mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-04 19:45:57 +02:00
Compare commits
35 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e121edc432 | |||
| 2f099b510f | |||
| aa50ba462f | |||
| de2ef53a4b | |||
| c508256db2 | |||
| 40aaa8a403 | |||
| a08c1d2845 | |||
| d785f9c1fd | |||
| 4032ca4066 | |||
| 515fdbf7ed | |||
| f5cd27b71d | |||
| a2d02d5793 | |||
| 17fc817b58 | |||
| 2bd1b30f69 | |||
| 259469c4b5 | |||
| 4c32832c59 | |||
| c3a2624339 | |||
| ffd0eae60b | |||
| b775345d78 | |||
| a70a8a69c2 | |||
| d13d0f6135 | |||
| 8a2afb7520 | |||
| 9ecf3e66a3 | |||
| faaaff5f94 | |||
| e16c4731c7 | |||
| 1dcd01960c | |||
| c10ed6cbcc | |||
| a127ff1780 | |||
| 3079e9ac8e | |||
| 8a1d206f1d | |||
| 797990c4bc | |||
| ab86335760 | |||
| cc74d5be99 | |||
| 5be24af73d | |||
| d394a9aedc |
@@ -48,3 +48,7 @@ end_of_line = unset
|
||||
charset = unset
|
||||
trim_trailing_whitespace = unset
|
||||
insert_final_newline = unset
|
||||
|
||||
[tools/mtmd/miniaudio.h]
|
||||
trim_trailing_whitespace = unset
|
||||
insert_final_newline = unset
|
||||
|
||||
@@ -260,16 +260,18 @@ jobs:
|
||||
architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }}
|
||||
|
||||
- name: Build
|
||||
shell: cmd
|
||||
env:
|
||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
||||
run: |
|
||||
cmake -S . -B build -G "Ninja Multi-Config" `
|
||||
-D CMAKE_TOOLCHAIN_FILE=cmake/${{ matrix.arch }}-windows-llvm.cmake `
|
||||
-DGGML_NATIVE=OFF `
|
||||
-DGGML_BACKEND_DL=ON `
|
||||
-DGGML_CPU_ALL_VARIANTS=ON `
|
||||
-DGGML_OPENMP=OFF `
|
||||
-DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include" `
|
||||
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" ${{ matrix.arch }}
|
||||
cmake -S . -B build -G "Ninja Multi-Config" ^
|
||||
-D CMAKE_TOOLCHAIN_FILE=cmake/${{ matrix.arch }}-windows-llvm.cmake ^
|
||||
-DGGML_NATIVE=OFF ^
|
||||
-DGGML_BACKEND_DL=ON ^
|
||||
-DGGML_CPU_ALL_VARIANTS=${{ matrix.arch == 'x64' && 'ON' || 'OFF' }} ^
|
||||
-DGGML_OPENMP=ON ^
|
||||
-DCURL_LIBRARY="%CURL_PATH%/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="%CURL_PATH%/include" ^
|
||||
${{ env.CMAKE_ARGS }}
|
||||
cmake --build build --config Release
|
||||
|
||||
@@ -279,6 +281,7 @@ jobs:
|
||||
CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }}
|
||||
run: |
|
||||
Copy-Item $env:CURL_PATH\bin\libcurl-${{ matrix.arch }}.dll .\build\bin\Release\
|
||||
Copy-Item "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Redist\MSVC\14.42.34433\debug_nonredist\${{ matrix.arch }}\Microsoft.VC143.OpenMP.LLVM\libomp140.${{ matrix.arch == 'x64' && 'x86_64' || 'aarch64' }}.dll" .\build\bin\Release\
|
||||
7z a llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\*
|
||||
|
||||
- name: Upload artifacts
|
||||
@@ -448,6 +451,7 @@ jobs:
|
||||
WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/7cd9bba0-7aab-4e30-b3ae-2221006a4a05/intel-oneapi-base-toolkit-2025.1.1.34_offline.exe
|
||||
WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel
|
||||
ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
@@ -513,7 +517,9 @@ jobs:
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
gpu_target: [gfx1100, gfx1101, gfx1030]
|
||||
include:
|
||||
- name: "radeon"
|
||||
gpu_targets: "gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -528,7 +534,7 @@ jobs:
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2.16
|
||||
with:
|
||||
key: windows-latest-cmake-hip-${{ matrix.gpu_target }}-x64
|
||||
key: windows-latest-cmake-hip-${{ matrix.name }}-x64
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Install
|
||||
@@ -554,9 +560,12 @@ jobs:
|
||||
cmake -G "Unix Makefiles" -B build -S . `
|
||||
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
|
||||
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
|
||||
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/" `
|
||||
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
|
||||
-DCMAKE_BUILD_TYPE=Release `
|
||||
-DAMDGPU_TARGETS=${{ matrix.gpu_target }} `
|
||||
-DGGML_BACKEND_DL=ON `
|
||||
-DGGML_NATIVE=OFF `
|
||||
-DGGML_CPU=OFF `
|
||||
-DAMDGPU_TARGETS="${{ matrix.gpu_targets }}" `
|
||||
-DGGML_HIP_ROCWMMA_FATTN=ON `
|
||||
-DGGML_HIP=ON `
|
||||
-DLLAMA_CURL=OFF
|
||||
@@ -569,13 +578,13 @@ jobs:
|
||||
- name: Pack artifacts
|
||||
id: pack_artifacts
|
||||
run: |
|
||||
7z a llama-bin-win-hip-${{ matrix.gpu_target }}-x64.zip .\build\bin\*
|
||||
7z a llama-bin-win-hip-${{ matrix.name }}-x64.zip .\build\bin\*
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
path: llama-bin-win-hip-${{ matrix.gpu_target }}-x64.zip
|
||||
name: llama-bin-win-hip-${{ matrix.gpu_target }}-x64.zip
|
||||
path: llama-bin-win-hip-${{ matrix.name }}-x64.zip
|
||||
name: llama-bin-win-hip-${{ matrix.name }}-x64.zip
|
||||
|
||||
ios-xcode-build:
|
||||
runs-on: macos-latest
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
name: Update Winget Package
|
||||
|
||||
on:
|
||||
workflow_dispatch: # allows manual triggering
|
||||
schedule:
|
||||
- cron: '28 5 * * *' # Update every day at 5:28 UTC
|
||||
|
||||
jobs:
|
||||
update:
|
||||
name: Update Winget Package
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Install cargo binstall
|
||||
uses: cargo-bins/cargo-binstall@268643a6b5ea099f5718ee5cd3ff7dc89a5eb49b
|
||||
|
||||
- name: Install komac
|
||||
run: |
|
||||
cargo binstall komac@2.11.2 -y
|
||||
|
||||
- name: Find latest release
|
||||
id: find_latest_release
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
const { data: releases } = await github.rest.repos.listReleases({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
});
|
||||
console.log("Latest release:", releases[0].tag_name);
|
||||
return releases[0].tag_name;
|
||||
|
||||
- name: Update manifest
|
||||
env:
|
||||
VERSION: ${{ steps.find_latest_release.outputs.result }}
|
||||
run: |
|
||||
echo "Updating manifest..."
|
||||
komac update --version ${{ env.VERSION }} \
|
||||
--urls "https://github.com/ggml-org/llama.cpp/releases/download/${{ env.VERSION }}/llama-${{ env.VERSION }}-bin-win-vulkan-x64.zip" \
|
||||
--token ${{ secrets.WINGET_GITHUB_TOKEN }} \
|
||||
--submit \
|
||||
ggml.llamacpp
|
||||
@@ -580,3 +580,4 @@ $ echo "source ~/.llama-completion.bash" >> ~/.bashrc
|
||||
- [minja](https://github.com/google/minja) - Minimal Jinja parser in C++, used by various tools/examples - MIT License
|
||||
- [linenoise.cpp](./tools/run/linenoise.cpp/linenoise.cpp) - C++ library that provides readline-like line editing capabilities, used by `llama-run` - BSD 2-Clause License
|
||||
- [curl](https://curl.se/) - Client-side URL transfer library, used by various tools/examples - [CURL License](https://curl.se/docs/copyright.html)
|
||||
- [miniaudio.h](https://github.com/mackron/miniaudio) - Single-header audio format decoder, used by multimodal subsystem - Public domain
|
||||
|
||||
@@ -60,12 +60,16 @@ add_library(${TARGET} STATIC
|
||||
base64.hpp
|
||||
chat.cpp
|
||||
chat.h
|
||||
chat-parser.cpp
|
||||
chat-parser.h
|
||||
common.cpp
|
||||
common.h
|
||||
console.cpp
|
||||
console.h
|
||||
json-schema-to-grammar.cpp
|
||||
json.hpp
|
||||
json-partial.h
|
||||
json-partial.cpp
|
||||
llguidance.cpp
|
||||
log.cpp
|
||||
log.h
|
||||
|
||||
+19
-10
@@ -39,7 +39,7 @@
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
std::initializer_list<enum llama_example> mmproj_examples = {
|
||||
LLAMA_EXAMPLE_LLAVA,
|
||||
LLAMA_EXAMPLE_MTMD,
|
||||
LLAMA_EXAMPLE_SERVER,
|
||||
};
|
||||
|
||||
@@ -2233,12 +2233,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
).set_examples(mmproj_examples).set_env("LLAMA_ARG_NO_MMPROJ_OFFLOAD"));
|
||||
add_opt(common_arg(
|
||||
{"--image"}, "FILE",
|
||||
"path to an image file. use with multimodal models. Specify multiple times for batching",
|
||||
{"--image", "--audio"}, "FILE",
|
||||
"path to an image or audio file. use with multimodal models, can be repeated if you have multiple files\n",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.image.emplace_back(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_LLAVA}));
|
||||
).set_examples({LLAMA_EXAMPLE_MTMD}));
|
||||
if (llama_supports_rpc()) {
|
||||
add_opt(common_arg(
|
||||
{"--rpc"}, "SERVERS",
|
||||
@@ -2848,15 +2848,24 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-format"}, "FORMAT",
|
||||
"reasoning format (default: deepseek; allowed values: deepseek, none)\n"
|
||||
"controls whether thought tags are extracted from the response, and in which format they're returned. 'none' leaves thoughts unparsed in `message.content`, 'deepseek' puts them in `message.reasoning_content` (for DeepSeek R1 & Command R7B only).\n"
|
||||
"only supported for non-streamed responses",
|
||||
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
|
||||
"- none: leaves thoughts unparsed in `message.content`\n"
|
||||
"- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n"
|
||||
"(default: deepseek)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
/**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
|
||||
else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
|
||||
else { std::invalid_argument("invalid value"); }
|
||||
else { throw std::invalid_argument("invalid value"); }
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-budget"}, "N",
|
||||
"controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)",
|
||||
[](common_params & params, int value) {
|
||||
if (value != 0 && value != -1) { throw std::invalid_argument("invalid value"); }
|
||||
params.reasoning_budget = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK_BUDGET"));
|
||||
add_opt(common_arg(
|
||||
{"--chat-template"}, "JINJA_TEMPLATE",
|
||||
string_format(
|
||||
@@ -2868,7 +2877,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.chat_template = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_LLAVA}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
|
||||
add_opt(common_arg(
|
||||
{"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
|
||||
string_format(
|
||||
@@ -2955,7 +2964,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, const std::string & value) {
|
||||
/**/ if (value == "jsonl") { params.batched_bench_output_jsonl = true; }
|
||||
else if (value == "md") { params.batched_bench_output_jsonl = false; }
|
||||
else { std::invalid_argument("invalid value"); }
|
||||
else { throw std::invalid_argument("invalid value"); }
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_BENCH}));
|
||||
add_opt(common_arg(
|
||||
|
||||
@@ -0,0 +1,376 @@
|
||||
#include "chat-parser.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <optional>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
|
||||
: input_(input), is_partial_(is_partial), syntax_(syntax)
|
||||
{
|
||||
result_.role = "assistant";
|
||||
|
||||
while (true) {
|
||||
std::string id = std::to_string(std::rand());
|
||||
if (input.find(id) == std::string::npos) {
|
||||
healing_marker_ = id;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string common_chat_msg_parser::str(const common_string_range & rng) const {
|
||||
GGML_ASSERT(rng.begin <= rng.end);
|
||||
return input_.substr(rng.begin, rng.end - rng.begin);
|
||||
}
|
||||
|
||||
void common_chat_msg_parser::add_content(const std::string &content) {
|
||||
result_.content += content;
|
||||
}
|
||||
|
||||
void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_content) {
|
||||
result_.reasoning_content += reasoning_content;
|
||||
}
|
||||
|
||||
bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) {
|
||||
if (name.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
common_chat_tool_call tool_call;
|
||||
tool_call.name = name;
|
||||
tool_call.arguments = arguments;
|
||||
tool_call.id = id;
|
||||
|
||||
// LOG_DBG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str());
|
||||
result_.tool_calls.emplace_back(tool_call);
|
||||
return true;
|
||||
}
|
||||
bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
|
||||
std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
|
||||
std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
|
||||
std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
|
||||
return add_tool_call(name, id, arguments);
|
||||
}
|
||||
|
||||
bool common_chat_msg_parser::add_tool_calls(const json & arr) {
|
||||
for (const auto & item : arr) {
|
||||
if (!add_tool_call(item)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
void common_chat_msg_parser::finish() {
|
||||
if (!is_partial_ && pos_ != input_.size()) {
|
||||
throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_));
|
||||
}
|
||||
}
|
||||
|
||||
bool common_chat_msg_parser::consume_spaces() {
|
||||
const auto length = input_.size();
|
||||
auto consumed = false;
|
||||
while (pos_ < length && std::isspace(input_[pos_])) {
|
||||
++pos_;
|
||||
consumed = true;
|
||||
}
|
||||
return consumed;
|
||||
}
|
||||
|
||||
bool common_chat_msg_parser::try_consume_literal(const std::string & literal) {
|
||||
auto pos = pos_;
|
||||
for (auto i = 0u; i < literal.size(); ++i) {
|
||||
if (pos >= input_.size()) {
|
||||
return false;
|
||||
}
|
||||
if (input_[pos] != literal[i]) {
|
||||
return false;
|
||||
}
|
||||
++pos;
|
||||
}
|
||||
pos_ = pos;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_literal(const std::string & literal) {
|
||||
auto idx = input_.find(literal, pos_);
|
||||
if (idx != std::string::npos) {
|
||||
find_regex_result res;
|
||||
res.prelude = input_.substr(pos_, idx - pos_);
|
||||
auto end = idx + literal.size();
|
||||
res.groups.emplace_back(common_string_range{idx, end});
|
||||
move_to(end);
|
||||
return res;
|
||||
}
|
||||
if (is_partial_) {
|
||||
idx = string_find_partial_stop(input_, literal);
|
||||
if (idx != std::string::npos && idx >= pos_) {
|
||||
find_regex_result res;
|
||||
res.prelude = input_.substr(pos_, idx - pos_);
|
||||
auto end = input_.size();
|
||||
res.groups.emplace_back(common_string_range{idx, end});
|
||||
move_to(end);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
void common_chat_msg_parser::consume_literal(const std::string & literal) {
|
||||
if (!try_consume_literal(literal)) {
|
||||
throw common_chat_msg_partial_exception(literal);
|
||||
}
|
||||
}
|
||||
|
||||
bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) {
|
||||
auto handle_reasoning = [&](const std::string & reasoning, bool closed) {
|
||||
auto stripped_reasoning = string_strip(reasoning);
|
||||
if (stripped_reasoning.empty()) {
|
||||
return;
|
||||
}
|
||||
if (syntax_.reasoning_in_content) {
|
||||
add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "<think>" : start_think);
|
||||
add_content(stripped_reasoning);
|
||||
if (closed) {
|
||||
add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "</think>" : end_think);
|
||||
}
|
||||
} else {
|
||||
add_reasoning_content(stripped_reasoning);
|
||||
}
|
||||
};
|
||||
if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
|
||||
if (syntax_.thinking_forced_open || try_consume_literal(start_think)) {
|
||||
if (auto res = try_find_literal(end_think)) {
|
||||
handle_reasoning(res->prelude, /* closed */ true);
|
||||
consume_spaces();
|
||||
return true;
|
||||
}
|
||||
auto rest = consume_rest();
|
||||
if (!rest.empty()) {
|
||||
handle_reasoning(rest, /* closed */ !is_partial());
|
||||
}
|
||||
if (!syntax_.thinking_forced_open) {
|
||||
throw common_chat_msg_partial_exception(end_think);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string common_chat_msg_parser::consume_rest() {
|
||||
auto rest = input_.substr(pos_);
|
||||
pos_ = input_.size();
|
||||
return rest;
|
||||
}
|
||||
|
||||
// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
|
||||
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from) {
|
||||
auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
|
||||
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
|
||||
if (is_partial()) {
|
||||
throw common_chat_msg_partial_exception(regex.str());
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
|
||||
pos_ = m.groups[0].end;
|
||||
|
||||
return find_regex_result{prelude, m.groups};
|
||||
}
|
||||
|
||||
common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) {
|
||||
if (auto result = try_consume_regex(regex)) {
|
||||
return *result;
|
||||
}
|
||||
throw common_chat_msg_partial_exception(regex.str());
|
||||
}
|
||||
|
||||
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
|
||||
auto m = regex.search(input_, pos_);
|
||||
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
|
||||
return std::nullopt;
|
||||
}
|
||||
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
|
||||
if (is_partial()) {
|
||||
throw common_chat_msg_partial_exception(regex.str());
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
if (m.groups[0].begin != pos_) {
|
||||
// Didn't match at the current position.
|
||||
return std::nullopt;
|
||||
}
|
||||
pos_ = m.groups[0].end;
|
||||
|
||||
return find_regex_result {
|
||||
/* .prelude = */ "",
|
||||
m.groups,
|
||||
};
|
||||
}
|
||||
|
||||
std::optional<common_json> common_chat_msg_parser::try_consume_json() {
|
||||
auto it = input_.cbegin() + pos_;
|
||||
const auto end = input_.cend();
|
||||
common_json result;
|
||||
if (!common_json_parse(it, end, healing_marker_, result)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
pos_ = std::distance(input_.cbegin(), it);
|
||||
if (result.healing_marker.marker.empty()) {
|
||||
// No healing marker, just return the parsed json
|
||||
return result;
|
||||
}
|
||||
if (!is_partial()) {
|
||||
throw common_chat_msg_partial_exception("JSON");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
common_json common_chat_msg_parser::consume_json() {
|
||||
if (auto result = try_consume_json()) {
|
||||
return *result;
|
||||
}
|
||||
throw common_chat_msg_partial_exception("JSON");
|
||||
}
|
||||
|
||||
common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json_with_dumped_args(
|
||||
const std::vector<std::vector<std::string>> & args_paths,
|
||||
const std::vector<std::vector<std::string>> & content_paths
|
||||
) {
|
||||
if (auto result = try_consume_json_with_dumped_args(args_paths, content_paths)) {
|
||||
return *result;
|
||||
}
|
||||
throw common_chat_msg_partial_exception("JSON");
|
||||
}
|
||||
|
||||
std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parser::try_consume_json_with_dumped_args(
|
||||
const std::vector<std::vector<std::string>> & args_paths,
|
||||
const std::vector<std::vector<std::string>> & content_paths
|
||||
) {
|
||||
auto partial = try_consume_json();
|
||||
if (!partial) {
|
||||
return std::nullopt;
|
||||
}
|
||||
auto is_arguments_path = [&](const std::vector<std::string> & path) {
|
||||
return std::find(args_paths.begin(), args_paths.end(), path) != args_paths.end();
|
||||
};
|
||||
auto is_content_path = [&](const std::vector<std::string> & path) {
|
||||
return std::find(content_paths.begin(), content_paths.end(), path) != content_paths.end();
|
||||
};
|
||||
|
||||
if (partial->healing_marker.marker.empty()) {
|
||||
if (args_paths.empty()) {
|
||||
// No arguments to dump, and JSON was parsed fully.
|
||||
return consume_json_result {
|
||||
partial->json,
|
||||
/* .is_partial = */ false,
|
||||
};
|
||||
}
|
||||
if (is_arguments_path({})) {
|
||||
// Entire JSON is the arguments and was parsed fully.
|
||||
return consume_json_result {
|
||||
partial->json.dump(),
|
||||
/* .is_partial = */ false,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
LOG_DBG("Parsed partial JSON: %s (json_healing_marker: %s)\n", partial->json.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
|
||||
|
||||
auto found_healing_marker = false;
|
||||
std::vector<std::string> path;
|
||||
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
|
||||
if (is_arguments_path(path)) {
|
||||
auto arguments = j.dump();
|
||||
if (is_partial() && !partial->healing_marker.marker.empty()) {
|
||||
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
|
||||
if (idx != std::string::npos) {
|
||||
arguments.resize(idx);
|
||||
found_healing_marker = true;
|
||||
}
|
||||
if (arguments == "\"") {
|
||||
// This happens because of completing `:"$magic` after `"arguments"`
|
||||
arguments = "";
|
||||
}
|
||||
}
|
||||
return arguments;
|
||||
}
|
||||
if (is_content_path(path)) {
|
||||
if (!j.is_string()) {
|
||||
throw std::runtime_error("Content path must be a string");
|
||||
}
|
||||
std::string str = j;
|
||||
auto idx = str.find(partial->healing_marker.marker); // not using json_dump_marker as we're inside a string
|
||||
if (idx != std::string::npos) {
|
||||
str.resize(idx);
|
||||
found_healing_marker = true;
|
||||
}
|
||||
return str;
|
||||
}
|
||||
if (j.is_object()) {
|
||||
auto obj = json::object();
|
||||
for (const auto & p : j.items()) {
|
||||
const auto & key = p.key();
|
||||
const auto & value = p.value();
|
||||
const std::string key_str = key; // NOLINT
|
||||
auto idx = key_str.find(healing_marker_);
|
||||
if (idx != std::string::npos) {
|
||||
found_healing_marker = true;
|
||||
break;
|
||||
}
|
||||
path.push_back(key_str);
|
||||
if (value.is_string()) {
|
||||
const std::string value_str = value;
|
||||
if (value_str.find(healing_marker_) != std::string::npos) {
|
||||
found_healing_marker = true;
|
||||
if (is_content_path(path)) {
|
||||
if (partial->healing_marker.marker == partial->healing_marker.json_dump_marker) {
|
||||
// The healing occurred inside the string: good. Otherwise we just ditch the entire key/value pair.
|
||||
obj[key] = remove_unsupported_healings_and_dump_args(value);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
obj[key] = value;
|
||||
} else {
|
||||
obj[key] = remove_unsupported_healings_and_dump_args(value);
|
||||
}
|
||||
path.pop_back();
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
if (j.is_array()) {
|
||||
auto arr = json::array();
|
||||
for (const auto & value : j) {
|
||||
if (value.is_string()) {
|
||||
std::string str = value;
|
||||
auto idx = str.find(healing_marker_);
|
||||
if (idx != std::string::npos) {
|
||||
// Don't heal array values that aren't in the arguments.
|
||||
found_healing_marker = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
arr.push_back(remove_unsupported_healings_and_dump_args(value));
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
return j;
|
||||
};
|
||||
|
||||
auto cleaned = remove_unsupported_healings_and_dump_args(partial->json);
|
||||
LOG_DBG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", partial->json.dump().c_str(), cleaned.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
|
||||
return consume_json_result {
|
||||
cleaned,
|
||||
/* .is_partial = */ found_healing_marker,
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
#pragma once
|
||||
|
||||
#include "chat.h"
|
||||
#include "json-partial.h"
|
||||
#include "json.hpp"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
class common_chat_msg_partial_exception : public std::runtime_error {
|
||||
public:
|
||||
common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
|
||||
};
|
||||
|
||||
class common_chat_msg_parser {
|
||||
std::string input_;
|
||||
bool is_partial_;
|
||||
common_chat_syntax syntax_;
|
||||
std::string healing_marker_;
|
||||
|
||||
size_t pos_ = 0;
|
||||
common_chat_msg result_;
|
||||
|
||||
public:
|
||||
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
const std::string & input() const { return input_; }
|
||||
size_t pos() const { return pos_; }
|
||||
const std::string & healing_marker() const { return healing_marker_; }
|
||||
const bool & is_partial() const { return is_partial_; }
|
||||
const common_chat_msg & result() const { return result_; }
|
||||
|
||||
void move_to(size_t pos) {
|
||||
if (pos > input_.size()) {
|
||||
throw std::runtime_error("Invalid position!");
|
||||
}
|
||||
pos_ = pos;
|
||||
}
|
||||
void move_back(size_t n) {
|
||||
if (pos_ < n) {
|
||||
throw std::runtime_error("Can't move back that far!");
|
||||
}
|
||||
pos_ -= n;
|
||||
}
|
||||
|
||||
// Get the substring of the input at the given range
|
||||
std::string str(const common_string_range & rng) const;
|
||||
|
||||
// Appends to the result.content field
|
||||
void add_content(const std::string & content);
|
||||
|
||||
// Appends to the result.reasoning_content field
|
||||
void add_reasoning_content(const std::string & reasoning_content);
|
||||
|
||||
// Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything.
|
||||
bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);
|
||||
|
||||
// Adds a tool call using the "name", "id" and "arguments" fields of the json object
|
||||
bool add_tool_call(const nlohmann::ordered_json & tool_call);
|
||||
|
||||
// Adds an array of tool calls using their "name", "id" and "arguments" fields.
|
||||
bool add_tool_calls(const nlohmann::ordered_json & arr);
|
||||
|
||||
void finish();
|
||||
|
||||
bool consume_spaces();
|
||||
|
||||
void consume_literal(const std::string & literal);
|
||||
|
||||
bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
|
||||
|
||||
std::string consume_rest();
|
||||
|
||||
struct find_regex_result {
|
||||
std::string prelude;
|
||||
std::vector<common_string_range> groups;
|
||||
};
|
||||
|
||||
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos);
|
||||
|
||||
bool try_consume_literal(const std::string & literal);
|
||||
|
||||
std::optional<find_regex_result> try_find_literal(const std::string & literal);
|
||||
|
||||
find_regex_result consume_regex(const common_regex & regex);
|
||||
|
||||
std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
|
||||
|
||||
std::optional<common_json> try_consume_json();
|
||||
common_json consume_json();
|
||||
|
||||
struct consume_json_result {
|
||||
nlohmann::ordered_json value;
|
||||
bool is_partial;
|
||||
};
|
||||
|
||||
/*
|
||||
Consume (possibly partial) json and converts specific subtrees to (possibly truncated) JSON strings.
|
||||
|
||||
By default, object keys can't be truncated, nor can string values (their corresponding key is removed,
|
||||
e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}`
|
||||
|
||||
But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings
|
||||
- with `content_paths={{"foo"}}` -> `{"foo": "b` -> {"foo": "b"}`
|
||||
- with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}`
|
||||
*/
|
||||
consume_json_result consume_json_with_dumped_args(
|
||||
const std::vector<std::vector<std::string>> & args_paths = {},
|
||||
const std::vector<std::vector<std::string>> & content_paths = {}
|
||||
);
|
||||
std::optional<consume_json_result> try_consume_json_with_dumped_args(
|
||||
const std::vector<std::vector<std::string>> & args_paths = {},
|
||||
const std::vector<std::vector<std::string>> & content_paths = {}
|
||||
);
|
||||
};
|
||||
+709
-600
File diff suppressed because it is too large
Load Diff
+70
-6
@@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include <functional>
|
||||
#include <chrono>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@@ -13,11 +14,19 @@ struct common_chat_tool_call {
|
||||
std::string name;
|
||||
std::string arguments;
|
||||
std::string id;
|
||||
|
||||
bool operator==(const common_chat_tool_call & other) const {
|
||||
return name == other.name && arguments == other.arguments && id == other.id;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_content_part {
|
||||
std::string type;
|
||||
std::string text;
|
||||
|
||||
bool operator==(const common_chat_msg_content_part & other) const {
|
||||
return type == other.type && text == other.text;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg {
|
||||
@@ -28,6 +37,51 @@ struct common_chat_msg {
|
||||
std::string reasoning_content;
|
||||
std::string tool_name;
|
||||
std::string tool_call_id;
|
||||
|
||||
template <class T> T to_json_oaicompat() const;
|
||||
|
||||
bool empty() const {
|
||||
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
|
||||
}
|
||||
void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
|
||||
for (auto i = 0u; i < tool_calls.size(); i++) {
|
||||
if (ids_cache.size() <= i) {
|
||||
auto id = tool_calls[i].id;
|
||||
if (id.empty()) {
|
||||
id = gen_tool_call_id();
|
||||
}
|
||||
ids_cache.push_back(id);
|
||||
}
|
||||
tool_calls[i].id = ids_cache[i];
|
||||
}
|
||||
}
|
||||
bool operator==(const common_chat_msg & other) const {
|
||||
return role == other.role
|
||||
&& content == other.content
|
||||
&& content_parts == other.content_parts
|
||||
&& tool_calls == other.tool_calls
|
||||
&& reasoning_content == other.reasoning_content
|
||||
&& tool_name == other.tool_name
|
||||
&& tool_call_id == other.tool_call_id;
|
||||
}
|
||||
bool operator!=(const common_chat_msg & other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_diff {
|
||||
// std::string reasoning_content_delta;
|
||||
std::string content_delta;
|
||||
size_t tool_call_index = std::string::npos;
|
||||
common_chat_tool_call tool_call_delta;
|
||||
|
||||
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
|
||||
|
||||
bool operator==(const common_chat_msg_diff & other) const {
|
||||
return content_delta == other.content_delta
|
||||
&& tool_call_index == other.tool_call_index
|
||||
&& tool_call_delta == other.tool_call_delta;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_tool {
|
||||
@@ -49,14 +103,11 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X,
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
|
||||
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
|
||||
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO,
|
||||
COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING,
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B,
|
||||
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
@@ -71,7 +122,8 @@ struct common_chat_templates_inputs {
|
||||
std::vector<common_chat_tool> tools;
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
bool parallel_tool_calls = false;
|
||||
bool extract_reasoning = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
bool enable_thinking = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
};
|
||||
|
||||
@@ -80,11 +132,20 @@ struct common_chat_params {
|
||||
std::string prompt;
|
||||
std::string grammar;
|
||||
bool grammar_lazy = false;
|
||||
bool thinking_forced_open = false;
|
||||
std::vector<common_grammar_trigger> grammar_triggers;
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
};
|
||||
|
||||
struct common_chat_syntax {
|
||||
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
|
||||
bool reasoning_in_content = false;
|
||||
bool thinking_forced_open = false;
|
||||
};
|
||||
|
||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
|
||||
|
||||
@@ -121,8 +182,9 @@ std::string common_chat_format_example(
|
||||
const struct common_chat_templates * tmpls,
|
||||
bool use_jinja);
|
||||
|
||||
std::string common_chat_format_name(common_chat_format format);
|
||||
common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);
|
||||
const char* common_chat_format_name(common_chat_format format);
|
||||
const char* common_reasoning_format_name(common_reasoning_format format);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
|
||||
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
|
||||
|
||||
@@ -135,3 +197,5 @@ template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common
|
||||
// T can be std::string containing JSON or nlohmann::ordered_json
|
||||
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
|
||||
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
|
||||
|
||||
template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
||||
|
||||
+1
-1
@@ -849,7 +849,7 @@ std::string fs_get_cache_directory() {
|
||||
if (getenv("LLAMA_CACHE")) {
|
||||
cache_directory = std::getenv("LLAMA_CACHE");
|
||||
} else {
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX)
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
|
||||
if (std::getenv("XDG_CACHE_HOME")) {
|
||||
cache_directory = std::getenv("XDG_CACHE_HOME");
|
||||
} else {
|
||||
|
||||
+3
-2
@@ -76,7 +76,7 @@ enum llama_example {
|
||||
LLAMA_EXAMPLE_SERVER,
|
||||
LLAMA_EXAMPLE_CVECTOR_GENERATOR,
|
||||
LLAMA_EXAMPLE_EXPORT_LORA,
|
||||
LLAMA_EXAMPLE_LLAVA,
|
||||
LLAMA_EXAMPLE_MTMD,
|
||||
LLAMA_EXAMPLE_LOOKUP,
|
||||
LLAMA_EXAMPLE_PARALLEL,
|
||||
LLAMA_EXAMPLE_TTS,
|
||||
@@ -115,7 +115,7 @@ enum common_grammar_trigger_type {
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
|
||||
};
|
||||
|
||||
struct common_grammar_trigger {
|
||||
@@ -368,6 +368,7 @@ struct common_params {
|
||||
bool use_jinja = false; // NOLINT
|
||||
bool enable_chat_template = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
int reasoning_budget = -1;
|
||||
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
|
||||
|
||||
std::vector<std::string> api_keys;
|
||||
|
||||
@@ -0,0 +1,255 @@
|
||||
#include <json-partial.h>
|
||||
#include "ggml.h"
|
||||
#include "log.h"
|
||||
#include <string>
|
||||
|
||||
#include <json.hpp>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
enum common_json_stack_element_type {
|
||||
COMMON_JSON_STACK_ELEMENT_OBJECT,
|
||||
COMMON_JSON_STACK_ELEMENT_KEY,
|
||||
COMMON_JSON_STACK_ELEMENT_ARRAY,
|
||||
};
|
||||
|
||||
struct common_json_stack_element {
|
||||
common_json_stack_element_type type;
|
||||
std::string key;
|
||||
};
|
||||
|
||||
bool common_json_parse(
|
||||
const std::string & input,
|
||||
const std::string & healing_marker,
|
||||
common_json & out)
|
||||
{
|
||||
std::string::const_iterator it = input.begin();
|
||||
const auto end = input.end();
|
||||
return common_json_parse(it, end, healing_marker, out);
|
||||
}
|
||||
|
||||
bool common_json_parse(
|
||||
std::string::const_iterator & it,
|
||||
const std::string::const_iterator & end,
|
||||
const std::string & healing_marker,
|
||||
common_json & out)
|
||||
{
|
||||
// // https://json.nlohmann.me/features/parsing/sax_interface/
|
||||
struct json_error_locator : public nlohmann::json_sax<json> {
|
||||
std::size_t position;
|
||||
bool found_error;
|
||||
std::string last_token;
|
||||
std::string exception_message;
|
||||
std::vector<common_json_stack_element> stack;
|
||||
|
||||
json_error_locator() : position(0), found_error(false) {}
|
||||
|
||||
bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
|
||||
this->position = position - 1;
|
||||
this->found_error = true;
|
||||
this->last_token = last_token;
|
||||
this->exception_message = ex.what();
|
||||
return false;
|
||||
}
|
||||
void close_value() {
|
||||
if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
|
||||
stack.pop_back();
|
||||
}
|
||||
}
|
||||
bool null() override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool boolean(bool) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool number_integer(number_integer_t) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool number_unsigned(number_unsigned_t) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool number_float(number_float_t, const string_t &) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool string(string_t &) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool binary(binary_t &) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool start_object(std::size_t) override { // NOLINT
|
||||
stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
|
||||
return true;
|
||||
}
|
||||
bool end_object() override {
|
||||
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
|
||||
stack.pop_back();
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool key(string_t & key) override { // NOLINT
|
||||
stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
|
||||
return true;
|
||||
}
|
||||
bool start_array(std::size_t) override { // NOLINT
|
||||
stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
|
||||
return true;
|
||||
}
|
||||
bool end_array() override {
|
||||
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
|
||||
stack.pop_back();
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
};
|
||||
json_error_locator err_loc;
|
||||
auto start = it;
|
||||
json::sax_parse(it, end, &err_loc);
|
||||
|
||||
if (err_loc.found_error) {
|
||||
it = start;
|
||||
auto temptative_end = it + err_loc.position;
|
||||
// LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
|
||||
|
||||
auto input = std::string(it, temptative_end);
|
||||
try {
|
||||
out.json = json::parse(input);
|
||||
// out.json = json::parse(it, temptative_end);
|
||||
it = temptative_end;
|
||||
return true;
|
||||
} catch (const std::exception & ex) {
|
||||
// No, needs healing.
|
||||
LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
|
||||
}
|
||||
auto can_parse = [](const std::string & str) {
|
||||
try {
|
||||
auto _ = json::parse(str); // NOLINT
|
||||
return true;
|
||||
} catch (const std::exception &) {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
if (!healing_marker.empty() && !err_loc.stack.empty()) {
|
||||
std::string str(it, temptative_end);
|
||||
auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
|
||||
if (last_non_sp_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
|
||||
}
|
||||
auto last_non_sp_char = str[last_non_sp_pos];
|
||||
// Used to detect stops on a number, which may not be complete.
|
||||
auto was_maybe_number = [&]() {
|
||||
if (!str.empty() && std::isspace(str.back())) {
|
||||
return false;
|
||||
}
|
||||
return std::isdigit(last_non_sp_char) ||
|
||||
last_non_sp_char == '.' ||
|
||||
last_non_sp_char == 'e' ||
|
||||
last_non_sp_char == 'E' ||
|
||||
last_non_sp_char == '-';
|
||||
};
|
||||
|
||||
std::string closing;
|
||||
for (size_t i = err_loc.stack.size(); i > 0; i--) {
|
||||
auto & el = err_loc.stack[i - 1];
|
||||
if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
|
||||
closing += "}";
|
||||
} else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
|
||||
closing += "]";
|
||||
} else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
|
||||
throw std::runtime_error("Unexpected stack element type");
|
||||
}
|
||||
}
|
||||
|
||||
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
|
||||
|
||||
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
|
||||
// We're inside an object value
|
||||
if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
|
||||
// Was about to create an object value
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + ": 1" + closing)) {
|
||||
str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
|
||||
} else if (last_non_sp_char == '{' && can_parse(str + closing)) {
|
||||
// Was about to create an object
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
|
||||
} else if (can_parse(str + "\"" + closing)) {
|
||||
// Was inside an object value string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an object value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else {
|
||||
// find last :
|
||||
auto last_pos = str.find_last_of(':');
|
||||
if (last_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
|
||||
}
|
||||
// Cutting back to opening : for object value
|
||||
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
}
|
||||
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
|
||||
if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
|
||||
// Was about to create an array value
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + "\"" + closing)) {
|
||||
// Was inside an array value string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an array value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
|
||||
// Had just finished a value
|
||||
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
|
||||
} else {
|
||||
auto last_pos = str.find_last_of("[,");
|
||||
if (last_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
|
||||
}
|
||||
// Cutting back to last [ or , for array value
|
||||
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
}
|
||||
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
|
||||
if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
|
||||
(last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
|
||||
// Was about to create an object key+value
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
|
||||
} else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
|
||||
// Was about to create an object key+value
|
||||
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
|
||||
} else if (can_parse(str + "\": 1" + closing)) {
|
||||
// Was inside an object key string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
|
||||
// Was inside an object key string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
|
||||
} else {
|
||||
auto last_pos = str.find_last_of(':');
|
||||
if (last_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
|
||||
}
|
||||
// fprintf(stderr, "Cutting back to last : for object key+value\n");
|
||||
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
|
||||
}
|
||||
// fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
|
||||
out.json = json::parse(str);
|
||||
it = temptative_end;
|
||||
return true;
|
||||
}
|
||||
// TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
|
||||
// fprintf(stderr, "Closing: TODO\n");
|
||||
return false;
|
||||
}
|
||||
out.json = json::parse(it, end);
|
||||
it = end;
|
||||
return true;
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
#pragma once
|
||||
#include <json.hpp>
|
||||
|
||||
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
|
||||
struct common_healing_marker {
|
||||
// Raw marker.
|
||||
std::string marker;
|
||||
|
||||
// Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format).
|
||||
std::string json_dump_marker;
|
||||
};
|
||||
|
||||
// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string)
|
||||
struct common_json {
|
||||
nlohmann::ordered_json json;
|
||||
|
||||
common_healing_marker healing_marker;
|
||||
};
|
||||
|
||||
// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty.
|
||||
//
|
||||
// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON.
|
||||
// This allows to parse the resulting healed JSON string, yet be able to cut it again if needed at the healing marker.
|
||||
// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format).
|
||||
//
|
||||
// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again).
|
||||
bool common_json_parse(
|
||||
const std::string & input,
|
||||
const std::string & healing_marker,
|
||||
common_json & out);
|
||||
|
||||
// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds.
|
||||
bool common_json_parse(
|
||||
std::string::const_iterator & it,
|
||||
const std::string::const_iterator & end,
|
||||
const std::string & healing_marker,
|
||||
common_json & out);
|
||||
+7
-8
@@ -161,7 +161,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
||||
#endif // LLAMA_USE_LLGUIDANCE
|
||||
} else {
|
||||
std::vector<std::string> patterns_at_start;
|
||||
std::vector<std::string> trigger_patterns;
|
||||
std::vector<std::string> patterns_anywhere;
|
||||
std::vector<llama_token> trigger_tokens;
|
||||
for (const auto & trigger : params.grammar_triggers) {
|
||||
@@ -173,10 +173,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
|
||||
{
|
||||
const auto & pattern = trigger.value;
|
||||
(trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
|
||||
patterns_anywhere.push_back(trigger.value);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
|
||||
{
|
||||
trigger_patterns.push_back(trigger.value);
|
||||
break;
|
||||
}
|
||||
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
|
||||
@@ -190,10 +193,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> trigger_patterns;
|
||||
if (!patterns_at_start.empty()) {
|
||||
trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
|
||||
}
|
||||
if (!patterns_anywhere.empty()) {
|
||||
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
|
||||
}
|
||||
|
||||
+135
-45
@@ -45,7 +45,7 @@ class SentencePieceTokenTypes(IntEnum):
|
||||
|
||||
class ModelType(IntEnum):
|
||||
TEXT = 1
|
||||
VISION = 2
|
||||
MMPROJ = 2
|
||||
|
||||
|
||||
AnyModel = TypeVar("AnyModel", bound="type[ModelBase]")
|
||||
@@ -54,7 +54,7 @@ AnyModel = TypeVar("AnyModel", bound="type[ModelBase]")
|
||||
class ModelBase:
|
||||
_model_classes: dict[ModelType, dict[str, type[ModelBase]]] = {
|
||||
ModelType.TEXT: {},
|
||||
ModelType.VISION: {},
|
||||
ModelType.MMPROJ: {},
|
||||
}
|
||||
|
||||
dir_model: Path
|
||||
@@ -88,7 +88,7 @@ class ModelBase:
|
||||
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
|
||||
if type(self) is ModelBase or \
|
||||
type(self) is TextModel or \
|
||||
type(self) is VisionModel:
|
||||
type(self) is MmprojModel:
|
||||
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
|
||||
|
||||
self.dir_model = dir_model
|
||||
@@ -309,6 +309,7 @@ class ModelBase:
|
||||
gguf.MODEL_TENSOR.POSNET_NORM1,
|
||||
gguf.MODEL_TENSOR.POSNET_NORM2,
|
||||
gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
|
||||
gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
|
||||
)
|
||||
)
|
||||
or not new_name.endswith(".weight")
|
||||
@@ -438,7 +439,7 @@ class ModelBase:
|
||||
assert names
|
||||
|
||||
def func(modelcls: AnyModel) -> AnyModel:
|
||||
model_type = ModelType.VISION if modelcls.model_arch == gguf.MODEL_ARCH.CLIP_VISION else ModelType.TEXT
|
||||
model_type = ModelType.MMPROJ if modelcls.model_arch == gguf.MODEL_ARCH.MMPROJ else ModelType.TEXT
|
||||
for name in names:
|
||||
cls._model_classes[model_type][name] = modelcls
|
||||
return modelcls
|
||||
@@ -1114,60 +1115,87 @@ class TextModel(ModelBase):
|
||||
self.gguf_writer.add_pooling_type(pooling_type)
|
||||
|
||||
|
||||
class VisionModel(ModelBase):
|
||||
model_type = ModelType.VISION
|
||||
model_arch = gguf.MODEL_ARCH.CLIP_VISION
|
||||
class MmprojModel(ModelBase):
|
||||
model_type = ModelType.MMPROJ
|
||||
model_arch = gguf.MODEL_ARCH.MMPROJ
|
||||
preprocessor_config: dict[str, Any]
|
||||
global_config: dict[str, Any]
|
||||
|
||||
has_vision_encoder: bool = True # by default
|
||||
has_audio_encoder: bool = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if self.model_arch != gguf.MODEL_ARCH.CLIP_VISION:
|
||||
raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")
|
||||
if self.model_arch != gguf.MODEL_ARCH.MMPROJ:
|
||||
raise TypeError("MmprojModel must be subclassed with model_arch = gguf.MODEL_ARCH.MMPROJ")
|
||||
|
||||
if self.has_vision_encoder and self.has_audio_encoder:
|
||||
raise NotImplementedError("both vision + audio not supported yet")
|
||||
|
||||
# get n_embd of the text model
|
||||
if "text_config" not in self.hparams:
|
||||
self.hparams["text_config"] = {}
|
||||
if "audio_config" not in self.hparams:
|
||||
self.hparams["audio_config"] = {}
|
||||
text_config = {**self.hparams, **self.hparams["text_config"]}
|
||||
self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
|
||||
assert self.n_embd_text > 0, "n_embd not found in hparams"
|
||||
|
||||
if "vision_config" not in self.hparams:
|
||||
raise ValueError("vision_config not found in hparams")
|
||||
# move vision config to the top level, while preserving the original hparams in global_config
|
||||
self.global_config = self.hparams
|
||||
self.hparams = self.hparams["vision_config"]
|
||||
|
||||
if "vision_config" in self.hparams:
|
||||
self.hparams = self.hparams["vision_config"]
|
||||
elif "audio_config" in self.hparams:
|
||||
self.hparams = self.hparams["audio_config"]
|
||||
else:
|
||||
raise ValueError("vision_config / audio_config not found in hparams")
|
||||
|
||||
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
|
||||
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)
|
||||
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)
|
||||
|
||||
# load preprocessor config
|
||||
with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
|
||||
self.preprocessor_config = json.load(f)
|
||||
|
||||
def set_type(self):
|
||||
self.gguf_writer.add_type(gguf.GGUFType.CLIP_VISION)
|
||||
self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
|
||||
self.gguf_writer.add_vision_has_vision_encoder(True)
|
||||
|
||||
# vision config
|
||||
self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"]))
|
||||
self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
|
||||
self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
|
||||
self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
|
||||
self.gguf_writer.add_vision_block_count(self.block_count)
|
||||
self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
|
||||
if self.has_vision_encoder:
|
||||
self.gguf_writer.add_clip_has_vision_encoder(True)
|
||||
self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
|
||||
|
||||
# preprocessor config
|
||||
self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
|
||||
self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"])
|
||||
# vision config
|
||||
self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"]))
|
||||
self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
|
||||
self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
|
||||
self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
|
||||
self.gguf_writer.add_vision_block_count(self.block_count)
|
||||
self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
|
||||
|
||||
# preprocessor config
|
||||
self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
|
||||
self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"])
|
||||
|
||||
elif self.has_audio_encoder:
|
||||
self.gguf_writer.add_clip_has_audio_encoder(True)
|
||||
self.gguf_writer.add_audio_projection_dim(self.n_embd_text)
|
||||
|
||||
# audio config
|
||||
self.gguf_writer.add_audio_embedding_length(self.find_hparam(["hidden_size"]))
|
||||
self.gguf_writer.add_audio_feed_forward_length(self.find_hparam(["intermediate_size"]))
|
||||
self.gguf_writer.add_audio_block_count(self.block_count)
|
||||
self.gguf_writer.add_audio_head_count(self.find_hparam(["num_attention_heads"]))
|
||||
|
||||
else:
|
||||
raise ValueError("MmprojModel must have either vision or audio encoder")
|
||||
|
||||
def write_vocab(self):
|
||||
raise ValueError("VisionModel does not support vocab writing")
|
||||
raise ValueError("MmprojModel does not support vocab writing")
|
||||
|
||||
|
||||
@ModelBase.register("GPTNeoXForCausalLM")
|
||||
@@ -1951,7 +1979,7 @@ class LlamaModel(TextModel):
|
||||
"LlavaForConditionalGeneration", # pixtral
|
||||
"Mistral3ForConditionalGeneration", # mistral small 3.1
|
||||
)
|
||||
class LlavaVisionModel(VisionModel):
|
||||
class LlavaVisionModel(MmprojModel):
|
||||
img_break_tok_id = -1
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -1977,7 +2005,7 @@ class LlavaVisionModel(VisionModel):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
if hparams["model_type"] == "pixtral":
|
||||
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL)
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
|
||||
|
||||
# hidden_act
|
||||
@@ -2016,7 +2044,7 @@ class LlavaVisionModel(VisionModel):
|
||||
|
||||
|
||||
@ModelBase.register("Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration")
|
||||
class SmolVLMModel(VisionModel):
|
||||
class SmolVLMModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.hparams["model_type"] == "smolvlm_vision":
|
||||
@@ -2028,7 +2056,7 @@ class SmolVLMModel(VisionModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.IDEFICS3)
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.IDEFICS3)
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
|
||||
self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("scale_factor", 2))
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
@@ -2094,10 +2122,10 @@ class Llama4Model(LlamaModel):
|
||||
|
||||
|
||||
@ModelBase.register("Llama4ForConditionalGeneration")
|
||||
class Llama4VisionModel(VisionModel):
|
||||
class Llama4VisionModel(MmprojModel):
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.LLAMA4)
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LLAMA4)
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams["norm_eps"])
|
||||
self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / self.hparams["pixel_shuffle_ratio"]))
|
||||
assert self.hparams["hidden_act"] == "gelu"
|
||||
@@ -2615,7 +2643,7 @@ class QwenModel(TextModel):
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
|
||||
@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM")
|
||||
@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration")
|
||||
class Qwen2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN2
|
||||
|
||||
@@ -2639,8 +2667,9 @@ class Qwen2Model(TextModel):
|
||||
name = f"model.{name}" # map to Qwen2ForCausalLM tensors
|
||||
if "language_model." in name:
|
||||
name = name.replace("language_model.", "") # for InternVL
|
||||
if name.startswith("mlp") or name.startswith("vision_model"):
|
||||
# skip visual tensors
|
||||
if name.startswith("mlp") or name.startswith("multi_modal_projector") \
|
||||
or name.startswith("vision_model") or name.startswith("audio_tower"):
|
||||
# skip vision and audio tensors
|
||||
return []
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
@@ -2670,7 +2699,7 @@ class Qwen2VLModel(TextModel):
|
||||
|
||||
|
||||
@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
|
||||
class Qwen2VLVisionModel(VisionModel):
|
||||
class Qwen2VLVisionModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.hparams["image_size"] = self.hparams.get("image_size", 560)
|
||||
@@ -2685,9 +2714,9 @@ class Qwen2VLVisionModel(VisionModel):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
if self.global_config['model_type'] == 'qwen2_vl':
|
||||
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN2VL)
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2VL)
|
||||
elif self.global_config['model_type'] == 'qwen2_5_vl':
|
||||
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN25VL)
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25VL)
|
||||
self.gguf_writer.add_vision_use_silu(True)
|
||||
# find n_wa_pattern (window attention pattern)
|
||||
fullatt_block_indexes = hparams.get("fullatt_block_indexes")
|
||||
@@ -2746,11 +2775,11 @@ class Qwen2VLVisionModel(VisionModel):
|
||||
|
||||
|
||||
@ModelBase.register("InternVisionModel")
|
||||
class InternVisionModel(VisionModel):
|
||||
class InternVisionModel(MmprojModel):
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.INTERNVL)
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.INTERNVL)
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
|
||||
# hidden_act
|
||||
if hparams["hidden_act"] == "silu":
|
||||
@@ -4008,11 +4037,11 @@ class Gemma3Model(TextModel):
|
||||
|
||||
|
||||
@ModelBase.register("Gemma3ForConditionalGeneration")
|
||||
class Gemma3VisionModel(VisionModel):
|
||||
class Gemma3VisionModel(MmprojModel):
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.GEMMA3)
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA3)
|
||||
# default values below are taken from HF tranformers code
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6))
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
@@ -5959,6 +5988,65 @@ class ChameleonModel(TextModel):
|
||||
return data_torch
|
||||
|
||||
|
||||
@ModelBase.register("UltravoxModel")
|
||||
class UltravoxModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA # dummy
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
raise NotImplementedError("Ultravox does not have text decoder. Instead, it uses Llama or other models for text. If you want to get the audio encoder, please use --mmproj argument")
|
||||
|
||||
|
||||
@ModelBase.register("Qwen2AudioForConditionalGeneration")
|
||||
class WhisperEncoderModel(MmprojModel):
|
||||
has_vision_encoder = False # no vision encoder
|
||||
has_audio_encoder = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.hparams["hidden_size"] = self.hparams["d_model"]
|
||||
self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
|
||||
self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2A)
|
||||
self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"])
|
||||
self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
|
||||
|
||||
def tensor_force_quant(self, name, new_name, bid, n_dims):
|
||||
del bid, new_name, n_dims # unused
|
||||
if ".conv" in name and ".weight" in name:
|
||||
return gguf.GGMLQuantizationType.F16
|
||||
return False
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
if name.startswith("language_model."):
|
||||
# skip language model tensors
|
||||
return []
|
||||
|
||||
# prevent clash naming with vision tensors
|
||||
if name.startswith("multi_modal_projector"):
|
||||
name = "audio." + name
|
||||
|
||||
if "conv1.bias" in name or "conv2.bias" in name:
|
||||
# transpose conv1 and conv2 bias
|
||||
data_torch = data_torch.unsqueeze(-1)
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register("UltravoxModel")
|
||||
class UltravoxWhisperEncoderModel(WhisperEncoderModel):
|
||||
has_vision_encoder = False # no vision encoder
|
||||
has_audio_encoder = True
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
@@ -6134,13 +6222,15 @@ def split_str_to_n_bytes(split_str: str) -> int:
|
||||
|
||||
|
||||
def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
|
||||
# TODO @ngxson : this won't work correctly if the model has both audio & vision encoders
|
||||
# maybe we should fallback to text model's arch in that case, since not many models have both
|
||||
text_config = hparams.get("text_config", {})
|
||||
vision_config = hparams.get("vision_config", {})
|
||||
arch = hparams["architectures"][0]
|
||||
# if "architectures" is found in the sub-config, use that instead
|
||||
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
|
||||
arch = text_config["architectures"][0]
|
||||
elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
|
||||
elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None:
|
||||
arch = vision_config["architectures"][0]
|
||||
return arch
|
||||
|
||||
@@ -6203,7 +6293,7 @@ def main() -> None:
|
||||
|
||||
with torch.inference_mode():
|
||||
output_type = ftype_map[args.outtype]
|
||||
model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
|
||||
model_type = ModelType.MMPROJ if args.mmproj else ModelType.TEXT
|
||||
hparams = ModelBase.load_hparams(dir_model)
|
||||
model_architecture = get_model_architecture(hparams, model_type)
|
||||
logger.info(f"Model architecture: {model_architecture}")
|
||||
|
||||
+53
-24
@@ -325,36 +325,65 @@ To get the official template from original HuggingFace repos, you can use [scrip
|
||||
> [!TIP]
|
||||
> If there is no official `tool_use` Jinja template, you may want to set `--chat-template chatml` to use a default that works with many models (YMMV!), or write your own (e.g. we provide a custom [llama-cpp-deepseek-r1.jinja](../models/templates/llama-cpp-deepseek-r1.jinja) for DeepSeek R1 distills)
|
||||
|
||||
> [!CAUTION]
|
||||
> Beware of extreme KV quantizations (e.g. `-ctk q4_0`), they can substantially degrade the model's tool calling performance.
|
||||
|
||||
Test in CLI (or with any library / software that can use OpenAI-compatible API backends):
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/v1/chat/completions -d '{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"tools": [
|
||||
{
|
||||
"type":"function",
|
||||
"function":{
|
||||
"name":"python",
|
||||
"description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
|
||||
"parameters":{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"code":{
|
||||
"type":"string",
|
||||
"description":"The code to run in the ipython interpreter."
|
||||
"model": "gpt-3.5-turbo",
|
||||
"tools": [
|
||||
{
|
||||
"type":"function",
|
||||
"function":{
|
||||
"name":"python",
|
||||
"description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
|
||||
"parameters":{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"code":{
|
||||
"type":"string",
|
||||
"description":"The code to run in the ipython interpreter."
|
||||
}
|
||||
},
|
||||
"required":["code"]
|
||||
}
|
||||
},
|
||||
"required":["code"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Print a hello world message with python."
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Print a hello world message with python."
|
||||
}
|
||||
]
|
||||
}'
|
||||
|
||||
|
||||
curl http://localhost:8080/v1/chat/completions -d '{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
|
||||
{"role": "user", "content": "What is the weather in Istanbul?"}
|
||||
],
|
||||
"tools": [{
|
||||
"type":"function",
|
||||
"function":{
|
||||
"name":"get_current_weather",
|
||||
"description":"Get the current weather in a given location",
|
||||
"parameters":{
|
||||
"type":"object",
|
||||
"properties":{
|
||||
"location":{
|
||||
"type":"string",
|
||||
"description":"The city and country/state, e.g. `San Francisco, CA`, or `Paris, France`"
|
||||
}
|
||||
},
|
||||
"required":["location"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
}'
|
||||
```
|
||||
|
||||
|
||||
+22
-2
@@ -4,7 +4,9 @@ llama.cpp supports multimodal input via `libmtmd`. Currently, there are 2 tools
|
||||
- [llama-mtmd-cli](../tools/mtmd/README.md)
|
||||
- [llama-server](../tools/server/README.md) via OpenAI-compatible `/chat/completions` API
|
||||
|
||||
To enable it, can use use one of the 2 methods below:
|
||||
Currently, we support **image** and **audio** input. Audio is highly experimental and may have reduced quality.
|
||||
|
||||
To enable it, you can use one of the 2 methods below:
|
||||
|
||||
- Use `-hf` option with a supported model (see a list of pre-quantized model below)
|
||||
- To load a model using `-hf` while disabling multimodal, use `--no-mmproj`
|
||||
@@ -31,12 +33,14 @@ llama-server -hf ggml-org/gemma-3-4b-it-GGUF --no-mmproj-offload
|
||||
|
||||
## Pre-quantized models
|
||||
|
||||
These are ready-to-use models, most of them come with `Q4_K_M` quantization by default. They can be found at the Hugging Face page of the ggml-org: https://huggingface.co/ggml-org
|
||||
These are ready-to-use models, most of them come with `Q4_K_M` quantization by default. They can be found at the Hugging Face page of the ggml-org: https://huggingface.co/collections/ggml-org/multimodal-ggufs-68244e01ff1f39e5bebeeedc
|
||||
|
||||
Replaces the `(tool_name)` with the name of binary you want to use. For example, `llama-mtmd-cli` or `llama-server`
|
||||
|
||||
NOTE: some models may require large context window, for example: `-c 8192`
|
||||
|
||||
**Vision models**:
|
||||
|
||||
```sh
|
||||
# Gemma 3
|
||||
(tool_name) -hf ggml-org/gemma-3-4b-it-GGUF
|
||||
@@ -77,4 +81,20 @@ NOTE: some models may require large context window, for example: `-c 8192`
|
||||
|
||||
# Llama 4 Scout
|
||||
(tool_name) -hf ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF
|
||||
|
||||
# Moondream2 20250414 version
|
||||
(tool_name) -hf ggml-org/moondream2-20250414-GGUF
|
||||
|
||||
```
|
||||
|
||||
**Audio models**:
|
||||
|
||||
```sh
|
||||
# Ultravox 0.5
|
||||
(tool_name) -hf ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF
|
||||
(tool_name) -hf ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF
|
||||
|
||||
# Qwen2-Audio and SeaLLM-Audio
|
||||
# note: no pre-quantized GGUF this model, as they have very poor result
|
||||
# ref: https://github.com/ggml-org/llama.cpp/pull/13760
|
||||
```
|
||||
|
||||
+2
-2
@@ -528,15 +528,15 @@ extern "C" {
|
||||
GGML_UNARY_OP_STEP,
|
||||
GGML_UNARY_OP_TANH,
|
||||
GGML_UNARY_OP_ELU,
|
||||
GGML_UNARY_OP_RELU,
|
||||
GGML_UNARY_OP_SIGMOID,
|
||||
GGML_UNARY_OP_GELU,
|
||||
GGML_UNARY_OP_GELU_ERF,
|
||||
GGML_UNARY_OP_GELU_QUICK,
|
||||
GGML_UNARY_OP_SILU,
|
||||
GGML_UNARY_OP_HARDSWISH,
|
||||
GGML_UNARY_OP_HARDSIGMOID,
|
||||
GGML_UNARY_OP_EXP,
|
||||
GGML_UNARY_OP_RELU,
|
||||
GGML_UNARY_OP_GELU_ERF,
|
||||
|
||||
GGML_UNARY_OP_COUNT,
|
||||
};
|
||||
|
||||
@@ -2697,14 +2697,10 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
|
||||
}
|
||||
}
|
||||
|
||||
// GroupedMatmulV2 required tensor_list.size < 128
|
||||
size_t GROUP_SIZE = 128;
|
||||
std::vector<std::vector<aclTensor*>> src0_tensor_vec_vec;
|
||||
std::vector<std::vector<aclTensor*>> src1_tensor_vec_vec;
|
||||
std::vector<std::vector<aclTensor*>> dst_tensor_vec_vec;
|
||||
|
||||
// split and call GroupedMatmulV2
|
||||
// GroupedMatmulV2 required tensor_list.size < 128
|
||||
for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) {
|
||||
// split and call GroupedMatmulV2
|
||||
size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size());
|
||||
std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end);
|
||||
std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end);
|
||||
@@ -2722,6 +2718,133 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs expert-specific matrix multiplication (MoE) with
|
||||
* quantized precision using the CANN backend.
|
||||
*
|
||||
* This function executes a matrix multiplication operation tailored for
|
||||
* Mixture of Experts (MoE) models, where the input tensor is multiplied
|
||||
* with expert-specific quantized weight matrices. It leverages the CANN
|
||||
* backend to perform efficient low-precision computations and stores the
|
||||
* quantized result in the destination tensor `dst`.
|
||||
*
|
||||
* Quantization techniques reduce memory footprint and improve performance
|
||||
* by using lower-bit representations (e.g., int8) instead of floating-point.
|
||||
* This function is designed to work with such formats and may incorporate
|
||||
* optimizations like identity-based fast paths or routing masks for sparse
|
||||
* expert selection.
|
||||
*
|
||||
* @param ctx The context for executing CANN backend operations.
|
||||
* @param dst The destination tensor where the quantized MoE multiplication result
|
||||
* will be stored.
|
||||
*
|
||||
* @note This function assumes quantized data types and is designed for
|
||||
* MoE architectures with potential sparse expert routing.
|
||||
*/
|
||||
static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
// TODO: Use aclnnGroupedMatMul
|
||||
//dst [M, K, N, 1]
|
||||
ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
|
||||
ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
|
||||
ggml_tensor * ids = dst->src[2]; //ids [K, N]
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
// copy index from npu to cpu
|
||||
int64_t n_as = ne02; // A
|
||||
int64_t n_ids = ids->ne[0]; // K
|
||||
|
||||
std::vector<char> ids_host(ggml_nbytes(ids));
|
||||
ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids),
|
||||
ACL_MEMCPY_DEVICE_TO_HOST);
|
||||
ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
|
||||
|
||||
char * src0_original = (char *) src0->data;
|
||||
char * src1_original = (char *) src1->data;
|
||||
char * dst_original = (char *) dst->data;
|
||||
|
||||
ggml_tensor src0_row = *src0;
|
||||
ggml_tensor src1_row = *src1;
|
||||
ggml_tensor dst_row = *dst;
|
||||
|
||||
const enum ggml_type type = dst->src[0]->type;
|
||||
float weight_elem_size;
|
||||
if (type == GGML_TYPE_Q4_0) {
|
||||
weight_elem_size = float(sizeof(uint8_t)) / 2;
|
||||
} else if (type == GGML_TYPE_Q8_0) {
|
||||
weight_elem_size = float(sizeof(uint8_t));
|
||||
} else {
|
||||
GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 ");
|
||||
}
|
||||
|
||||
// src0_row [D, M, 1, 1] weight without permute
|
||||
src0_row.ne[2] = 1;
|
||||
src0_row.ne[3] = 1;
|
||||
src0_row.nb[0] = weight_elem_size;
|
||||
src0_row.nb[1] = weight_elem_size * ne00;
|
||||
src0_row.nb[2] = weight_elem_size * ne00;
|
||||
src0_row.nb[3] = weight_elem_size * ne00;
|
||||
size_t weight_stride = ne00 * ne01 * weight_elem_size;
|
||||
size_t weight_size = weight_stride * ne02 * ne03;
|
||||
|
||||
// scale [D, M, 1, 1] -> scale && permute
|
||||
size_t scale_elem_size = sizeof(uint16_t);
|
||||
size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
|
||||
|
||||
// src1_row [D, 1, 1, 1] -> input
|
||||
src1_row.ne[1] = 1;
|
||||
src1_row.ne[2] = 1;
|
||||
src1_row.ne[3] = 1;
|
||||
src1_row.nb[2] = nb11;
|
||||
src1_row.nb[3] = nb11;
|
||||
|
||||
// dst_row [M, 1, 1, 1] -> out
|
||||
dst_row.ne[1] = 1;
|
||||
dst_row.ne[2] = 1;
|
||||
dst_row.ne[3] = 1;
|
||||
dst_row.nb[2] = nb1;
|
||||
dst_row.nb[3] = nb1;
|
||||
|
||||
//create weight for one row
|
||||
ggml_cann_pool_alloc weight_allocator(ctx.pool());
|
||||
void* weight_buffer = weight_allocator.alloc(nb02);
|
||||
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
|
||||
for (int64_t id = 0; id < n_ids; id++) {
|
||||
// expert index
|
||||
int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
|
||||
GGML_ASSERT(i02 >= 0 && i02 < n_as);
|
||||
|
||||
// If B = 1 (broadcast), always use 0; otherwise, use id.
|
||||
int64_t i11 = (ne11 == 1 ? 0 : id);
|
||||
int64_t i12 = iid1;
|
||||
|
||||
int64_t i1 = id;
|
||||
int64_t i2 = i12;
|
||||
|
||||
void* src0_tmp_ptr = src0_original + i02*weight_stride;
|
||||
void* scale_tmp_ptr = src0_original + weight_size + i02*scale_stride;
|
||||
void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
|
||||
void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
|
||||
|
||||
// mem cpy
|
||||
ggml_cann_async_memcpy(ctx, weight_buffer, src0_tmp_ptr, weight_stride,
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE);
|
||||
void* scale_buffer = (char*)weight_buffer + weight_stride;
|
||||
ggml_cann_async_memcpy(ctx, scale_buffer, scale_tmp_ptr, scale_stride,
|
||||
ACL_MEMCPY_DEVICE_TO_DEVICE);
|
||||
|
||||
src0_row.data = weight_buffer;
|
||||
src1_row.data = src1_tmp_ptr;
|
||||
dst_row.data = dst_tmp_ptr;
|
||||
dst_row.src[0] = &src0_row;
|
||||
dst_row.src[1] = &src1_row;
|
||||
|
||||
ggml_cann_mul_mat(ctx, &dst_row);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
const enum ggml_type type = dst->src[0]->type;
|
||||
switch (type) {
|
||||
@@ -2729,6 +2852,10 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
case GGML_TYPE_F16:
|
||||
ggml_cann_mul_mat_id_fp(ctx, dst);
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
ggml_cann_mul_mat_id_quant(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported type for mul_mat_id");
|
||||
break;
|
||||
|
||||
@@ -2035,6 +2035,15 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
return true;
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
#ifdef ASCEND_310P
|
||||
// Q4 && Q8 per group is not suppor on 310p device
|
||||
return false;
|
||||
#endif
|
||||
// only support contiguous for quantized types.
|
||||
return ggml_is_contiguous(op->src[0]) &&
|
||||
ggml_is_contiguous(op->src[1]);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -3484,6 +3484,19 @@ void ggml_cpu_init(void) {
|
||||
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
|
||||
|
||||
GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
|
||||
|
||||
#ifdef GGML_USE_OPENMP
|
||||
//if (!getenv("OMP_WAIT_POLICY")) {
|
||||
// // set the wait policy to active, so that OpenMP threads don't sleep
|
||||
// putenv("OMP_WAIT_POLICY=active");
|
||||
//}
|
||||
|
||||
if (!getenv("KMP_BLOCKTIME")) {
|
||||
// set the time to wait before sleeping a thread
|
||||
// this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases
|
||||
putenv("KMP_BLOCKTIME=200"); // 200ms
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(__ARM_ARCH)
|
||||
|
||||
@@ -212,6 +212,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
}
|
||||
}
|
||||
if (__all_sync(0xFFFFFFFF, skip)) {
|
||||
__syncthreads();
|
||||
continue;
|
||||
}
|
||||
#endif // GGML_USE_HIP
|
||||
|
||||
@@ -217,6 +217,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
}
|
||||
}
|
||||
if (__all_sync(0xFFFFFFFF, skip)) {
|
||||
__syncthreads();
|
||||
continue;
|
||||
}
|
||||
#endif // GGML_USE_HIP
|
||||
|
||||
@@ -2192,6 +2192,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_UNARY_OP_SILU:
|
||||
ggml_cuda_op_silu(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_GELU_ERF:
|
||||
ggml_cuda_op_gelu_erf(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
ggml_cuda_op_gelu_quick(ctx, dst);
|
||||
break;
|
||||
@@ -2977,6 +2980,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
case GGML_UNARY_OP_HARDSWISH:
|
||||
case GGML_UNARY_OP_GELU_ERF:
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_EXP:
|
||||
|
||||
@@ -23,6 +23,12 @@ static __device__ __forceinline__ float op_gelu(float x) {
|
||||
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_gelu_erf(float x) {
|
||||
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
||||
|
||||
return 0.5f*x*(1.0f + erff(x*SQRT_2_INV));
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_gelu_quick(float x) {
|
||||
const float GELU_QUICK_COEF = -1.702f;
|
||||
|
||||
@@ -134,6 +140,10 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_gelu>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_gelu_erf>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_unary<op_gelu_quick>(ctx, dst);
|
||||
}
|
||||
|
||||
@@ -30,6 +30,8 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
+222
-111
@@ -1,74 +1,93 @@
|
||||
#include "binbcast.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
#include "dpct/helper.hpp"
|
||||
#include "ggml.h"
|
||||
|
||||
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static __dpct_inline__ void k_bin_bcast_contiguous(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1,
|
||||
dst_t * dst, std::size_t num_elements, const sycl::nd_item<1> & it) {
|
||||
auto element_id = it.get_global_id(0);
|
||||
auto global_range = it.get_global_range(0);
|
||||
for (; element_id < num_elements; element_id += global_range) {
|
||||
auto src0_float_val = sycl::vec(src0[element_id]).template convert<float, sycl::rounding_mode::rte>();
|
||||
auto src1_float_val = sycl::vec(src1[element_id]).template convert<float, sycl::rounding_mode::rte>();
|
||||
float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
|
||||
auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
|
||||
dst[element_id] = val_to_store;
|
||||
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
||||
int ne0, int ne1, int ne2, int ne3,
|
||||
int ne10, int ne11, int ne12, int ne13,
|
||||
/*int s0, */ int s1, int s2, int s3,
|
||||
/*int s00,*/ int s01, int s02, int s03,
|
||||
/*int s10,*/ int s11, int s12, int s13,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
||||
item_ct1.get_local_id(1));
|
||||
const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
|
||||
item_ct1.get_local_id(0)) /
|
||||
ne3;
|
||||
const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
|
||||
item_ct1.get_local_id(0)) %
|
||||
ne3;
|
||||
|
||||
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i11 = i1 % ne11;
|
||||
const int i12 = i2 % ne12;
|
||||
const int i13 = i3 % ne13;
|
||||
|
||||
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
|
||||
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
|
||||
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
|
||||
|
||||
const src0_t * src0_row = src0 + i_src0;
|
||||
const src1_t * src1_row = src1 + i_src1;
|
||||
dst_t * dst_row = dst + i_dst;
|
||||
|
||||
for (int i0 = i0s; i0 < ne0;
|
||||
i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
|
||||
const int i10 = i0 % ne10;
|
||||
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
|
||||
}
|
||||
}
|
||||
|
||||
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static __dpct_inline__ void k_bin_bcast(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, dst_t * dst,
|
||||
int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13,
|
||||
int s0, int s1, int s2, int s3, int s00, int s01, int s02, int s03, int s10,
|
||||
int s11, int s12, int s13, std::size_t num_dst_elements,
|
||||
const sycl::nd_item<1> & item_ct1) {
|
||||
auto calculate_logical_index =
|
||||
[](const std::array<int, 4> & dims, std::size_t element_id) __attribute__((always_inline))->std::array<int, 4> {
|
||||
std::array<int, 4> logical_index;
|
||||
#pragma unroll(4)
|
||||
for (int i = 3; i >= 0; i--) {
|
||||
logical_index[i] = element_id % dims[i];
|
||||
element_id /= dims[i];
|
||||
}
|
||||
return logical_index;
|
||||
};
|
||||
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
||||
int ne0, int ne1, int ne2, int ne3,
|
||||
int ne10, int ne11, int ne12, int ne13,
|
||||
/*int s0, */ int s1, int s2, int s3,
|
||||
/*int s00,*/ int s01, int s02, int s03,
|
||||
/*int s10,*/ int s11, int s12, int s13,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
|
||||
auto calculate_index = [](const std::array<int, 4> & dims, const std::array<int, 4> & strides,
|
||||
const std::array<int, 4> & indices) __attribute__((always_inline))
|
||||
->std::size_t {
|
||||
std::size_t index = 0;
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < 4; i++) {
|
||||
auto index_i = indices[i];
|
||||
if (indices[i] >= dims[i]) {
|
||||
index_i = indices[i] % dims[i];
|
||||
}
|
||||
index += strides[i] * index_i;
|
||||
}
|
||||
return index;
|
||||
};
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
|
||||
auto element_id = item_ct1.get_global_id(0);
|
||||
for (; element_id < num_dst_elements; element_id += item_ct1.get_global_range(0)) {
|
||||
auto logical_index = calculate_logical_index({ ne3, ne2, ne1, ne0 }, element_id);
|
||||
auto src_0_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s03, s02, s01, s00 }, logical_index);
|
||||
auto src_1_index = calculate_index({ ne13, ne12, ne11, ne10 }, { s13, s12, s11, s10 }, logical_index);
|
||||
auto dst_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s3, s2, s1, s0 }, logical_index);
|
||||
auto src0_float_val = sycl::vec(src0[src_0_index]).template convert<float, sycl::rounding_mode::rte>();
|
||||
auto src1_float_val = sycl::vec(src1[src_1_index]).template convert<float, sycl::rounding_mode::rte>();
|
||||
float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
|
||||
auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
|
||||
dst[dst_index] = val_to_store;
|
||||
const int i3 = i/(ne2*ne1*ne0);
|
||||
const int i2 = (i/(ne1*ne0)) % ne2;
|
||||
const int i1 = (i/ne0) % ne1;
|
||||
const int i0 = i % ne0;
|
||||
|
||||
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i11 = i1 % ne11;
|
||||
const int i12 = i2 % ne12;
|
||||
const int i13 = i3 % ne13;
|
||||
|
||||
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
|
||||
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
|
||||
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
|
||||
|
||||
const src0_t * src0_row = src0 + i_src0;
|
||||
const src1_t * src1_row = src1 + i_src1;
|
||||
dst_t * dst_row = dst + i_dst;
|
||||
|
||||
const int i10 = i0 % ne10;
|
||||
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
|
||||
}
|
||||
|
||||
template <float (*bin_op)(const float, const float)> struct bin_bcast_sycl {
|
||||
|
||||
template<float (*bin_op)(const float, const float)>
|
||||
struct bin_bcast_sycl {
|
||||
template <typename src0_t, typename src1_t, typename dst_t>
|
||||
void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,
|
||||
const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,
|
||||
@@ -77,73 +96,165 @@ template <float (*bin_op)(const float, const float)> struct bin_bcast_sycl {
|
||||
const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
|
||||
const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
|
||||
const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
|
||||
auto check_bcast_required = [](const std::array<int64_t, 4> & src_dims,
|
||||
const std::array<int64_t, 4> & dst_dims) -> bool {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (dst_dims[i] > src_dims[i]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
int nr0 = ne10 / ne0;
|
||||
int nr1 = ne11/ne1;
|
||||
int nr2 = ne12/ne2;
|
||||
int nr3 = ne13/ne3;
|
||||
|
||||
int nr[4] = { nr0, nr1, nr2, nr3 };
|
||||
|
||||
// collapse dimensions until first broadcast dimension
|
||||
int64_t cne[] = {ne0, ne1, ne2, ne3};
|
||||
int64_t cne0[] = {ne00, ne01, ne02, ne03};
|
||||
int64_t cne1[] = {ne10, ne11, ne12, ne13};
|
||||
size_t cnb[] = {nb0, nb1, nb2, nb3};
|
||||
size_t cnb0[] = {nb00, nb01, nb02, nb03};
|
||||
size_t cnb1[] = {nb10, nb11, nb12, nb13};
|
||||
auto collapse = [](int64_t cne[]) {
|
||||
cne[0] *= cne[1];
|
||||
cne[1] = cne[2];
|
||||
cne[2] = cne[3];
|
||||
cne[3] = 1;
|
||||
};
|
||||
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
|
||||
cnb[1] *= cne[1];
|
||||
cnb[2] *= cne[2];
|
||||
cnb[3] *= cne[3];
|
||||
};
|
||||
|
||||
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
|
||||
if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (nr[i] != 1) {
|
||||
break;
|
||||
}
|
||||
if (i > 0) {
|
||||
collapse_nb(cnb, cne);
|
||||
collapse_nb(cnb0, cne0);
|
||||
collapse_nb(cnb1, cne1);
|
||||
collapse(cne);
|
||||
collapse(cne0);
|
||||
collapse(cne1);
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
int64_t ne0 = cne[0];
|
||||
int64_t ne1 = cne[1];
|
||||
int64_t ne2 = cne[2];
|
||||
int64_t ne3 = cne[3];
|
||||
|
||||
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
|
||||
int64_t ne10 = cne1[0];
|
||||
int64_t ne11 = cne1[1];
|
||||
int64_t ne12 = cne1[2];
|
||||
int64_t ne13 = cne1[3];
|
||||
|
||||
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
|
||||
size_t nb0 = cnb[0];
|
||||
size_t nb1 = cnb[1];
|
||||
size_t nb2 = cnb[2];
|
||||
size_t nb3 = cnb[3];
|
||||
|
||||
// dst strides in number of elements
|
||||
size_t s0 = nb0 / sizeof(dst_t);
|
||||
size_t s1 = nb1 / sizeof(dst_t);
|
||||
size_t s2 = nb2 / sizeof(dst_t);
|
||||
size_t s3 = nb3 / sizeof(dst_t);
|
||||
size_t nb00 = cnb0[0];
|
||||
size_t nb01 = cnb0[1];
|
||||
size_t nb02 = cnb0[2];
|
||||
size_t nb03 = cnb0[3];
|
||||
|
||||
// src1 strides in number of elements
|
||||
size_t s10 = nb10 / sizeof(src0_t);
|
||||
size_t s11 = nb11 / sizeof(src1_t);
|
||||
size_t s12 = nb12 / sizeof(src1_t);
|
||||
size_t s13 = nb13 / sizeof(src1_t);
|
||||
size_t nb10 = cnb1[0];
|
||||
size_t nb11 = cnb1[1];
|
||||
size_t nb12 = cnb1[2];
|
||||
size_t nb13 = cnb1[3];
|
||||
|
||||
// src0 strides in number of elements
|
||||
size_t s00 = nb00 / sizeof(src0_t);
|
||||
size_t s01 = nb01 / sizeof(src0_t);
|
||||
size_t s02 = nb02 / sizeof(src0_t);
|
||||
size_t s03 = nb03 / sizeof(src0_t);
|
||||
size_t s0 = nb0 / sizeof(dst_t);
|
||||
size_t s1 = nb1 / sizeof(dst_t);
|
||||
size_t s2 = nb2 / sizeof(dst_t);
|
||||
size_t s3 = nb3 / sizeof(dst_t);
|
||||
|
||||
std::size_t num_dst_elements = static_cast<std::size_t>(ne0) * static_cast<std::size_t>(ne1) *
|
||||
static_cast<std::size_t>(ne2) * static_cast<std::size_t>(ne3);
|
||||
std::size_t local_range = 256;
|
||||
std::size_t global_range = ceil_div(num_dst_elements, local_range) * local_range;
|
||||
size_t s10 = nb10 / sizeof(src1_t);
|
||||
size_t s11 = nb11 / sizeof(src1_t);
|
||||
size_t s12 = nb12 / sizeof(src1_t);
|
||||
size_t s13 = nb13 / sizeof(src1_t);
|
||||
|
||||
bool needs_broadcasting = check_bcast_required({ ne00, ne01, ne02, ne03 }, { ne0, ne1, ne2, ne3 }) ||
|
||||
check_bcast_required({ ne10, ne11, ne12, ne13 }, { ne0, ne1, ne2, ne3 });
|
||||
bool all_contiguous = src0_is_contiguous && src1_is_contiguous && dst_is_contiguous;
|
||||
size_t s00 = nb00 / sizeof(src0_t);
|
||||
size_t s01 = nb01 / sizeof(src0_t);
|
||||
size_t s02 = nb02 / sizeof(src0_t);
|
||||
size_t s03 = nb03 / sizeof(src0_t);
|
||||
|
||||
if (! needs_broadcasting && all_contiguous) {
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
|
||||
k_bin_bcast_contiguous<bin_op>(src0_dd, src1_dd, dst_dd, num_dst_elements, it);
|
||||
});
|
||||
});
|
||||
} else {
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) {
|
||||
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, s0, s1,
|
||||
s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, num_dst_elements, it);
|
||||
});
|
||||
});
|
||||
GGML_UNUSED(s00);
|
||||
|
||||
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
|
||||
|
||||
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
|
||||
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
|
||||
|
||||
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
|
||||
|
||||
GGML_ASSERT(s0 == 1);
|
||||
GGML_ASSERT(s10 == 1);
|
||||
|
||||
const int block_size = 128;
|
||||
|
||||
int64_t hne0 = std::max(ne0/2LL, 1LL);
|
||||
|
||||
sycl::range<3> block_dims(1, 1, 1);
|
||||
block_dims[2] = std::min<unsigned int>(hne0, block_size);
|
||||
block_dims[1] = std::min<unsigned int>(
|
||||
ne1, block_size / (unsigned int)block_dims[2]);
|
||||
block_dims[0] = std::min(
|
||||
std::min<unsigned int>(
|
||||
ne2 * ne3, block_size / (unsigned int)block_dims[2] /
|
||||
(unsigned int)block_dims[1]),
|
||||
64U);
|
||||
|
||||
sycl::range<3> block_nums(
|
||||
(ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
|
||||
(ne1 + block_dims[1] - 1) / block_dims[1],
|
||||
(hne0 + block_dims[2] - 1) / block_dims[2]);
|
||||
|
||||
if (block_nums[0] > 65535) {
|
||||
// this is the maximum number of blocks in z direction, fallback to 1D grid kernel
|
||||
int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
|
||||
{
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
|
||||
sycl::range<3>(1, 1, block_size),
|
||||
sycl::range<3>(1, 1, block_size)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_bin_bcast_unravel<bin_op>(
|
||||
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
|
||||
ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
|
||||
s03, s11, s12, s13, item_ct1);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
/*
|
||||
DPCT1049:16: The work-group size passed to the SYCL kernel may
|
||||
exceed the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if
|
||||
needed.
|
||||
*/
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
|
||||
ne2, ne3, ne10, ne11, ne12, ne13,
|
||||
s1, s2, s3, s01, s02, s03, s11, s12, s13,
|
||||
item_ct1);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -3740,7 +3740,7 @@ static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend,
|
||||
GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
|
||||
const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
|
||||
SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
|
||||
data, (const char *)tensor->data + offset, size).wait()));
|
||||
data, (const char *)tensor->data + offset, size)));
|
||||
}
|
||||
catch (sycl::exception const &exc) {
|
||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
||||
@@ -3760,7 +3760,7 @@ static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend,
|
||||
*/
|
||||
const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
|
||||
SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
|
||||
dst->data, src->data, ggml_nbytes(dst)).wait()));
|
||||
dst->data, src->data, ggml_nbytes(dst))));
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -2804,23 +2804,29 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
pipeline_robustness = true;
|
||||
} else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
|
||||
device->subgroup_size_control = true;
|
||||
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
} else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_COOPMAT")) {
|
||||
device->coopmat_support = true;
|
||||
device->coopmat_m = 0;
|
||||
device->coopmat_n = 0;
|
||||
device->coopmat_k = 0;
|
||||
#endif
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
|
||||
coopmat2_support = true;
|
||||
#endif
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
||||
device->integer_dot_product = true;
|
||||
#endif
|
||||
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||
} else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
|
||||
bfloat16_support = true;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4670,6 +4676,19 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
}
|
||||
|
||||
if (src->type == to) {
|
||||
// Copy two or four bytes at a time, depending on block size.
|
||||
// For quantized types, we scale by block size/type size. But
|
||||
// this path is also used for bf16->bf16 for example, where the
|
||||
// type size must be exactly 2 or 4.
|
||||
GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
|
||||
if ((ggml_type_size(src->type) % 4) == 0) {
|
||||
return ctx->device->pipeline_contig_cpy_f32_f32;
|
||||
} else {
|
||||
return ctx->device->pipeline_contig_cpy_f16_f16;
|
||||
}
|
||||
}
|
||||
|
||||
std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl;
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
@@ -6731,7 +6750,16 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
case GGML_OP_UNARY:
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
{
|
||||
const uint32_t ne = ggml_nelements(dst);
|
||||
uint32_t ne = ggml_nelements(dst);
|
||||
if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
|
||||
// Convert from number of logical elements to 2- or 4-byte units.
|
||||
ne /= ggml_blck_size(src0->type);
|
||||
if ((ggml_type_size(src0->type) % 4) == 0) {
|
||||
ne *= ggml_type_size(src0->type) / 4;
|
||||
} else {
|
||||
ne *= ggml_type_size(src0->type) / 2;
|
||||
}
|
||||
}
|
||||
if (ne > 262144) {
|
||||
elements = { 512, 512, CEIL_DIV(ne, 262144) };
|
||||
} else if (ne > 512) {
|
||||
@@ -7281,8 +7309,19 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
uint32_t ne = (uint32_t)ggml_nelements(src0);
|
||||
if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
|
||||
// Convert from number of logical elements to 2- or 4-byte units.
|
||||
ne /= ggml_blck_size(src0->type);
|
||||
if ((ggml_type_size(src0->type) % 4) == 0) {
|
||||
ne *= ggml_type_size(src0->type) / 4;
|
||||
} else {
|
||||
ne *= ggml_type_size(src0->type) / 2;
|
||||
}
|
||||
}
|
||||
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
ne,
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
@@ -9264,8 +9303,7 @@ static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_
|
||||
try {
|
||||
ptr = ggml_vk_host_malloc(vk_instance.devices[0], size);
|
||||
} catch (vk::SystemError& e) {
|
||||
std::cerr << "ggml_vulkan: Failed to allocate pinned memory." << std::endl;
|
||||
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
|
||||
GGML_LOG_WARN("ggml_vulkan: Failed to allocate pinned memory (%s)\n", e.what());
|
||||
// fallback to cpu buffer
|
||||
return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
|
||||
}
|
||||
@@ -9867,6 +9905,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// We can handle copying from a type to the same type if it's
|
||||
// contiguous (memcpy). We use f16 or f32 shaders to do the copy,
|
||||
// so the type/block size must be a multiple of 4.
|
||||
if (src0_type == src1_type &&
|
||||
ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op) &&
|
||||
(ggml_type_size(src0_type) % 2) == 0) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
} break;
|
||||
case GGML_OP_REPEAT:
|
||||
|
||||
@@ -219,10 +219,13 @@ class Keys:
|
||||
TYPE = "adapter.type"
|
||||
LORA_ALPHA = "adapter.lora.alpha"
|
||||
|
||||
class ClipVision:
|
||||
class Clip:
|
||||
PROJECTOR_TYPE = "clip.projector_type"
|
||||
HAS_VISION_ENCODER = "clip.has_vision_encoder"
|
||||
HAS_AUDIO_ENCODER = "clip.has_audio_encoder"
|
||||
HAS_LLAVA_PROJECTOR = "clip.has_llava_projector"
|
||||
|
||||
class ClipVision:
|
||||
IMAGE_SIZE = "clip.vision.image_size"
|
||||
PATCH_SIZE = "clip.vision.patch_size"
|
||||
EMBEDDING_LENGTH = "clip.vision.embedding_length"
|
||||
@@ -243,19 +246,33 @@ class Keys:
|
||||
class Projector:
|
||||
SCALE_FACTOR = "clip.vision.projector.scale_factor"
|
||||
|
||||
class ClipAudio:
|
||||
NUM_MEL_BINS = "clip.audio.num_mel_bins"
|
||||
EMBEDDING_LENGTH = "clip.audio.embedding_length"
|
||||
FEED_FORWARD_LENGTH = "clip.audio.feed_forward_length"
|
||||
PROJECTION_DIM = "clip.audio.projection_dim"
|
||||
BLOCK_COUNT = "clip.audio.block_count"
|
||||
|
||||
class Attention:
|
||||
HEAD_COUNT = "clip.audio.attention.head_count"
|
||||
LAYERNORM_EPS = "clip.audio.attention.layer_norm_epsilon"
|
||||
|
||||
class Projector:
|
||||
STACK_FACTOR = "clip.audio.projector.stack_factor"
|
||||
|
||||
#
|
||||
# recommended mapping of model tensor names for storage in gguf
|
||||
#
|
||||
|
||||
|
||||
class GGUFType:
|
||||
MODEL = "model"
|
||||
ADAPTER = "adapter"
|
||||
CLIP_VISION = "clip-vision"
|
||||
MODEL = "model"
|
||||
ADAPTER = "adapter"
|
||||
MMPROJ = "mmproj" # dummy, unused for now
|
||||
|
||||
|
||||
class MODEL_ARCH(IntEnum):
|
||||
CLIP_VISION = auto() # dummy arch for clip.cpp
|
||||
MMPROJ = auto() # dummy arch for clip.cpp
|
||||
LLAMA = auto()
|
||||
LLAMA4 = auto()
|
||||
DECI = auto()
|
||||
@@ -514,10 +531,28 @@ class MODEL_TENSOR(IntEnum):
|
||||
V_RESMPL_QUERY = auto() # minicpmv
|
||||
V_TOK_EMBD_IMG_BREAK = auto() # pixtral
|
||||
V_MM_PATCH_MERGER = auto() # mistral small 3.1
|
||||
# audio (mtmd)
|
||||
A_ENC_EMBD_POS = auto()
|
||||
A_ENC_CONV1D = auto()
|
||||
A_PRE_NORM = auto()
|
||||
A_POST_NORM = auto()
|
||||
A_ENC_ATTN_Q = auto()
|
||||
A_ENC_ATTN_K = auto()
|
||||
A_ENC_ATTN_V = auto()
|
||||
A_ENC_INPUT_NORM = auto()
|
||||
A_ENC_OUTPUT = auto()
|
||||
A_ENC_OUTPUT_NORM = auto()
|
||||
A_ENC_FFN_UP = auto()
|
||||
A_ENC_FFN_GATE = auto()
|
||||
A_ENC_FFN_DOWN = auto()
|
||||
A_MMPROJ = auto()
|
||||
A_MMPROJ_FC = auto()
|
||||
A_MM_NORM_PRE = auto()
|
||||
A_MM_NORM_MID = auto()
|
||||
|
||||
|
||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.CLIP_VISION: "clip", # dummy arch for clip.cpp
|
||||
MODEL_ARCH.MMPROJ: "clip", # dummy arch for clip.cpp
|
||||
MODEL_ARCH.LLAMA: "llama",
|
||||
MODEL_ARCH.LLAMA4: "llama4",
|
||||
MODEL_ARCH.DECI: "deci",
|
||||
@@ -776,10 +811,28 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
|
||||
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral
|
||||
MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1
|
||||
# audio (mtmd)
|
||||
MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
|
||||
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
|
||||
MODEL_TENSOR.A_PRE_NORM: "a.pre_ln",
|
||||
MODEL_TENSOR.A_POST_NORM: "a.post_ln",
|
||||
MODEL_TENSOR.A_ENC_ATTN_Q: "a.blk.{bid}.attn_q",
|
||||
MODEL_TENSOR.A_ENC_ATTN_K: "a.blk.{bid}.attn_k",
|
||||
MODEL_TENSOR.A_ENC_ATTN_V: "a.blk.{bid}.attn_v",
|
||||
MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1",
|
||||
MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out",
|
||||
MODEL_TENSOR.A_ENC_OUTPUT_NORM: "a.blk.{bid}.ln2",
|
||||
MODEL_TENSOR.A_ENC_FFN_UP: "a.blk.{bid}.ffn_up",
|
||||
MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate",
|
||||
MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}",
|
||||
MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc",
|
||||
MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre",
|
||||
MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid",
|
||||
}
|
||||
|
||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_ARCH.CLIP_VISION: [
|
||||
MODEL_ARCH.MMPROJ: [
|
||||
MODEL_TENSOR.V_MMPROJ,
|
||||
MODEL_TENSOR.V_MMPROJ_FC,
|
||||
MODEL_TENSOR.V_MMPROJ_MLP,
|
||||
@@ -819,6 +872,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.V_RESMPL_QUERY,
|
||||
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
|
||||
MODEL_TENSOR.V_MM_PATCH_MERGER,
|
||||
# audio
|
||||
MODEL_TENSOR.A_ENC_EMBD_POS,
|
||||
MODEL_TENSOR.A_ENC_CONV1D,
|
||||
MODEL_TENSOR.A_PRE_NORM,
|
||||
MODEL_TENSOR.A_POST_NORM,
|
||||
MODEL_TENSOR.A_ENC_ATTN_Q,
|
||||
MODEL_TENSOR.A_ENC_ATTN_K,
|
||||
MODEL_TENSOR.A_ENC_ATTN_V,
|
||||
MODEL_TENSOR.A_ENC_INPUT_NORM,
|
||||
MODEL_TENSOR.A_ENC_OUTPUT,
|
||||
MODEL_TENSOR.A_ENC_OUTPUT_NORM,
|
||||
MODEL_TENSOR.A_ENC_FFN_UP,
|
||||
MODEL_TENSOR.A_ENC_FFN_GATE,
|
||||
MODEL_TENSOR.A_ENC_FFN_DOWN,
|
||||
MODEL_TENSOR.A_MMPROJ,
|
||||
MODEL_TENSOR.A_MMPROJ_FC,
|
||||
MODEL_TENSOR.A_MM_NORM_PRE,
|
||||
MODEL_TENSOR.A_MM_NORM_MID,
|
||||
],
|
||||
MODEL_ARCH.LLAMA: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
@@ -2186,7 +2257,9 @@ class VisionProjectorType:
|
||||
LLAMA4 = "llama4"
|
||||
QWEN2VL = "qwen2vl_merger"
|
||||
QWEN25VL = "qwen2.5vl_merger"
|
||||
ULTRAVOX = "ultravox"
|
||||
INTERNVL = "internvl"
|
||||
QWEN2A = "qwen2a" # audio
|
||||
|
||||
|
||||
# Items here are (block size, type size)
|
||||
|
||||
@@ -896,7 +896,7 @@ class GGUFWriter:
|
||||
def add_remove_extra_whitespaces(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Tokenizer.REMOVE_EXTRA_WS, value)
|
||||
|
||||
def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None:
|
||||
def add_precompiled_charsmap(self, charsmap: bytes) -> None:
|
||||
self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)
|
||||
|
||||
def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
|
||||
@@ -936,12 +936,18 @@ class GGUFWriter:
|
||||
|
||||
# for vision models
|
||||
|
||||
def add_clip_has_vision_encoder(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Clip.HAS_VISION_ENCODER, value)
|
||||
|
||||
def add_clip_has_audio_encoder(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Clip.HAS_AUDIO_ENCODER, value)
|
||||
|
||||
def add_clip_projector_type(self, value: str) -> None:
|
||||
self.add_string(Keys.Clip.PROJECTOR_TYPE, value)
|
||||
|
||||
def add_vision_projection_dim(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.PROJECTION_DIM, value)
|
||||
|
||||
def add_vision_has_vision_encoder(self, value: bool) -> None:
|
||||
self.add_bool(Keys.ClipVision.HAS_VISION_ENCODER, value)
|
||||
|
||||
def add_vision_patch_size(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.PATCH_SIZE, value)
|
||||
|
||||
@@ -957,9 +963,6 @@ class GGUFWriter:
|
||||
def add_vision_head_count(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value)
|
||||
|
||||
def add_vision_projector_type(self, value: str) -> None:
|
||||
self.add_string(Keys.ClipVision.PROJECTOR_TYPE, value)
|
||||
|
||||
def add_vision_attention_layernorm_eps(self, value: float) -> None:
|
||||
self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value)
|
||||
|
||||
@@ -987,6 +990,32 @@ class GGUFWriter:
|
||||
def add_vision_n_wa_pattern(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
|
||||
|
||||
# audio models
|
||||
|
||||
def add_audio_projection_dim(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.PROJECTION_DIM, value)
|
||||
|
||||
def add_audio_embedding_length(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.EMBEDDING_LENGTH, value)
|
||||
|
||||
def add_audio_feed_forward_length(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.FEED_FORWARD_LENGTH, value)
|
||||
|
||||
def add_audio_block_count(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.BLOCK_COUNT, value)
|
||||
|
||||
def add_audio_head_count(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.Attention.HEAD_COUNT, value)
|
||||
|
||||
def add_audio_attention_layernorm_eps(self, value: float) -> None:
|
||||
self.add_float32(Keys.ClipAudio.Attention.LAYERNORM_EPS, value)
|
||||
|
||||
def add_audio_num_mel_bins(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.NUM_MEL_BINS, value)
|
||||
|
||||
def add_audio_stack_factor(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value)
|
||||
|
||||
def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
|
||||
pack_prefix = ''
|
||||
if not skip_pack_prefix:
|
||||
|
||||
@@ -1110,6 +1110,72 @@ class TensorNameMap:
|
||||
MODEL_TENSOR.V_MM_PATCH_MERGER: (
|
||||
"multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1
|
||||
),
|
||||
|
||||
# audio (mtmd)
|
||||
|
||||
MODEL_TENSOR.A_ENC_EMBD_POS: (
|
||||
"audio_tower.embed_positions", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_CONV1D: (
|
||||
"audio_tower.conv{bid}", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_PRE_NORM: (),
|
||||
|
||||
MODEL_TENSOR.A_POST_NORM: (
|
||||
"audio_tower.layer_norm", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_ATTN_Q: (
|
||||
"audio_tower.layers.{bid}.self_attn.q_proj", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_ATTN_K: (
|
||||
"audio_tower.layers.{bid}.self_attn.k_proj", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_ATTN_V: (
|
||||
"audio_tower.layers.{bid}.self_attn.v_proj", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_INPUT_NORM: (
|
||||
"audio_tower.layers.{bid}.self_attn_layer_norm", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_OUTPUT: (
|
||||
"audio_tower.layers.{bid}.self_attn.out_proj", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_OUTPUT_NORM: (
|
||||
"audio_tower.layers.{bid}.final_layer_norm", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_FFN_UP: (
|
||||
"audio_tower.layers.{bid}.fc1", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_ENC_FFN_GATE: (),
|
||||
|
||||
MODEL_TENSOR.A_ENC_FFN_DOWN: (
|
||||
"audio_tower.layers.{bid}.fc2", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_MMPROJ: (
|
||||
"audio.multi_modal_projector.linear_{bid}", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_MMPROJ_FC: (
|
||||
"audio.multi_modal_projector.linear", # qwen2audio
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_MM_NORM_PRE: (
|
||||
"audio.multi_modal_projector.ln_pre", # ultravox
|
||||
),
|
||||
|
||||
MODEL_TENSOR.A_MM_NORM_MID: (
|
||||
"audio.multi_modal_projector.ln_mid", # ultravox
|
||||
),
|
||||
}
|
||||
|
||||
# architecture-specific block mappings
|
||||
|
||||
@@ -471,6 +471,7 @@ extern "C" {
|
||||
LLAMA_API int64_t llama_time_us(void);
|
||||
|
||||
LLAMA_API size_t llama_max_devices(void);
|
||||
LLAMA_API size_t llama_max_parallel_sequences(void);
|
||||
|
||||
LLAMA_API bool llama_supports_mmap (void);
|
||||
LLAMA_API bool llama_supports_mlock (void);
|
||||
|
||||
Binary file not shown.
@@ -0,0 +1,112 @@
|
||||
ied 4 ½ months
|
||||
__ggml_vocab_test__
|
||||
Führer
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
__ggml_vocab_test__
|
||||
Hello world
|
||||
__ggml_vocab_test__
|
||||
Hello world
|
||||
__ggml_vocab_test__
|
||||
Hello World
|
||||
__ggml_vocab_test__
|
||||
Hello World
|
||||
__ggml_vocab_test__
|
||||
Hello World!
|
||||
__ggml_vocab_test__
|
||||
Hello, world!
|
||||
__ggml_vocab_test__
|
||||
Hello, world!
|
||||
__ggml_vocab_test__
|
||||
this is 🦙.cpp
|
||||
__ggml_vocab_test__
|
||||
w048 7tuijk dsdfhu
|
||||
__ggml_vocab_test__
|
||||
нещо на Български
|
||||
__ggml_vocab_test__
|
||||
កាន់តែពិសេសអាចខលចេញ
|
||||
__ggml_vocab_test__
|
||||
🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
Hello
|
||||
Hello
|
||||
__ggml_vocab_test__
|
||||
(
|
||||
__ggml_vocab_test__
|
||||
|
||||
=
|
||||
__ggml_vocab_test__
|
||||
' era
|
||||
__ggml_vocab_test__
|
||||
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
|
||||
__ggml_vocab_test__
|
||||
!!!!!!
|
||||
__ggml_vocab_test__
|
||||
3
|
||||
__ggml_vocab_test__
|
||||
33
|
||||
__ggml_vocab_test__
|
||||
333
|
||||
__ggml_vocab_test__
|
||||
3333
|
||||
__ggml_vocab_test__
|
||||
33333
|
||||
__ggml_vocab_test__
|
||||
333333
|
||||
__ggml_vocab_test__
|
||||
3333333
|
||||
__ggml_vocab_test__
|
||||
33333333
|
||||
__ggml_vocab_test__
|
||||
333333333
|
||||
__ggml_vocab_test__
|
||||
Cửa Việt
|
||||
__ggml_vocab_test__
|
||||
discards
|
||||
__ggml_vocab_test__
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
🚀 (normal) 😶🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
|
||||
__ggml_vocab_test__
|
||||
@@ -0,0 +1,46 @@
|
||||
17 297 201 78660 21775
|
||||
72805 4097 56
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
35378 8999
|
||||
35378 8999
|
||||
35378 6661
|
||||
35378 6661
|
||||
35378 6661 38
|
||||
35378 4 8999 38
|
||||
35378 4 8999 38
|
||||
903 83 6 3 5 238 6366
|
||||
148 7709 1019 361 458 134362 104 7 71 420 1132
|
||||
14271 29 117152
|
||||
6 149561 78270 48967 64254 7616 81705
|
||||
6 247206 15 33176 16 6 247442 6 3 15755 15 144227 8705 18255 40292 158 4460 33 27686 16 6 142325 15 191 538 28 121505 450 1556 6863 10002 47 1098 16
|
||||
35378
|
||||
35378
|
||||
35378
|
||||
35378
|
||||
35378
|
||||
35378 35378
|
||||
15
|
||||
2203
|
||||
242 1615
|
||||
35378 4 113 25 5584 38 11249 621 398 6 201344 705 23638 213 9007 133 1879 2681 2592 135224 1906 6087
|
||||
6 90827
|
||||
138
|
||||
3912
|
||||
6 66000
|
||||
138 66000
|
||||
3912 66000
|
||||
6 66000 66000
|
||||
138 66000 66000
|
||||
3912 66000 66000
|
||||
6 66000 66000 66000
|
||||
199152 3763
|
||||
17116 99397
|
||||
6 247206 15 33176 16 6 247442 6 3 15755 15 144227 8705 18255 40292 158 4460 33 27686 16 6 142325 6 3 138 3912 6 66000 138 66000 3912 66000 6 66000 66000 138 66000 66000 3912 66000 66000 80308 1031 5 363 138 27 363 6 149561 78270 48967 201344 705 23638 213 9007 133 1879 2681 2592 135224 1906 6087 6 110405 1369 69112 69112 69112 14271 29 117152 5106 4765 4765 1135 164721 164721 164721 58 58 58 58 2551 90827 32 85908 87 25 272 2809 242 18 18345 764 25 7 2685 4 242 11766 398 9077 32 242 594 959 9077 87 25 1181 3249 442 4 242 397 398 1884 3060 26156 32 1401 25 26455 10 25 141 866
|
||||
@@ -0,0 +1,62 @@
|
||||
{%- if tools %}
|
||||
{{- '<|im_start|>system\n' }}
|
||||
{%- if messages[0]['role'] == 'system' %}
|
||||
{{- messages[0]['content'] }}
|
||||
{%- else %}
|
||||
{{- '' }}
|
||||
{%- endif %}
|
||||
{{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
||||
{%- for tool in tools %}
|
||||
{{- "\n" }}
|
||||
{{- tool | tojson }}
|
||||
{%- endfor %}
|
||||
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
||||
{%- else %}
|
||||
{%- if messages[0]['role'] == 'system' %}
|
||||
{{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- for message in messages %}
|
||||
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "assistant" and not message.tool_calls %}
|
||||
{%- set content = message.content %}
|
||||
{%- if not loop.last %}
|
||||
{%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
|
||||
{%- endif %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "assistant" %}
|
||||
{%- set content = message.content %}
|
||||
{%- if not loop.last %}
|
||||
{%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
|
||||
{%- endif %}
|
||||
{{- '<|im_start|>' + message.role }}
|
||||
{%- if message.content %}
|
||||
{{- '\n' + content }}
|
||||
{%- endif %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if tool_call.function is defined %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_call>\n{"name": "' }}
|
||||
{{- tool_call.name }}
|
||||
{{- '", "arguments": ' }}
|
||||
{{- tool_call.arguments | tojson }}
|
||||
{{- '}\n</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
|
||||
{{- '<|im_start|>user' }}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_response>\n' }}
|
||||
{{- message.content }}
|
||||
{{- '\n</tool_response>' }}
|
||||
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n<think>\n' }}
|
||||
{%- endif %}
|
||||
@@ -0,0 +1,85 @@
|
||||
{%- if tools %}
|
||||
{{- '<|im_start|>system\n' }}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{{- messages[0].content + '\n\n' }}
|
||||
{%- endif %}
|
||||
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
||||
{%- for tool in tools %}
|
||||
{{- "\n" }}
|
||||
{{- tool | tojson }}
|
||||
{%- endfor %}
|
||||
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
||||
{%- else %}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
||||
{%- for message in messages[::-1] %}
|
||||
{%- set index = (messages|length - 1) - loop.index0 %}
|
||||
{%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
|
||||
{%- set ns.multi_step_tool = false %}
|
||||
{%- set ns.last_query_index = index %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- for message in messages %}
|
||||
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "assistant" %}
|
||||
{%- set content = message.content %}
|
||||
{%- set reasoning_content = '' %}
|
||||
{%- if message.reasoning_content is defined and message.reasoning_content is not none %}
|
||||
{%- set reasoning_content = message.reasoning_content %}
|
||||
{%- else %}
|
||||
{%- if '</think>' in message.content %}
|
||||
{%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
|
||||
{%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if loop.index0 > ns.last_query_index %}
|
||||
{%- if loop.last or (not loop.last and reasoning_content) %}
|
||||
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content }}
|
||||
{%- endif %}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content }}
|
||||
{%- endif %}
|
||||
{%- if message.tool_calls %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if (loop.first and content) or (not loop.first) %}
|
||||
{{- '\n' }}
|
||||
{%- endif %}
|
||||
{%- if tool_call.function %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '<tool_call>\n{"name": "' }}
|
||||
{{- tool_call.name }}
|
||||
{{- '", "arguments": ' }}
|
||||
{%- if tool_call.arguments is string %}
|
||||
{{- tool_call.arguments }}
|
||||
{%- else %}
|
||||
{{- tool_call.arguments | tojson }}
|
||||
{%- endif %}
|
||||
{{- '}\n</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
||||
{{- '<|im_start|>user' }}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_response>\n' }}
|
||||
{{- message.content }}
|
||||
{{- '\n</tool_response>' }}
|
||||
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- if enable_thinking is defined and enable_thinking is false %}
|
||||
{{- '<think>\n\n</think>\n\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
@@ -19,4 +19,6 @@ These templates can be updated with the following commands:
|
||||
./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja
|
||||
./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja
|
||||
./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
|
||||
./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja
|
||||
./scripts/get_chat_template.py Qwen/Qwen3-0.6B > models/templates/Qwen-Qwen3-0.6B.jinja
|
||||
```
|
||||
@@ -1,3 +1,7 @@
|
||||
-r ./requirements-convert_legacy_llama.txt
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch~=2.2.1
|
||||
torch~=2.2.1; platform_machine != "s390x"
|
||||
|
||||
# torch s390x packages can only be found from nightly builds
|
||||
--extra-index-url https://download.pytorch.org/whl/nightly
|
||||
torch>=0.0.0.dev0; platform_machine == "s390x"
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
-r ./requirements-convert_legacy_llama.txt
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch~=2.2.1
|
||||
torch~=2.2.1; platform_machine != "s390x"
|
||||
|
||||
# torch s390x packages can only be found from nightly builds
|
||||
--extra-index-url https://download.pytorch.org/whl/nightly
|
||||
torch>=0.0.0.dev0; platform_machine == "s390x"
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
-r ./requirements-convert_hf_to_gguf.txt
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
# torch s390x packages can only be found from nightly builds
|
||||
--extra-index-url https://download.pytorch.org/whl/nightly
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server
|
||||
export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp}
|
||||
|
||||
./scripts/tool_bench.py run --n 10 --temp -1 --temp 0 --temp 1 --temp 2 --temp 5 --llama-baseline $PWD/buildMaster/bin/llama-server --output qwen14b.jsonl --hf bartowski/Qwen2.5-14B-Instruct-GGUF:Q4_K_L
|
||||
./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M
|
||||
./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b
|
||||
|
||||
@@ -205,6 +206,7 @@ def run(
|
||||
model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None,
|
||||
hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None,
|
||||
chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None,
|
||||
chat_template_file: Annotated[Optional[str], typer.Option(help="Chat template file override for llama-server")] = None,
|
||||
ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None,
|
||||
llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server baseline binary path to use as baseline")] = None,
|
||||
n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10,
|
||||
@@ -229,6 +231,12 @@ def run(
|
||||
# n_ctx = 8192
|
||||
n_ctx = 2048
|
||||
|
||||
if model is None:
|
||||
if hf is not None:
|
||||
model = hf.split("/")[-1]
|
||||
elif ollama is not None:
|
||||
model = ollama
|
||||
|
||||
assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite"
|
||||
|
||||
with output.open('a' if append else 'w') as output_file:
|
||||
@@ -320,6 +328,7 @@ def run(
|
||||
server.model_hf_repo = hf
|
||||
server.model_hf_file = None
|
||||
server.chat_template = chat_template
|
||||
server.chat_template_file = chat_template_file
|
||||
server.server_path = server_path
|
||||
if port is not None:
|
||||
server.server_port = port
|
||||
@@ -335,6 +344,7 @@ def run(
|
||||
temp=t,
|
||||
output_kwargs=dict(
|
||||
chat_template=chat_template,
|
||||
chat_template_file=chat_template_file,
|
||||
),
|
||||
request_kwargs=dict(
|
||||
ignore_chat_grammar=ignore_chat_grammar,
|
||||
@@ -355,6 +365,7 @@ def run(
|
||||
temp=t,
|
||||
output_kwargs=dict(
|
||||
chat_template=None,
|
||||
chat_template_file=None,
|
||||
),
|
||||
request_kwargs=dict(
|
||||
model=ollama,
|
||||
|
||||
@@ -25,7 +25,11 @@ llama_context::llama_context(
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
||||
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
||||
if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
|
||||
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
|
||||
}
|
||||
|
||||
cparams.n_threads = params.n_threads;
|
||||
cparams.n_threads_batch = params.n_threads_batch;
|
||||
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
||||
|
||||
@@ -1 +1,5 @@
|
||||
#include "llama-cparams.h"
|
||||
|
||||
size_t llama_max_parallel_sequences(void) {
|
||||
return LLAMA_MAX_PARALLEL_SEQUENCES;
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#define LLAMA_MAX_PARALLEL_SEQUENCES 64
|
||||
|
||||
struct llama_cparams {
|
||||
uint32_t n_ctx; // context size used during inference
|
||||
uint32_t n_batch;
|
||||
|
||||
+12
-2
@@ -1177,8 +1177,18 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
|
||||
for (const auto & trigger_pattern : grammar.trigger_patterns) {
|
||||
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
|
||||
grammar.awaiting_trigger = false;
|
||||
// get from the first match to the end of the string
|
||||
auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
|
||||
// get from the first matched capturing group to the end of the string
|
||||
size_t start = std::string::npos;
|
||||
for (auto i = 1u; i < match.size(); i++) {
|
||||
if (match.length(i) > 0) {
|
||||
start = match.position(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (start == std::string::npos) {
|
||||
start = match.position(0);
|
||||
}
|
||||
auto constrained_str = grammar.trigger_buffer.substr(start);
|
||||
// std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
|
||||
grammar.trigger_buffer.clear();
|
||||
llama_grammar_accept_str(grammar, constrained_str);
|
||||
|
||||
+4
-4
@@ -1287,6 +1287,10 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
|
||||
if (wo) {
|
||||
cur = build_lora_mm(wo, cur);
|
||||
if (arch == LLM_ARCH_GLM4) {
|
||||
// GLM4 seems to have numerical issues with half-precision accumulators
|
||||
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
||||
}
|
||||
}
|
||||
|
||||
if (wo_b) {
|
||||
@@ -1367,10 +1371,6 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
|
||||
if (wo) {
|
||||
cur = build_lora_mm(wo, cur);
|
||||
if (arch == LLM_ARCH_GLM4) {
|
||||
// GLM4 seems to have numerical issues with half-precision accumulators
|
||||
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
||||
}
|
||||
}
|
||||
|
||||
if (wo_b) {
|
||||
|
||||
+17
-1
@@ -2,6 +2,22 @@
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
|
||||
}
|
||||
}
|
||||
|
||||
bool llama_hparams::is_swa_any() const {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
if (swa_layers[il]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_head(uint32_t il) const {
|
||||
if (il < n_layer) {
|
||||
return n_head_arr[il];
|
||||
@@ -72,7 +88,7 @@ uint32_t llama_hparams::n_embd_v_s() const {
|
||||
|
||||
bool llama_hparams::is_swa(uint32_t il) const {
|
||||
if (il < n_layer) {
|
||||
return n_swa_pattern == 0 || (il % n_swa_pattern < (n_swa_pattern - 1));
|
||||
return swa_layers[il];
|
||||
}
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
|
||||
+23
-14
@@ -102,20 +102,12 @@ struct llama_hparams {
|
||||
|
||||
// Sliding Window Attention (SWA)
|
||||
llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
|
||||
uint32_t n_swa = 0; // the size of the sliding window (0 - no SWA)
|
||||
uint32_t n_swa_pattern = 1; // this value n means that every nth layer is dense (i.e. non-SWA)
|
||||
// by default n == 1, all layers are dense
|
||||
// note that if n_swa_pattern == 0, all layers are SWA
|
||||
// example: n_swa_pattern = 3
|
||||
// il == 0: swa
|
||||
// il == 1: swa
|
||||
// il == 2: dense
|
||||
// il == 3: swa
|
||||
// il == 4: swa
|
||||
// il == 5: dense
|
||||
// il == 6: swa
|
||||
// etc ...
|
||||
// the size of the sliding window (0 - no SWA)
|
||||
uint32_t n_swa = 0;
|
||||
// if swa_layers[il] == true, then layer il is SWA
|
||||
// if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
|
||||
// by default, all layers are dense
|
||||
std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
|
||||
|
||||
// for State Space Models
|
||||
uint32_t ssm_d_conv = 0;
|
||||
@@ -153,6 +145,23 @@ struct llama_hparams {
|
||||
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
|
||||
enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
|
||||
|
||||
// this value n_pattern means that every nth layer is dense (i.e. non-SWA)
|
||||
// note that if n_pattern == 0, all layers are SWA
|
||||
// if n_pattern == 1, all layers are dense
|
||||
// example: n_pattern = 3
|
||||
// il == 0: swa
|
||||
// il == 1: swa
|
||||
// il == 2: dense
|
||||
// il == 3: swa
|
||||
// il == 4: swa
|
||||
// il == 5: dense
|
||||
// il == 6: swa
|
||||
// etc ...
|
||||
void set_swa_pattern(uint32_t n_pattern);
|
||||
|
||||
// return true if one of the layers is SWA
|
||||
bool is_swa_any() const;
|
||||
|
||||
uint32_t n_head(uint32_t il = 0) const;
|
||||
|
||||
uint32_t n_head_kv(uint32_t il = 0) const;
|
||||
|
||||
+162
-220
@@ -65,8 +65,6 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
||||
};
|
||||
|
||||
head = 0;
|
||||
size = kv_size;
|
||||
used = 0;
|
||||
|
||||
cells.resize(kv_size);
|
||||
|
||||
@@ -138,13 +136,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::clear() {
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
cells[i].pos = -1;
|
||||
cells[i].seq_id.clear();
|
||||
}
|
||||
cells.reset();
|
||||
|
||||
head = 0;
|
||||
used = 0;
|
||||
|
||||
for (auto & buf : bufs) {
|
||||
ggml_backend_buffer_clear(buf.get(), 0);
|
||||
@@ -152,7 +146,7 @@ void llama_kv_cache_unified::clear() {
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
uint32_t new_head = size;
|
||||
uint32_t new_head = cells.size();
|
||||
|
||||
if (p0 < 0) {
|
||||
p0 = 0;
|
||||
@@ -162,33 +156,20 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
||||
p1 = std::numeric_limits<llama_pos>::max();
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].pos >= p0 && cells[i].pos < p1) {
|
||||
if (seq_id < 0) {
|
||||
cells[i].seq_id.clear();
|
||||
} else if (cells[i].has_seq_id(seq_id)) {
|
||||
cells[i].seq_id.erase(seq_id);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (!cells.pos_in(i, p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (cells[i].is_empty()) {
|
||||
// keep count of the number of used cells
|
||||
if (cells[i].pos >= 0) {
|
||||
used--;
|
||||
}
|
||||
|
||||
cells[i].pos = -1;
|
||||
|
||||
if (new_head == size) {
|
||||
new_head = i;
|
||||
}
|
||||
if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
|
||||
if (new_head == cells.size()) {
|
||||
new_head = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we freed up a slot, set head to it so searching can start there.
|
||||
if (new_head != size && new_head < head) {
|
||||
if (new_head != cells.size() && new_head < head) {
|
||||
head = new_head;
|
||||
}
|
||||
|
||||
@@ -208,49 +189,40 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
|
||||
p1 = std::numeric_limits<llama_pos>::max();
|
||||
}
|
||||
|
||||
// otherwise, this is the KV of a Transformer-like model
|
||||
head = 0;
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (!cells.pos_in(i, p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) {
|
||||
cells[i].seq_id.insert(seq_id_dst);
|
||||
if (cells.seq_has(i, seq_id_src)) {
|
||||
cells.seq_add(i, seq_id_dst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
|
||||
uint32_t new_head = size;
|
||||
uint32_t new_head = cells.size();
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (!cells[i].has_seq_id(seq_id)) {
|
||||
if (cells[i].pos >= 0) {
|
||||
used--;
|
||||
}
|
||||
|
||||
cells[i].pos = -1;
|
||||
cells[i].seq_id.clear();
|
||||
|
||||
if (new_head == size){
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (cells.seq_keep(i, seq_id)) {
|
||||
if (new_head == cells.size()) {
|
||||
new_head = i;
|
||||
}
|
||||
} else {
|
||||
cells[i].seq_id.clear();
|
||||
cells[i].seq_id.insert(seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
// If we freed up a slot, set head to it so searching can start there.
|
||||
if (new_head != size && new_head < head) {
|
||||
if (new_head != cells.size() && new_head < head) {
|
||||
head = new_head;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
|
||||
if (delta == 0) {
|
||||
void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
if (shift == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint32_t new_head = size;
|
||||
uint32_t new_head = cells.size();
|
||||
|
||||
if (p0 < 0) {
|
||||
p0 = 0;
|
||||
@@ -260,25 +232,19 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
|
||||
p1 = std::numeric_limits<llama_pos>::max();
|
||||
}
|
||||
|
||||
// If there is no range then return early to avoid looping over the
|
||||
// If there is no range then return early to avoid looping over all cells.
|
||||
if (p0 == p1) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
|
||||
has_shift = true;
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (!cells.pos_in(i, p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
cells[i].pos += delta;
|
||||
cells[i].delta += delta;
|
||||
|
||||
if (cells[i].pos < 0) {
|
||||
if (!cells[i].is_empty()) {
|
||||
used--;
|
||||
}
|
||||
cells[i].pos = -1;
|
||||
cells[i].seq_id.clear();
|
||||
if (new_head == size) {
|
||||
if (cells.seq_has(i, seq_id)) {
|
||||
if (cells.pos_add(i, shift)) {
|
||||
if (new_head == cells.size()) {
|
||||
new_head = i;
|
||||
}
|
||||
}
|
||||
@@ -287,7 +253,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
|
||||
|
||||
// If we freed up a slot, set head to it so searching can start there.
|
||||
// Otherwise we just start the next search from the beginning.
|
||||
head = new_head != size ? new_head : 0;
|
||||
head = new_head != cells.size() ? new_head : 0;
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
@@ -308,15 +274,13 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
|
||||
return;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
|
||||
has_shift = true;
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (!cells.pos_in(i, p0, p1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
{
|
||||
llama_pos p_old = cells[i].pos;
|
||||
cells[i].pos /= d;
|
||||
cells[i].delta += cells[i].pos - p_old;
|
||||
}
|
||||
if (cells.seq_has(i, seq_id)) {
|
||||
cells.pos_div(i, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -324,9 +288,9 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
|
||||
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
|
||||
llama_pos result = std::numeric_limits<llama_pos>::max();
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].has_seq_id(seq_id)) {
|
||||
result = std::min(result, cells[i].pos);
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (cells.seq_has(i, seq_id)) {
|
||||
result = std::min(result, cells.pos_get(i));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -340,9 +304,9 @@ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
|
||||
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
||||
llama_pos result = -1;
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].has_seq_id(seq_id)) {
|
||||
result = std::max(result, cells[i].pos);
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (cells.seq_has(i, seq_id)) {
|
||||
result = std::max(result, cells.pos_get(i));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -350,25 +314,15 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::restore() {
|
||||
for (const auto & [id, cell] : recovery.cells) {
|
||||
// TODO: move to new `struct kv_cells`
|
||||
const bool is_empty0 = cells[id].is_empty();
|
||||
const bool is_empty1 = cell.is_empty();
|
||||
|
||||
if (!is_empty0 && is_empty1) {
|
||||
used--;
|
||||
} else if (is_empty0 && !is_empty1) {
|
||||
used++;
|
||||
}
|
||||
|
||||
cells[id] = cell;
|
||||
for (auto & state : recovery.states) {
|
||||
cells.set(state.i, state.cells);
|
||||
}
|
||||
|
||||
recovery.clear();
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::commit() {
|
||||
if (recovery.cells.empty()) {
|
||||
if (recovery.states.empty()) {
|
||||
LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
|
||||
__func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
|
||||
return;
|
||||
@@ -382,7 +336,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
||||
|
||||
auto * sched = lctx.get_sched();
|
||||
|
||||
if (has_shift) {
|
||||
if (cells.get_has_shift()) {
|
||||
if (!get_can_shift()) {
|
||||
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
|
||||
}
|
||||
@@ -406,13 +360,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
||||
need_reserve = true;
|
||||
}
|
||||
|
||||
{
|
||||
has_shift = false;
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
cells[i].delta = 0;
|
||||
}
|
||||
}
|
||||
cells.reset_shift();
|
||||
}
|
||||
|
||||
if (do_defrag) {
|
||||
@@ -443,7 +391,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
||||
void llama_kv_cache_unified::defrag_sched(float thold) {
|
||||
// - do not defrag small contexts (i.e. < 2048 tokens)
|
||||
// - count the padding towards the number of used tokens
|
||||
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f;
|
||||
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f;
|
||||
|
||||
// queue defragmentation for next llama_kv_cache_update
|
||||
if (fragmentation > thold) {
|
||||
@@ -454,7 +402,7 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::set_full() {
|
||||
n = size;
|
||||
n = cells.size();
|
||||
|
||||
// when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
|
||||
// affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
|
||||
@@ -478,14 +426,14 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
|
||||
|
||||
// if we have enough unused cells before the current head ->
|
||||
// better to start searching from the beginning of the cache, hoping to fill it
|
||||
if (head > used + 2*ubatch.n_tokens) {
|
||||
if (head > cells.get_used() + 2*ubatch.n_tokens) {
|
||||
head = 0;
|
||||
}
|
||||
|
||||
// otherwise, one cell per token.
|
||||
|
||||
if (n_tokens > size) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
|
||||
if (n_tokens > cells.size()) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -498,10 +446,10 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
|
||||
std::string ss;
|
||||
if (n_swa > 0) {
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].pos == -1) {
|
||||
if (cells.is_empty(i)) {
|
||||
ss += '.';
|
||||
} else {
|
||||
ss += std::to_string(*cells[i].seq_id.begin());
|
||||
ss += 'x';
|
||||
}
|
||||
if (i%256 == 255) {
|
||||
ss += '\n';
|
||||
@@ -515,15 +463,16 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
|
||||
uint32_t n_tested = 0;
|
||||
|
||||
while (true) {
|
||||
if (head + n_tokens > size) {
|
||||
n_tested += size - head;
|
||||
if (head + n_tokens > cells.size()) {
|
||||
n_tested += cells.size() - head;
|
||||
head = 0;
|
||||
continue;
|
||||
}
|
||||
|
||||
bool found = true;
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
if (cells[head + i].pos >= 0) {
|
||||
// TODO: improve to accept cells that are masked by the SWA
|
||||
if (!cells.is_empty(head + i)) {
|
||||
found = false;
|
||||
head += i + 1;
|
||||
n_tested += i + 1;
|
||||
@@ -535,31 +484,27 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (n_tested >= size) {
|
||||
if (n_tested >= cells.size()) {
|
||||
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < n_tokens; ++i) {
|
||||
// remember the original state
|
||||
if (recovery.cells.find(head + i) == recovery.cells.end()) {
|
||||
recovery.cells[head + i] = cells[head + i];
|
||||
}
|
||||
// store the old state of the cells in the recovery stack
|
||||
recovery.states.push_back({head, cells.cp(head, n_tokens)});
|
||||
|
||||
cells[head + i].pos = ubatch.pos[i];
|
||||
for (uint32_t i = 0; i < n_tokens; ++i) {
|
||||
cells.pos_set(head + i, ubatch.pos[i]);
|
||||
|
||||
for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
|
||||
cells[head + i].seq_id.insert(ubatch.seq_id[i][j]);
|
||||
cells.seq_add(head + i, ubatch.seq_id[i][j]);
|
||||
}
|
||||
}
|
||||
|
||||
used += n_tokens;
|
||||
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// after enough generations, the benefit from this heuristic disappears
|
||||
// if we start defragmenting the cache, the benefit from this will be more important
|
||||
n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
|
||||
n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
|
||||
|
||||
#ifdef FIND_SLOT_DEBUG
|
||||
LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
|
||||
@@ -577,7 +522,7 @@ uint32_t llama_kv_cache_unified::get_n() const {
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_unified::get_size() const {
|
||||
return size;
|
||||
return cells.size();
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
|
||||
@@ -661,30 +606,19 @@ void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llam
|
||||
|
||||
int n_attended = 0;
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
const llama_pos p0 = cells[i].pos;
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (!cells.seq_has(i, seq_id)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const llama_pos p0 = cells.pos_get(i);
|
||||
|
||||
if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
|
||||
n_attended++;
|
||||
}
|
||||
|
||||
if (is_masked_swa(p0, pmax)) {
|
||||
if (seq_id < 0) {
|
||||
cells[i].seq_id.clear();
|
||||
} else if (cells[i].has_seq_id(seq_id)) {
|
||||
cells[i].seq_id.erase(seq_id);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (cells[i].is_empty()) {
|
||||
// keep count of the number of used cells
|
||||
if (cells[i].pos >= 0) {
|
||||
used--;
|
||||
}
|
||||
|
||||
cells[i].pos = -1;
|
||||
}
|
||||
cells.seq_rm(i, seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -723,25 +657,31 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
||||
const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
|
||||
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
const llama_pos p0 = cells[i].pos;
|
||||
float f = 0.0f;
|
||||
|
||||
bool masked = false;
|
||||
|
||||
// mask the token if not the same sequence
|
||||
masked = masked || (!cells[i].has_seq_id(seq_id));
|
||||
if (cells.is_empty(i)) {
|
||||
masked = true;
|
||||
} else {
|
||||
const llama_pos p0 = cells.pos_get(i);
|
||||
|
||||
// mask future tokens
|
||||
masked = masked || (causal_attn && p0 > p1);
|
||||
// mask the token if not the same sequence
|
||||
masked = masked || (!cells.seq_has(i, seq_id));
|
||||
|
||||
// apply SWA if any
|
||||
masked = masked || (is_masked_swa(p0, p1));
|
||||
// mask future tokens
|
||||
masked = masked || (causal_attn && p0 > p1);
|
||||
|
||||
float f = 0.0f;
|
||||
// apply SWA if any
|
||||
masked = masked || (is_masked_swa(p0, p1));
|
||||
|
||||
if (!masked && hparams.use_alibi) {
|
||||
f = -std::abs(p0 - p1);
|
||||
}
|
||||
}
|
||||
|
||||
if (masked) {
|
||||
f = -INFINITY;
|
||||
} else if (hparams.use_alibi) {
|
||||
f = -std::abs(p0 - p1);
|
||||
}
|
||||
|
||||
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
||||
@@ -765,8 +705,8 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
|
||||
|
||||
int32_t * data = (int32_t *) dst->data;
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
data[i] = cells[i].delta;
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -783,7 +723,10 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
|
||||
// the position when the cells is empty is irrelevant - it will be masked out later in the attention
|
||||
const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
|
||||
|
||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -910,7 +853,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
|
||||
|
||||
ggml_tensor * k =
|
||||
ggml_view_3d(ctx, layer.k,
|
||||
n_embd_head_k, n_head_kv, size,
|
||||
n_embd_head_k, n_head_kv, cells.size(),
|
||||
ggml_row_size(layer.k->type, n_embd_head_k),
|
||||
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
||||
0);
|
||||
@@ -1050,12 +993,12 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
||||
} else {
|
||||
view_v_src = ggml_view_2d(ctx, layer.v,
|
||||
nm, n_embd_v_gqa,
|
||||
ggml_row_size(layer.v->type, size),
|
||||
ggml_row_size(layer.v->type, cells.size()),
|
||||
ggml_row_size(layer.v->type, i));
|
||||
|
||||
view_v_dst = ggml_view_2d(ctx, layer.v,
|
||||
nm, n_embd_v_gqa,
|
||||
ggml_row_size(layer.v->type, size),
|
||||
ggml_row_size(layer.v->type, cells.size()),
|
||||
ggml_row_size(layer.v->type, id));
|
||||
}
|
||||
|
||||
@@ -1076,7 +1019,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
||||
const uint32_t n_layer = layers.size();
|
||||
|
||||
const uint32_t n_kv = cell_max();
|
||||
const uint32_t n_used = used;
|
||||
const uint32_t n_used = cells.get_used();
|
||||
|
||||
assert(n_used <= n_kv);
|
||||
|
||||
@@ -1104,9 +1047,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
||||
ids.resize(n_kv, n_kv);
|
||||
|
||||
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
||||
const auto & cell0 = cells[i0];
|
||||
|
||||
if (!cell0.is_empty()) {
|
||||
if (!cells.is_empty(i0)) {
|
||||
ids[i0] = i0;
|
||||
|
||||
continue;
|
||||
@@ -1117,7 +1058,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
||||
uint32_t nh = 1;
|
||||
|
||||
// determine the size of the hole
|
||||
while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
|
||||
while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
|
||||
nh++;
|
||||
}
|
||||
|
||||
@@ -1126,9 +1067,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
||||
|
||||
// starting from the end, find nh non-empty cells
|
||||
for (; is > i0; --is) {
|
||||
const auto & cell1 = cells[is];
|
||||
|
||||
if (cell1.is_empty() || ids[is] != n_kv) {
|
||||
if (cells.is_empty(is) || ids[is] != n_kv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -1155,9 +1094,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
||||
|
||||
// go back and move the nf cells to the hole
|
||||
for (; i1 < n_kv; ++i1) {
|
||||
auto & cell1 = cells[i1];
|
||||
|
||||
if (cell1.is_empty() || ids[i1] != n_kv) {
|
||||
if (cells.is_empty(i1) || ids[i1] != n_kv) {
|
||||
if (n_moves == max_moves) {
|
||||
stop = true;
|
||||
break;
|
||||
@@ -1171,10 +1108,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
||||
ids[i1] = i0 + nf;
|
||||
|
||||
// move the cell meta data
|
||||
cells[i0 + nf] = cell1;
|
||||
cells.mv(i1, i0 + nf);
|
||||
|
||||
// clear the old cell and move the head there
|
||||
cell1 = kv_cell();
|
||||
head = n_used;
|
||||
|
||||
if (!cont) {
|
||||
@@ -1210,10 +1145,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_unified::cell_max() const {
|
||||
for (uint32_t i = size; i > 0; --i) {
|
||||
const kv_cell & cell = cells[i - 1];
|
||||
|
||||
if (cell.pos >= 0 && !cell.is_empty()) {
|
||||
for (uint32_t i = cells.size(); i > 0; --i) {
|
||||
if (!cells.is_empty(i - 1)) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
@@ -1222,9 +1155,7 @@ uint32_t llama_kv_cache_unified::cell_max() const {
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
||||
if (p0 < 0) {
|
||||
return true;
|
||||
}
|
||||
assert(p0 >= 0 && p1 >= 0);
|
||||
|
||||
switch (swa_type) {
|
||||
case LLAMA_SWA_TYPE_NONE:
|
||||
@@ -1255,23 +1186,24 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
|
||||
|
||||
// Count the number of cells with the specified seq_id
|
||||
// Find all the ranges of cells with this seq id (or all, when -1)
|
||||
uint32_t cell_range_begin = size;
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
const auto & cell = cells[i];
|
||||
if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
|
||||
uint32_t cell_range_begin = cells.size();
|
||||
|
||||
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||
if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
|
||||
++cell_count;
|
||||
if (cell_range_begin == size) {
|
||||
if (cell_range_begin == cells.size()) {
|
||||
cell_range_begin = i;
|
||||
}
|
||||
} else {
|
||||
if (cell_range_begin != size) {
|
||||
if (cell_range_begin != cells.size()) {
|
||||
cell_ranges.emplace_back(cell_range_begin, i);
|
||||
cell_range_begin = size;
|
||||
cell_range_begin = cells.size();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (cell_range_begin != size) {
|
||||
cell_ranges.emplace_back(cell_range_begin, size);
|
||||
|
||||
if (cell_range_begin != cells.size()) {
|
||||
cell_ranges.emplace_back(cell_range_begin, cells.size());
|
||||
}
|
||||
|
||||
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
||||
@@ -1308,17 +1240,24 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
|
||||
void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
|
||||
for (const auto & range : cell_ranges) {
|
||||
for (uint32_t i = range.first; i < range.second; ++i) {
|
||||
const auto & cell = cells[i];
|
||||
const llama_pos pos = cell.pos;
|
||||
const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
|
||||
std::vector<llama_seq_id> seq_ids;
|
||||
|
||||
for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
|
||||
if (cur == seq_id || seq_id == -1) {
|
||||
if (cells.seq_has(i, cur)) {
|
||||
seq_ids.push_back(cur);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const llama_pos pos = cells.pos_get(i);
|
||||
const uint32_t n_seq_id = seq_ids.size();
|
||||
|
||||
io.write(&pos, sizeof(pos));
|
||||
io.write(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
if (n_seq_id) {
|
||||
for (auto seq_id : cell.seq_id) {
|
||||
io.write(&seq_id, sizeof(seq_id));
|
||||
}
|
||||
for (const auto & seq_id : seq_ids) {
|
||||
io.write(&seq_id, sizeof(seq_id));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1379,7 +1318,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
|
||||
}
|
||||
} else {
|
||||
// When v is transposed, we also need the element size and get the element ranges from each row
|
||||
const uint32_t kv_size = size;
|
||||
const uint32_t kv_size = cells.size();
|
||||
|
||||
for (const auto & layer : layers) {
|
||||
const uint32_t il = layer.il;
|
||||
@@ -1429,14 +1368,20 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
||||
io.read_to(&pos, sizeof(pos));
|
||||
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
if (n_seq_id != 0) {
|
||||
if (n_seq_id != 1) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
batch.pos[i] = pos;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i] = &dest_seq_id;
|
||||
// read the sequence id, but directly discard it - we will use dest_seq_id instead
|
||||
{
|
||||
llama_seq_id seq_id;
|
||||
io.read_to(&seq_id, sizeof(seq_id));
|
||||
}
|
||||
|
||||
batch.pos[i] = pos;
|
||||
batch.n_seq_id[i] = n_seq_id;
|
||||
batch.seq_id[i] = &dest_seq_id;
|
||||
}
|
||||
|
||||
if (!find_slot(batch)) {
|
||||
@@ -1448,15 +1393,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
||||
// Assume that this is one contiguous block of cells
|
||||
GGML_ASSERT(head + cell_count <= size);
|
||||
GGML_ASSERT(cells[head].pos == batch.pos[0]);
|
||||
GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
|
||||
GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
|
||||
GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
|
||||
GGML_ASSERT(head + cell_count <= cells.size());
|
||||
GGML_ASSERT(cells.pos_get(head) == batch.pos[0]);
|
||||
GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]);
|
||||
GGML_ASSERT(cells.seq_has(head, dest_seq_id));
|
||||
GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id));
|
||||
} else {
|
||||
// whole KV cache restore
|
||||
|
||||
if (cell_count > size) {
|
||||
if (cell_count > cells.size()) {
|
||||
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
|
||||
return false;
|
||||
}
|
||||
@@ -1464,15 +1409,13 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
||||
clear();
|
||||
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
kv_cell & cell = cells[i];
|
||||
|
||||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
|
||||
io.read_to(&pos, sizeof(pos));
|
||||
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
||||
|
||||
cell.pos = pos;
|
||||
cells.pos_set(i, pos);
|
||||
|
||||
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
||||
llama_seq_id seq_id;
|
||||
@@ -1483,12 +1426,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
||||
return false;
|
||||
}
|
||||
|
||||
cell.seq_id.insert(seq_id);
|
||||
cells.seq_add(i, seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
head = 0;
|
||||
used = cell_count;
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -1505,8 +1447,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
|
||||
return false;
|
||||
}
|
||||
if (cell_count > size) {
|
||||
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
|
||||
if (cell_count > cells.size()) {
|
||||
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
|
||||
return false;
|
||||
}
|
||||
if (this->v_trans != (bool) v_trans) {
|
||||
@@ -1609,7 +1551,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
if (cell_count) {
|
||||
// For each row in the transposed matrix, read the values for the whole cell range
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
const size_t dst_offset = (head + j * size) * v_size_el;
|
||||
const size_t dst_offset = (head + j * cells.size()) * v_size_el;
|
||||
ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
||||
}
|
||||
}
|
||||
@@ -1689,9 +1631,9 @@ void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
|
||||
kv_swa ->seq_keep(seq_id);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
|
||||
kv_base->seq_add(seq_id, p0, p1, delta);
|
||||
kv_swa ->seq_add(seq_id, p0, p1, delta);
|
||||
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
kv_base->seq_add(seq_id, p0, p1, shift);
|
||||
kv_swa ->seq_add(seq_id, p0, p1, shift);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
@@ -2063,8 +2005,8 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
|
||||
if (delta == 0) {
|
||||
void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
if (shift == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -2087,7 +2029,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
|
||||
if (tail_id >= 0) {
|
||||
kv_cell & cell = cells[tail_id];
|
||||
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
||||
cell.pos += delta;
|
||||
cell.pos += shift;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+22
-31
@@ -4,6 +4,7 @@
|
||||
#include "llama-io.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-memory.h"
|
||||
#include "llama-kv-cells.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
@@ -35,6 +36,7 @@ struct llama_kv_cache : public llama_memory_i {
|
||||
virtual void defrag_sched(float thold) = 0;
|
||||
|
||||
// simulate full cache, used for allocating worst-case compute buffers
|
||||
// TODO: remove
|
||||
virtual void set_full() = 0;
|
||||
|
||||
//
|
||||
@@ -42,7 +44,7 @@ struct llama_kv_cache : public llama_memory_i {
|
||||
//
|
||||
|
||||
// =============================================================================================================
|
||||
// TODO: refactor and simplify this
|
||||
// TODO: refactor and simplify this [TAG: KV_API]
|
||||
|
||||
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
|
||||
|
||||
@@ -121,7 +123,7 @@ public:
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
@@ -159,7 +161,7 @@ public:
|
||||
// llama_kv_cache_unified specific API
|
||||
//
|
||||
|
||||
uint32_t get_n() const;
|
||||
uint32_t get_n() const;
|
||||
uint32_t get_size() const;
|
||||
|
||||
// get views of the current state of the cache
|
||||
@@ -180,26 +182,6 @@ private:
|
||||
const llama_model & model;
|
||||
const llama_hparams & hparams;
|
||||
|
||||
struct kv_cell {
|
||||
llama_pos pos = -1;
|
||||
llama_pos delta = 0;
|
||||
|
||||
// TODO: replace with bitset uint64_t
|
||||
std::set<llama_seq_id> seq_id;
|
||||
|
||||
bool has_seq_id(const llama_seq_id & id) const {
|
||||
return seq_id.find(id) != seq_id.end();
|
||||
}
|
||||
|
||||
bool is_empty() const {
|
||||
return seq_id.empty();
|
||||
}
|
||||
|
||||
bool is_same_seq(const kv_cell & other) const {
|
||||
return seq_id == other.seq_id;
|
||||
}
|
||||
};
|
||||
|
||||
struct kv_layer {
|
||||
// layer index in the model
|
||||
// note: can be different from the layer index in the KV cache
|
||||
@@ -209,15 +191,13 @@ private:
|
||||
ggml_tensor * v;
|
||||
};
|
||||
|
||||
bool has_shift = false;
|
||||
bool do_defrag = false;
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
|
||||
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
||||
uint32_t size = 0; // total number of cells, shared across all sequences
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id) (TODO: add `struct kv_cells` and keep track automaticallt)
|
||||
|
||||
// computed before each graph build
|
||||
// TODO: cells should start to maintain this value dynamically based on the edits
|
||||
uint32_t n = 0;
|
||||
|
||||
const uint32_t n_seq_max = 1;
|
||||
@@ -233,19 +213,29 @@ private:
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
std::vector<kv_cell> cells; // TODO: replace with `struct kv_cells`
|
||||
llama_kv_cells_unified cells;
|
||||
|
||||
std::vector<kv_layer> layers;
|
||||
|
||||
// model layer id -> KV cache layer id
|
||||
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
||||
|
||||
// recovery information used to restore the KV cells to their original state in case of a failure
|
||||
// TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation
|
||||
// to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API]
|
||||
struct {
|
||||
void clear() {
|
||||
cells.clear();
|
||||
states.clear();
|
||||
}
|
||||
|
||||
std::unordered_map<uint32_t, kv_cell> cells;
|
||||
struct state {
|
||||
uint32_t i;
|
||||
|
||||
llama_kv_cells_unified cells;
|
||||
};
|
||||
|
||||
// stack with the partial states before each ubatch
|
||||
std::vector<state> states;
|
||||
} recovery;
|
||||
|
||||
// defrag
|
||||
@@ -257,6 +247,7 @@ private:
|
||||
bool defrag_prepare(int32_t n_max_nodes);
|
||||
|
||||
// find how many cells are currently in use
|
||||
// TODO: optimize
|
||||
uint32_t cell_max() const;
|
||||
|
||||
size_t total_size() const;
|
||||
@@ -325,7 +316,7 @@ public:
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
@@ -431,7 +422,7 @@ public:
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
|
||||
@@ -0,0 +1,273 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama.h"
|
||||
#include "llama-cparams.h"
|
||||
|
||||
#include <bitset>
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
|
||||
// meta information about KV cells that can be part of multiple sequences at the same time
|
||||
// TODO: add unit tests
|
||||
class llama_kv_cells_unified {
|
||||
public:
|
||||
void reset() {
|
||||
for (uint32_t i = 0; i < pos.size(); ++i) {
|
||||
pos[i] = -1;
|
||||
shift[i] = 0;
|
||||
seq[i].reset();
|
||||
}
|
||||
|
||||
used = 0;
|
||||
has_shift = false;
|
||||
}
|
||||
|
||||
void reset_shift() {
|
||||
has_shift = false;
|
||||
|
||||
for (uint32_t i = 0; i < shift.size(); ++i) {
|
||||
shift[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t size() const {
|
||||
return pos.size();
|
||||
}
|
||||
|
||||
void resize(uint32_t n) {
|
||||
pos.resize(n);
|
||||
shift.resize(n);
|
||||
seq.resize(n);
|
||||
|
||||
reset();
|
||||
}
|
||||
|
||||
bool is_empty(uint32_t i) const {
|
||||
assert(i < pos.size());
|
||||
assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);
|
||||
|
||||
return pos[i] == -1;
|
||||
}
|
||||
|
||||
uint32_t get_used() const {
|
||||
return used;
|
||||
}
|
||||
|
||||
bool get_has_shift() const {
|
||||
return has_shift;
|
||||
}
|
||||
|
||||
// move cell isrc to idst (used during defrag)
|
||||
void mv(uint32_t isrc, uint32_t idst) {
|
||||
assert(isrc < pos.size());
|
||||
assert(idst < pos.size());
|
||||
|
||||
pos [idst] = pos [isrc];
|
||||
shift[idst] = shift[isrc];
|
||||
seq [idst] = seq [isrc];
|
||||
|
||||
pos [isrc] = -1;
|
||||
shift[isrc] = 0;
|
||||
seq [isrc].reset();
|
||||
}
|
||||
|
||||
// copy the state of cells [i, i + n) (used for save/restore the state of the cells)
|
||||
llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
|
||||
assert(i + n <= pos.size());
|
||||
|
||||
llama_kv_cells_unified res;
|
||||
|
||||
res.resize(n);
|
||||
|
||||
for (uint32_t j = 0; j < n; ++j) {
|
||||
res.pos[j] = pos[i + j];
|
||||
res.seq[j] = seq[i + j];
|
||||
|
||||
assert(shift[i + j] == 0);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
|
||||
void set(uint32_t i, const llama_kv_cells_unified & other) {
|
||||
assert(i + other.pos.size() <= pos.size());
|
||||
|
||||
for (uint32_t j = 0; j < other.pos.size(); ++j) {
|
||||
if (pos[i + j] == -1 && other.pos[j] != -1) {
|
||||
used++;
|
||||
}
|
||||
|
||||
if (pos[i + j] != -1 && other.pos[j] == -1) {
|
||||
used--;
|
||||
}
|
||||
|
||||
pos[i + j] = other.pos[j];
|
||||
seq[i + j] = other.seq[j];
|
||||
|
||||
assert(shift[i + j] == 0);
|
||||
}
|
||||
}
|
||||
|
||||
// note: call only if the cell has seq_id
|
||||
// return true if the cell becomes empty
|
||||
bool seq_rm(uint32_t i, llama_seq_id seq_id) {
|
||||
assert(i < pos.size());
|
||||
assert(seq[i].test(seq_id));
|
||||
assert(pos[i] != -1);
|
||||
assert(seq_id >= 0);
|
||||
|
||||
seq[i].reset(seq_id);
|
||||
|
||||
if (seq[i].none()) {
|
||||
pos[i] = -1;
|
||||
|
||||
used--;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// return true if the cell becomes empty (i.e. it did not contain seq_id before the call)
|
||||
bool seq_keep(uint32_t i, llama_seq_id seq_id) {
|
||||
assert(i < pos.size());
|
||||
|
||||
if (seq[i].test(seq_id)) {
|
||||
seq[i].reset();
|
||||
seq[i].set(seq_id);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
if (seq[i].any()) {
|
||||
seq[i].reset();
|
||||
pos[i] = -1;
|
||||
|
||||
used--;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
assert(pos[i] == -1);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool seq_has(uint32_t i, llama_seq_id seq_id) const {
|
||||
assert(i < pos.size());
|
||||
assert(seq_id >= 0);
|
||||
|
||||
return seq[i].test(seq_id);
|
||||
}
|
||||
|
||||
// note: call only if the cell is not empty and the seq_id is not in the cell
|
||||
void seq_add(uint32_t i, llama_seq_id seq_id) {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
assert(!seq[i].test(seq_id));
|
||||
|
||||
seq[i].set(seq_id);
|
||||
}
|
||||
|
||||
// note: call only if the cell is not empty
|
||||
llama_pos pos_get(uint32_t i) const {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
|
||||
return pos[i];
|
||||
}
|
||||
|
||||
// note: call only if the cell is not empty
|
||||
llama_pos get_shift(uint32_t i) const {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
|
||||
return shift[i];
|
||||
}
|
||||
|
||||
// check if a cell is not empty and its position is within [p0, p1)
|
||||
bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
|
||||
assert(i < pos.size());
|
||||
|
||||
return pos[i] >= p0 && pos[i] < p1;
|
||||
}
|
||||
|
||||
// set the position of an empty cell
|
||||
// does not modify "has_shift"
|
||||
// note: call only if the cell is empty
|
||||
void pos_set(uint32_t i, llama_pos p) {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] == -1);
|
||||
|
||||
pos[i] = p;
|
||||
used++;
|
||||
}
|
||||
|
||||
// pos[i] = pos[i] + d
|
||||
// sets "has_shift" to true
|
||||
// note: call only if the cell is not empty
|
||||
bool pos_add(uint32_t i, llama_pos d) {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
|
||||
pos[i] += d;
|
||||
shift[i] += d;
|
||||
|
||||
has_shift = true;
|
||||
|
||||
if (pos[i] < 0) {
|
||||
pos[i] = -1;
|
||||
seq[i].reset();
|
||||
|
||||
used--;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// pos[i] = pos[i] / d
|
||||
// sets "has_shift" to true
|
||||
// note: call only if the cell is not empty
|
||||
void pos_div(uint32_t i, int d) {
|
||||
assert(i < pos.size());
|
||||
assert(pos[i] != -1);
|
||||
|
||||
const llama_pos p_old = pos[i];
|
||||
|
||||
pos[i] /= d;
|
||||
shift[i] += p_old - pos[i];
|
||||
|
||||
has_shift = true;
|
||||
}
|
||||
|
||||
private:
|
||||
uint32_t used = 0; // used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
|
||||
|
||||
bool has_shift = false;
|
||||
|
||||
std::vector<llama_pos> pos;
|
||||
|
||||
// this array accumulates any applied shifts to the pos array since the last reset_shift() call
|
||||
// this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
|
||||
//
|
||||
// cells.pos_add(x, shift_x);
|
||||
// cells.pos_div(y, shift_y);
|
||||
// ...
|
||||
//
|
||||
// if (cells.has_shift()) {
|
||||
// for (int i = 0; i < n; ++i) {
|
||||
// auto shift_i = cells.get_shift(i);
|
||||
// ...
|
||||
// }
|
||||
// cells.reset_shift();
|
||||
// }
|
||||
//
|
||||
std::vector<llama_pos> shift;
|
||||
|
||||
std::vector<std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>> seq;
|
||||
};
|
||||
|
||||
+1
-1
@@ -22,7 +22,7 @@ public:
|
||||
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
|
||||
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
|
||||
virtual void seq_keep(llama_seq_id seq_id) = 0;
|
||||
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
|
||||
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
|
||||
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
|
||||
|
||||
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
|
||||
|
||||
+18
-10
@@ -463,11 +463,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
GGML_ASSERT(hparams.n_expert_used == 0);
|
||||
}
|
||||
|
||||
// zero-out the array hparams
|
||||
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
|
||||
std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
|
||||
std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
|
||||
|
||||
std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
|
||||
|
||||
std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0);
|
||||
|
||||
ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false);
|
||||
ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
|
||||
|
||||
@@ -574,7 +577,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
|
||||
hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
|
||||
hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full
|
||||
hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
|
||||
|
||||
switch (hparams.n_expert) {
|
||||
case 16: type = LLM_TYPE_17B_16E; break;
|
||||
@@ -863,7 +866,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
|
||||
hparams.n_swa = 0;
|
||||
hparams.n_swa_pattern = 1;
|
||||
hparams.set_swa_pattern(1);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_PHIMOE:
|
||||
@@ -935,7 +938,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
{
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
hparams.n_swa = 4096; // default value of gemma 2
|
||||
hparams.n_swa_pattern = 2;
|
||||
hparams.set_swa_pattern(2);
|
||||
hparams.attn_soft_cap = true;
|
||||
|
||||
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
||||
@@ -953,7 +956,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
case LLM_ARCH_GEMMA3:
|
||||
{
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
hparams.n_swa_pattern = 6;
|
||||
hparams.set_swa_pattern(6);
|
||||
|
||||
hparams.rope_freq_base_train_swa = 10000.0f;
|
||||
hparams.rope_freq_scale_train_swa = 1.0f;
|
||||
@@ -1038,7 +1041,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
case LLM_ARCH_COHERE2:
|
||||
{
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
hparams.n_swa_pattern = 4;
|
||||
hparams.set_swa_pattern(4);
|
||||
|
||||
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
|
||||
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
|
||||
@@ -2486,7 +2489,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
@@ -4320,7 +4327,7 @@ void llama_model::print_info() const {
|
||||
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
|
||||
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
|
||||
LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern);
|
||||
LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
|
||||
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
|
||||
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
|
||||
@@ -13189,6 +13196,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
case LLM_ARCH_JINA_BERT_V2:
|
||||
case LLM_ARCH_NOMIC_BERT:
|
||||
case LLM_ARCH_NOMIC_BERT_MOE:
|
||||
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||
{
|
||||
res = nullptr;
|
||||
} break;
|
||||
@@ -13215,7 +13223,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
|
||||
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
GGML_ASSERT(hparams.n_swa_pattern != 1);
|
||||
GGML_ASSERT(hparams.is_swa_any());
|
||||
|
||||
res = new llama_kv_cache_unified_iswa(
|
||||
*this,
|
||||
@@ -13229,7 +13237,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
cparams.n_batch,
|
||||
padding);
|
||||
} else {
|
||||
GGML_ASSERT(hparams.n_swa_pattern == 1);
|
||||
GGML_ASSERT(!hparams.is_swa_any());
|
||||
|
||||
res = new llama_kv_cache_unified(
|
||||
*this,
|
||||
|
||||
+4
-4
@@ -835,7 +835,7 @@ struct llm_tokenizer_ugm_session {
|
||||
}
|
||||
|
||||
// initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
|
||||
std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -FLT_MAX});
|
||||
std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -DBL_MAX});
|
||||
// at the beginning tokenization score is zero
|
||||
tokenization_results[0] = { vocab.token_unk(), 0, 0 };
|
||||
|
||||
@@ -867,7 +867,7 @@ struct llm_tokenizer_ugm_session {
|
||||
const double challenger_score = current_best.score_sum + token_score;
|
||||
struct best_tokenization & current_champ = tokenization_results[prefix_offset];
|
||||
if (challenger_score > current_champ.score_sum) {
|
||||
struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score };
|
||||
struct best_tokenization challenger = { token_id, input_offset, challenger_score };
|
||||
current_champ = challenger;
|
||||
}
|
||||
}
|
||||
@@ -881,7 +881,7 @@ struct llm_tokenizer_ugm_session {
|
||||
prefix_offset = input_offset + n_utf8_code_units;
|
||||
struct best_tokenization & current_champ = tokenization_results[prefix_offset];
|
||||
if (challenger_score > current_champ.score_sum) {
|
||||
struct best_tokenization challenger = { vocab.token_unk(), input_offset, (float) challenger_score };
|
||||
struct best_tokenization challenger = { vocab.token_unk(), input_offset, challenger_score };
|
||||
current_champ = challenger;
|
||||
}
|
||||
}
|
||||
@@ -1007,7 +1007,7 @@ private:
|
||||
struct best_tokenization {
|
||||
llama_token token_id;
|
||||
size_t input_offset;
|
||||
float score_sum;
|
||||
double score_sum;
|
||||
};
|
||||
|
||||
struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
|
||||
|
||||
@@ -92,6 +92,7 @@ llama_test(test-tokenizer-0 NAME test-tokenizer-0-gpt-2 ARGS ${CMAKE
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-llama-bpe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-llama-spm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-spm.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-mpt ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-nomic-bert-moe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-nomic-bert-moe.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-phi-3 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-phi-3.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-qwen2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-qwen2.gguf)
|
||||
llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
|
||||
@@ -142,8 +143,10 @@ if (NOT WIN32)
|
||||
# llama_build_and_test(test-double-float.cpp) # SLOW
|
||||
endif()
|
||||
|
||||
llama_build_and_test(test-log.cpp)
|
||||
llama_build_and_test(test-chat-parser.cpp)
|
||||
llama_build_and_test(test-chat-template.cpp)
|
||||
llama_build_and_test(test-json-partial.cpp)
|
||||
llama_build_and_test(test-log.cpp)
|
||||
llama_build_and_test(test-regex-partial.cpp)
|
||||
|
||||
# this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
|
||||
|
||||
@@ -0,0 +1,355 @@
|
||||
// Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
|
||||
//
|
||||
// Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
|
||||
// e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
|
||||
//
|
||||
// cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
|
||||
//
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
#include <json.hpp>
|
||||
#include <string>
|
||||
|
||||
#include "chat-parser.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
template <class T>
|
||||
static void assert_equals(const T & expected, const T & actual) {
|
||||
if (expected != actual) {
|
||||
std::cerr << "Expected: " << expected << std::endl;
|
||||
std::cerr << "Actual: " << actual << std::endl;
|
||||
std::cerr << std::flush;
|
||||
throw std::runtime_error("Test failed");
|
||||
}
|
||||
}
|
||||
static void assert_equals(const char * expected, const std::string & actual) {
|
||||
return assert_equals<std::string>(expected, actual);
|
||||
}
|
||||
|
||||
static void assert_throws(const std::function<void()> & fn, const std::string & expected_exception_pattern = "") {
|
||||
try {
|
||||
fn();
|
||||
} catch (const std::exception & e) {
|
||||
if (expected_exception_pattern.empty()) {
|
||||
return;
|
||||
}
|
||||
std::regex expected_exception_regex(expected_exception_pattern);
|
||||
std::string actual_message = e.what();
|
||||
if (std::regex_search(actual_message, expected_exception_regex)) {
|
||||
return;
|
||||
}
|
||||
throw std::runtime_error("Exception doesn't match expected pattern: " + actual_message + " (pattern: " + expected_exception_pattern + ")");
|
||||
throw std::runtime_error("Exception of unexpected type: " + std::string(e.what()));
|
||||
}
|
||||
throw std::runtime_error("Exception was expected but not thrown");
|
||||
}
|
||||
|
||||
static void test_reasoning() {
|
||||
{
|
||||
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
});
|
||||
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals("<tnk>Cogito</tnk>Ergo sum", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
});
|
||||
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals(std::string("Cogito"), builder.result().reasoning_content);
|
||||
assert_equals("Ergo sum", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
});
|
||||
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals("Cogito</tnk>Ergo sum", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
});
|
||||
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals(std::string("Cogito"), builder.result().reasoning_content);
|
||||
assert_equals("Ergo sum", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ true,
|
||||
/* .thinking_forced_open = */ true,
|
||||
});
|
||||
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals("<think>Cogito</think>", builder.result().content);
|
||||
assert_equals("Ergo sum", builder.consume_rest());
|
||||
}
|
||||
}
|
||||
|
||||
static void test_regex() {
|
||||
auto test_throws = [](const std::string & input, const std::string & regex, const std::string & expected_exception_pattern = "") {
|
||||
common_chat_msg_parser builder(input, /* is_partial= */ false, {});
|
||||
assert_throws([&]() { builder.consume_regex(common_regex(regex)); }, expected_exception_pattern);
|
||||
};
|
||||
|
||||
test_throws("Hello, world!", "abc", "^abc$");
|
||||
test_throws("Hello, world!", "e", "^e$");
|
||||
|
||||
{
|
||||
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
|
||||
builder.consume_regex(common_regex("Hello"));
|
||||
assert_equals(", world!", builder.consume_rest());
|
||||
}
|
||||
|
||||
{
|
||||
// When in non partial mode, we can say whether the regex was consumed or not.
|
||||
common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
|
||||
assert_equals(false, builder.try_consume_regex(common_regex("Hello, world!")).has_value());
|
||||
}
|
||||
{
|
||||
common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
|
||||
auto res = builder.try_consume_regex(common_regex("H(el)l(?:o, world!)?"));
|
||||
assert_equals(true, res.has_value());
|
||||
// Verify captures
|
||||
assert_equals<size_t>(2, res->groups.size());
|
||||
assert_equals("Hell", builder.str(res->groups[0]));
|
||||
assert_equals("el", builder.str(res->groups[1]));
|
||||
// Verify position is after the match
|
||||
assert_equals<size_t>(4, builder.pos());
|
||||
assert_equals("o,", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
// But in partial mode, we have a partial final match / can't decide, so we throw a partial exception.
|
||||
common_chat_msg_parser builder("Hello,", /* is_partial= */ true, {});
|
||||
assert_throws([&]() {
|
||||
builder.try_consume_regex(common_regex("Hello, world!"));
|
||||
}, "^Hello, world!$");
|
||||
}
|
||||
|
||||
// Now regardless of the mode, we can tell these aren't a match.
|
||||
for (const auto is_partial : {false, true}) {
|
||||
common_chat_msg_parser builder("Hello,", is_partial, {});
|
||||
assert_equals(false, builder.try_consume_regex(common_regex("a(b|c)(d|e)f")).has_value());
|
||||
}
|
||||
for (const auto is_partial : {false, true}) {
|
||||
common_chat_msg_parser builder("Hello,", is_partial, {});
|
||||
assert_equals(false, builder.try_consume_literal("Oh"));
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<std::string> barely_healable_jsons = {
|
||||
"{",
|
||||
"{\"",
|
||||
"{\"\\",
|
||||
"{\"n",
|
||||
"{\"name\"",
|
||||
"{\"name\":",
|
||||
"{\"name\":\"",
|
||||
"{\"name\":\"\\",
|
||||
"{\"name\":\"python",
|
||||
"{\"name\":\"python\\",
|
||||
"{\",",
|
||||
"{\":",
|
||||
"{\"[",
|
||||
"{\"]",
|
||||
"{\"{",
|
||||
"{\"}",
|
||||
"{\"1",
|
||||
"{\"name\":\",",
|
||||
"{\"name\":\":",
|
||||
"{\"name\":\"[",
|
||||
"{\"name\":\"]",
|
||||
"{\"name\":\"{",
|
||||
"{\"name\":\"}",
|
||||
"{\"name\":\"1",
|
||||
};
|
||||
|
||||
static void test(const std::string & input, bool is_partial, const std::vector<std::vector<std::string>> & args_paths, const std::vector<std::vector<std::string>> & content_paths, const std::string & expected) {
|
||||
common_chat_msg_parser builder(input, is_partial, {});
|
||||
auto js = builder.try_consume_json_with_dumped_args(args_paths, content_paths);
|
||||
assert_equals(true, js.has_value());
|
||||
assert_equals(is_partial, js->is_partial);
|
||||
assert_equals(expected, args_paths.size() == 1 && args_paths[0].empty() ? js->value.get<std::string>() : js->value.dump());
|
||||
}
|
||||
static void test_with_args(const std::string & input, const std::string & expected, bool parse_as_partial = true, bool is_partial = true) {
|
||||
common_chat_msg_parser builder(input, parse_as_partial, {});
|
||||
auto js = builder.try_consume_json_with_dumped_args({{"args"}}, {});
|
||||
assert_equals(true, js.has_value());
|
||||
assert_equals(is_partial, js->is_partial);
|
||||
assert_equals(expected, js->value.dump());
|
||||
}
|
||||
|
||||
static void test_json_with_dumped_args_no_args() {
|
||||
// Normal JSON, nothing to heal, nothing to dump
|
||||
test("{\"name\": \"python\"}", false, {}, {}, "{\"name\":\"python\"}");
|
||||
// Full json is args
|
||||
test("{\"name\": \"python\"}", false, {{}}, {}, "{\"name\":\"python\"}");
|
||||
|
||||
// If the arguments are further down, don't heal partial content.
|
||||
for (const auto & src : barely_healable_jsons) {
|
||||
test(src, true, {{"arguments"}}, {}, "{}");
|
||||
}
|
||||
// But heal content that isn't partial.
|
||||
test("{\"name\": \"python\"", true, {{"arguments"}}, {}, "{\"name\":\"python\"}");
|
||||
}
|
||||
|
||||
static void test_json_with_dumped_args() {
|
||||
|
||||
// Partial content.
|
||||
test("{\"content\": \"t", true, {}, {{"content"}}, "{\"content\":\"t\"}");
|
||||
test("{\"content\": \"", true, {}, {{"content"}}, "{\"content\":\"\"}");
|
||||
test("{\"content\": ", true, {}, {{"content"}}, "{}");
|
||||
|
||||
// If the entire JSON is the arguments, healing it them dumping it produces the same output as the input (just reformatted).
|
||||
test("{\"name\": \"python", true, {{}}, {}, "{\"name\":\"python");
|
||||
for (const auto & src : barely_healable_jsons) {
|
||||
test(src, true, {{}}, {}, src);
|
||||
}
|
||||
|
||||
// Full JSON w/ args
|
||||
for (auto parse_as_partial : {true, false}) {
|
||||
test_with_args(
|
||||
R"({"name": "python", "args": {"arg1": 1}})",
|
||||
R"({"name":"python","args":"{\"arg1\":1}"})",
|
||||
parse_as_partial,
|
||||
/* is_partial= */ false
|
||||
);
|
||||
}
|
||||
|
||||
// Partial JSON w/ partial args
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {")",
|
||||
R"({"foo":"bar","args":"{\""})"
|
||||
);
|
||||
// Partial args broken in object key
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"ar)",
|
||||
R"({"foo":"bar","args":"{\"ar"})"
|
||||
);
|
||||
// Partial args broken after object key
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1")",
|
||||
R"({"foo":"bar","args":"{\"arg1\""})"
|
||||
);
|
||||
// Partial args broken before object value
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1":)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":"})"
|
||||
);
|
||||
// Partial args broken before object value (space)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": )",
|
||||
R"({"foo":"bar","args":"{\"arg1\":"})"
|
||||
);
|
||||
// Partial args broken in object value that may not be complete (int)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": 1)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":"})"
|
||||
);
|
||||
// Partial args broken in object value that is complete (int)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": 1 )",
|
||||
R"({"foo":"bar","args":"{\"arg1\":1"})"
|
||||
);
|
||||
// Partial args broken in object value that is incomplete (string)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": ")",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\""})"
|
||||
);
|
||||
// Partial args broken in object value that is complete (string)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": "1")",
|
||||
R"({"foo":"bar","args":"{\"arg1\":\"1\""})"
|
||||
);
|
||||
// Partial args broken on array opening
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": [)",
|
||||
R"({"foo":"bar","args":"["})"
|
||||
);
|
||||
// Partial args broken on array value that is incomplete (int)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": [1)",
|
||||
R"({"foo":"bar","args":"["})"
|
||||
);
|
||||
// Partial args broken on array value that is complete (int)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": [1 )",
|
||||
R"({"foo":"bar","args":"[1"})"
|
||||
);
|
||||
// Partial args broken on array value that is complete (string)
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": ["1")",
|
||||
R"({"foo":"bar","args":"[\"1\""})"
|
||||
);
|
||||
// Partial args broken after array value
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": [1,)",
|
||||
R"({"foo":"bar","args":"[1,"})"
|
||||
);
|
||||
// Partial args broken on nested array
|
||||
test_with_args(
|
||||
R"({"foo": "bar", "args": {"arg1": [)",
|
||||
R"({"foo":"bar","args":"{\"arg1\":["})"
|
||||
);
|
||||
}
|
||||
|
||||
static void test_positions() {
|
||||
{
|
||||
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
|
||||
assert_equals<size_t>(0, builder.pos());
|
||||
assert_throws([&]() { builder.move_to(100); });
|
||||
assert_equals<size_t>(0, builder.pos());
|
||||
assert_throws([&]() { builder.move_back(1); });
|
||||
assert_equals<size_t>(0, builder.pos());
|
||||
|
||||
builder.move_to(8);
|
||||
assert_equals<size_t>(8, builder.pos());
|
||||
builder.move_back(1);
|
||||
assert_equals<size_t>(7, builder.pos());
|
||||
assert_equals("world!", builder.consume_rest());
|
||||
|
||||
builder.move_to(0);
|
||||
assert_equals<size_t>(0, builder.pos());
|
||||
|
||||
assert_throws([&]() { builder.finish(); });
|
||||
assert_equals<size_t>(0, builder.pos());
|
||||
|
||||
builder.move_to(builder.input().size());
|
||||
builder.finish();
|
||||
}
|
||||
{
|
||||
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ true, {});
|
||||
|
||||
builder.move_to(builder.input().size());
|
||||
assert_equals<size_t>(builder.input().size(), builder.pos());
|
||||
builder.finish();
|
||||
}
|
||||
}
|
||||
|
||||
int main() {
|
||||
test_positions();
|
||||
test_json_with_dumped_args_no_args();
|
||||
test_json_with_dumped_args();
|
||||
test_reasoning();
|
||||
test_regex();
|
||||
std::cout << "All tests passed!\n";
|
||||
return 0;
|
||||
}
|
||||
+730
-277
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,237 @@
|
||||
#include "common.h"
|
||||
#include "json-partial.h"
|
||||
#include <exception>
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
|
||||
template <class T> static void assert_equals(const T & expected, const T & actual) {
|
||||
if (expected != actual) {
|
||||
std::cerr << "Expected: " << expected << std::endl;
|
||||
std::cerr << "Actual: " << actual << std::endl;
|
||||
std::cerr << std::flush;
|
||||
throw std::runtime_error("Test failed");
|
||||
}
|
||||
}
|
||||
|
||||
static void test_json_healing() {
|
||||
auto parse = [](const std::string & str) {
|
||||
std::cerr << "# Parsing: " << str << '\n';
|
||||
std::string::const_iterator it = str.begin();
|
||||
const auto end = str.end();
|
||||
common_json out;
|
||||
std::string healing_marker = "$llama.cpp.json$";
|
||||
if (common_json_parse(it, end, healing_marker, out)) {
|
||||
auto dump = out.json.dump();
|
||||
std::cerr << "Parsed: " << dump << '\n';
|
||||
std::cerr << "Magic: " << out.healing_marker.json_dump_marker << '\n';
|
||||
std::string result;
|
||||
if (!out.healing_marker.json_dump_marker.empty()) {
|
||||
auto i = dump.find(out.healing_marker.json_dump_marker);
|
||||
if (i == std::string::npos) {
|
||||
throw std::runtime_error("Failed to find magic in dump " + dump + " (magic: " + out.healing_marker.json_dump_marker + ")");
|
||||
}
|
||||
result = dump.substr(0, i);
|
||||
} else {
|
||||
result = dump;
|
||||
}
|
||||
std::cerr << "Result: " << result << '\n';
|
||||
if (string_starts_with(str, result)) {
|
||||
std::cerr << "Failure!\n";
|
||||
}
|
||||
// return dump;
|
||||
} else {
|
||||
throw std::runtime_error("Failed to parse: " + str);
|
||||
}
|
||||
|
||||
};
|
||||
auto parse_all = [&](const std::string & str) {
|
||||
for (size_t i = 1; i < str.size(); i++) {
|
||||
parse(str.substr(0, i));
|
||||
}
|
||||
};
|
||||
parse_all("{\"a\": \"b\"}");
|
||||
parse_all("{\"hey\": 1, \"ho\\\"ha\": [1]}");
|
||||
|
||||
parse_all("[{\"a\": \"b\"}]");
|
||||
|
||||
auto test = [&](const std::vector<std::string> & inputs, const std::string & expected, const std::string & expected_marker) {
|
||||
for (const auto & input : inputs) {
|
||||
common_json out;
|
||||
assert_equals(true, common_json_parse(input, "$foo", out));
|
||||
assert_equals<std::string>(expected, out.json.dump());
|
||||
assert_equals<std::string>(expected_marker, out.healing_marker.json_dump_marker);
|
||||
}
|
||||
};
|
||||
// No healing needed:
|
||||
test(
|
||||
{
|
||||
R"([{"a":"b"}, "y"])",
|
||||
},
|
||||
R"([{"a":"b"},"y"])",
|
||||
""
|
||||
);
|
||||
// Partial literals can't be healed:
|
||||
test(
|
||||
{
|
||||
R"([1)",
|
||||
R"([tru)",
|
||||
R"([n)",
|
||||
R"([nul)",
|
||||
R"([23.2)",
|
||||
},
|
||||
R"(["$foo"])",
|
||||
R"("$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({"a": 1)",
|
||||
R"({"a": tru)",
|
||||
R"({"a": n)",
|
||||
R"({"a": nul)",
|
||||
R"({"a": 23.2)",
|
||||
},
|
||||
R"({"a":"$foo"})",
|
||||
R"("$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({)",
|
||||
},
|
||||
R"({"$foo":1})",
|
||||
R"("$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([)",
|
||||
},
|
||||
R"(["$foo"])",
|
||||
R"("$foo)"
|
||||
);
|
||||
// Healing right after a full literal
|
||||
test(
|
||||
{
|
||||
R"(1 )",
|
||||
},
|
||||
R"(1)",
|
||||
""
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"(true)",
|
||||
R"(true )",
|
||||
},
|
||||
R"(true)",
|
||||
""
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"(null)",
|
||||
R"(null )",
|
||||
},
|
||||
R"(null)",
|
||||
""
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([1 )",
|
||||
},
|
||||
R"([1,"$foo"])",
|
||||
R"(,"$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([{})",
|
||||
R"([{} )",
|
||||
},
|
||||
R"([{},"$foo"])",
|
||||
R"(,"$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([true)",
|
||||
},
|
||||
// TODO: detect the true/false/null literal was complete
|
||||
R"(["$foo"])",
|
||||
R"("$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([true )",
|
||||
},
|
||||
R"([true,"$foo"])",
|
||||
R"(,"$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([true,)",
|
||||
},
|
||||
R"([true,"$foo"])",
|
||||
R"("$foo)"
|
||||
);
|
||||
// Test nesting
|
||||
test(
|
||||
{
|
||||
R"([{"a": [{"b": [{)",
|
||||
},
|
||||
R"([{"a":[{"b":[{"$foo":1}]}]}])",
|
||||
R"("$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([{"a": [{"b": [)",
|
||||
},
|
||||
R"([{"a":[{"b":["$foo"]}]}])",
|
||||
R"("$foo)"
|
||||
);
|
||||
|
||||
test(
|
||||
{
|
||||
R"([{"a": "b"})",
|
||||
R"([{"a": "b"} )",
|
||||
},
|
||||
R"([{"a":"b"},"$foo"])",
|
||||
R"(,"$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"([{"a": "b"},)",
|
||||
R"([{"a": "b"}, )",
|
||||
},
|
||||
R"([{"a":"b"},"$foo"])",
|
||||
R"("$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({ "code)",
|
||||
},
|
||||
R"({"code$foo":1})",
|
||||
R"($foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({ "code\)",
|
||||
},
|
||||
R"({"code\\$foo":1})",
|
||||
R"(\$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({ "code")",
|
||||
},
|
||||
R"({"code":"$foo"})",
|
||||
R"(:"$foo)"
|
||||
);
|
||||
test(
|
||||
{
|
||||
R"({ "key")",
|
||||
},
|
||||
R"({"key":"$foo"})",
|
||||
R"(:"$foo)"
|
||||
);
|
||||
}
|
||||
|
||||
int main() {
|
||||
test_json_healing();
|
||||
std::cerr << "All tests passed.\n";
|
||||
return 0;
|
||||
}
|
||||
@@ -1,5 +1,15 @@
|
||||
# mtmd
|
||||
|
||||
# compile mtmd-audio separately to avoid long compile times with miniaudio.h
|
||||
# TODO @ngxson : move miniaudio.h and stb_image.h to mtmd-helper.cpp, then compile the helper as a separate library
|
||||
add_library(mtmd_audio STATIC mtmd-audio.cpp mtmd-audio.h)
|
||||
if (BUILD_SHARED_LIBS)
|
||||
set_target_properties(mtmd_audio PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
endif()
|
||||
target_link_libraries(mtmd_audio PRIVATE ggml ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(mtmd_audio PRIVATE cxx_std_17)
|
||||
target_include_directories(mtmd_audio PRIVATE .)
|
||||
|
||||
add_library(mtmd OBJECT
|
||||
mtmd.cpp
|
||||
mtmd-helper.cpp
|
||||
@@ -9,7 +19,7 @@ add_library(mtmd OBJECT
|
||||
clip-impl.h
|
||||
)
|
||||
|
||||
target_link_libraries(mtmd PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(mtmd PRIVATE ggml llama mtmd_audio ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
target_include_directories(mtmd PUBLIC .)
|
||||
target_include_directories(mtmd PRIVATE ../..)
|
||||
@@ -22,12 +32,13 @@ if (BUILD_SHARED_LIBS)
|
||||
set_target_properties(mtmd PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_compile_definitions(mtmd PRIVATE LLAMA_SHARED LLAMA_BUILD)
|
||||
add_library(mtmd_shared SHARED $<TARGET_OBJECTS:mtmd>)
|
||||
target_link_libraries(mtmd_shared PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(mtmd_shared PRIVATE ggml llama mtmd_audio ${CMAKE_THREAD_LIBS_INIT})
|
||||
install(TARGETS mtmd_shared LIBRARY)
|
||||
endif()
|
||||
|
||||
if (NOT MSVC)
|
||||
target_compile_options(mtmd PRIVATE -Wno-cast-qual) # stb_image.h
|
||||
target_compile_options(mtmd_audio PRIVATE -Wno-cast-qual) # miniaudio.h
|
||||
endif()
|
||||
|
||||
if(TARGET BUILD_INFO)
|
||||
|
||||
+40
-12
@@ -16,22 +16,26 @@
|
||||
#define KEY_FTYPE "general.file_type"
|
||||
#define KEY_NAME "general.name"
|
||||
#define KEY_DESCRIPTION "general.description"
|
||||
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
|
||||
#define KEY_PROJ_TYPE "clip.projector_type"
|
||||
#define KEY_HAS_AUDIO_ENC "clip.has_audio_encoder"
|
||||
#define KEY_HAS_VISION_ENC "clip.has_vision_encoder"
|
||||
#define KEY_USE_GELU "clip.use_gelu"
|
||||
#define KEY_USE_SILU "clip.use_silu"
|
||||
#define KEY_N_EMBD "clip.vision.embedding_length"
|
||||
#define KEY_N_FF "clip.vision.feed_forward_length"
|
||||
#define KEY_N_BLOCK "clip.vision.block_count"
|
||||
#define KEY_N_HEAD "clip.vision.attention.head_count"
|
||||
#define KEY_LAYER_NORM_EPS "clip.vision.attention.layer_norm_epsilon"
|
||||
#define KEY_PROJ_DIM "clip.vision.projection_dim"
|
||||
|
||||
#define KEY_N_EMBD "clip.%s.embedding_length"
|
||||
#define KEY_N_FF "clip.%s.feed_forward_length"
|
||||
#define KEY_N_BLOCK "clip.%s.block_count"
|
||||
#define KEY_PROJ_DIM "clip.%s.projection_dim"
|
||||
#define KEY_N_HEAD "clip.%s.attention.head_count"
|
||||
#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon"
|
||||
|
||||
// vision-specific
|
||||
#define KEY_IMAGE_SIZE "clip.vision.image_size"
|
||||
#define KEY_PATCH_SIZE "clip.vision.patch_size"
|
||||
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
|
||||
#define KEY_IMAGE_STD "clip.vision.image_std"
|
||||
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
|
||||
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
|
||||
#define KEY_PROJ_TYPE "clip.projector_type"
|
||||
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
|
||||
|
||||
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
|
||||
@@ -39,13 +43,18 @@
|
||||
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
|
||||
#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern"
|
||||
#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
|
||||
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
|
||||
|
||||
// audio-specific
|
||||
#define KEY_A_NUM_MEL_BINS "clip.audio.num_mel_bins"
|
||||
#define KEY_A_PROJ_STACK_FACTOR "clip.audio.projector.stack_factor"
|
||||
|
||||
|
||||
//
|
||||
// tensor name constants
|
||||
//
|
||||
|
||||
#define TN_POS_EMBD "v.position_embd.weight"
|
||||
#define TN_POS_EMBD "%s.position_embd.weight"
|
||||
#define TN_CLASS_EMBD "v.class_embd"
|
||||
#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat
|
||||
#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1"
|
||||
@@ -95,6 +104,13 @@
|
||||
#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s"
|
||||
#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s"
|
||||
|
||||
// ultravox
|
||||
#define TN_CONV1D "a.conv1d.%d.%s"
|
||||
#define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s"
|
||||
#define TN_MM_AUDIO_FC "mm.a.fc.%s" // fully connected layer
|
||||
#define TN_MM_NORM_PRE "mm.a.norm_pre.%s"
|
||||
#define TN_MM_NORM_MID "mm.a.norm_mid.%s"
|
||||
|
||||
// align x to upper multiple of n
|
||||
#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
|
||||
|
||||
@@ -110,8 +126,10 @@ enum projector_type {
|
||||
PROJECTOR_TYPE_IDEFICS3,
|
||||
PROJECTOR_TYPE_PIXTRAL,
|
||||
PROJECTOR_TYPE_QWEN25VL,
|
||||
PROJECTOR_TYPE_ULTRAVOX,
|
||||
PROJECTOR_TYPE_INTERNVL,
|
||||
PROJECTOR_TYPE_LLAMA4,
|
||||
PROJECTOR_TYPE_QWEN2A,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -126,8 +144,10 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
|
||||
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
|
||||
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
|
||||
{ PROJECTOR_TYPE_ULTRAVOX, "ultravox"},
|
||||
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
|
||||
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
|
||||
{ PROJECTOR_TYPE_QWEN2A, "qwen2a"},
|
||||
};
|
||||
|
||||
static projector_type clip_projector_type_from_string(const std::string & str) {
|
||||
@@ -147,8 +167,10 @@ struct clip_image_u8 {
|
||||
std::vector<uint8_t> buf;
|
||||
};
|
||||
|
||||
// RGB float32 image (NHWC)
|
||||
// Memory layout: RGBRGBRGB...
|
||||
// For images, buf.size() == nx*ny*3
|
||||
// Memory layout: RGBRGBRGB...
|
||||
// For audio, only one channel is used, buf.size() == nx*ny
|
||||
// nx will be n_frames and ny will be n_mel
|
||||
struct clip_image_f32 {
|
||||
int nx;
|
||||
int ny;
|
||||
@@ -242,6 +264,7 @@ struct clip_image_u8_batch {
|
||||
|
||||
struct clip_image_f32_batch {
|
||||
std::vector<clip_image_f32_ptr> entries;
|
||||
bool is_audio = false;
|
||||
|
||||
// for llava-uhd style models, we need to know the grid size
|
||||
// note: entries.size() == grid_x * grid_y + 1 (one overview image)
|
||||
@@ -249,7 +272,12 @@ struct clip_image_f32_batch {
|
||||
int grid_y = 0;
|
||||
|
||||
clip_image_f32_batch clone() const {
|
||||
clip_image_f32_batch new_batch;
|
||||
clip_image_f32_batch new_batch{
|
||||
/* entries */ {},
|
||||
/* is_audio */ is_audio,
|
||||
/* grid_x */ grid_x,
|
||||
/* grid_y */ grid_y,
|
||||
};
|
||||
new_batch.entries.reserve(entries.size());
|
||||
for (const auto & entry : entries) {
|
||||
new_batch.entries.emplace_back(new clip_image_f32(*entry));
|
||||
|
||||
+309
-59
@@ -35,6 +35,7 @@ struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callbac
|
||||
|
||||
enum ffn_op_type {
|
||||
FFN_GELU,
|
||||
FFN_GELU_ERF,
|
||||
FFN_SILU,
|
||||
FFN_GELU_QUICK,
|
||||
};
|
||||
@@ -165,6 +166,9 @@ enum patch_merge_type {
|
||||
};
|
||||
|
||||
struct clip_hparams {
|
||||
bool has_vision = false;
|
||||
bool has_audio = false;
|
||||
|
||||
int32_t image_size;
|
||||
int32_t patch_size;
|
||||
int32_t n_embd;
|
||||
@@ -191,6 +195,10 @@ struct clip_hparams {
|
||||
int32_t attn_window_size = 0;
|
||||
int32_t n_wa_pattern = 0;
|
||||
int32_t spatial_merge_size = 0;
|
||||
|
||||
// audio
|
||||
int32_t n_mel_bins = 0; // whisper preprocessor
|
||||
int32_t proj_stack_factor = 0; // ultravox
|
||||
};
|
||||
|
||||
struct clip_layer {
|
||||
@@ -246,7 +254,9 @@ struct clip_vision_model {
|
||||
ggml_tensor * post_ln_w;
|
||||
ggml_tensor * post_ln_b;
|
||||
|
||||
ggml_tensor * projection;
|
||||
ggml_tensor * projection; // TODO: rename it to fc (fully connected layer)
|
||||
ggml_tensor * mm_fc_w;
|
||||
ggml_tensor * mm_fc_b;
|
||||
|
||||
// LLaVA projection
|
||||
ggml_tensor * mm_input_norm_w = nullptr;
|
||||
@@ -332,6 +342,14 @@ struct clip_vision_model {
|
||||
// pixtral
|
||||
ggml_tensor * token_embd_img_break = nullptr;
|
||||
ggml_tensor * mm_patch_merger_w = nullptr;
|
||||
|
||||
// ultravox / whisper encoder
|
||||
ggml_tensor * conv1d_1_w = nullptr;
|
||||
ggml_tensor * conv1d_1_b = nullptr;
|
||||
ggml_tensor * conv1d_2_w = nullptr;
|
||||
ggml_tensor * conv1d_2_b = nullptr;
|
||||
ggml_tensor * mm_norm_pre_w = nullptr;
|
||||
ggml_tensor * mm_norm_mid_w = nullptr;
|
||||
};
|
||||
|
||||
struct clip_ctx {
|
||||
@@ -1408,6 +1426,114 @@ struct clip_graph {
|
||||
return gf;
|
||||
}
|
||||
|
||||
// whisper encoder with custom projector
|
||||
ggml_cgraph * build_whisper_enc() {
|
||||
const int n_frames = img.nx;
|
||||
const int n_pos = n_frames / 2;
|
||||
GGML_ASSERT(model.position_embeddings->ne[1] >= n_pos);
|
||||
|
||||
ggml_tensor * inp = build_inp_raw(1);
|
||||
|
||||
// conv1d block
|
||||
{
|
||||
// convolution + gelu
|
||||
ggml_tensor * cur = ggml_conv_1d_ph(ctx0, model.conv1d_1_w, inp, 1, 1);
|
||||
cur = ggml_add(ctx0, cur, model.conv1d_1_b);
|
||||
|
||||
cur = ggml_gelu_erf(ctx0, cur);
|
||||
|
||||
cur = ggml_conv_1d_ph(ctx0, model.conv1d_2_w, cur, 2, 1);
|
||||
cur = ggml_add(ctx0, cur, model.conv1d_2_b);
|
||||
|
||||
cur = ggml_gelu_erf(ctx0, cur);
|
||||
// transpose
|
||||
inp = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
|
||||
cb(inp, "after_conv1d", -1);
|
||||
}
|
||||
|
||||
// sanity check (only check one layer, but it should be the same for all)
|
||||
GGML_ASSERT(model.layers[0].ln_1_w && model.layers[0].ln_1_b);
|
||||
GGML_ASSERT(model.layers[0].ln_2_w && model.layers[0].ln_2_b);
|
||||
GGML_ASSERT(model.layers[0].q_b);
|
||||
GGML_ASSERT(model.layers[0].v_b);
|
||||
GGML_ASSERT(!model.layers[0].k_b); // no bias for k
|
||||
GGML_ASSERT(model.post_ln_w && model.post_ln_b);
|
||||
|
||||
ggml_tensor * pos_embd_selected = ggml_view_2d(
|
||||
ctx0, model.position_embeddings,
|
||||
model.position_embeddings->ne[0], n_pos,
|
||||
model.position_embeddings->nb[1], 0
|
||||
);
|
||||
ggml_tensor * cur = build_vit(
|
||||
inp, n_pos,
|
||||
NORM_TYPE_NORMAL,
|
||||
hparams.ffn_op,
|
||||
pos_embd_selected,
|
||||
nullptr);
|
||||
|
||||
cb(cur, "after_transformer", -1);
|
||||
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX) {
|
||||
// StackAudioFrames
|
||||
// https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
|
||||
{
|
||||
int64_t stride = n_embd * hparams.proj_stack_factor;
|
||||
int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
|
||||
int64_t pad = padded_len - ggml_nelements(cur);
|
||||
if (pad > 0) {
|
||||
cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
|
||||
cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
|
||||
}
|
||||
cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
|
||||
ggml_row_size(cur->type, stride), 0);
|
||||
}
|
||||
|
||||
cb(cur, "after_stacked", -1);
|
||||
|
||||
// UltravoxProjector
|
||||
{
|
||||
// pre-norm
|
||||
cur = ggml_rms_norm(ctx0, cur, 1e-6);
|
||||
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
|
||||
|
||||
// ffn in
|
||||
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
|
||||
|
||||
// swiglu
|
||||
{
|
||||
int64_t split_point = cur->ne[0] / 2;
|
||||
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
||||
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
||||
|
||||
// see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
|
||||
x1 = ggml_silu(ctx0, x1);
|
||||
cur = ggml_mul(ctx0, x0, x1);
|
||||
}
|
||||
|
||||
// mid-norm
|
||||
cur = ggml_rms_norm(ctx0, cur, 1e-6);
|
||||
cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
|
||||
|
||||
// ffn out
|
||||
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
|
||||
}
|
||||
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) {
|
||||
// projector
|
||||
cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur);
|
||||
cur = ggml_add(ctx0, cur, model.mm_fc_b);
|
||||
|
||||
} else {
|
||||
GGML_ABORT("%s: unknown projector type", __func__);
|
||||
}
|
||||
|
||||
cb(cur, "projected", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
private:
|
||||
//
|
||||
// utility functions
|
||||
@@ -1541,6 +1667,17 @@ private:
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
// TODO @ngxson : find a way to move this outside
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) {
|
||||
ggml_tensor * cur = inpL;
|
||||
cur = ggml_transpose(ctx0, cur);
|
||||
cur = ggml_cont(ctx0, cur);
|
||||
cur = ggml_pool_1d(ctx0, cur, GGML_OP_POOL_AVG, 2, 2, 0);
|
||||
cur = ggml_transpose(ctx0, cur);
|
||||
cur = ggml_cont(ctx0, cur);
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
// post-layernorm
|
||||
if (model.post_ln_w) {
|
||||
inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1);
|
||||
@@ -1562,8 +1699,8 @@ private:
|
||||
return inp;
|
||||
}
|
||||
|
||||
ggml_tensor * build_inp_raw() {
|
||||
ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, 3);
|
||||
ggml_tensor * build_inp_raw(int channels = 3) {
|
||||
ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, channels);
|
||||
ggml_set_name(inp_raw, "inp_raw");
|
||||
ggml_set_input(inp_raw);
|
||||
return inp_raw;
|
||||
@@ -1641,6 +1778,11 @@ private:
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
cb(cur, "ffn_gelu", il);
|
||||
} break;
|
||||
case FFN_GELU_ERF:
|
||||
{
|
||||
cur = ggml_gelu_erf(ctx0, cur);
|
||||
cb(cur, "ggml_gelu_erf", il);
|
||||
} break;
|
||||
case FFN_GELU_QUICK:
|
||||
{
|
||||
cur = ggml_gelu_quick(ctx0, cur);
|
||||
@@ -1832,6 +1974,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
{
|
||||
res = graph.build_llama4();
|
||||
} break;
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
{
|
||||
res = graph.build_whisper_enc();
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
res = graph.build_llava();
|
||||
@@ -1915,18 +2062,30 @@ struct clip_model_loader {
|
||||
|
||||
// other hparams
|
||||
{
|
||||
get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); // legacy
|
||||
get_bool(KEY_HAS_AUDIO_ENC, hparams.has_audio, false);
|
||||
get_bool(KEY_HAS_VISION_ENC, hparams.has_vision, false);
|
||||
|
||||
get_u32(KEY_N_EMBD, hparams.n_embd);
|
||||
get_u32(KEY_N_HEAD, hparams.n_head);
|
||||
get_u32(KEY_N_FF, hparams.n_ff);
|
||||
get_u32(KEY_N_BLOCK, hparams.n_layer);
|
||||
get_u32(KEY_PROJ_DIM, hparams.projection_dim);
|
||||
get_f32(KEY_LAYER_NORM_EPS, hparams.eps);
|
||||
get_u32(KEY_IMAGE_SIZE, hparams.image_size);
|
||||
get_u32(KEY_PATCH_SIZE, hparams.patch_size);
|
||||
get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
|
||||
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
|
||||
const char * prefix = hparams.has_vision ? "vision" : "audio";
|
||||
get_u32(string_format(KEY_N_EMBD, prefix), hparams.n_embd);
|
||||
get_u32(string_format(KEY_N_HEAD, prefix), hparams.n_head);
|
||||
get_u32(string_format(KEY_N_FF, prefix), hparams.n_ff);
|
||||
get_u32(string_format(KEY_N_BLOCK, prefix), hparams.n_layer);
|
||||
get_u32(string_format(KEY_PROJ_DIM, prefix), hparams.projection_dim);
|
||||
get_f32(string_format(KEY_LAYER_NORM_EPS, prefix), hparams.eps);
|
||||
|
||||
if (hparams.has_vision) {
|
||||
get_u32(KEY_IMAGE_SIZE, hparams.image_size);
|
||||
get_u32(KEY_PATCH_SIZE, hparams.patch_size);
|
||||
get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
|
||||
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
|
||||
get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); // legacy
|
||||
|
||||
} else if (hparams.has_audio) {
|
||||
get_u32(KEY_A_NUM_MEL_BINS, hparams.n_mel_bins);
|
||||
|
||||
} else {
|
||||
throw std::runtime_error(string_format("%s: neither vision nor audio encoder is present\n", __func__));
|
||||
}
|
||||
|
||||
// default warmup value
|
||||
hparams.warmup_image_size = hparams.image_size;
|
||||
@@ -1964,7 +2123,7 @@ struct clip_model_loader {
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
if (hparams.has_vision) {
|
||||
int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN);
|
||||
int idx_std = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD);
|
||||
GGML_ASSERT(idx_mean >= 0 && "image_mean not found");
|
||||
@@ -2050,30 +2209,45 @@ struct clip_model_loader {
|
||||
isize, isize*3, // 336, 1008
|
||||
};
|
||||
} break;
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
{
|
||||
bool require_stack = ctx_clip.proj_type == PROJECTOR_TYPE_ULTRAVOX;
|
||||
get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack);
|
||||
if (hparams.n_mel_bins != 128) {
|
||||
throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
|
||||
}
|
||||
hparams.ffn_op = FFN_GELU_ERF;
|
||||
log_ffn_op = "gelu_erf"; // temporary solution for logging
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str());
|
||||
LOG_INF("%s: has_vision_encoder: %d\n", __func__, hparams.has_vision);
|
||||
LOG_INF("%s: has_audio_encoder: %d\n", __func__, hparams.has_audio);
|
||||
LOG_INF("%s: n_embd: %d\n", __func__, hparams.n_embd);
|
||||
LOG_INF("%s: n_head: %d\n", __func__, hparams.n_head);
|
||||
LOG_INF("%s: n_ff: %d\n", __func__, hparams.n_ff);
|
||||
LOG_INF("%s: n_layer: %d\n", __func__, hparams.n_layer);
|
||||
LOG_INF("%s: projection_dim: %d\n", __func__, hparams.projection_dim);
|
||||
LOG_INF("%s: image_size: %d\n", __func__, hparams.image_size);
|
||||
LOG_INF("%s: patch_size: %d\n", __func__, hparams.patch_size);
|
||||
LOG_INF("\n");
|
||||
LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector);
|
||||
LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version);
|
||||
LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
|
||||
LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
|
||||
LOG_INF("%s: ffn_op: %s\n", __func__, log_ffn_op.c_str());
|
||||
LOG_INF("%s: projection_dim: %d\n", __func__, hparams.projection_dim);
|
||||
LOG_INF("\n");
|
||||
if (hparams.has_vision) {
|
||||
LOG_INF("%s: image_size: %d\n", __func__, hparams.image_size);
|
||||
LOG_INF("%s: patch_size: %d\n", __func__, hparams.patch_size);
|
||||
LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector);
|
||||
LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version);
|
||||
LOG_INF("%s: proj_scale_factor: %d\n", __func__, hparams.proj_scale_factor);
|
||||
LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern);
|
||||
} else if (hparams.has_audio) {
|
||||
LOG_INF("%s: n_mel_bins: %d\n", __func__, hparams.n_mel_bins);
|
||||
LOG_INF("%s: proj_stack_factor: %d\n", __func__, hparams.proj_stack_factor);
|
||||
}
|
||||
LOG_INF("\n");
|
||||
LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
|
||||
LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
|
||||
|
||||
if (ctx_clip.proj_type == PROJECTOR_TYPE_LLAMA4) {
|
||||
LOG_WRN("%s: llama 4 vision is known to have degraded quality: https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2082,6 +2256,9 @@ struct clip_model_loader {
|
||||
std::map<std::string, size_t> tensor_offset;
|
||||
std::vector<ggml_tensor *> tensors_to_load;
|
||||
|
||||
// TODO @ngxson : support both audio and video in the future
|
||||
const char * prefix = hparams.has_audio ? "a" : "v";
|
||||
|
||||
// get offsets
|
||||
for (int64_t i = 0; i < gguf_get_n_tensors(ctx_gguf.get()); ++i) {
|
||||
const char * name = gguf_get_tensor_name(ctx_gguf.get(), i);
|
||||
@@ -2115,51 +2292,51 @@ struct clip_model_loader {
|
||||
return cur;
|
||||
};
|
||||
|
||||
auto & vision_model = ctx_clip.vision_model;
|
||||
auto & vision_model = ctx_clip.vision_model; // TODO: rename this to just "model"
|
||||
|
||||
vision_model.class_embedding = get_tensor(TN_CLASS_EMBD, false);
|
||||
|
||||
vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, "v", "weight"), false);
|
||||
vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, "v", "bias"), false);
|
||||
vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, prefix, "weight"), false);
|
||||
vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, prefix, "bias"), false);
|
||||
|
||||
vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, "v", "weight"), false);
|
||||
vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, "v", "bias"), false);
|
||||
vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, prefix, "weight"), false);
|
||||
vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, prefix, "bias"), false);
|
||||
|
||||
vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false);
|
||||
vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false);
|
||||
vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
|
||||
|
||||
vision_model.position_embeddings = get_tensor(TN_POS_EMBD, false);
|
||||
vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false);
|
||||
|
||||
// layers
|
||||
vision_model.layers.resize(hparams.n_layer);
|
||||
for (int il = 0; il < hparams.n_layer; ++il) {
|
||||
auto & layer = vision_model.layers[il];
|
||||
layer.k_w = get_tensor(string_format(TN_ATTN_K, "v", il, "weight"));
|
||||
layer.q_w = get_tensor(string_format(TN_ATTN_Q, "v", il, "weight"));
|
||||
layer.v_w = get_tensor(string_format(TN_ATTN_V, "v", il, "weight"));
|
||||
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight"));
|
||||
layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, "v", il, "weight"), false);
|
||||
layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, "v", il, "weight"), false);
|
||||
layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false);
|
||||
layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false);
|
||||
layer.ls_1_w = get_tensor(string_format(TN_LS_1, "v", il, "weight"), false); // no bias
|
||||
layer.ls_2_w = get_tensor(string_format(TN_LS_2, "v", il, "weight"), false); // no bias
|
||||
layer.k_w = get_tensor(string_format(TN_ATTN_K, prefix, il, "weight"));
|
||||
layer.q_w = get_tensor(string_format(TN_ATTN_Q, prefix, il, "weight"));
|
||||
layer.v_w = get_tensor(string_format(TN_ATTN_V, prefix, il, "weight"));
|
||||
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "weight"));
|
||||
layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, prefix, il, "weight"), false);
|
||||
layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false);
|
||||
layer.ln_1_w = get_tensor(string_format(TN_LN_1, prefix, il, "weight"), false);
|
||||
layer.ln_2_w = get_tensor(string_format(TN_LN_2, prefix, il, "weight"), false);
|
||||
layer.ls_1_w = get_tensor(string_format(TN_LS_1, prefix, il, "weight"), false); // no bias
|
||||
layer.ls_2_w = get_tensor(string_format(TN_LS_2, prefix, il, "weight"), false); // no bias
|
||||
|
||||
layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false);
|
||||
layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false);
|
||||
layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false);
|
||||
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false);
|
||||
layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
|
||||
layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
|
||||
layer.k_b = get_tensor(string_format(TN_ATTN_K, prefix, il, "bias"), false);
|
||||
layer.q_b = get_tensor(string_format(TN_ATTN_Q, prefix, il, "bias"), false);
|
||||
layer.v_b = get_tensor(string_format(TN_ATTN_V, prefix, il, "bias"), false);
|
||||
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false);
|
||||
layer.ln_1_b = get_tensor(string_format(TN_LN_1, prefix, il, "bias"), false);
|
||||
layer.ln_2_b = get_tensor(string_format(TN_LN_2, prefix, il, "bias"), false);
|
||||
|
||||
// ffn
|
||||
layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
|
||||
layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
|
||||
layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false);
|
||||
layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"), false);
|
||||
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
|
||||
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
|
||||
layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, prefix, il, "weight"));
|
||||
layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, prefix, il, "bias"), false);
|
||||
layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, prefix, il, "weight"), false);
|
||||
layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, prefix, il, "bias"), false);
|
||||
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight"));
|
||||
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"), false);
|
||||
|
||||
// some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
|
||||
// note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
|
||||
@@ -2301,6 +2478,26 @@ struct clip_model_loader {
|
||||
vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
|
||||
vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
{
|
||||
vision_model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
|
||||
vision_model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
|
||||
vision_model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
|
||||
vision_model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
|
||||
vision_model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
|
||||
vision_model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
|
||||
vision_model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
|
||||
vision_model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
{
|
||||
vision_model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
|
||||
vision_model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
|
||||
vision_model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
|
||||
vision_model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
|
||||
vision_model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
|
||||
vision_model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_INTERNVL:
|
||||
{
|
||||
vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
|
||||
@@ -2358,13 +2555,19 @@ struct clip_model_loader {
|
||||
}
|
||||
|
||||
void alloc_compute_meta() {
|
||||
const auto & hparams = ctx_clip.vision_model.hparams;
|
||||
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
|
||||
|
||||
// create a fake batch
|
||||
clip_image_f32_batch batch;
|
||||
clip_image_f32_ptr img(clip_image_f32_init());
|
||||
img->nx = ctx_clip.vision_model.hparams.warmup_image_size;
|
||||
img->ny = ctx_clip.vision_model.hparams.warmup_image_size;
|
||||
if (hparams.has_vision) {
|
||||
img->nx = hparams.warmup_image_size;
|
||||
img->ny = hparams.warmup_image_size;
|
||||
} else {
|
||||
img->nx = 1024; // TODO @ngxson : use a better default
|
||||
img->ny = hparams.n_mel_bins;
|
||||
}
|
||||
img->buf.resize(img->nx * img->ny * 3);
|
||||
batch.entries.push_back(std::move(img));
|
||||
|
||||
@@ -3278,6 +3481,14 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
||||
n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
|
||||
n_patches /= (scale_factor * scale_factor);
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX) {
|
||||
const int proj_stack_factor = ctx->vision_model.hparams.proj_stack_factor;
|
||||
const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
|
||||
n_patches = n_len / proj_stack_factor / 2;
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2A) {
|
||||
// divide by 2 because of whisper
|
||||
// another divide by 2 because of nn.AvgPool1d(2, stride=2)
|
||||
n_patches = img->nx / 4;
|
||||
}
|
||||
|
||||
return n_patches;
|
||||
@@ -3435,7 +3646,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
};
|
||||
|
||||
// set input pixel values
|
||||
{
|
||||
if (!imgs.is_audio) {
|
||||
size_t nelem = 0;
|
||||
for (const auto & img : imgs.entries) {
|
||||
nelem += img->nx * img->ny * 3;
|
||||
@@ -3472,6 +3683,16 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
}
|
||||
}
|
||||
set_input_f32("inp_raw", inp_raw);
|
||||
|
||||
} else {
|
||||
// audio input
|
||||
GGML_ASSERT(imgs.entries.size() == 1);
|
||||
const auto & mel_inp = imgs.entries[0];
|
||||
const int n_step = mel_inp->nx;
|
||||
const int n_mel = mel_inp->ny;
|
||||
std::vector<float> inp_raw(n_step * n_mel);
|
||||
std::memcpy(inp_raw.data(), mel_inp->buf.data(), n_step * n_mel * sizeof(float));
|
||||
set_input_f32("inp_raw", inp_raw);
|
||||
}
|
||||
|
||||
// set input per projector
|
||||
@@ -3668,6 +3889,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
case PROJECTOR_TYPE_IDEFICS3:
|
||||
case PROJECTOR_TYPE_INTERNVL:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
{
|
||||
// do nothing
|
||||
} break;
|
||||
@@ -3727,7 +3950,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
const int n_tokens_out = embeddings->ne[1];
|
||||
const int expected_n_tokens_out = clip_n_output_tokens(ctx, imgs.entries[0].get());
|
||||
if (n_tokens_out != expected_n_tokens_out) {
|
||||
LOG_ERR("%s: expected %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out);
|
||||
LOG_ERR("%s: expected output %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out);
|
||||
GGML_ABORT("Invalid number of output tokens");
|
||||
}
|
||||
|
||||
@@ -3766,10 +3989,14 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.mm_input_proj_w->ne[0];
|
||||
case PROJECTOR_TYPE_IDEFICS3:
|
||||
return ctx->vision_model.projection->ne[1];
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
return ctx->vision_model.mm_2_w->ne[1];
|
||||
case PROJECTOR_TYPE_INTERNVL:
|
||||
return ctx->vision_model.mm_3_w->ne[1];
|
||||
case PROJECTOR_TYPE_LLAMA4:
|
||||
return ctx->vision_model.mm_model_proj->ne[1];
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
return ctx->vision_model.mm_fc_w->ne[1];
|
||||
default:
|
||||
GGML_ABORT("Unknown projector type");
|
||||
}
|
||||
@@ -3798,6 +4025,18 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) {
|
||||
return ctx->proj_type == PROJECTOR_TYPE_GEMMA3;
|
||||
}
|
||||
|
||||
bool clip_has_vision_encoder(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.has_vision;
|
||||
}
|
||||
|
||||
bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.hparams.has_audio;
|
||||
}
|
||||
|
||||
bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
|
||||
return ctx->proj_type == PROJECTOR_TYPE_ULTRAVOX || ctx->proj_type == PROJECTOR_TYPE_QWEN2A;
|
||||
}
|
||||
|
||||
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
|
||||
clip_image_f32 clip_img;
|
||||
clip_img.buf.resize(h * w * 3);
|
||||
@@ -3818,3 +4057,14 @@ bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img,
|
||||
projector_type clip_get_projector_type(const struct clip_ctx * ctx) {
|
||||
return ctx->proj_type;
|
||||
}
|
||||
|
||||
void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel) {
|
||||
clip_image_f32 * audio = new clip_image_f32;
|
||||
audio->nx = n_frames;
|
||||
audio->ny = n_mel;
|
||||
audio->buf.resize(n_frames * n_mel);
|
||||
std::memcpy(audio->buf.data(), mel, n_frames * n_mel * sizeof(float));
|
||||
|
||||
batch->entries.push_back(clip_image_f32_ptr(audio));
|
||||
batch->is_audio = true;
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
// !!! Internal header, to be used by mtmd only !!!
|
||||
|
||||
struct clip_ctx;
|
||||
|
||||
struct clip_image_size {
|
||||
@@ -93,3 +95,10 @@ bool clip_is_llava(const struct clip_ctx * ctx);
|
||||
bool clip_is_gemma3(const struct clip_ctx * ctx);
|
||||
|
||||
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
|
||||
|
||||
// use by audio input
|
||||
void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel);
|
||||
|
||||
bool clip_has_vision_encoder(const struct clip_ctx * ctx);
|
||||
bool clip_has_audio_encoder(const struct clip_ctx * ctx);
|
||||
bool clip_has_whisper_encoder(const struct clip_ctx * ctx);
|
||||
|
||||
+93468
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,855 @@
|
||||
// fix problem with std::min and std::max
|
||||
#if defined(_WIN32)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#ifndef NOMINMAX
|
||||
# define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
#include "mtmd-audio.h"
|
||||
|
||||
//#define MTMD_AUDIO_DEBUG
|
||||
|
||||
#define MINIAUDIO_IMPLEMENTATION
|
||||
#ifndef MTMD_AUDIO_DEBUG
|
||||
# define MA_NO_ENCODING
|
||||
#endif
|
||||
#define MA_NO_DEVICE_IO
|
||||
#define MA_NO_RESOURCE_MANAGER
|
||||
#define MA_NO_NODE_GRAPH
|
||||
#define MA_NO_ENGINE
|
||||
#define MA_NO_GENERATION
|
||||
#define MA_API static
|
||||
#include "miniaudio.h"
|
||||
|
||||
#define _USE_MATH_DEFINES // for M_PI
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
|
||||
// most of the code here is copied from whisper.cpp
|
||||
|
||||
// align x to upper multiple of n
|
||||
#define _ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
|
||||
|
||||
namespace whisper_preprocessor {
|
||||
|
||||
#define SIN_COS_N_COUNT WHISPER_N_FFT
|
||||
namespace {
|
||||
struct whisper_global_cache {
|
||||
// In FFT, we frequently use sine and cosine operations with the same values.
|
||||
// We can use precalculated values to speed up the process.
|
||||
float sin_vals[SIN_COS_N_COUNT];
|
||||
float cos_vals[SIN_COS_N_COUNT];
|
||||
|
||||
// Hann window (Use cosf to eliminate difference)
|
||||
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
|
||||
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
|
||||
float hann_window[WHISPER_N_FFT];
|
||||
|
||||
whisper_global_cache() {
|
||||
fill_sin_cos_table();
|
||||
fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
|
||||
}
|
||||
|
||||
void fill_sin_cos_table() {
|
||||
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
|
||||
double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
|
||||
sin_vals[i] = sinf(theta);
|
||||
cos_vals[i] = cosf(theta);
|
||||
}
|
||||
}
|
||||
|
||||
void fill_hann_window(int length, bool periodic, float * output) {
|
||||
int offset = -1;
|
||||
if (periodic) {
|
||||
offset = 0;
|
||||
}
|
||||
for (int i = 0; i < length; i++) {
|
||||
output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
|
||||
}
|
||||
}
|
||||
} global_cache;
|
||||
}
|
||||
|
||||
// naive Discrete Fourier Transform
|
||||
// input is real-valued
|
||||
// output is complex-valued
|
||||
static void dft(const float* in, int N, float* out) {
|
||||
const int sin_cos_step = SIN_COS_N_COUNT / N;
|
||||
|
||||
for (int k = 0; k < N; k++) {
|
||||
float re = 0;
|
||||
float im = 0;
|
||||
|
||||
for (int n = 0; n < N; n++) {
|
||||
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
|
||||
re += in[n]*global_cache.cos_vals[idx]; // cos(t)
|
||||
im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
|
||||
}
|
||||
|
||||
out[k*2 + 0] = re;
|
||||
out[k*2 + 1] = im;
|
||||
}
|
||||
}
|
||||
|
||||
// Cooley-Tukey FFT
|
||||
// poor man's implementation - use something better
|
||||
// input is real-valued
|
||||
// output is complex-valued
|
||||
static void fft(float* in, int N, float* out) {
|
||||
if (N == 1) {
|
||||
out[0] = in[0];
|
||||
out[1] = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
const int half_N = N / 2;
|
||||
if (N - half_N*2 == 1) {
|
||||
dft(in, N, out);
|
||||
return;
|
||||
}
|
||||
|
||||
float* even = in + N;
|
||||
for (int i = 0; i < half_N; ++i) {
|
||||
even[i]= in[2*i];
|
||||
}
|
||||
float* even_fft = out + 2 * N;
|
||||
fft(even, half_N, even_fft);
|
||||
|
||||
float* odd = even;
|
||||
for (int i = 0; i < half_N; ++i) {
|
||||
odd[i] = in[2*i + 1];
|
||||
}
|
||||
float* odd_fft = even_fft + N;
|
||||
fft(odd, half_N, odd_fft);
|
||||
|
||||
const int sin_cos_step = SIN_COS_N_COUNT / N;
|
||||
for (int k = 0; k < half_N; k++) {
|
||||
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
|
||||
float re = global_cache.cos_vals[idx]; // cos(t)
|
||||
float im = -global_cache.sin_vals[idx]; // sin(t)
|
||||
|
||||
float re_odd = odd_fft[2*k + 0];
|
||||
float im_odd = odd_fft[2*k + 1];
|
||||
|
||||
out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
|
||||
out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
|
||||
|
||||
out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
|
||||
out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
|
||||
}
|
||||
}
|
||||
|
||||
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
|
||||
int n_samples, int frame_size, int frame_step, int n_threads,
|
||||
const whisper_filters & filters, whisper_mel & mel) {
|
||||
std::vector<float> fft_in(frame_size * 2, 0.0);
|
||||
std::vector<float> fft_out(frame_size * 2 * 2 * 2);
|
||||
|
||||
int n_fft = filters.n_fft;
|
||||
int i = ith;
|
||||
|
||||
// make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
|
||||
WHISPER_ASSERT(n_fft == 1 + (frame_size / 2));
|
||||
|
||||
// calculate FFT only when fft_in are not all zero
|
||||
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
|
||||
const int offset = i * frame_step;
|
||||
|
||||
// apply Hann window (~10% faster)
|
||||
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
|
||||
fft_in[j] = hann[j] * samples[offset + j];
|
||||
}
|
||||
|
||||
// fill the rest with zeros
|
||||
if (n_samples - offset < frame_size) {
|
||||
std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
|
||||
}
|
||||
|
||||
// FFT
|
||||
fft(fft_in.data(), frame_size, fft_out.data());
|
||||
|
||||
// Calculate modulus^2 of complex numbers
|
||||
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
|
||||
for (int j = 0; j < n_fft; j++) {
|
||||
fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
|
||||
}
|
||||
|
||||
// mel spectrogram
|
||||
for (int j = 0; j < mel.n_mel; j++) {
|
||||
double sum = 0.0;
|
||||
// unroll loop (suggested by GH user @lunixbochs)
|
||||
int k = 0;
|
||||
for (k = 0; k < n_fft - 3; k += 4) {
|
||||
sum +=
|
||||
fft_out[k + 0] * filters.data[j * n_fft + k + 0] +
|
||||
fft_out[k + 1] * filters.data[j * n_fft + k + 1] +
|
||||
fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
|
||||
fft_out[k + 3] * filters.data[j * n_fft + k + 3];
|
||||
}
|
||||
// handle n_fft remainder
|
||||
for (; k < n_fft; k++) {
|
||||
sum += fft_out[k] * filters.data[j * n_fft + k];
|
||||
}
|
||||
sum = log10(std::max(sum, 1e-10));
|
||||
mel.data[j * mel.n_len + i] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise fft_out are all zero
|
||||
double sum = log10(1e-10);
|
||||
for (; i < mel.n_len; i += n_threads) {
|
||||
for (int j = 0; j < mel.n_mel; j++) {
|
||||
mel.data[j * mel.n_len + i] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
|
||||
static bool log_mel_spectrogram(
|
||||
const float * samples,
|
||||
const int n_samples,
|
||||
const int /*sample_rate*/,
|
||||
const int frame_size,
|
||||
const int frame_step,
|
||||
const int n_mel,
|
||||
const int n_threads,
|
||||
const whisper_filters & filters,
|
||||
const bool debug,
|
||||
whisper_mel & mel) {
|
||||
//const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
// Hann window
|
||||
WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
|
||||
const float * hann = global_cache.hann_window;
|
||||
|
||||
// Calculate the length of padding
|
||||
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
|
||||
int64_t stage_2_pad = frame_size / 2;
|
||||
|
||||
// Initialize a vector and copy data from C array to it.
|
||||
std::vector<float> samples_padded;
|
||||
samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
|
||||
std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
|
||||
|
||||
// pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
|
||||
std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
|
||||
|
||||
// reflective pad 200 samples at the beginning of audio
|
||||
std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
|
||||
|
||||
mel.n_mel = n_mel;
|
||||
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
|
||||
// Calculate number of frames + remove the last frame
|
||||
mel.n_len = (samples_padded.size() - frame_size) / frame_step;
|
||||
// Calculate semi-padded sample length to ensure compatibility
|
||||
mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
|
||||
mel.data.resize(mel.n_mel * mel.n_len);
|
||||
|
||||
{
|
||||
std::vector<std::thread> workers(n_threads - 1);
|
||||
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
||||
workers[iw] = std::thread(
|
||||
log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
|
||||
n_samples + stage_2_pad, frame_size, frame_step, n_threads,
|
||||
std::cref(filters), std::ref(mel));
|
||||
}
|
||||
|
||||
// main thread
|
||||
log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel);
|
||||
|
||||
for (int iw = 0; iw < n_threads - 1; ++iw) {
|
||||
workers[iw].join();
|
||||
}
|
||||
}
|
||||
|
||||
// clamping and normalization
|
||||
double mmax = -1e20;
|
||||
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
|
||||
if (mel.data[i] > mmax) {
|
||||
mmax = mel.data[i];
|
||||
}
|
||||
}
|
||||
|
||||
mmax -= 8.0;
|
||||
|
||||
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
|
||||
if (mel.data[i] < mmax) {
|
||||
mel.data[i] = mmax;
|
||||
}
|
||||
|
||||
mel.data[i] = (mel.data[i] + 4.0)/4.0;
|
||||
}
|
||||
|
||||
// Dump log_mel_spectrogram
|
||||
if (debug) {
|
||||
std::ofstream outFile("log_mel_spectrogram.json");
|
||||
outFile << "[";
|
||||
for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
|
||||
outFile << mel.data[i] << ", ";
|
||||
}
|
||||
outFile << mel.data[mel.data.size() - 1] << "]";
|
||||
outFile.close();
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool preprocess_audio(
|
||||
const float * samples,
|
||||
size_t n_samples,
|
||||
const whisper_filters & filters,
|
||||
std::vector<whisper_mel> & output) {
|
||||
|
||||
if (n_samples == 0) {
|
||||
// empty audio
|
||||
return false;
|
||||
}
|
||||
|
||||
whisper_mel out_full;
|
||||
bool ok = log_mel_spectrogram(
|
||||
samples,
|
||||
n_samples,
|
||||
COMMON_SAMPLE_RATE,
|
||||
WHISPER_N_FFT,
|
||||
WHISPER_HOP_LENGTH,
|
||||
filters.n_mel,
|
||||
4, // n_threads
|
||||
filters,
|
||||
false, // debug
|
||||
out_full);
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// because the cgraph in clip.cpp only accepts 3000 frames each, we need to split the mel
|
||||
// we always expect the mel to have 3000 silent frames at the end
|
||||
// printf("n_len %d\n", out_full.n_len);
|
||||
const size_t frames_per_chunk = 3000;
|
||||
GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk);
|
||||
for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) {
|
||||
int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off);
|
||||
if ((size_t)n_len < frames_per_chunk) {
|
||||
break; // last uncomplete chunk will always be a padded chunk, safe to ignore
|
||||
}
|
||||
|
||||
whisper_mel out_chunk;
|
||||
out_chunk.n_len = n_len;
|
||||
out_chunk.n_mel = out_full.n_mel;
|
||||
out_chunk.n_len_org = out_full.n_mel; // unused
|
||||
out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len);
|
||||
|
||||
for (int i = 0; i < out_full.n_mel; i++) {
|
||||
auto src = out_full.data.begin() + i*out_full.n_len + off;
|
||||
out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk);
|
||||
}
|
||||
|
||||
output.push_back(std::move(out_chunk));
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace whisper_preprocessor
|
||||
|
||||
|
||||
namespace audio_helpers {
|
||||
|
||||
bool is_audio_file(const char * buf, size_t len) {
|
||||
if (len < 12) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// RIFF ref: https://en.wikipedia.org/wiki/Resource_Interchange_File_Format
|
||||
// WAV ref: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
|
||||
bool is_wav = memcmp(buf, "RIFF", 4) == 0 && memcmp(buf + 8, "WAVE", 4) == 0;
|
||||
bool is_mp3 = len >= 3 && (
|
||||
memcmp(buf, "ID3", 3) == 0 ||
|
||||
// Check for MPEG sync word (simplified check)
|
||||
((unsigned char)buf[0] == 0xFF && ((unsigned char)buf[1] & 0xE0) == 0xE0)
|
||||
);
|
||||
bool is_flac = memcmp(buf, "fLaC", 4) == 0;
|
||||
|
||||
return is_wav || is_mp3 || is_flac;
|
||||
}
|
||||
|
||||
// returns true if the buffer is a valid audio file
|
||||
bool decode_audio_from_buf(const unsigned char * buf_in, size_t len, int target_sampler_rate, std::vector<float> & pcmf32_mono) {
|
||||
ma_result result;
|
||||
const int channels = 1;
|
||||
ma_decoder_config decoder_config = ma_decoder_config_init(ma_format_f32, channels, target_sampler_rate);
|
||||
ma_decoder decoder;
|
||||
|
||||
result = ma_decoder_init_memory(buf_in, len, &decoder_config, &decoder);
|
||||
if (result != MA_SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ma_uint64 frame_count;
|
||||
ma_uint64 frames_read;
|
||||
result = ma_decoder_get_length_in_pcm_frames(&decoder, &frame_count);
|
||||
if (result != MA_SUCCESS) {
|
||||
ma_decoder_uninit(&decoder);
|
||||
return false;
|
||||
}
|
||||
|
||||
pcmf32_mono.resize(frame_count);
|
||||
result = ma_decoder_read_pcm_frames(&decoder, pcmf32_mono.data(), frame_count, &frames_read);
|
||||
if (result != MA_SUCCESS) {
|
||||
ma_decoder_uninit(&decoder);
|
||||
return false;
|
||||
}
|
||||
|
||||
#ifdef MTMD_AUDIO_DEBUG
|
||||
// save audio to wav file
|
||||
ma_encoder_config config = ma_encoder_config_init(ma_encoding_format_wav, ma_format_f32, 1, target_sampler_rate);
|
||||
ma_encoder encoder;
|
||||
ma_encoder_init_file("output.wav", &config, &encoder);
|
||||
ma_encoder_write_pcm_frames(&encoder, pcmf32_mono.data(), pcmf32_mono.size(), &frames_read);
|
||||
ma_encoder_uninit(&encoder);
|
||||
#endif
|
||||
|
||||
ma_decoder_uninit(&decoder);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace wav_utils
|
||||
|
||||
|
||||
// precalculated mel filter banks
|
||||
// values are multiplied by 1000.0 to save space, and will be divided by 1000.0 in the end of the function
|
||||
//
|
||||
// generated from python code:
|
||||
//
|
||||
// from numpy import load
|
||||
// data = load('mel_filters.npz')
|
||||
// lst = data.files
|
||||
// for item in lst:
|
||||
// print(item)
|
||||
// print(data[item].shape)
|
||||
// n_mel = data[item].shape[0]
|
||||
// n_fft = data[item].shape[1]
|
||||
// for i, row in enumerate(data[item]):
|
||||
// for j, val in enumerate(row):
|
||||
// val = val * 1000.0
|
||||
// if val != 0:
|
||||
// print(f"data[{i*n_fft + j}] = {val:.6f};")
|
||||
|
||||
namespace whisper_precalc_filters {
|
||||
|
||||
whisper_preprocessor::whisper_filters get_128_bins() {
|
||||
whisper_preprocessor::whisper_filters filters;
|
||||
filters.n_mel = 128;
|
||||
filters.n_fft = 201;
|
||||
std::vector data(filters.n_mel * filters.n_fft, 0.0f);
|
||||
|
||||
data[1] = 12.37398665;
|
||||
data[202] = 30.39256483;
|
||||
data[404] = 24.74797331;
|
||||
data[605] = 18.01857911;
|
||||
data[807] = 37.12195903;
|
||||
data[1008] = 5.64459199;
|
||||
data[1009] = 6.72939420;
|
||||
data[1210] = 36.03715822;
|
||||
data[1412] = 19.10337992;
|
||||
data[1613] = 23.66316877;
|
||||
data[1815] = 31.47736564;
|
||||
data[2016] = 11.28918398;
|
||||
data[2017] = 1.08480197;
|
||||
data[2218] = 41.68175161;
|
||||
data[2420] = 13.45878839;
|
||||
data[2621] = 29.30776216;
|
||||
data[2823] = 25.83277412;
|
||||
data[3024] = 16.93377644;
|
||||
data[3226] = 38.20675984;
|
||||
data[3427] = 4.55979025;
|
||||
data[3428] = 7.81419594;
|
||||
data[3629] = 34.95235741;
|
||||
data[3831] = 20.18818259;
|
||||
data[4032] = 22.57836796;
|
||||
data[4234] = 32.56217018;
|
||||
data[4435] = 10.20438317;
|
||||
data[4436] = 2.16960395;
|
||||
data[4637] = 40.59694707;
|
||||
data[4839] = 14.54358920;
|
||||
data[5040] = 28.22295949;
|
||||
data[5242] = 26.91757679;
|
||||
data[5443] = 15.84897563;
|
||||
data[5645] = 39.29156065;
|
||||
data[5846] = 3.47498828;
|
||||
data[5847] = 8.89899861;
|
||||
data[6048] = 33.86755288;
|
||||
data[6250] = 21.27298526;
|
||||
data[6451] = 21.49356715;
|
||||
data[6653] = 33.64697099;
|
||||
data[6854] = 9.11958050;
|
||||
data[6855] = 3.25440569;
|
||||
data[7056] = 39.51214626;
|
||||
data[7258] = 15.62839188;
|
||||
data[7459] = 27.13815868;
|
||||
data[7661] = 28.00237760;
|
||||
data[7862] = 14.76417296;
|
||||
data[8064] = 40.37636518;
|
||||
data[8265] = 2.38068704;
|
||||
data[8266] = 10.20263787;
|
||||
data[8467] = 31.61146119;
|
||||
data[8669] = 24.54700135;
|
||||
data[8870] = 15.32919332;
|
||||
data[8871] = 1.66583748;
|
||||
data[9072] = 36.72905266;
|
||||
data[9274] = 20.09709924;
|
||||
data[9475] = 16.93102531;
|
||||
data[9476] = 2.90265540;
|
||||
data[9677] = 32.84499049;
|
||||
data[9879] = 23.52004871;
|
||||
data[10080] = 11.03894413;
|
||||
data[10081] = 10.72582975;
|
||||
data[10282] = 22.71829173;
|
||||
data[10484] = 32.27872774;
|
||||
data[10685] = 0.11626833;
|
||||
data[10686] = 22.85348251;
|
||||
data[10887] = 8.56344029;
|
||||
data[10888] = 14.97978810;
|
||||
data[11089] = 15.51398356;
|
||||
data[11090] = 8.51490628;
|
||||
data[11291] = 21.10680379;
|
||||
data[11292] = 3.32652032;
|
||||
data[11493] = 25.47064796;
|
||||
data[11695] = 27.35907957;
|
||||
data[11896] = 0.65853616;
|
||||
data[11897] = 23.83812517;
|
||||
data[12098] = 3.44359246;
|
||||
data[12099] = 21.22455277;
|
||||
data[12300] = 5.35842171;
|
||||
data[12301] = 19.42555793;
|
||||
data[12502] = 6.49324711;
|
||||
data[12503] = 18.35542172;
|
||||
data[12704] = 6.93138083;
|
||||
data[12705] = 17.93504693;
|
||||
data[12906] = 6.74968259;
|
||||
data[12907] = 18.09151843;
|
||||
data[13108] = 6.01899112;
|
||||
data[13109] = 18.75767298;
|
||||
data[13310] = 4.80452832;
|
||||
data[13311] = 19.87172849;
|
||||
data[13512] = 3.16627859;
|
||||
data[13513] = 21.37690969;
|
||||
data[13514] = 1.25317345;
|
||||
data[13714] = 1.15934468;
|
||||
data[13715] = 20.80361731;
|
||||
data[13716] = 4.04486805;
|
||||
data[13917] = 17.55363122;
|
||||
data[13918] = 7.08320038;
|
||||
data[14119] = 14.07538634;
|
||||
data[14120] = 10.32655034;
|
||||
data[14321] = 10.40921453;
|
||||
data[14322] = 13.73696327;
|
||||
data[14523] = 6.59187697;
|
||||
data[14524] = 17.27988198;
|
||||
data[14525] = 1.46804214;
|
||||
data[14725] = 2.65681883;
|
||||
data[14726] = 18.09193194;
|
||||
data[14727] = 5.85655728;
|
||||
data[14928] = 13.34277913;
|
||||
data[14929] = 10.28267574;
|
||||
data[15130] = 8.56800377;
|
||||
data[15131] = 14.72230814;
|
||||
data[15132] = 1.04039861;
|
||||
data[15332] = 3.79085587;
|
||||
data[15333] = 17.14678481;
|
||||
data[15334] = 6.11609267;
|
||||
data[15535] = 11.75929047;
|
||||
data[15536] = 11.13393717;
|
||||
data[15737] = 6.43857848;
|
||||
data[15738] = 16.07806236;
|
||||
data[15739] = 4.23917221;
|
||||
data[15939] = 1.19989377;
|
||||
data[15940] = 12.75671553;
|
||||
data[15941] = 9.65298992;
|
||||
data[16142] = 7.06935255;
|
||||
data[16143] = 14.94054683;
|
||||
data[16144] = 4.19024844;
|
||||
data[16344] = 1.51483389;
|
||||
data[16345] = 12.00899947;
|
||||
data[16346] = 9.84823331;
|
||||
data[16547] = 6.10224018;
|
||||
data[16548] = 15.33857174;
|
||||
data[16549] = 5.57676842;
|
||||
data[16749] = 0.36827257;
|
||||
data[16750] = 9.89749376;
|
||||
data[16751] = 11.35340426;
|
||||
data[16752] = 2.05122307;
|
||||
data[16952] = 3.89297144;
|
||||
data[16953] = 12.97352277;
|
||||
data[16954] = 8.06631614;
|
||||
data[17155] = 6.74493238;
|
||||
data[17156] = 13.85874674;
|
||||
data[17157] = 5.41190524;
|
||||
data[17357] = 0.74220158;
|
||||
data[17358] = 8.98779090;
|
||||
data[17359] = 11.37871388;
|
||||
data[17360] = 3.32958088;
|
||||
data[17560] = 2.82313535;
|
||||
data[17561] = 10.68049297;
|
||||
data[17562] = 9.43340641;
|
||||
data[17563] = 1.76325557;
|
||||
data[17763] = 4.39018616;
|
||||
data[17764] = 11.87758986;
|
||||
data[17765] = 7.97005836;
|
||||
data[17766] = 0.66104700;
|
||||
data[17966] = 5.49466675;
|
||||
data[17967] = 12.62953598;
|
||||
data[17968] = 6.93987962;
|
||||
data[18169] = 6.18401915;
|
||||
data[18170] = 12.93473132;
|
||||
data[18171] = 6.29778765;
|
||||
data[18371] = 0.02325210;
|
||||
data[18372] = 6.50206627;
|
||||
data[18373] = 12.32661773;
|
||||
data[18374] = 6.00216538;
|
||||
data[18574] = 0.31548753;
|
||||
data[18575] = 6.48925547;
|
||||
data[18576] = 12.04130240;
|
||||
data[18577] = 6.01462880;
|
||||
data[18777] = 0.29979556;
|
||||
data[18778] = 6.18288014;
|
||||
data[18779] = 12.04272825;
|
||||
data[18780] = 6.29981188;
|
||||
data[18781] = 0.55689598;
|
||||
data[18980] = 0.01120471;
|
||||
data[18981] = 5.61729167;
|
||||
data[18982] = 11.22337859;
|
||||
data[18983] = 6.82516303;
|
||||
data[18984] = 1.35264499;
|
||||
data[19184] = 4.82410006;
|
||||
data[19185] = 10.16623247;
|
||||
data[19186] = 7.56075513;
|
||||
data[19187] = 2.34590308;
|
||||
data[19387] = 3.83235747;
|
||||
data[19388] = 8.92296247;
|
||||
data[19389] = 8.47910438;
|
||||
data[19390] = 3.50978645;
|
||||
data[19590] = 2.66873185;
|
||||
data[19591] = 7.51965167;
|
||||
data[19592] = 9.55500547;
|
||||
data[19593] = 4.81966138;
|
||||
data[19594] = 0.08431751;
|
||||
data[19793] = 1.35767367;
|
||||
data[19794] = 5.98019501;
|
||||
data[19795] = 10.60271543;
|
||||
data[19796] = 6.25298498;
|
||||
data[19797] = 1.74059917;
|
||||
data[19997] = 4.32644226;
|
||||
data[19998] = 8.73131864;
|
||||
data[19999] = 7.78916525;
|
||||
data[20000] = 3.48923868;
|
||||
data[20200] = 2.57835095;
|
||||
data[20201] = 6.77582854;
|
||||
data[20202] = 9.40941647;
|
||||
data[20203] = 5.31194592;
|
||||
data[20204] = 1.21447595;
|
||||
data[20403] = 0.75411191;
|
||||
data[20404] = 4.75395704;
|
||||
data[20405] = 8.75380263;
|
||||
data[20406] = 7.19209015;
|
||||
data[20407] = 3.28754401;
|
||||
data[20607] = 2.68179690;
|
||||
data[20608] = 6.49331464;
|
||||
data[20609] = 9.11457930;
|
||||
data[20610] = 5.39387390;
|
||||
data[20611] = 1.67316827;
|
||||
data[20810] = 0.57394296;
|
||||
data[20811] = 4.20600036;
|
||||
data[20812] = 7.83805829;
|
||||
data[20813] = 7.52023002;
|
||||
data[20814] = 3.97470826;
|
||||
data[20815] = 0.42918732;
|
||||
data[21014] = 1.90464477;
|
||||
data[21015] = 5.36569161;
|
||||
data[21016] = 8.82673822;
|
||||
data[21017] = 6.27609482;
|
||||
data[21018] = 2.89750961;
|
||||
data[21218] = 2.89885257;
|
||||
data[21219] = 6.19694078;
|
||||
data[21220] = 8.56699049;
|
||||
data[21221] = 5.34748193;
|
||||
data[21222] = 2.12797290;
|
||||
data[21421] = 0.44750227;
|
||||
data[21422] = 3.59030394;
|
||||
data[21423] = 6.73310598;
|
||||
data[21424] = 7.77023612;
|
||||
data[21425] = 4.70231380;
|
||||
data[21426] = 1.63439126;
|
||||
data[21625] = 1.01536023;
|
||||
data[21626] = 4.01018746;
|
||||
data[21627] = 7.00501446;
|
||||
data[21628] = 7.23442994;
|
||||
data[21629] = 4.31095669;
|
||||
data[21630] = 1.38748321;
|
||||
data[21829] = 1.33348850;
|
||||
data[21830] = 4.18730825;
|
||||
data[21831] = 7.04112789;
|
||||
data[21832] = 6.93188375;
|
||||
data[21833] = 4.14605811;
|
||||
data[21834] = 1.36023236;
|
||||
data[22033] = 1.42879714;
|
||||
data[22034] = 4.14824858;
|
||||
data[22035] = 6.86769979;
|
||||
data[22036] = 6.83705276;
|
||||
data[22037] = 4.18239459;
|
||||
data[22038] = 1.52773573;
|
||||
data[22237] = 1.32610439;
|
||||
data[22238] = 3.91751388;
|
||||
data[22239] = 6.50892360;
|
||||
data[22240] = 6.92639686;
|
||||
data[22241] = 4.39672917;
|
||||
data[22242] = 1.86706171;
|
||||
data[22441] = 1.04827771;
|
||||
data[22442] = 3.51767405;
|
||||
data[22443] = 5.98707050;
|
||||
data[22444] = 7.17824046;
|
||||
data[22445] = 4.76767914;
|
||||
data[22446] = 2.35711760;
|
||||
data[22645] = 0.61636406;
|
||||
data[22646] = 2.96949223;
|
||||
data[22647] = 5.32262027;
|
||||
data[22648] = 7.57265091;
|
||||
data[22649] = 5.27558755;
|
||||
data[22650] = 2.97852419;
|
||||
data[22651] = 0.68146095;
|
||||
data[22849] = 0.04971400;
|
||||
data[22850] = 2.29204819;
|
||||
data[22851] = 4.53438237;
|
||||
data[22852] = 6.77671656;
|
||||
data[22853] = 5.90240723;
|
||||
data[22854] = 3.71349836;
|
||||
data[22855] = 1.52458926;
|
||||
data[23054] = 1.50285335;
|
||||
data[23055] = 3.63961048;
|
||||
data[23056] = 5.77636715;
|
||||
data[23057] = 6.63159089;
|
||||
data[23058] = 4.54574358;
|
||||
data[23059] = 2.45989650;
|
||||
data[23060] = 0.37404924;
|
||||
data[23258] = 0.61795861;
|
||||
data[23259] = 2.65410915;
|
||||
data[23260] = 4.69025923;
|
||||
data[23261] = 6.72641024;
|
||||
data[23262] = 5.46034705;
|
||||
data[23263] = 3.47270933;
|
||||
data[23264] = 1.48507138;
|
||||
data[23463] = 1.59233576;
|
||||
data[23464] = 3.53261665;
|
||||
data[23465] = 5.47289755;
|
||||
data[23466] = 6.44368259;
|
||||
data[23467] = 4.54962999;
|
||||
data[23468] = 2.65557761;
|
||||
data[23469] = 0.76152512;
|
||||
data[23667] = 0.46749352;
|
||||
data[23668] = 2.31641904;
|
||||
data[23669] = 4.16534441;
|
||||
data[23670] = 6.01426978;
|
||||
data[23671] = 5.67844696;
|
||||
data[23672] = 3.87357362;
|
||||
data[23673] = 2.06870004;
|
||||
data[23674] = 0.26382666;
|
||||
data[23872] = 1.05349103;
|
||||
data[23873] = 2.81536230;
|
||||
data[23874] = 4.57723346;
|
||||
data[23875] = 6.33910485;
|
||||
data[23876] = 5.12815686;
|
||||
data[23877] = 3.40826320;
|
||||
data[23878] = 1.68837002;
|
||||
data[24077] = 1.43350090;
|
||||
data[24078] = 3.11241671;
|
||||
data[24079] = 4.79133241;
|
||||
data[24080] = 6.40943693;
|
||||
data[24081] = 4.77052201;
|
||||
data[24082] = 3.13160778;
|
||||
data[24083] = 1.49269309;
|
||||
data[24281] = 0.02932359;
|
||||
data[24282] = 1.62918994;
|
||||
data[24283] = 3.22905602;
|
||||
data[24284] = 4.82892245;
|
||||
data[24285] = 6.14671456;
|
||||
data[24286] = 4.58496623;
|
||||
data[24287] = 3.02321767;
|
||||
data[24288] = 1.46146910;
|
||||
data[24486] = 0.13601698;
|
||||
data[24487] = 1.66055572;
|
||||
data[24488] = 3.18509457;
|
||||
data[24489] = 4.70963307;
|
||||
data[24490] = 6.04072399;
|
||||
data[24491] = 4.55250870;
|
||||
data[24492] = 3.06429295;
|
||||
data[24493] = 1.57607743;
|
||||
data[24494] = 0.08786193;
|
||||
data[24691] = 0.09328097;
|
||||
data[24692] = 1.54603878;
|
||||
data[24693] = 2.99879676;
|
||||
data[24694] = 4.45155473;
|
||||
data[24695] = 5.90431225;
|
||||
data[24696] = 4.65566106;
|
||||
data[24697] = 3.23751615;
|
||||
data[24698] = 1.81937125;
|
||||
data[24699] = 0.40122634;
|
||||
data[24897] = 1.30262633;
|
||||
data[24898] = 2.68698297;
|
||||
data[24899] = 4.07133950;
|
||||
data[24900] = 5.45569602;
|
||||
data[24901] = 4.87832492;
|
||||
data[24902] = 3.52695142;
|
||||
data[24903] = 2.17557792;
|
||||
data[24904] = 0.82420459;
|
||||
data[25102] = 0.94595028;
|
||||
data[25103] = 2.26512621;
|
||||
data[25104] = 3.58430226;
|
||||
data[25105] = 4.90347855;
|
||||
data[25106] = 5.20569785;
|
||||
data[25107] = 3.91795207;
|
||||
data[25108] = 2.63020652;
|
||||
data[25109] = 1.34246063;
|
||||
data[25110] = 0.05471494;
|
||||
data[25307] = 0.49037894;
|
||||
data[25308] = 1.74744334;
|
||||
data[25309] = 3.00450763;
|
||||
data[25310] = 4.26157191;
|
||||
data[25311] = 5.51863620;
|
||||
data[25312] = 4.39707236;
|
||||
data[25313] = 3.16995848;
|
||||
data[25314] = 1.94284460;
|
||||
data[25315] = 0.71573065;
|
||||
data[25513] = 1.14698056;
|
||||
data[25514] = 2.34485767;
|
||||
data[25515] = 3.54273478;
|
||||
data[25516] = 4.74061165;
|
||||
data[25517] = 4.95198462;
|
||||
data[25518] = 3.78264743;
|
||||
data[25519] = 2.61331047;
|
||||
data[25520] = 1.44397374;
|
||||
data[25521] = 0.27463681;
|
||||
data[25718] = 0.47569509;
|
||||
data[25719] = 1.61717169;
|
||||
data[25720] = 2.75864848;
|
||||
data[25721] = 3.90012516;
|
||||
data[25722] = 5.04160160;
|
||||
data[25723] = 4.45712078;
|
||||
data[25724] = 3.34284059;
|
||||
data[25725] = 2.22856039;
|
||||
data[25726] = 1.11428020;
|
||||
|
||||
for (auto & val : data) {
|
||||
val /= 1000.0f;
|
||||
}
|
||||
|
||||
filters.data = std::move(data);
|
||||
return filters;
|
||||
}
|
||||
|
||||
} // namespace whisper_precalc_filters
|
||||
@@ -0,0 +1,62 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#define WHISPER_ASSERT GGML_ASSERT
|
||||
|
||||
#define WHISPER_SAMPLE_RATE 16000
|
||||
#define WHISPER_N_FFT 400
|
||||
#define WHISPER_HOP_LENGTH 160
|
||||
#define WHISPER_CHUNK_SIZE 30
|
||||
|
||||
#define COMMON_SAMPLE_RATE 16000
|
||||
|
||||
namespace whisper_preprocessor {
|
||||
|
||||
struct whisper_mel {
|
||||
int n_len;
|
||||
int n_len_org;
|
||||
int n_mel;
|
||||
|
||||
std::vector<float> data;
|
||||
};
|
||||
|
||||
struct whisper_filters {
|
||||
int32_t n_mel;
|
||||
int32_t n_fft;
|
||||
|
||||
std::vector<float> data;
|
||||
};
|
||||
|
||||
extern bool preprocess_audio(
|
||||
const float * samples,
|
||||
size_t n_samples,
|
||||
const whisper_filters & filters,
|
||||
std::vector<whisper_mel> & output);
|
||||
|
||||
} // namespace whisper_preprocessor
|
||||
|
||||
|
||||
// TODO @ngxson : move this helper to mtmd-helpers.cpp
|
||||
namespace audio_helpers {
|
||||
|
||||
extern bool is_audio_file(const char * buf, size_t len);
|
||||
|
||||
extern bool decode_audio_from_buf(
|
||||
const unsigned char * buf_in,
|
||||
size_t len,
|
||||
int target_sampler_rate,
|
||||
std::vector<float> & pcmf32_mono);
|
||||
|
||||
} // namespace audio_helpers
|
||||
|
||||
|
||||
namespace whisper_precalc_filters {
|
||||
|
||||
extern whisper_preprocessor::whisper_filters get_128_bins();
|
||||
|
||||
} // namespace whisper_precalc_filters
|
||||
+21
-14
@@ -37,10 +37,10 @@ static volatile bool g_is_interrupted = false;
|
||||
static void show_additional_info(int /*argc*/, char ** argv) {
|
||||
LOG(
|
||||
"Experimental CLI for multimodal\n\n"
|
||||
"Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> -p <prompt>\n\n"
|
||||
"Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> --audio <audio> -p <prompt>\n\n"
|
||||
" -m and --mmproj are required\n"
|
||||
" -hf user/repo can replace both -m and --mmproj in most cases\n"
|
||||
" --image and -p are optional, if NOT provided, the CLI will run in chat mode\n"
|
||||
" --image, --audio and -p are optional, if NOT provided, the CLI will run in chat mode\n"
|
||||
" to disable using GPU for mmproj model, add --no-mmproj-offload\n",
|
||||
argv[0]
|
||||
);
|
||||
@@ -142,7 +142,7 @@ struct mtmd_cli_context {
|
||||
);
|
||||
}
|
||||
|
||||
bool load_image(const std::string & fname) {
|
||||
bool load_media(const std::string & fname) {
|
||||
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_file(fname.c_str()));
|
||||
if (!bmp.ptr) {
|
||||
return false;
|
||||
@@ -243,7 +243,7 @@ int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
params.sampling.temp = 0.2; // lower temp by default for better quality
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) {
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MTMD, show_additional_info)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -283,14 +283,14 @@ int main(int argc, char ** argv) {
|
||||
|
||||
if (is_single_turn) {
|
||||
g_is_generating = true;
|
||||
if (params.prompt.find("<__image__>") == std::string::npos) {
|
||||
params.prompt += " <__image__>";
|
||||
if (params.prompt.find(mtmd_default_marker()) == std::string::npos) {
|
||||
params.prompt += mtmd_default_marker();
|
||||
}
|
||||
common_chat_msg msg;
|
||||
msg.role = "user";
|
||||
msg.content = params.prompt;
|
||||
for (const auto & image : params.image) {
|
||||
if (!ctx.load_image(image)) {
|
||||
if (!ctx.load_media(image)) {
|
||||
return 1; // error is already printed by libmtmd
|
||||
}
|
||||
}
|
||||
@@ -303,7 +303,12 @@ int main(int argc, char ** argv) {
|
||||
|
||||
} else {
|
||||
LOG("\n Running in chat mode, available commands:");
|
||||
LOG("\n /image <path> load an image");
|
||||
if (mtmd_support_vision(ctx.ctx_vision.get())) {
|
||||
LOG("\n /image <path> load an image");
|
||||
}
|
||||
if (mtmd_support_audio(ctx.ctx_vision.get())) {
|
||||
LOG("\n /audio <path> load an audio");
|
||||
}
|
||||
LOG("\n /clear clear the chat history");
|
||||
LOG("\n /quit or /exit exit the program");
|
||||
LOG("\n");
|
||||
@@ -333,15 +338,17 @@ int main(int argc, char ** argv) {
|
||||
continue;
|
||||
}
|
||||
g_is_generating = true;
|
||||
if (line == "/image" || line.find("/image ") == 0) {
|
||||
bool is_image = line == "/image" || line.find("/image ") == 0;
|
||||
bool is_audio = line == "/audio" || line.find("/audio ") == 0;
|
||||
if (is_image || is_audio) {
|
||||
if (line.size() < 8) {
|
||||
LOG_ERR("ERR: Missing image filename\n");
|
||||
LOG_ERR("ERR: Missing media filename\n");
|
||||
continue;
|
||||
}
|
||||
std::string image = line.substr(7);
|
||||
if (ctx.load_image(image)) {
|
||||
LOG("Image %s loaded\n", image.c_str());
|
||||
content += "<__image__>";
|
||||
std::string media_path = line.substr(7);
|
||||
if (ctx.load_media(media_path)) {
|
||||
LOG("%s %s loaded\n", media_path.c_str(), is_image ? "image" : "audio");
|
||||
content += mtmd_default_marker();
|
||||
}
|
||||
// else, error is already printed by libmtmd
|
||||
continue;
|
||||
|
||||
+29
-44
@@ -12,17 +12,7 @@ size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) {
|
||||
size_t n_tokens = 0;
|
||||
for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
|
||||
auto chunk = mtmd_input_chunks_get(chunks, i);
|
||||
auto chunk_type = mtmd_input_chunk_get_type(chunk);
|
||||
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||
size_t n_tokens_text;
|
||||
mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
|
||||
n_tokens += n_tokens_text;
|
||||
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||
auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
|
||||
n_tokens += mtmd_image_tokens_get_n_tokens(tokens_image);
|
||||
} else {
|
||||
GGML_ASSERT(false && "chunk type not supported");
|
||||
}
|
||||
n_tokens += mtmd_input_chunk_get_n_tokens(chunk);
|
||||
}
|
||||
return n_tokens;
|
||||
}
|
||||
@@ -31,17 +21,7 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) {
|
||||
llama_pos n_pos = 0;
|
||||
for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) {
|
||||
auto chunk = mtmd_input_chunks_get(chunks, i);
|
||||
auto chunk_type = mtmd_input_chunk_get_type(chunk);
|
||||
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||
size_t n_tokens_text;
|
||||
mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text);
|
||||
n_pos += n_tokens_text;
|
||||
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||
auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk);
|
||||
n_pos += mtmd_image_tokens_get_n_pos(tokens_image);
|
||||
} else {
|
||||
GGML_ASSERT(false && "chunk type not supported");
|
||||
}
|
||||
n_pos += mtmd_input_chunk_get_n_pos(chunk);
|
||||
}
|
||||
return n_pos;
|
||||
}
|
||||
@@ -149,13 +129,10 @@ int32_t mtmd_helper_decode_image_chunk(
|
||||
llama_seq_id seq_id,
|
||||
int32_t n_batch,
|
||||
llama_pos * new_n_past) {
|
||||
if (mtmd_input_chunk_get_type(chunk) != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||
LOG_ERR("failed to decode image chunk: input chunk not of image type\n");
|
||||
return -1;
|
||||
}
|
||||
const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
|
||||
if (!image_tokens) {
|
||||
LOG_ERR("failed to decode image chunk: image tokens are null\n");
|
||||
auto chunk_type = mtmd_input_chunk_get_type(chunk);
|
||||
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
|
||||
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||
LOG_ERR("failed to decode chunk: input chunk not of image/audio type\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
@@ -163,15 +140,23 @@ int32_t mtmd_helper_decode_image_chunk(
|
||||
int n_mmproj_embd = llama_model_n_embd(model);
|
||||
int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
|
||||
|
||||
int32_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
|
||||
int32_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
|
||||
int32_t i_batch = 0;
|
||||
int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
|
||||
decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
|
||||
|
||||
const int nx = mtmd_image_tokens_get_nx(image_tokens);
|
||||
const int ny = mtmd_image_tokens_get_ny(image_tokens);
|
||||
|
||||
if (mtmd_decode_use_mrope(ctx)) {
|
||||
const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
|
||||
if (chunk_type != MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||
LOG_ERR("failed to decode chunk: M-RoPE only accepts image chunk\n");
|
||||
return -1;
|
||||
}
|
||||
if (!image_tokens) {
|
||||
LOG_ERR("failed to decode chunk: image tokens are null\n");
|
||||
return -1;
|
||||
}
|
||||
const int nx = mtmd_image_tokens_get_nx(image_tokens);
|
||||
const int ny = mtmd_image_tokens_get_ny(image_tokens);
|
||||
batch_embd.set_position_mrope(n_past, nx, ny, seq_id);
|
||||
} else {
|
||||
batch_embd.set_position_normal(n_past, seq_id);
|
||||
@@ -187,22 +172,22 @@ int32_t mtmd_helper_decode_image_chunk(
|
||||
int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
|
||||
llama_batch batch_embd_view = batch_embd.get_view(pos_offset, n_tokens_batch);
|
||||
|
||||
LOG_INF("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
|
||||
LOG_INF("decoding %s batch %d/%d, n_tokens_batch = %d\n", name, i_batch+1, n_img_batches, n_tokens_batch);
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
int32_t ret = llama_decode(lctx, batch_embd_view);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to decode image\n");
|
||||
LOG_ERR("failed to decode %s\n", name);
|
||||
llama_set_causal_attn(lctx, true); // restore causal attn
|
||||
return ret;
|
||||
}
|
||||
|
||||
LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
|
||||
LOG_INF("%s decoded (batch %d/%d) in %" PRId64 " ms\n", name, i_batch+1, n_img_batches, ggml_time_ms() - t1);
|
||||
|
||||
i_batch++;
|
||||
}
|
||||
|
||||
n_past += mtmd_image_tokens_get_n_pos(image_tokens);
|
||||
n_past += mtmd_input_chunk_get_n_pos(chunk);
|
||||
*new_n_past = n_past;
|
||||
|
||||
if (mtmd_decode_use_non_causal(ctx)) {
|
||||
@@ -253,25 +238,25 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
|
||||
*new_n_past += text_batch.n_tokens;
|
||||
}
|
||||
|
||||
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||
const auto image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
|
||||
} else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE || chunk_type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
|
||||
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
|
||||
int64_t t0 = ggml_time_ms();
|
||||
|
||||
LOG_INF("encoding image or slice...\n");
|
||||
LOG_INF("encoding %s slice...\n", name);
|
||||
|
||||
ret = mtmd_encode(ctx, image_tokens);
|
||||
ret = mtmd_encode_chunk(ctx, chunk);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to encode image\n");
|
||||
LOG_ERR("failed to encode %s slice\n", name);
|
||||
llama_batch_free(text_batch);
|
||||
return ret;
|
||||
}
|
||||
|
||||
LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
|
||||
LOG_INF("%s slice encoded in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
|
||||
|
||||
float * embd = mtmd_get_output_embd(ctx);
|
||||
ret = mtmd_helper_decode_image_chunk(ctx, lctx, chunk, embd, n_past, seq_id, n_batch, new_n_past);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("failed to decode image\n");
|
||||
LOG_ERR("failed to decode %s\n", name);
|
||||
llama_batch_free(text_batch);
|
||||
return ret;
|
||||
}
|
||||
|
||||
+306
-73
@@ -1,6 +1,7 @@
|
||||
#include "clip.h"
|
||||
#include "clip-impl.h"
|
||||
#include "mtmd.h"
|
||||
#include "mtmd-audio.h"
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
@@ -19,17 +20,49 @@ struct mtmd_bitmap {
|
||||
uint32_t ny;
|
||||
std::vector<unsigned char> data;
|
||||
std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking
|
||||
bool is_audio = false; // true if the bitmap is audio
|
||||
};
|
||||
|
||||
struct mtmd_image_tokens_deleter {
|
||||
void operator()(mtmd_image_tokens * val); // forward declaration
|
||||
struct mtmd_image_tokens {
|
||||
uint32_t nx; // number of tokens in x direction
|
||||
uint32_t ny; // number of tokens in y direction
|
||||
bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position)
|
||||
uint32_t n_tokens() const { return nx * ny; }
|
||||
clip_image_f32_batch batch_f32; // preprocessed image patches
|
||||
std::string id; // optional user-defined ID, useful for KV cache tracking
|
||||
|
||||
mtmd_image_tokens clone() {
|
||||
return mtmd_image_tokens{
|
||||
nx,
|
||||
ny,
|
||||
use_mrope_pos,
|
||||
batch_f32.clone(),
|
||||
id
|
||||
};
|
||||
}
|
||||
};
|
||||
using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;
|
||||
using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens>;
|
||||
|
||||
struct mtmd_audio_tokens {
|
||||
uint32_t n_tokens; // number of tokens
|
||||
clip_image_f32_batch batch_f32; // preprocessed image patches
|
||||
std::string id; // optional user-defined ID, useful for KV cache tracking
|
||||
|
||||
mtmd_audio_tokens clone() {
|
||||
return mtmd_audio_tokens{
|
||||
n_tokens,
|
||||
batch_f32.clone(),
|
||||
id
|
||||
};
|
||||
}
|
||||
};
|
||||
using mtmd_audio_tokens_ptr = std::unique_ptr<mtmd_audio_tokens>;
|
||||
|
||||
struct mtmd_input_chunk {
|
||||
mtmd_input_chunk_type type;
|
||||
std::vector<llama_token> tokens_text;
|
||||
mtmd_image_tokens_ptr tokens_image;
|
||||
mtmd_audio_tokens_ptr tokens_audio;
|
||||
};
|
||||
|
||||
struct mtmd_input_chunks {
|
||||
@@ -46,6 +79,10 @@ enum mtmd_slice_tmpl {
|
||||
// TODO @ngxson : add support for idefics (SmolVLM)
|
||||
};
|
||||
|
||||
const char * mtmd_default_marker() {
|
||||
return "<__media__>";
|
||||
}
|
||||
|
||||
mtmd_context_params mtmd_context_params_default() {
|
||||
mtmd_context_params params;
|
||||
params.use_gpu = true;
|
||||
@@ -53,6 +90,7 @@ mtmd_context_params mtmd_context_params_default() {
|
||||
params.n_threads = 4;
|
||||
params.verbosity = GGML_LOG_LEVEL_INFO;
|
||||
params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
|
||||
params.media_marker = mtmd_default_marker();
|
||||
return params;
|
||||
}
|
||||
|
||||
@@ -63,7 +101,9 @@ struct mtmd_context {
|
||||
|
||||
bool print_timings;
|
||||
int n_threads;
|
||||
std::string image_marker;
|
||||
std::string media_marker;
|
||||
bool has_vision;
|
||||
bool has_audio;
|
||||
|
||||
// for llava-uhd style models, we need special tokens in-between slices
|
||||
// minicpmv calls them "slices", llama 4 calls them "tiles"
|
||||
@@ -81,6 +121,9 @@ struct mtmd_context {
|
||||
|
||||
bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE
|
||||
|
||||
// for whisper, we pre-calculate the mel filter bank
|
||||
whisper_preprocessor::whisper_filters w_filters;
|
||||
|
||||
// TODO @ngxson : add timings
|
||||
|
||||
mtmd_context(const char * mmproj_fname,
|
||||
@@ -89,8 +132,12 @@ struct mtmd_context {
|
||||
text_model (text_model),
|
||||
print_timings(ctx_params.print_timings),
|
||||
n_threads (ctx_params.n_threads),
|
||||
image_marker (ctx_params.image_marker)
|
||||
media_marker (ctx_params.media_marker)
|
||||
{
|
||||
if (std::string(ctx_params.image_marker) != MTMD_DEFAULT_IMAGE_MARKER) {
|
||||
throw std::runtime_error("custom image_marker is not supported anymore, use media_marker instead");
|
||||
}
|
||||
|
||||
clip_context_params ctx_clip_params;
|
||||
ctx_clip_params.use_gpu = ctx_params.use_gpu;
|
||||
ctx_clip_params.verbosity = ctx_params.verbosity;
|
||||
@@ -99,7 +146,16 @@ struct mtmd_context {
|
||||
throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
|
||||
}
|
||||
|
||||
use_mrope = clip_is_qwen2vl(ctx_clip);
|
||||
if (llama_model_n_embd(text_model) != clip_n_mmproj_embd(ctx_clip)) {
|
||||
throw std::runtime_error(string_format(
|
||||
"mismatch between text model (n_embd = %d) and mmproj (n_embd = %d)\n"
|
||||
"hint: you may be using wrong mmproj\n",
|
||||
llama_model_n_embd(text_model), clip_n_mmproj_embd(ctx_clip)));
|
||||
}
|
||||
|
||||
has_vision = clip_has_vision_encoder(ctx_clip);
|
||||
has_audio = clip_has_audio_encoder(ctx_clip);
|
||||
use_mrope = clip_is_qwen2vl(ctx_clip);
|
||||
|
||||
projector_type proj = clip_get_projector_type(ctx_clip);
|
||||
int minicpmv_version = clip_is_minicpmv(ctx_clip);
|
||||
@@ -146,6 +202,21 @@ struct mtmd_context {
|
||||
tok_row_end_trail = true; // add trailing end-of-row token
|
||||
ov_img_first = false; // overview image is last
|
||||
}
|
||||
|
||||
if (clip_has_whisper_encoder(ctx_clip)) {
|
||||
// TODO @ngxson : check if model n_mel is 128 or 80
|
||||
w_filters = whisper_precalc_filters::get_128_bins();
|
||||
}
|
||||
|
||||
// warning messages
|
||||
if (proj == PROJECTOR_TYPE_LLAMA4) {
|
||||
LOG_WRN("%s: llama 4 vision is known to have degraded quality:\n"
|
||||
" https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
|
||||
}
|
||||
if (has_audio) {
|
||||
LOG_WRN("%s: audio input is in experimental stage and may have reduced quality:\n"
|
||||
" https://github.com/ggml-org/llama.cpp/discussions/13759\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
~mtmd_context() {
|
||||
@@ -179,29 +250,6 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
struct mtmd_image_tokens_data {
|
||||
clip_image_f32_batch batch_f32; // preprocessed image patches
|
||||
};
|
||||
|
||||
struct mtmd_image_tokens {
|
||||
uint32_t nx; // number of tokens in x direction
|
||||
uint32_t ny; // number of tokens in y direction
|
||||
bool use_mrope_pos = false; // use M-RoPE position counting (the whole image is 1 temporal position)
|
||||
uint32_t n_tokens() const { return nx * ny; }
|
||||
clip_image_f32_batch batch_f32; // preprocessed image patches
|
||||
std::string id; // optional user-defined ID, useful for KV cache tracking
|
||||
|
||||
mtmd_image_tokens clone() {
|
||||
return mtmd_image_tokens{
|
||||
nx,
|
||||
ny,
|
||||
use_mrope_pos,
|
||||
batch_f32.clone(),
|
||||
id
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
|
||||
const struct llama_model * text_model,
|
||||
const struct mtmd_context_params ctx_params) {
|
||||
@@ -247,59 +295,68 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
|
||||
auto vocab = llama_model_get_vocab(ctx->text_model);
|
||||
|
||||
std::string prompt_modified(text->text);
|
||||
std::string marker_modified(ctx->image_marker);
|
||||
std::string marker_modified(ctx->media_marker);
|
||||
projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
|
||||
|
||||
// for compatibility, we convert image marker to media marker
|
||||
string_replace_all(prompt_modified, MTMD_DEFAULT_IMAGE_MARKER, ctx->media_marker);
|
||||
|
||||
// a bit hacky here, but works for now
|
||||
// for some models, we need to add prefix and suffix to the image embeddings
|
||||
if (clip_is_gemma3(ctx->ctx_clip)) {
|
||||
// gemma 3
|
||||
// <start_of_image> ... (image embeddings) ... <end_of_image>
|
||||
marker_modified = "<start_of_image>" + ctx->image_marker + "<end_of_image>";
|
||||
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
|
||||
marker_modified = "<start_of_image>" + ctx->media_marker + "<end_of_image>";
|
||||
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
|
||||
|
||||
} else if (proj_type == PROJECTOR_TYPE_IDEFICS3) {
|
||||
// https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
|
||||
marker_modified = "<fake_token_around_image><global-img>" + ctx->image_marker + "<fake_token_around_image>";
|
||||
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
|
||||
marker_modified = "<fake_token_around_image><global-img>" + ctx->media_marker + "<fake_token_around_image>";
|
||||
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
|
||||
|
||||
} else if (proj_type == PROJECTOR_TYPE_PIXTRAL) {
|
||||
// https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
|
||||
marker_modified = ctx->image_marker + "[IMG_END]";
|
||||
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
|
||||
marker_modified = ctx->media_marker + "[IMG_END]";
|
||||
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
|
||||
|
||||
} else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) {
|
||||
// <|vision_start|> ... (image embeddings) ... <|vision_end|>
|
||||
marker_modified = "<|vision_start|>" + ctx->image_marker + "<|vision_end|>";
|
||||
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
|
||||
marker_modified = "<|vision_start|>" + ctx->media_marker + "<|vision_end|>";
|
||||
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
|
||||
|
||||
} else if (proj_type == PROJECTOR_TYPE_LLAMA4) {
|
||||
// (more details in mtmd_context constructor)
|
||||
marker_modified = "<|image_start|>" + ctx->image_marker + "<|image_end|>";
|
||||
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
|
||||
marker_modified = "<|image_start|>" + ctx->media_marker + "<|image_end|>";
|
||||
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
|
||||
|
||||
} else if (proj_type == PROJECTOR_TYPE_INTERNVL) {
|
||||
// <img> ... (image embeddings) ... </img>
|
||||
marker_modified = "<img>" + ctx->image_marker + "</img>";
|
||||
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
|
||||
marker_modified = "<img>" + ctx->media_marker + "</img>";
|
||||
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
|
||||
|
||||
} else if (proj_type == PROJECTOR_TYPE_QWEN2A) {
|
||||
// <|audio_bos|> ... (embeddings) ... <|audio_eos|>
|
||||
marker_modified = "<|audio_bos|>" + ctx->media_marker + "<|audio_eos|>";
|
||||
string_replace_all(prompt_modified, ctx->media_marker, marker_modified);
|
||||
|
||||
}
|
||||
|
||||
// llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
|
||||
// for glm-edge, BOI and EOI token's embeddings are not present in the text model
|
||||
|
||||
std::vector<std::string> parts = string_split_str(prompt_modified, ctx->image_marker);
|
||||
std::vector<std::string> parts = string_split_str(prompt_modified, ctx->media_marker);
|
||||
output->entries.clear();
|
||||
output->entries.reserve(parts.size());
|
||||
|
||||
size_t i_img = 0;
|
||||
size_t i_bm = 0;
|
||||
|
||||
// utility for adding raw tokens
|
||||
auto add_text_chunk = [&output](std::vector<llama_token> && tokens) {
|
||||
mtmd_input_chunk chunk{
|
||||
MTMD_INPUT_CHUNK_TYPE_TEXT,
|
||||
std::move(tokens),
|
||||
{},
|
||||
nullptr, // image tokens
|
||||
nullptr, // audio tokens
|
||||
};
|
||||
output->entries.emplace_back(std::move(chunk));
|
||||
};
|
||||
@@ -317,8 +374,9 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
|
||||
|
||||
mtmd_input_chunk chunk{
|
||||
MTMD_INPUT_CHUNK_TYPE_IMAGE,
|
||||
{},
|
||||
{}, // text tokens
|
||||
std::move(image_tokens),
|
||||
nullptr, // audio tokens
|
||||
};
|
||||
chunks.emplace_back(std::move(chunk));
|
||||
}
|
||||
@@ -336,24 +394,36 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
|
||||
mtmd_input_chunk chunk{
|
||||
MTMD_INPUT_CHUNK_TYPE_TEXT,
|
||||
std::move(tokens),
|
||||
{},
|
||||
nullptr, // image tokens
|
||||
nullptr, // audio tokens
|
||||
};
|
||||
output->entries.emplace_back(std::move(chunk));
|
||||
|
||||
if (&parts.back() != &part) {
|
||||
// add image token to middle of 2 parts
|
||||
// only add image/audio tokens to middle of 2 parts
|
||||
// therefore, we skip handling image/audio if this is the last part
|
||||
if (&parts.back() == &part) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (i_img >= n_bitmaps) {
|
||||
if (!bitmaps[i_bm]->is_audio) {
|
||||
// handle image
|
||||
|
||||
if (i_bm >= n_bitmaps) {
|
||||
LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!ctx->has_vision) {
|
||||
LOG_ERR("%s: error: model does not support vision input\n", __func__);
|
||||
return 2;
|
||||
}
|
||||
|
||||
// convert mtmd_bitmap to clip_image_u8
|
||||
clip_image_u8_ptr img_u8(clip_image_u8_init());
|
||||
img_u8->nx = bitmaps[i_img]->nx;
|
||||
img_u8->ny = bitmaps[i_img]->ny;
|
||||
img_u8->buf.resize(bitmaps[i_img]->data.size());
|
||||
std::memcpy(img_u8->buf.data(), bitmaps[i_img]->data.data(), img_u8->nx * img_u8->ny * 3);
|
||||
img_u8->nx = bitmaps[i_bm]->nx;
|
||||
img_u8->ny = bitmaps[i_bm]->ny;
|
||||
img_u8->buf.resize(bitmaps[i_bm]->data.size());
|
||||
std::memcpy(img_u8->buf.data(), bitmaps[i_bm]->data.data(), img_u8->nx * img_u8->ny * 3);
|
||||
|
||||
// preprocess image
|
||||
clip_image_f32_batch batch_f32;
|
||||
@@ -370,7 +440,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
|
||||
|| ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4
|
||||
) {
|
||||
// split batch into chunks of single images
|
||||
auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img]->id);
|
||||
auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_bm]->id);
|
||||
GGML_ASSERT(chunks.size() > 0);
|
||||
|
||||
auto ov_chunk = std::move(chunks.front());
|
||||
@@ -446,7 +516,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
|
||||
image_tokens->ny = 1;
|
||||
}
|
||||
image_tokens->batch_f32 = std::move(batch_f32);
|
||||
image_tokens->id = bitmaps[i_img]->id; // optional
|
||||
image_tokens->id = bitmaps[i_bm]->id; // optional
|
||||
|
||||
LOG_DBG("image_tokens->nx = %d\n", image_tokens->nx);
|
||||
LOG_DBG("image_tokens->ny = %d\n", image_tokens->ny);
|
||||
@@ -454,23 +524,101 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
|
||||
|
||||
mtmd_input_chunk chunk{
|
||||
MTMD_INPUT_CHUNK_TYPE_IMAGE,
|
||||
{},
|
||||
{}, // text tokens
|
||||
std::move(image_tokens),
|
||||
nullptr, // audio tokens
|
||||
};
|
||||
output->entries.emplace_back(std::move(chunk));
|
||||
}
|
||||
|
||||
i_img++; // move to next image
|
||||
i_bm++; // move to next image
|
||||
continue;
|
||||
|
||||
} else {
|
||||
// handle audio
|
||||
|
||||
if (i_bm >= n_bitmaps) {
|
||||
LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!ctx->has_audio) {
|
||||
LOG_ERR("%s: error: model does not support audio input\n", __func__);
|
||||
return 2;
|
||||
}
|
||||
|
||||
if (bitmaps[i_bm]->data.size() == 0) {
|
||||
LOG_ERR("%s: error: empty audio data\n", __func__);
|
||||
return 2;
|
||||
}
|
||||
|
||||
// preprocess audio
|
||||
GGML_ASSERT(ctx->w_filters.n_mel); // make sure we have filter preloaded
|
||||
std::vector<whisper_preprocessor::whisper_mel> mel_spec_chunks;
|
||||
const float * samples = (const float *)bitmaps[i_bm]->data.data();
|
||||
size_t n_samples = bitmaps[i_bm]->data.size() / sizeof(float);
|
||||
bool ok = whisper_preprocessor::preprocess_audio(samples, n_samples, ctx->w_filters, mel_spec_chunks);
|
||||
if (!ok) {
|
||||
LOG_ERR("Unable to preprocess audio\n");
|
||||
return 2;
|
||||
}
|
||||
|
||||
// consider each mel_spec as a separate audio chunk
|
||||
// TODO: maybe support batching, but this may come with memory cost
|
||||
for (auto & mel_spec : mel_spec_chunks) {
|
||||
clip_image_f32_ptr mel_f32(clip_image_f32_init());
|
||||
mel_f32->nx = mel_spec.n_len;
|
||||
mel_f32->ny = mel_spec.n_mel;
|
||||
mel_f32->buf = std::move(mel_spec.data);
|
||||
size_t n_tokens = clip_n_output_tokens(ctx->ctx_clip, mel_f32.get());
|
||||
|
||||
clip_image_f32_batch batch_f32;
|
||||
batch_f32.is_audio = true;
|
||||
batch_f32.entries.push_back(std::move(mel_f32));
|
||||
|
||||
mtmd_audio_tokens_ptr audio_tokens(new mtmd_audio_tokens);
|
||||
audio_tokens->n_tokens = n_tokens;
|
||||
audio_tokens->batch_f32 = std::move(batch_f32);
|
||||
audio_tokens->id = bitmaps[i_bm]->id; // optional
|
||||
|
||||
LOG_DBG("audio_tokens->n_tokens = %d\n", audio_tokens->n_tokens);
|
||||
|
||||
mtmd_input_chunk chunk{
|
||||
MTMD_INPUT_CHUNK_TYPE_AUDIO,
|
||||
{}, // text tokens
|
||||
nullptr, // image tokens
|
||||
std::move(audio_tokens),
|
||||
};
|
||||
output->entries.emplace_back(std::move(chunk));
|
||||
}
|
||||
|
||||
i_bm++;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
|
||||
if (image_tokens) {
|
||||
delete image_tokens;
|
||||
int32_t mtmd_encode_chunk(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
|
||||
if (chunk->type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||
LOG_WRN("mtmd_encode_chunk has no effect for text chunks\n");
|
||||
return 0;
|
||||
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||
return mtmd_encode(ctx, chunk->tokens_image.get());
|
||||
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
|
||||
int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
|
||||
ctx->image_embd_v.resize(chunk->tokens_audio->n_tokens * n_mmproj_embd);
|
||||
bool ok = clip_image_batch_encode(
|
||||
ctx->ctx_clip,
|
||||
ctx->n_threads,
|
||||
&chunk->tokens_audio->batch_f32,
|
||||
ctx->image_embd_v.data());
|
||||
return ok ? 0 : 1;
|
||||
}
|
||||
|
||||
LOG_ERR("mtmd_encode_chunk: unknown chunk type %d\n", (int)chunk->type);
|
||||
return 1;
|
||||
}
|
||||
|
||||
int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
|
||||
@@ -516,8 +664,12 @@ bool mtmd_decode_use_mrope(mtmd_context * ctx) {
|
||||
return ctx->use_mrope;
|
||||
}
|
||||
|
||||
void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
|
||||
mtmd_image_tokens_free(val);
|
||||
bool mtmd_support_vision(mtmd_context * ctx) {
|
||||
return ctx->has_vision;
|
||||
}
|
||||
|
||||
bool mtmd_support_audio(mtmd_context * ctx) {
|
||||
return ctx->has_audio;
|
||||
}
|
||||
|
||||
// these 2 helpers below use internal clip_image_u8_ptr,
|
||||
@@ -526,6 +678,15 @@ void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
|
||||
// whichever library they want, and then use mtmd_bitmap_init() to create bitmap
|
||||
|
||||
mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len) {
|
||||
if (audio_helpers::is_audio_file((const char *)buf, len)) {
|
||||
std::vector<float> pcmf32;
|
||||
if (!audio_helpers::decode_audio_from_buf(buf, len, COMMON_SAMPLE_RATE, pcmf32)) {
|
||||
LOG_ERR("Unable to read WAV audio file from buffer\n");
|
||||
return nullptr;
|
||||
}
|
||||
return mtmd_bitmap_init_from_audio(pcmf32.size(), pcmf32.data());
|
||||
}
|
||||
|
||||
clip_image_u8_ptr img_u8(clip_image_u8_init());
|
||||
bool ok = clip_image_load_from_bytes(buf, len, img_u8.get());
|
||||
if (!ok) {
|
||||
@@ -538,15 +699,26 @@ mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t
|
||||
}
|
||||
|
||||
mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname) {
|
||||
clip_image_u8_ptr img_u8(clip_image_u8_init());
|
||||
bool ok = clip_image_load_from_file(fname, img_u8.get());
|
||||
if (!ok) {
|
||||
LOG_ERR("Unable to load image %s\n", fname);
|
||||
std::vector<unsigned char> buf;
|
||||
FILE * f = fopen(fname, "rb");
|
||||
if (!f) {
|
||||
LOG_ERR("Unable to open file %s: %s\n", fname, strerror(errno));
|
||||
return nullptr;
|
||||
}
|
||||
uint32_t nx, ny;
|
||||
unsigned char * data = clip_image_u8_get_data(img_u8.get(), &nx, &ny);
|
||||
return mtmd_bitmap_init(nx, ny, data);
|
||||
|
||||
fseek(f, 0, SEEK_END);
|
||||
long file_size = ftell(f);
|
||||
fseek(f, 0, SEEK_SET);
|
||||
buf.resize(file_size);
|
||||
|
||||
size_t n_read = fread(buf.data(), 1, file_size, f);
|
||||
fclose(f);
|
||||
if (n_read != (size_t)file_size) {
|
||||
LOG_ERR("Failed to read entire file %s", fname);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return mtmd_helper_bitmap_init_from_buf(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
//
|
||||
@@ -567,6 +739,18 @@ mtmd_bitmap * mtmd_bitmap_init(uint32_t nx,
|
||||
return bitmap;
|
||||
}
|
||||
|
||||
mtmd_bitmap * mtmd_bitmap_init_from_audio(size_t n_samples,
|
||||
const float * data) {
|
||||
mtmd_bitmap * bitmap = new mtmd_bitmap;
|
||||
bitmap->nx = n_samples;
|
||||
bitmap->ny = 1;
|
||||
bitmap->is_audio = true;
|
||||
size_t data_size = n_samples * sizeof(float);
|
||||
bitmap->data.resize(data_size);
|
||||
std::memcpy(bitmap->data.data(), data, data_size);
|
||||
return bitmap;
|
||||
}
|
||||
|
||||
uint32_t mtmd_bitmap_get_nx(const mtmd_bitmap * bitmap) {
|
||||
return bitmap->nx;
|
||||
}
|
||||
@@ -579,6 +763,14 @@ const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap) {
|
||||
return bitmap->data.data();
|
||||
}
|
||||
|
||||
size_t mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap) {
|
||||
return bitmap->data.size();
|
||||
}
|
||||
|
||||
bool mtmd_bitmap_is_audio(const mtmd_bitmap * bitmap) {
|
||||
return bitmap->is_audio;
|
||||
}
|
||||
|
||||
const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap) {
|
||||
return bitmap->id.c_str();
|
||||
}
|
||||
@@ -642,17 +834,56 @@ const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd_input_chu
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
size_t mtmd_input_chunk_get_n_tokens(const mtmd_input_chunk * chunk) {
|
||||
if (chunk->type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||
return chunk->tokens_text.size();
|
||||
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||
return mtmd_image_tokens_get_n_tokens(chunk->tokens_image.get());
|
||||
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
|
||||
return chunk->tokens_audio->n_tokens;
|
||||
} else {
|
||||
GGML_ABORT("invalid chunk type");
|
||||
}
|
||||
}
|
||||
|
||||
llama_pos mtmd_input_chunk_get_n_pos(const mtmd_input_chunk * chunk) {
|
||||
if (chunk->type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||
return chunk->tokens_text.size();
|
||||
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||
return mtmd_image_tokens_get_n_pos(chunk->tokens_image.get());
|
||||
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
|
||||
return chunk->tokens_audio->n_tokens;
|
||||
} else {
|
||||
GGML_ABORT("invalid chunk type");
|
||||
}
|
||||
}
|
||||
|
||||
const char * mtmd_input_chunk_get_id(const mtmd_input_chunk * chunk) {
|
||||
if (chunk->type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||
return chunk->tokens_image->id.c_str();
|
||||
} else if (chunk->type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
|
||||
return chunk->tokens_audio->id.c_str();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
mtmd_input_chunk * mtmd_input_chunk_copy(const mtmd_input_chunk * chunk) {
|
||||
mtmd_input_chunk * copy = new mtmd_input_chunk{
|
||||
chunk->type,
|
||||
chunk->tokens_text,
|
||||
mtmd_image_tokens_ptr(),
|
||||
nullptr,
|
||||
nullptr,
|
||||
};
|
||||
if (chunk->tokens_image) {
|
||||
// copy the image tokens
|
||||
copy->tokens_image = mtmd_image_tokens_ptr(new mtmd_image_tokens());
|
||||
*copy->tokens_image = chunk->tokens_image->clone();
|
||||
}
|
||||
if (chunk->tokens_audio) {
|
||||
// copy the audio tokens
|
||||
copy->tokens_audio = mtmd_audio_tokens_ptr(new mtmd_audio_tokens());
|
||||
*copy->tokens_audio = chunk->tokens_audio->clone();
|
||||
}
|
||||
return copy;
|
||||
}
|
||||
|
||||
@@ -700,7 +931,8 @@ mtmd_input_chunks * mtmd_test_create_input_chunks() {
|
||||
mtmd_input_chunk chunk_text{
|
||||
MTMD_INPUT_CHUNK_TYPE_TEXT,
|
||||
std::move(tokens_text),
|
||||
{},
|
||||
nullptr, // image tokens
|
||||
nullptr, // audio tokens
|
||||
};
|
||||
chunks->entries.emplace_back(std::move(chunk_text));
|
||||
|
||||
@@ -712,8 +944,9 @@ mtmd_input_chunks * mtmd_test_create_input_chunks() {
|
||||
image_tokens->id = "image_1";
|
||||
mtmd_input_chunk chunk_image{
|
||||
MTMD_INPUT_CHUNK_TYPE_IMAGE,
|
||||
{},
|
||||
{}, // text tokens
|
||||
std::move(image_tokens),
|
||||
nullptr, // audio tokens
|
||||
};
|
||||
chunks->entries.emplace_back(std::move(chunk_image));
|
||||
|
||||
|
||||
+54
-21
@@ -39,6 +39,7 @@
|
||||
# define MTMD_API
|
||||
#endif
|
||||
|
||||
// deprecated marker, use mtmd_default_marker() instead
|
||||
#define MTMD_DEFAULT_IMAGE_MARKER "<__image__>"
|
||||
|
||||
#ifdef __cplusplus
|
||||
@@ -48,6 +49,7 @@ extern "C" {
|
||||
enum mtmd_input_chunk_type {
|
||||
MTMD_INPUT_CHUNK_TYPE_TEXT,
|
||||
MTMD_INPUT_CHUNK_TYPE_IMAGE,
|
||||
MTMD_INPUT_CHUNK_TYPE_AUDIO,
|
||||
};
|
||||
|
||||
// opaque types
|
||||
@@ -79,9 +81,12 @@ struct mtmd_context_params {
|
||||
bool print_timings;
|
||||
int n_threads;
|
||||
enum ggml_log_level verbosity;
|
||||
const char * image_marker;
|
||||
const char * image_marker; // deprecated, use media_marker instead
|
||||
const char * media_marker;
|
||||
};
|
||||
|
||||
MTMD_API const char * mtmd_default_marker(void);
|
||||
|
||||
MTMD_API struct mtmd_context_params mtmd_context_params_default(void);
|
||||
|
||||
// initialize the mtmd context
|
||||
@@ -98,18 +103,28 @@ MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
|
||||
// whether the current model use M-RoPE for llama_decode
|
||||
MTMD_API bool mtmd_decode_use_mrope(mtmd_context * ctx);
|
||||
|
||||
// whether the current model supports vision input
|
||||
MTMD_API bool mtmd_support_vision(mtmd_context * ctx);
|
||||
|
||||
// whether the current model supports audio input
|
||||
MTMD_API bool mtmd_support_audio(mtmd_context * ctx);
|
||||
|
||||
// mtmd_bitmap
|
||||
//
|
||||
// length of data must be nx * ny * 3
|
||||
// the data is in RGBRGBRGB... format
|
||||
MTMD_API mtmd_bitmap * mtmd_bitmap_init (uint32_t nx,
|
||||
uint32_t ny,
|
||||
const unsigned char * data);
|
||||
MTMD_API uint32_t mtmd_bitmap_get_nx (const mtmd_bitmap * bitmap);
|
||||
MTMD_API uint32_t mtmd_bitmap_get_ny (const mtmd_bitmap * bitmap);
|
||||
MTMD_API const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap);
|
||||
MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap);
|
||||
// if bitmap is image:
|
||||
// length of data must be nx * ny * 3
|
||||
// the data is in RGBRGBRGB... format
|
||||
// if bitmap is audio:
|
||||
// length of data must be n_samples * sizeof(float)
|
||||
// the data is in float format (PCM F32)
|
||||
MTMD_API mtmd_bitmap * mtmd_bitmap_init (uint32_t nx, uint32_t ny, const unsigned char * data);
|
||||
MTMD_API mtmd_bitmap * mtmd_bitmap_init_from_audio(size_t n_samples, const float * data);
|
||||
MTMD_API uint32_t mtmd_bitmap_get_nx (const mtmd_bitmap * bitmap);
|
||||
MTMD_API uint32_t mtmd_bitmap_get_ny (const mtmd_bitmap * bitmap);
|
||||
MTMD_API const unsigned char * mtmd_bitmap_get_data (const mtmd_bitmap * bitmap);
|
||||
MTMD_API size_t mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap);
|
||||
MTMD_API bool mtmd_bitmap_is_audio (const mtmd_bitmap * bitmap);
|
||||
MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap);
|
||||
// bitmap ID is optional, but useful for KV cache tracking
|
||||
// these getters/setters are dedicated functions, so you can for example calculate the hash of the image based on mtmd_bitmap_get_data()
|
||||
MTMD_API const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap);
|
||||
@@ -132,6 +147,11 @@ MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chu
|
||||
MTMD_API enum mtmd_input_chunk_type mtmd_input_chunk_get_type (const mtmd_input_chunk * chunk);
|
||||
MTMD_API const llama_token * mtmd_input_chunk_get_tokens_text (const mtmd_input_chunk * chunk, size_t * n_tokens_output);
|
||||
MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd_input_chunk * chunk);
|
||||
MTMD_API size_t mtmd_input_chunk_get_n_tokens (const mtmd_input_chunk * chunk);
|
||||
// returns nullptr for ID on text chunk
|
||||
MTMD_API const char * mtmd_input_chunk_get_id (const mtmd_input_chunk * chunk);
|
||||
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
|
||||
MTMD_API llama_pos mtmd_input_chunk_get_n_pos (const mtmd_input_chunk * chunk);
|
||||
|
||||
// in case you want to use custom logic to handle the chunk (i.e. KV cache management)
|
||||
// you can move the chunk ownership to your own code by copying it
|
||||
@@ -144,27 +164,28 @@ MTMD_API void mtmd_input_chunk_free(mtmd_input_chunk * chunk);
|
||||
//
|
||||
// the instance will be constructed via mtmd_tokenize()
|
||||
// it will be freed along with mtmd_input_chunk
|
||||
MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens);
|
||||
MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens); // TODO: deprecate
|
||||
MTMD_API size_t mtmd_image_tokens_get_nx (const mtmd_image_tokens * image_tokens);
|
||||
MTMD_API size_t mtmd_image_tokens_get_ny (const mtmd_image_tokens * image_tokens);
|
||||
MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens);
|
||||
MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens); // TODO: deprecate
|
||||
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
|
||||
MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens);
|
||||
MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens); // TODO: deprecate
|
||||
|
||||
// tokenize an input text prompt and an image
|
||||
// the prompt must have the input image marker (default: "<__image__>") in it
|
||||
// the marker will be replaced with the image tokens
|
||||
// tokenize an input text prompt and a list of bitmaps (images/audio)
|
||||
// the prompt must have the input image marker (default: "<__media__>") in it
|
||||
// the default marker is defined by mtmd_default_marker()
|
||||
// the marker will be replaced with the image/audio chunk
|
||||
// for example:
|
||||
// "here is an image: <__image__>\ndescribe it in detail."
|
||||
// "here is an image: <__media__>\ndescribe it in detail."
|
||||
// this will gives 3 chunks:
|
||||
// 1. "here is an image: <start_of_image>"
|
||||
// 2. (image tokens)
|
||||
// 2. (image/audio tokens)
|
||||
// 3. "<end_of_image>\ndescribe it in detail."
|
||||
// number of bitmaps must be equal to the number of image markers in the prompt
|
||||
// number of bitmaps must be equal to the number of markers in the prompt
|
||||
// this function is thread-safe (shared ctx)
|
||||
// return values:
|
||||
// 0 on success
|
||||
// 1 on number of images not matching the number of markers
|
||||
// 1 on number of bitmaps not matching the number of markers
|
||||
// 2 on image preprocessing error
|
||||
MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
|
||||
mtmd_input_chunks * output,
|
||||
@@ -173,10 +194,17 @@ MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
|
||||
size_t n_bitmaps);
|
||||
|
||||
// returns 0 on success
|
||||
// TODO: deprecate
|
||||
MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
|
||||
const mtmd_image_tokens * image_tokens);
|
||||
|
||||
// returns 0 on success
|
||||
MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx,
|
||||
const mtmd_input_chunk * chunk);
|
||||
|
||||
// get output embeddings from the last encode pass
|
||||
// the reading size (in bytes) is equal to:
|
||||
// llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float)
|
||||
MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
|
||||
|
||||
/////////////////////////////////////////
|
||||
@@ -189,12 +217,16 @@ MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
|
||||
//
|
||||
|
||||
// helper function to construct a mtmd_bitmap from a file
|
||||
// it calls mtmd_helper_bitmap_init_from_buf() internally
|
||||
// returns nullptr on failure
|
||||
// this function is thread-safe
|
||||
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname);
|
||||
|
||||
// helper function to construct a mtmd_bitmap from a buffer containing a file
|
||||
// the file content must be an image in format supported by stb_image (jpg, png, bmp, gif, etc.)
|
||||
// supported formats:
|
||||
// image: formats supported by stb_image: jpg, png, bmp, gif, etc.
|
||||
// audio: formats supported by miniaudio: wav, mp3, flac
|
||||
// note: audio files will be auto-detected based on magic bytes
|
||||
// returns nullptr on failure
|
||||
// this function is thread-safe
|
||||
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len);
|
||||
@@ -293,6 +325,7 @@ struct bitmap {
|
||||
uint32_t nx() { return mtmd_bitmap_get_nx(ptr.get()); }
|
||||
uint32_t ny() { return mtmd_bitmap_get_ny(ptr.get()); }
|
||||
const unsigned char * data() { return mtmd_bitmap_get_data(ptr.get()); }
|
||||
size_t n_bytes() { return mtmd_bitmap_get_n_bytes(ptr.get()); }
|
||||
std::string id() { return mtmd_bitmap_get_id(ptr.get()); }
|
||||
void set_id(const char * id) { mtmd_bitmap_set_id(ptr.get(), id); }
|
||||
};
|
||||
|
||||
@@ -111,7 +111,7 @@ static std::string fs_get_cache_directory() {
|
||||
if (getenv("LLAMA_CACHE")) {
|
||||
cache_directory = std::getenv("LLAMA_CACHE");
|
||||
} else {
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX)
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
|
||||
if (std::getenv("XDG_CACHE_HOME")) {
|
||||
cache_directory = std::getenv("XDG_CACHE_HOME");
|
||||
} else {
|
||||
|
||||
@@ -173,7 +173,8 @@ The project is under active development, and we are [looking for feedback and co
|
||||
| `--no-slots` | disables slots monitoring endpoint<br/>(env: LLAMA_ARG_NO_ENDPOINT_SLOTS) |
|
||||
| `--slot-save-path PATH` | path to save slot kv cache (default: disabled) |
|
||||
| `--jinja` | use jinja template for chat (default: disabled)<br/>(env: LLAMA_ARG_JINJA) |
|
||||
| `--reasoning-format FORMAT` | reasoning format (default: deepseek; allowed values: deepseek, none)<br/>controls whether thought tags are extracted from the response, and in which format they're returned. 'none' leaves thoughts unparsed in `message.content`, 'deepseek' puts them in `message.reasoning_content` (for DeepSeek R1 & Command R7B only).<br/>only supported for non-streamed responses<br/>(env: LLAMA_ARG_THINK) |
|
||||
| `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:<br/>- none: leaves thoughts unparsed in `message.content`<br/>- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)<br/>(default: deepseek)<br/>(env: LLAMA_ARG_THINK) |
|
||||
| `--reasoning-budget N` | controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)<br/>(env: LLAMA_ARG_THINK_BUDGET) |
|
||||
| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, falcon3, gemma, gigachat, glmedge, granite, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, smolvlm, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE) |
|
||||
| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, falcon3, gemma, gigachat, glmedge, granite, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, smolvlm, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) |
|
||||
| `--no-prefill-assistant` | whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)<br/>when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled<br/>(env: LLAMA_ARG_NO_PREFILL_ASSISTANT) |
|
||||
|
||||
Binary file not shown.
+198
-164
@@ -1,3 +1,4 @@
|
||||
#include "chat.h"
|
||||
#include "utils.hpp"
|
||||
|
||||
#include "arg.h"
|
||||
@@ -114,11 +115,11 @@ struct slot_params {
|
||||
struct common_params_speculative speculative;
|
||||
|
||||
// OAI-compat fields
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_syntax oaicompat_chat_syntax;
|
||||
|
||||
json to_json() const {
|
||||
std::vector<std::string> samplers;
|
||||
@@ -176,7 +177,10 @@ struct slot_params {
|
||||
{"grammar_lazy", sampling.grammar_lazy},
|
||||
{"grammar_triggers", grammar_triggers},
|
||||
{"preserved_tokens", sampling.preserved_tokens},
|
||||
{"chat_format", common_chat_format_name(oaicompat_chat_format)},
|
||||
{"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)},
|
||||
{"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)},
|
||||
{"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
|
||||
{"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
|
||||
{"samplers", samplers},
|
||||
{"speculative.n_max", speculative.n_max},
|
||||
{"speculative.n_min", speculative.n_min},
|
||||
@@ -352,11 +356,14 @@ struct server_task {
|
||||
{
|
||||
auto it = data.find("chat_format");
|
||||
if (it != data.end()) {
|
||||
params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<int>());
|
||||
SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
|
||||
params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
|
||||
SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format));
|
||||
} else {
|
||||
params.oaicompat_chat_format = defaults.oaicompat_chat_format;
|
||||
params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
|
||||
}
|
||||
params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format;
|
||||
params.oaicompat_chat_syntax.reasoning_in_content = params.stream;
|
||||
params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
|
||||
}
|
||||
|
||||
{
|
||||
@@ -396,7 +403,14 @@ struct server_task {
|
||||
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
|
||||
}
|
||||
} else {
|
||||
params.sampling.grammar_triggers.push_back(std::move(ct.value));
|
||||
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
|
||||
SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
|
||||
} else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
|
||||
SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
|
||||
} else {
|
||||
throw std::runtime_error("Unknown grammar trigger type");
|
||||
}
|
||||
params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -639,11 +653,12 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
slot_params generation_params;
|
||||
|
||||
// OAI-compat fields
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
common_chat_msg oaicompat_msg;
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
|
||||
|
||||
virtual int get_index() override {
|
||||
return index;
|
||||
@@ -738,47 +753,20 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
json to_json_oaicompat_chat() {
|
||||
std::string finish_reason = "length";
|
||||
common_chat_msg msg;
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
SRV_DBG("Parsing chat message: %s\n", content.c_str());
|
||||
msg = common_chat_parse(content, oaicompat_chat_format);
|
||||
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
|
||||
if (!oaicompat_msg.empty()) {
|
||||
msg = oaicompat_msg;
|
||||
} else {
|
||||
msg.role = "assistant";
|
||||
msg.content = content;
|
||||
}
|
||||
|
||||
json message {
|
||||
{"role", "assistant"},
|
||||
};
|
||||
if (!msg.reasoning_content.empty()) {
|
||||
message["reasoning_content"] = msg.reasoning_content;
|
||||
}
|
||||
if (msg.content.empty() && !msg.tool_calls.empty()) {
|
||||
message["content"] = json();
|
||||
} else {
|
||||
message["content"] = msg.content;
|
||||
}
|
||||
if (!msg.tool_calls.empty()) {
|
||||
auto tool_calls = json::array();
|
||||
for (const auto & tc : msg.tool_calls) {
|
||||
tool_calls.push_back({
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", tc.name},
|
||||
{"arguments", tc.arguments},
|
||||
}},
|
||||
// Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
|
||||
// We only generate a random id for the ones that don't generate one by themselves
|
||||
// (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
|
||||
{"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
|
||||
});
|
||||
}
|
||||
message["tool_calls"] = tool_calls;
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
|
||||
}
|
||||
|
||||
json choice {
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"message", message},
|
||||
{"message", msg.to_json_oaicompat<json>()},
|
||||
};
|
||||
|
||||
if (!stream && probs_output.size() > 0) {
|
||||
@@ -818,17 +806,35 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
std::time_t t = std::time(0);
|
||||
std::string finish_reason = "length";
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
finish_reason = "stop";
|
||||
finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
|
||||
}
|
||||
|
||||
json choice = json {
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"delta", json::object()}
|
||||
};
|
||||
json deltas = json::array();
|
||||
for (const auto & diff : oaicompat_msg_diffs) {
|
||||
deltas.push_back({
|
||||
{"choices", json::array({
|
||||
json {
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
|
||||
},
|
||||
})},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"system_fingerprint", build_info},
|
||||
{"object", "chat.completion.chunk"},
|
||||
});
|
||||
}
|
||||
|
||||
json ret = json {
|
||||
{"choices", json::array({choice})},
|
||||
deltas.push_back({
|
||||
{"choices", json::array({
|
||||
json {
|
||||
{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"delta", json::object()},
|
||||
},
|
||||
})},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
@@ -839,18 +845,18 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens},
|
||||
}},
|
||||
};
|
||||
});
|
||||
|
||||
if (timings.prompt_n >= 0) {
|
||||
ret.push_back({"timings", timings.to_json()});
|
||||
deltas.back().push_back({"timings", timings.to_json()});
|
||||
}
|
||||
|
||||
// extra fields for debugging purposes
|
||||
if (verbose) {
|
||||
ret["__verbose"] = to_json_non_oaicompat();
|
||||
if (verbose && !deltas.empty()) {
|
||||
deltas.front()["__verbose"] = to_json_non_oaicompat();
|
||||
}
|
||||
|
||||
return ret;
|
||||
return deltas;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -868,10 +874,11 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||
result_timings timings;
|
||||
|
||||
// OAI-compat fields
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
bool verbose = false;
|
||||
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||
std::string oaicompat_model;
|
||||
std::string oaicompat_cmpl_id;
|
||||
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
|
||||
|
||||
virtual int get_index() override {
|
||||
return index;
|
||||
@@ -955,84 +962,50 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||
std::time_t t = std::time(0);
|
||||
json choices;
|
||||
|
||||
if (first) {
|
||||
if (content.empty()) {
|
||||
choices = json::array({json{{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", json{{"role", "assistant"}}}}});
|
||||
} else {
|
||||
// We have to send this as two updates to conform to openai behavior
|
||||
// initial_ret is the role message for stream=True
|
||||
json initial_ret = json{{"choices", json::array({json{
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", json{
|
||||
{"role", "assistant"},
|
||||
{"content", ""}
|
||||
}}}})},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"system_fingerprint", build_info},
|
||||
{"object", "chat.completion.chunk"}};
|
||||
|
||||
json second_ret = json{
|
||||
{"choices", json::array({json{{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", json {
|
||||
{"content", content}}}
|
||||
}})},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"system_fingerprint", build_info},
|
||||
{"object", "chat.completion.chunk"}};
|
||||
|
||||
if (prob_output.probs.size() > 0) {
|
||||
second_ret["choices"][0]["logprobs"] = json{
|
||||
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
|
||||
};
|
||||
}
|
||||
|
||||
if (timings.prompt_n >= 0) {
|
||||
second_ret.push_back({"timings", timings.to_json()});
|
||||
}
|
||||
|
||||
return std::vector<json>({initial_ret, second_ret});
|
||||
}
|
||||
} else {
|
||||
choices = json::array({json{
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta",
|
||||
json {
|
||||
{"content", content},
|
||||
}},
|
||||
}});
|
||||
}
|
||||
|
||||
GGML_ASSERT(choices.size() >= 1);
|
||||
|
||||
if (prob_output.probs.size() > 0) {
|
||||
choices[0]["logprobs"] = json{
|
||||
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
|
||||
};
|
||||
}
|
||||
|
||||
json ret = json {
|
||||
{"choices", choices},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"system_fingerprint", build_info},
|
||||
{"object", "chat.completion.chunk"}
|
||||
std::vector<json> deltas;
|
||||
auto add_delta = [&](const json & delta) {
|
||||
deltas.push_back({
|
||||
{"choices", json::array({
|
||||
json {
|
||||
{"finish_reason", nullptr},
|
||||
{"index", 0},
|
||||
{"delta", delta},
|
||||
},
|
||||
})},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"system_fingerprint", build_info},
|
||||
{"object", "chat.completion.chunk"},
|
||||
});
|
||||
};
|
||||
|
||||
if (timings.prompt_n >= 0) {
|
||||
ret.push_back({"timings", timings.to_json()});
|
||||
// We have to send an initial update to conform to openai behavior
|
||||
if (first) {
|
||||
add_delta({
|
||||
{"role", "assistant"},
|
||||
{"content", nullptr},
|
||||
});
|
||||
}
|
||||
|
||||
return std::vector<json>({ret});
|
||||
for (const auto & diff : oaicompat_msg_diffs) {
|
||||
add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
|
||||
}
|
||||
|
||||
if (!deltas.empty()) {
|
||||
GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1);
|
||||
|
||||
if (prob_output.probs.size() > 0) {
|
||||
deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json {
|
||||
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
|
||||
};
|
||||
}
|
||||
|
||||
if (timings.prompt_n >= 0) {
|
||||
deltas[deltas.size() - 1].push_back({"timings", timings.to_json()});
|
||||
}
|
||||
}
|
||||
|
||||
return deltas;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1293,6 +1266,7 @@ struct server_slot {
|
||||
|
||||
std::string generated_text;
|
||||
llama_tokens generated_tokens;
|
||||
common_chat_msg chat_msg;
|
||||
|
||||
server_tokens cache_tokens;
|
||||
|
||||
@@ -1313,6 +1287,7 @@ struct server_slot {
|
||||
llama_token sampled;
|
||||
|
||||
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
std::vector<std::string> generated_tool_call_ids;
|
||||
|
||||
// stats
|
||||
size_t n_sent_text = 0; // number of sent text character
|
||||
@@ -1342,9 +1317,13 @@ struct server_slot {
|
||||
n_past = 0;
|
||||
n_sent_text = 0;
|
||||
task_type = SERVER_TASK_TYPE_COMPLETION;
|
||||
chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
|
||||
generated_tokens.clear();
|
||||
generated_token_probs.clear();
|
||||
chat_msg = {};
|
||||
json_schema = json();
|
||||
generated_tool_call_ids.clear();
|
||||
|
||||
// clear speculative decoding stats
|
||||
n_draft_total = 0;
|
||||
@@ -1424,6 +1403,21 @@ struct server_slot {
|
||||
return timings;
|
||||
}
|
||||
|
||||
const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
|
||||
auto previous_msg = chat_msg;
|
||||
SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
|
||||
auto new_msg = common_chat_parse(
|
||||
generated_text,
|
||||
/* is_partial= */ stop != STOP_TYPE_EOS,
|
||||
params.oaicompat_chat_syntax);
|
||||
if (!new_msg.empty()) {
|
||||
new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id);
|
||||
chat_msg = new_msg;
|
||||
diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
|
||||
}
|
||||
return chat_msg;
|
||||
}
|
||||
|
||||
size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
|
||||
size_t stop_pos = std::string::npos;
|
||||
|
||||
@@ -1891,6 +1885,7 @@ struct server_context {
|
||||
float slot_prompt_similarity = 0.0f;
|
||||
|
||||
common_chat_templates_ptr chat_templates;
|
||||
oaicompat_parser_options oai_parser_opt;
|
||||
|
||||
~server_context() {
|
||||
mtmd_free(mctx);
|
||||
@@ -2086,6 +2081,16 @@ struct server_context {
|
||||
}
|
||||
|
||||
metrics.init();
|
||||
|
||||
oai_parser_opt = {
|
||||
/* use_jinja */ params_base.use_jinja,
|
||||
/* prefill_assistant */ params_base.prefill_assistant,
|
||||
/* reasoning_format */ params_base.reasoning_format,
|
||||
/* common_chat_templates */ chat_templates.get(),
|
||||
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
|
||||
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
|
||||
/* enable_thinking */ params_base.reasoning_budget != 0,
|
||||
};
|
||||
}
|
||||
|
||||
server_slot * get_slot_by_id(int id) {
|
||||
@@ -2465,10 +2470,12 @@ struct server_context {
|
||||
res->n_prompt_tokens = slot.n_prompt_tokens;
|
||||
res->post_sampling_probs = slot.params.post_sampling_probs;
|
||||
|
||||
res->verbose = slot.params.verbose;
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
res->verbose = slot.params.verbose;
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
|
||||
slot.update_chat_msg(res->oaicompat_msg_diffs);
|
||||
|
||||
// populate res.probs_output
|
||||
if (slot.params.sampling.n_probs > 0) {
|
||||
@@ -2489,7 +2496,7 @@ struct server_context {
|
||||
res->id_slot = slot.id;
|
||||
|
||||
res->index = slot.index;
|
||||
res->content = std::move(slot.generated_text);
|
||||
res->content = slot.generated_text;
|
||||
res->tokens = std::move(slot.generated_tokens);
|
||||
res->timings = slot.get_timings();
|
||||
res->prompt = slot.prompt_tokens.detokenize(ctx, true);
|
||||
@@ -2509,7 +2516,8 @@ struct server_context {
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||
res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
|
||||
res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
|
||||
|
||||
// populate res.probs_output
|
||||
if (slot.params.sampling.n_probs > 0) {
|
||||
if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
|
||||
@@ -3341,6 +3349,37 @@ struct server_context {
|
||||
common_set_adapter_lora(ctx, slot_batched->lora);
|
||||
}
|
||||
|
||||
const bool do_encode = (params_base.embedding || params_base.reranking);
|
||||
|
||||
// pad the batch so that batch.n_tokens >= n_slots
|
||||
// TODO: temporary workaround for https://github.com/ggml-org/llama.cpp/issues/13689
|
||||
if (do_encode) {
|
||||
const int n_slots = slots.size();
|
||||
|
||||
if (batch.n_tokens < n_slots) {
|
||||
std::set<llama_seq_id> seq_ids;
|
||||
for (int j = 0; j < batch.n_tokens; ++j) {
|
||||
seq_ids.insert(batch.seq_id[j][0]);
|
||||
}
|
||||
|
||||
// find unused sequence id
|
||||
llama_seq_id seq_id = -1;
|
||||
for (int i = 0; i < n_slots; ++i) {
|
||||
if (seq_ids.find(i) == seq_ids.end()) {
|
||||
seq_id = i;
|
||||
}
|
||||
}
|
||||
|
||||
const int n_add = n_slots - batch.n_tokens;
|
||||
|
||||
SRV_WRN("adding %d dummy tokens to the batch, seq_id = %d\n", n_add, seq_id);
|
||||
|
||||
for (int j = 0; j < n_add; ++j) {
|
||||
common_batch_add(batch, 0, j, { seq_id }, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// process the created batch of tokens
|
||||
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
||||
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
|
||||
@@ -3357,7 +3396,7 @@ struct server_context {
|
||||
|
||||
int ret = 0;
|
||||
|
||||
if (params_base.embedding || params_base.reranking) {
|
||||
if (do_encode) {
|
||||
ret = llama_encode(ctx, batch_view);
|
||||
} else {
|
||||
ret = llama_decode(ctx, batch_view);
|
||||
@@ -4061,7 +4100,10 @@ int main(int argc, char ** argv) {
|
||||
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
|
||||
{ "total_slots", ctx_server.params_base.n_parallel },
|
||||
{ "model_path", ctx_server.params_base.model.path },
|
||||
{ "modalities", json{{"vision", ctx_server.mctx != nullptr}} }, // TODO: add more in the future
|
||||
{ "modalities", json{
|
||||
{"vision", ctx_server.oai_parser_opt.allow_image},
|
||||
{"audio", ctx_server.oai_parser_opt.allow_audio},
|
||||
} },
|
||||
{ "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
|
||||
{ "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
|
||||
{ "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
|
||||
@@ -4152,10 +4194,10 @@ int main(int argc, char ** argv) {
|
||||
for (auto & file : files) {
|
||||
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size()));
|
||||
if (!bmp.ptr) {
|
||||
throw std::runtime_error("Failed to load image");
|
||||
throw std::runtime_error("Failed to load image or audio file");
|
||||
}
|
||||
// calculate bitmap hash (for KV caching)
|
||||
std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3);
|
||||
std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
|
||||
bmp.set_id(hash.c_str());
|
||||
bitmaps.entries.push_back(std::move(bmp));
|
||||
}
|
||||
@@ -4387,7 +4429,7 @@ int main(int argc, char ** argv) {
|
||||
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
|
||||
};
|
||||
|
||||
const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_chat_completions = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
LOG_DBG("request: %s\n", req.body.c_str());
|
||||
if (ctx_server.params_base.embedding) {
|
||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
@@ -4396,13 +4438,9 @@ int main(int argc, char ** argv) {
|
||||
|
||||
auto body = json::parse(req.body);
|
||||
std::vector<raw_buffer> files;
|
||||
json data = oaicompat_completion_params_parse(
|
||||
json data = oaicompat_chat_params_parse(
|
||||
body,
|
||||
params.use_jinja,
|
||||
params.prefill_assistant,
|
||||
params.reasoning_format,
|
||||
ctx_server.chat_templates.get(),
|
||||
ctx_server.mctx,
|
||||
ctx_server.oai_parser_opt,
|
||||
files);
|
||||
|
||||
handle_completions_impl(
|
||||
@@ -4415,16 +4453,12 @@ int main(int argc, char ** argv) {
|
||||
};
|
||||
|
||||
// same with handle_chat_completions, but without inference part
|
||||
const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||
auto body = json::parse(req.body);
|
||||
std::vector<raw_buffer> files; // dummy, unused
|
||||
json data = oaicompat_completion_params_parse(
|
||||
json data = oaicompat_chat_params_parse(
|
||||
body,
|
||||
params.use_jinja,
|
||||
params.prefill_assistant,
|
||||
params.reasoning_format,
|
||||
ctx_server.chat_templates.get(),
|
||||
ctx_server.mctx,
|
||||
ctx_server.oai_parser_opt,
|
||||
files);
|
||||
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
|
||||
};
|
||||
|
||||
@@ -75,7 +75,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
|
||||
choice = data["choices"][0]
|
||||
if i == 0:
|
||||
# Check first role message for stream=True
|
||||
assert choice["delta"]["content"] == ""
|
||||
assert choice["delta"]["content"] is None
|
||||
assert choice["delta"]["role"] == "assistant"
|
||||
else:
|
||||
assert "role" not in choice["delta"]
|
||||
@@ -92,7 +92,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
|
||||
assert choice["finish_reason"] == finish_reason
|
||||
else:
|
||||
assert choice["finish_reason"] is None
|
||||
content += choice["delta"]["content"]
|
||||
content += choice["delta"]["content"] or ''
|
||||
|
||||
|
||||
def test_chat_completion_with_openai_library():
|
||||
@@ -251,8 +251,9 @@ def test_chat_completion_with_timings_per_token():
|
||||
for i, data in enumerate(res):
|
||||
if i == 0:
|
||||
# Check first role message for stream=True
|
||||
assert data["choices"][0]["delta"]["content"] == ""
|
||||
assert data["choices"][0]["delta"]["content"] is None
|
||||
assert data["choices"][0]["delta"]["role"] == "assistant"
|
||||
assert "timings" not in data, f'First event should not have timings: {data}'
|
||||
else:
|
||||
assert "role" not in data["choices"][0]["delta"]
|
||||
assert "timings" in data
|
||||
@@ -311,7 +312,7 @@ def test_logprobs_stream():
|
||||
choice = data.choices[0]
|
||||
if i == 0:
|
||||
# Check first role message for stream=True
|
||||
assert choice.delta.content == ""
|
||||
assert choice.delta.content is None
|
||||
assert choice.delta.role == "assistant"
|
||||
else:
|
||||
assert choice.delta.role is None
|
||||
|
||||
@@ -25,6 +25,40 @@ def create_server():
|
||||
server.n_slots = 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
|
||||
@pytest.mark.parametrize("template_name,reasoning_budget,expected_end", [
|
||||
("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", None, "<think>\n"),
|
||||
("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", -1, "<think>\n"),
|
||||
("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", 0, "<think>\n</think>"),
|
||||
|
||||
("Qwen-Qwen3-0.6B", -1, "<|im_start|>assistant\n"),
|
||||
("Qwen-Qwen3-0.6B", 0, "<|im_start|>assistant\n<think>\n\n</think>\n\n"),
|
||||
|
||||
("Qwen-QwQ-32B", -1, "<|im_start|>assistant\n<think>\n"),
|
||||
("Qwen-QwQ-32B", 0, "<|im_start|>assistant\n<think>\n</think>"),
|
||||
|
||||
("CohereForAI-c4ai-command-r7b-12-2024-tool_use", -1, "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"),
|
||||
("CohereForAI-c4ai-command-r7b-12-2024-tool_use", 0, "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|><|END_THINKING|>"),
|
||||
])
|
||||
def test_reasoning_budget(template_name: str, reasoning_budget: int | None, expected_end: str, tools: list[dict]):
|
||||
global server
|
||||
server.jinja = True
|
||||
server.reasoning_budget = reasoning_budget
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
|
||||
res = server.make_request("POST", "/apply-template", data={
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is today?"},
|
||||
],
|
||||
"tools": tools,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
prompt = res.body["prompt"]
|
||||
|
||||
assert prompt.endswith(expected_end), f"Expected prompt to end with '{expected_end}', got '{prompt}'"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
|
||||
@pytest.mark.parametrize("template_name,format", [
|
||||
("meta-llama-Llama-3.3-70B-Instruct", "%d %b %Y"),
|
||||
@@ -47,3 +81,28 @@ def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
|
||||
|
||||
today_str = datetime.date.today().strftime(format)
|
||||
assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("add_generation_prompt", [False, True])
|
||||
@pytest.mark.parametrize("template_name,expected_generation_prompt", [
|
||||
("meta-llama-Llama-3.3-70B-Instruct", "<|start_header_id|>assistant<|end_header_id|>"),
|
||||
])
|
||||
def test_add_generation_prompt(template_name: str, expected_generation_prompt: str, add_generation_prompt: bool):
|
||||
global server
|
||||
server.jinja = True
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
|
||||
res = server.make_request("POST", "/apply-template", data={
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is today?"},
|
||||
],
|
||||
"add_generation_prompt": add_generation_prompt,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
prompt = res.body["prompt"]
|
||||
|
||||
if add_generation_prompt:
|
||||
assert expected_generation_prompt in prompt, f"Expected generation prompt ({expected_generation_prompt}) in content ({prompt})"
|
||||
else:
|
||||
assert expected_generation_prompt not in prompt, f"Did not expect generation prompt ({expected_generation_prompt}) in content ({prompt})"
|
||||
|
||||
@@ -8,6 +8,7 @@ path = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(path))
|
||||
|
||||
from utils import *
|
||||
from enum import Enum
|
||||
|
||||
server: ServerProcess
|
||||
|
||||
@@ -20,7 +21,11 @@ def create_server():
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.model_alias = "tinyllama-2-tool-call"
|
||||
server.server_port = 8081
|
||||
server.n_slots = 1
|
||||
|
||||
class CompletionMode(Enum):
|
||||
NORMAL = "normal"
|
||||
STREAMED = "streamed"
|
||||
|
||||
TEST_TOOL = {
|
||||
"type":"function",
|
||||
@@ -73,9 +78,8 @@ WEATHER_TOOL = {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
@@ -86,13 +90,13 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
|
||||
"parallel_tool_calls": False,
|
||||
**kwargs,
|
||||
})
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
# assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
||||
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
|
||||
assert expected_function_name == tool_call["function"]["name"]
|
||||
actual_arguments = tool_call["function"]["arguments"]
|
||||
@@ -102,12 +106,16 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
|
||||
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("template_name,tool,argument_key", [
|
||||
("google-gemma-2-2b-it", TEST_TOOL, "success"),
|
||||
("google-gemma-2-2b-it", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
|
||||
])
|
||||
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
|
||||
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
|
||||
global server
|
||||
n_predict = 1024
|
||||
# server = ServerPreset.stories15m_moe()
|
||||
@@ -115,31 +123,43 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
|
||||
server.n_predict = n_predict
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, temperature=0.0, top_k=1, top_p=1.0)
|
||||
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("template_name,tool,argument_key", [
|
||||
("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
|
||||
|
||||
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
|
||||
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
|
||||
|
||||
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
|
||||
("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
|
||||
# Functionary v3.2 format supports raw python content, which w/ a dummy stories model will never end on its own.
|
||||
# ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
|
||||
|
||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
|
||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
|
||||
|
||||
("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
|
||||
|
||||
("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
|
||||
("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
|
||||
|
||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
|
||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
|
||||
|
||||
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
|
||||
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
|
||||
|
||||
("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
|
||||
# ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "codeFalse), True),
|
||||
# ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
|
||||
|
||||
])
|
||||
def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None):
|
||||
def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
|
||||
global server
|
||||
n_predict = 512
|
||||
# server = ServerPreset.stories15m_moe()
|
||||
@@ -147,10 +167,11 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
|
||||
server.n_predict = n_predict
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict)
|
||||
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [
|
||||
(TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
@@ -184,9 +205,9 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
# (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
# (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
# (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
(TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
|
||||
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
|
||||
@@ -203,10 +224,9 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
|
||||
(TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
])
|
||||
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
|
||||
global server
|
||||
n_predict = 512
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192
|
||||
server.n_predict = n_predict
|
||||
@@ -219,7 +239,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
@@ -228,12 +248,12 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
|
||||
"tool_choice": "required",
|
||||
"tools": [tool],
|
||||
"parallel_tool_calls": False,
|
||||
"stream": stream == CompletionMode.STREAMED,
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"top_p": 1.0,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
choice = body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
@@ -248,7 +268,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
|
||||
|
||||
|
||||
def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs):
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
@@ -258,26 +278,27 @@ def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int,
|
||||
"tool_choice": tool_choice,
|
||||
**kwargs,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
choice = body["choices"][0]
|
||||
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
|
||||
("meta-llama-Llama-3.3-70B-Instruct", 128, [], None),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None),
|
||||
("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'),
|
||||
])
|
||||
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
|
||||
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
|
||||
global server
|
||||
server.jinja = True
|
||||
server.n_predict = n_predict
|
||||
server.jinja = True
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
|
||||
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
|
||||
("meetkai-functionary-medium-v3.2", 256, [], None),
|
||||
("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None),
|
||||
@@ -289,16 +310,17 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
|
||||
("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'),
|
||||
])
|
||||
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
|
||||
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
|
||||
global server
|
||||
server.jinja = True
|
||||
server.n_predict = n_predict
|
||||
server.jinja = True
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
|
||||
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("hf_repo,template_override", [
|
||||
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
@@ -321,11 +343,11 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
|
||||
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
|
||||
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
# ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
|
||||
# ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
|
||||
|
||||
("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
|
||||
("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
|
||||
# ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
|
||||
# ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
|
||||
|
||||
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
|
||||
@@ -339,10 +361,9 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
|
||||
|
||||
# ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
|
||||
])
|
||||
def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
|
||||
global server
|
||||
n_predict = 512
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192
|
||||
server.n_predict = n_predict
|
||||
@@ -355,11 +376,11 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
do_test_weather(server, max_tokens=n_predict)
|
||||
do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
|
||||
|
||||
|
||||
def do_test_weather(server: ServerProcess, **kwargs):
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
|
||||
{"role": "user", "content": "What is the weather in Istanbul?"},
|
||||
@@ -367,14 +388,13 @@ def do_test_weather(server: ServerProcess, **kwargs):
|
||||
"tools": [WEATHER_TOOL],
|
||||
**kwargs,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
choice = body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
||||
assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}'
|
||||
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
actual_arguments = json.loads(tool_call["function"]["arguments"])
|
||||
assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
|
||||
location = actual_arguments["location"]
|
||||
@@ -383,6 +403,7 @@ def do_test_weather(server: ServerProcess, **kwargs):
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
|
||||
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
|
||||
(None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
|
||||
@@ -400,9 +421,8 @@ def do_test_weather(server: ServerProcess, **kwargs):
|
||||
# (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
|
||||
# ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
])
|
||||
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
|
||||
global server
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192 * 2
|
||||
server.n_predict = n_predict
|
||||
@@ -415,11 +435,11 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
do_test_calc_result(server, result_override, n_predict)
|
||||
do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)
|
||||
|
||||
|
||||
def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs):
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."},
|
||||
@@ -466,8 +486,7 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
|
||||
],
|
||||
**kwargs,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
choice = body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls is None, f'Expected no tool call in {choice["message"]}'
|
||||
content = choice["message"].get("content")
|
||||
@@ -480,18 +499,18 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("n_predict,reasoning_format,expect_content,expect_reasoning_content,hf_repo,template_override", [
|
||||
(128, 'deepseek', "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(128, None, "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
|
||||
(1024, 'deepseek', "To find the sum of[\\s\\S]*", "I need to calculate the sum of 102 and 7[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(1024, 'none', "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
|
||||
(1024, 'deepseek', "To find the sum of[\\s\\S]*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
@pytest.mark.parametrize("n_predict,reasoning_format,stream,expect_reasoning_content,expect_content,hf_repo,template_override", [
|
||||
(128, 'deepseek', CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(128, None, CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
|
||||
(1024, 'deepseek', CompletionMode.NORMAL, "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(1024, 'deepseek', CompletionMode.STREAMED, None, "^<think>I need to calculate [\\s\\S]*?</think>To find the sum of [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
(1024, 'deepseek', CompletionMode.NORMAL, "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
(1024, 'deepseek', CompletionMode.STREAMED, None, "^<think>First, I [\\s\\S]*?</think>To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
|
||||
# (1024, 'none', CompletionMode.NORMAL, None, "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
# (128, 'deepseek', None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*", "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M", None),
|
||||
])
|
||||
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
|
||||
global server
|
||||
server.n_slots = 1
|
||||
server.reasoning_format = reasoning_format
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192 * 2
|
||||
@@ -505,14 +524,14 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
|
||||
elif isinstance(template_override, str):
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"max_tokens": n_predict,
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the sum of 102 and 7?"},
|
||||
]
|
||||
],
|
||||
"stream": stream == CompletionMode.STREAMED,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
choice = body["choices"][0]
|
||||
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
|
||||
|
||||
content = choice["message"].get("content")
|
||||
@@ -529,6 +548,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||
@pytest.mark.parametrize("hf_repo,template_override", [
|
||||
("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
|
||||
|
||||
@@ -562,10 +582,9 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
|
||||
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
|
||||
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
|
||||
])
|
||||
def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
|
||||
def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
|
||||
global server
|
||||
n_predict = 512 # High because of DeepSeek R1
|
||||
server.n_slots = 1
|
||||
server.jinja = True
|
||||
server.n_ctx = 8192
|
||||
server.n_predict = n_predict
|
||||
@@ -579,11 +598,11 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non
|
||||
server.chat_template = template_override
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
|
||||
do_test_hello_world(server, max_tokens=n_predict)
|
||||
do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
|
||||
|
||||
|
||||
def do_test_hello_world(server: ServerProcess, **kwargs):
|
||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
||||
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a tool-calling agent."},
|
||||
{"role": "user", "content": "say hello world with python"},
|
||||
@@ -591,16 +610,15 @@ def do_test_hello_world(server: ServerProcess, **kwargs):
|
||||
"tools": [PYTHON_TOOL],
|
||||
**kwargs,
|
||||
}, timeout=TIMEOUT_HTTP_REQUEST)
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
choice = body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
||||
assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
|
||||
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||
actual_arguments = json.loads(tool_call["function"]["arguments"])
|
||||
assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
|
||||
code = actual_arguments["code"]
|
||||
assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
|
||||
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
|
||||
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', re.sub(r'#.*\n?', '', code)), f'Expected hello world, got {code}'
|
||||
|
||||
@@ -30,6 +30,7 @@ def create_server():
|
||||
("What is this:\n", "malformed", False, None),
|
||||
("What is this:\n", "https://google.com/404", False, None), # non-existent image
|
||||
("What is this:\n", "https://ggml.ai", False, None), # non-image data
|
||||
# TODO @ngxson : test with multiple images, no images and with audio
|
||||
]
|
||||
)
|
||||
def test_vision_chat_completion(prompt, image_url, success, re_content):
|
||||
|
||||
@@ -84,7 +84,8 @@ class ServerProcess:
|
||||
draft_max: int | None = None
|
||||
no_webui: bool | None = None
|
||||
jinja: bool | None = None
|
||||
reasoning_format: Literal['deepseek', 'none'] | None = None
|
||||
reasoning_format: Literal['deepseek', 'none', 'nothink'] | None = None
|
||||
reasoning_budget: int | None = None
|
||||
chat_template: str | None = None
|
||||
chat_template_file: str | None = None
|
||||
server_path: str | None = None
|
||||
@@ -191,6 +192,8 @@ class ServerProcess:
|
||||
server_args.append("--jinja")
|
||||
if self.reasoning_format is not None:
|
||||
server_args.extend(("--reasoning-format", self.reasoning_format))
|
||||
if self.reasoning_budget is not None:
|
||||
server_args.extend(("--reasoning-budget", self.reasoning_budget))
|
||||
if self.chat_template:
|
||||
server_args.extend(["--chat-template", self.chat_template])
|
||||
if self.chat_template_file:
|
||||
@@ -294,6 +297,77 @@ class ServerProcess:
|
||||
print("Partial response from server", json.dumps(data, indent=2))
|
||||
yield data
|
||||
|
||||
def make_any_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: dict | None = None,
|
||||
headers: dict | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> dict:
|
||||
stream = data.get('stream', False)
|
||||
if stream:
|
||||
content: list[str] = []
|
||||
tool_calls: list[dict] = []
|
||||
finish_reason: Optional[str] = None
|
||||
|
||||
content_parts = 0
|
||||
tool_call_parts = 0
|
||||
arguments_parts = 0
|
||||
|
||||
for chunk in self.make_stream_request(method, path, data, headers):
|
||||
assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
|
||||
choice = chunk['choices'][0]
|
||||
if choice['delta'].get('content') is not None:
|
||||
assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
|
||||
content.append(choice['delta']['content'])
|
||||
content_parts += 1
|
||||
if choice['delta'].get('finish_reason') is not None:
|
||||
finish_reason = choice['delta']['finish_reason']
|
||||
for tc in choice['delta'].get('tool_calls', []):
|
||||
if 'function' not in tc:
|
||||
raise ValueError(f"Expected function type, got {tc['type']}")
|
||||
if tc['index'] >= len(tool_calls):
|
||||
tool_calls.append(dict(
|
||||
id="",
|
||||
type="function",
|
||||
function=dict(
|
||||
name="",
|
||||
arguments="",
|
||||
)
|
||||
))
|
||||
tool_call = tool_calls[tc['index']]
|
||||
if tc.get('id') is not None:
|
||||
tool_call['id'] = tc['id']
|
||||
fct = tc['function']
|
||||
if fct.get('name') is not None:
|
||||
tool_call['function']['name'] = fct['name']
|
||||
if fct.get('arguments') is not None:
|
||||
assert len(fct['arguments']) > 0, f'Expected non empty arguments delta!'
|
||||
tool_call['function']['arguments'] += fct['arguments']
|
||||
|
||||
print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
|
||||
result = dict(
|
||||
choices=[
|
||||
dict(
|
||||
index=0,
|
||||
finish_reason=finish_reason,
|
||||
message=dict(
|
||||
role='assistant',
|
||||
content=''.join(content) if content else None,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
print("Final response from server", json.dumps(result, indent=2))
|
||||
return result
|
||||
else:
|
||||
response = self.make_request(method, path, data, headers, timeout=timeout)
|
||||
assert response.status_code == 200, f"Server returned error: {response.status_code}"
|
||||
return response.body
|
||||
|
||||
|
||||
|
||||
server_instances: Set[ServerProcess] = set()
|
||||
|
||||
|
||||
+84
-73
@@ -474,26 +474,6 @@ static std::string gen_tool_call_id() {
|
||||
// other common utils
|
||||
//
|
||||
|
||||
static bool ends_with(const std::string & str, const std::string & suffix) {
|
||||
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
|
||||
}
|
||||
|
||||
static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
|
||||
if (!text.empty() && !stop.empty()) {
|
||||
const char text_last_char = text.back();
|
||||
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
|
||||
if (stop[char_index] == text_last_char) {
|
||||
const std::string current_partial = stop.substr(0, char_index + 1);
|
||||
if (ends_with(text, current_partial)) {
|
||||
return text.size() - char_index - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
// TODO: reuse llama_detokenize
|
||||
template <class Iter>
|
||||
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
|
||||
@@ -536,6 +516,7 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
|
||||
// OAI utils
|
||||
//
|
||||
|
||||
// used by /completions endpoint
|
||||
static json oaicompat_completion_params_parse(const json & body) {
|
||||
json llama_params;
|
||||
|
||||
@@ -580,31 +561,35 @@ static json oaicompat_completion_params_parse(const json & body) {
|
||||
return llama_params;
|
||||
}
|
||||
|
||||
static json oaicompat_completion_params_parse(
|
||||
struct oaicompat_parser_options {
|
||||
bool use_jinja;
|
||||
bool prefill_assistant;
|
||||
common_reasoning_format reasoning_format;
|
||||
common_chat_templates * tmpls;
|
||||
bool allow_image;
|
||||
bool allow_audio;
|
||||
bool enable_thinking = true;
|
||||
};
|
||||
|
||||
// used by /chat/completions endpoint
|
||||
static json oaicompat_chat_params_parse(
|
||||
const json & body, /* openai api json semantics */
|
||||
bool use_jinja,
|
||||
bool prefill_assistant,
|
||||
common_reasoning_format reasoning_format,
|
||||
const struct common_chat_templates * tmpls,
|
||||
bool allow_non_text,
|
||||
const oaicompat_parser_options & opt,
|
||||
std::vector<raw_buffer> & out_files)
|
||||
{
|
||||
json llama_params;
|
||||
|
||||
auto tools = json_value(body, "tools", json());
|
||||
auto has_tools = tools.is_array() && !tools.empty();
|
||||
auto stream = json_value(body, "stream", false);
|
||||
auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
|
||||
|
||||
if (tools.is_array() && !tools.empty()) {
|
||||
if (stream) {
|
||||
throw std::runtime_error("Cannot use tools with stream");
|
||||
}
|
||||
if (!use_jinja) {
|
||||
if (!opt.use_jinja) {
|
||||
if (has_tools) {
|
||||
throw std::runtime_error("tools param requires --jinja flag");
|
||||
}
|
||||
}
|
||||
if (!use_jinja) {
|
||||
if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) {
|
||||
throw std::runtime_error("Unsupported param: tool_choice");
|
||||
if (tool_choice != "auto") {
|
||||
throw std::runtime_error("tool_choice param requires --jinja flag");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -667,12 +652,12 @@ static json oaicompat_completion_params_parse(
|
||||
|
||||
for (auto & p : content) {
|
||||
std::string type = json_value(p, "type", std::string());
|
||||
json image_url = json_value(p, "image_url", json::object());
|
||||
if (type == "image_url") {
|
||||
if (!allow_non_text) {
|
||||
throw std::runtime_error("image input is not supported by this server");
|
||||
if (!opt.allow_image) {
|
||||
throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
|
||||
}
|
||||
|
||||
json image_url = json_value(p, "image_url", json::object());
|
||||
std::string url = json_value(image_url, "url", std::string());
|
||||
if (string_starts_with(url, "http")) {
|
||||
// download remote image
|
||||
@@ -710,8 +695,31 @@ static json oaicompat_completion_params_parse(
|
||||
|
||||
// replace this chunk with a marker
|
||||
p["type"] = "text";
|
||||
p["text"] = MTMD_DEFAULT_IMAGE_MARKER;
|
||||
p["text"] = mtmd_default_marker();
|
||||
p.erase("image_url");
|
||||
|
||||
} else if (type == "input_audio") {
|
||||
if (!opt.allow_audio) {
|
||||
throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
|
||||
}
|
||||
|
||||
json input_audio = json_value(p, "input_audio", json::object());
|
||||
std::string data = json_value(input_audio, "data", std::string());
|
||||
std::string format = json_value(input_audio, "format", std::string());
|
||||
// while we also support flac, we don't allow it here so we matches the OAI spec
|
||||
if (format != "wav" && format != "mp3") {
|
||||
throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'");
|
||||
}
|
||||
auto decoded_data = base64_decode(data); // expected to be base64 encoded
|
||||
out_files.push_back(decoded_data);
|
||||
|
||||
// replace this chunk with a marker
|
||||
p["type"] = "text";
|
||||
p["text"] = mtmd_default_marker();
|
||||
p.erase("input_audio");
|
||||
|
||||
} else if (type != "text") {
|
||||
throw std::runtime_error("unsupported content[].type");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -719,21 +727,21 @@ static json oaicompat_completion_params_parse(
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
|
||||
inputs.tools = common_chat_tools_parse_oaicompat(tools);
|
||||
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
|
||||
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice);
|
||||
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
|
||||
inputs.grammar = grammar;
|
||||
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
|
||||
inputs.use_jinja = use_jinja;
|
||||
inputs.use_jinja = opt.use_jinja;
|
||||
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
|
||||
inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
|
||||
inputs.reasoning_format = opt.reasoning_format;
|
||||
inputs.enable_thinking = opt.enable_thinking;
|
||||
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
|
||||
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
|
||||
}
|
||||
|
||||
// if the assistant message appears at the end of list, we do not add end-of-turn token
|
||||
// for ex. this can be useful to modify the reasoning process in reasoning models
|
||||
bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && prefill_assistant;
|
||||
bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant;
|
||||
common_chat_msg last_message;
|
||||
if (prefill_assistant_message) {
|
||||
last_message = inputs.messages.back();
|
||||
@@ -744,12 +752,13 @@ static json oaicompat_completion_params_parse(
|
||||
throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list.");
|
||||
}
|
||||
|
||||
inputs.extract_reasoning = false;
|
||||
/* TODO: test this properly */
|
||||
inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
inputs.add_generation_prompt = true;
|
||||
}
|
||||
|
||||
// Apply chat template to the list of messages
|
||||
auto chat_params = common_chat_templates_apply(tmpls, inputs);
|
||||
auto chat_params = common_chat_templates_apply(opt.tmpls, inputs);
|
||||
|
||||
/* Append assistant prefilled message */
|
||||
if (prefill_assistant_message) {
|
||||
@@ -769,6 +778,7 @@ static json oaicompat_completion_params_parse(
|
||||
}
|
||||
llama_params["grammar_triggers"] = grammar_triggers;
|
||||
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
|
||||
llama_params["thinking_forced_open"] = chat_params.thinking_forced_open;
|
||||
for (const auto & stop : chat_params.additional_stops) {
|
||||
llama_params["stop"].push_back(stop);
|
||||
}
|
||||
@@ -782,6 +792,9 @@ static json oaicompat_completion_params_parse(
|
||||
// Handle "logprobs" field
|
||||
// TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
|
||||
if (json_value(body, "logprobs", false)) {
|
||||
if (has_tools && stream) {
|
||||
throw std::runtime_error("logprobs is not supported with tools + stream");
|
||||
}
|
||||
llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
|
||||
} else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
|
||||
throw std::runtime_error("top_logprobs requires logprobs to be set to true");
|
||||
@@ -1040,7 +1053,7 @@ struct server_tokens {
|
||||
private: // disallow accessing these members directly, risking out-of-sync
|
||||
|
||||
// map a **start** position in tokens to the image chunk
|
||||
std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_image;
|
||||
std::unordered_map<llama_pos, mtmd::input_chunk_ptr> map_pos_to_media;
|
||||
|
||||
// list of tokens
|
||||
// it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token
|
||||
@@ -1051,7 +1064,7 @@ private: // disallow accessing these members directly, risking out-of-sync
|
||||
// for ex. with input of 5 text tokens and 2 images:
|
||||
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
|
||||
// pos 0 1 2 3 4 5 6 7 8 9
|
||||
// map_pos_to_image will contain: {5, img0}, {8, img1}
|
||||
// map_pos_to_media will contain: {5, img0}, {8, img1}
|
||||
|
||||
public:
|
||||
server_tokens() = default;
|
||||
@@ -1090,15 +1103,15 @@ public:
|
||||
}
|
||||
oss << "\n";
|
||||
oss << "image pos: ";
|
||||
for (const auto & it : map_pos_to_image) {
|
||||
for (const auto & it : map_pos_to_media) {
|
||||
oss << it.first << ", ";
|
||||
}
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const {
|
||||
auto it = map_pos_to_image.find(pos);
|
||||
if (it != map_pos_to_image.end()) {
|
||||
auto it = map_pos_to_media.find(pos);
|
||||
if (it != map_pos_to_media.end()) {
|
||||
return it->second;
|
||||
} else {
|
||||
throw std::runtime_error("Chunk not found");
|
||||
@@ -1115,16 +1128,15 @@ public:
|
||||
// will create a copy of the chunk if it contains non-text data
|
||||
void push_back(const mtmd_input_chunk * chunk) {
|
||||
auto type = mtmd_input_chunk_get_type(chunk);
|
||||
if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
|
||||
if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
|
||||
GGML_ASSERT(has_mtmd);
|
||||
auto img_tokens = mtmd_input_chunk_get_tokens_image(chunk);
|
||||
const int n_pos = mtmd_image_tokens_get_n_pos(img_tokens);
|
||||
const int n_pos = mtmd_input_chunk_get_n_pos(chunk);
|
||||
llama_pos start_pos = tokens.size();
|
||||
for (int i = 0; i < n_pos; ++i) {
|
||||
tokens.emplace_back(LLAMA_TOKEN_NULL);
|
||||
}
|
||||
mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
|
||||
map_pos_to_image[start_pos] = std::move(new_chunk);
|
||||
map_pos_to_media[start_pos] = std::move(new_chunk);
|
||||
} else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||
size_t n_tokens;
|
||||
auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
|
||||
@@ -1169,6 +1181,9 @@ public:
|
||||
void keep_first(size_t n) {
|
||||
GGML_ASSERT(n <= tokens.size());
|
||||
if (has_mtmd) {
|
||||
if (n == tokens.size()) {
|
||||
return; // nothing to do
|
||||
}
|
||||
// we throw an error if we try to remove a token in the middle of an image
|
||||
// for ex. with input of 5 text tokens and 2 images:
|
||||
// [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1]
|
||||
@@ -1183,10 +1198,10 @@ public:
|
||||
}
|
||||
}
|
||||
// remove all image chunks that are not used anymore
|
||||
for (auto it = map_pos_to_image.begin(); it != map_pos_to_image.end(); ) {
|
||||
for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ) {
|
||||
llama_pos pos = it->first;
|
||||
if (pos >= (llama_pos)n) {
|
||||
it = map_pos_to_image.erase(it);
|
||||
it = map_pos_to_media.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
@@ -1217,14 +1232,12 @@ public:
|
||||
const auto & a_chunk = find_chunk(i);
|
||||
const auto & b_chunk = b.find_chunk(i);
|
||||
GGML_ASSERT(a_chunk && b_chunk);
|
||||
const auto * a_img = mtmd_input_chunk_get_tokens_image(a_chunk.get());
|
||||
const auto * b_img = mtmd_input_chunk_get_tokens_image(b_chunk.get());
|
||||
std::string ai_id = mtmd_image_tokens_get_id(a_img);
|
||||
std::string bi_id = mtmd_image_tokens_get_id(b_img);
|
||||
size_t a_pos = mtmd_image_tokens_get_n_pos(a_img);
|
||||
size_t b_pos = mtmd_image_tokens_get_n_pos(b_img);
|
||||
std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get());
|
||||
std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get());
|
||||
size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get());
|
||||
size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get());
|
||||
if (ai_id == bi_id && a_pos == b_pos) {
|
||||
GGML_ASSERT(a_pos > 0 && "Invalid image token"); // should never happen
|
||||
GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen
|
||||
i += a_pos - 1; // will be +1 by the for loop
|
||||
continue;
|
||||
} else {
|
||||
@@ -1250,8 +1263,7 @@ public:
|
||||
if (t == LLAMA_TOKEN_NULL) {
|
||||
try {
|
||||
const auto & chunk = find_chunk(i);
|
||||
const auto * img_tokens = mtmd_input_chunk_get_tokens_image(chunk.get());
|
||||
size_t n_pos = mtmd_image_tokens_get_n_pos(img_tokens);
|
||||
size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
|
||||
i += n_pos - 1; // will be +1 by the for loop
|
||||
} catch (const std::exception & e) {
|
||||
return false;
|
||||
@@ -1270,22 +1282,21 @@ public:
|
||||
llama_pos n_past,
|
||||
int32_t seq_id,
|
||||
llama_pos & n_pos_out) {
|
||||
auto it = map_pos_to_image.find(n_past);
|
||||
if (it == map_pos_to_image.end()) {
|
||||
throw std::runtime_error("Chunk not found");
|
||||
}
|
||||
SRV_INF("%s\n", "processing image...");
|
||||
auto & chunk = find_chunk(n_past);
|
||||
const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
|
||||
? "image" : "audio";
|
||||
SRV_INF("processing %s...\n", name);
|
||||
int32_t n_batch = llama_n_batch(ctx);
|
||||
int64_t t0 = ggml_time_ms();
|
||||
llama_pos new_n_past = n_past;
|
||||
int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx,
|
||||
it->second.get(), // chunk
|
||||
chunk.get(),
|
||||
n_past,
|
||||
seq_id,
|
||||
n_batch,
|
||||
true, // logits last
|
||||
&new_n_past);
|
||||
SRV_INF("image processed in %" PRId64 " ms\n", ggml_time_ms() - t0);
|
||||
SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0);
|
||||
if (result != 0) {
|
||||
LOG_ERR("mtmd_helper_eval failed with status %d", result);
|
||||
n_pos_out = n_past;
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
import { DocumentTextIcon, XMarkIcon } from '@heroicons/react/24/outline';
|
||||
import {
|
||||
DocumentTextIcon,
|
||||
SpeakerWaveIcon,
|
||||
XMarkIcon,
|
||||
} from '@heroicons/react/24/outline';
|
||||
import { MessageExtra } from '../utils/types';
|
||||
import { useState } from 'react';
|
||||
import { classNames } from '../utils/misc';
|
||||
@@ -66,7 +70,11 @@ export default function ChatInputExtraContextItem({
|
||||
className="w-14 h-14 flex items-center justify-center"
|
||||
aria-description="Document icon"
|
||||
>
|
||||
<DocumentTextIcon className="h-8 w-14 text-base-content/50" />
|
||||
{item.type === 'audioFile' ? (
|
||||
<SpeakerWaveIcon className="h-8 w-8 text-gray-500" />
|
||||
) : (
|
||||
<DocumentTextIcon className="h-8 w-8 text-gray-500" />
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="text-xs pr-4">
|
||||
@@ -98,6 +106,19 @@ export default function ChatInputExtraContextItem({
|
||||
src={showingItem.base64Url}
|
||||
alt={`Preview image for ${showingItem.name}`}
|
||||
/>
|
||||
) : showingItem.type === 'audioFile' ? (
|
||||
<audio
|
||||
controls
|
||||
className="w-full"
|
||||
aria-description={`Audio file ${showingItem.name}`}
|
||||
>
|
||||
<source
|
||||
src={`data:${showingItem.mimeType};base64,${showingItem.base64Data}`}
|
||||
type={showingItem.mimeType}
|
||||
aria-description={`Audio file ${showingItem.name}`}
|
||||
/>
|
||||
Your browser does not support the audio element.
|
||||
</audio>
|
||||
) : (
|
||||
<div className="overflow-x-auto">
|
||||
<pre className="whitespace-pre-wrap break-words text-sm">
|
||||
|
||||
@@ -278,6 +278,13 @@ export default function ChatScreen() {
|
||||
|
||||
function ServerInfo() {
|
||||
const { serverProps } = useAppContext();
|
||||
const modalities = [];
|
||||
if (serverProps?.modalities?.audio) {
|
||||
modalities.push('audio');
|
||||
}
|
||||
if (serverProps?.modalities?.vision) {
|
||||
modalities.push('vision');
|
||||
}
|
||||
return (
|
||||
<div
|
||||
className="card card-sm shadow-sm border-1 border-base-content/20 text-base-content/70 mb-6"
|
||||
@@ -291,6 +298,13 @@ function ServerInfo() {
|
||||
<br />
|
||||
<b>Build</b>: {serverProps?.build_info}
|
||||
<br />
|
||||
{modalities.length > 0 ? (
|
||||
<>
|
||||
<b>Supported modalities:</b> {modalities.join(', ')}
|
||||
</>
|
||||
) : (
|
||||
''
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -11,6 +11,7 @@ pdfjs.GlobalWorkerOptions.workerSrc = pdfjsWorkerSrc;
|
||||
// This file handles uploading extra context items (a.k.a files)
|
||||
// It allows processing these kinds of files:
|
||||
// - image files (converted to base64)
|
||||
// - audio files (converted to base64)
|
||||
// - text files (including code files)
|
||||
// - pdf (converted to text)
|
||||
|
||||
@@ -41,96 +42,76 @@ export function useChatExtraContext(): ChatExtraContextApi {
|
||||
|
||||
const isSupportVision = serverProps?.modalities?.vision;
|
||||
|
||||
const onFileAdded = (files: File[]) => {
|
||||
for (const file of files) {
|
||||
const mimeType = file.type;
|
||||
console.debug({ mimeType, file });
|
||||
if (file.size > 10 * 1024 * 1024) {
|
||||
toast.error('File is too large. Maximum size is 10MB.');
|
||||
break;
|
||||
}
|
||||
const onFileAdded = async (files: File[]) => {
|
||||
try {
|
||||
for (const file of files) {
|
||||
const mimeType = file.type;
|
||||
|
||||
if (mimeType.startsWith('image/')) {
|
||||
if (!isSupportVision) {
|
||||
toast.error('Multimodal is not supported by this server or model.');
|
||||
// this limit is only to prevent accidental uploads of huge files
|
||||
// it can potentially crashes the browser because we read the file as base64
|
||||
if (file.size > 500 * 1024 * 1024) {
|
||||
toast.error('File is too large. Maximum size is 500MB.');
|
||||
break;
|
||||
}
|
||||
const reader = new FileReader();
|
||||
reader.onload = async (event) => {
|
||||
if (event.target?.result) {
|
||||
let base64Url = event.target.result as string;
|
||||
|
||||
if (mimeType === 'image/svg+xml') {
|
||||
// Convert SVG to PNG
|
||||
base64Url = await svgBase64UrlToPngDataURL(base64Url);
|
||||
}
|
||||
if (mimeType.startsWith('image/')) {
|
||||
if (!isSupportVision) {
|
||||
toast.error('Multimodal is not supported by this server or model.');
|
||||
break;
|
||||
}
|
||||
|
||||
addItems([
|
||||
{
|
||||
let base64Url = await getFileAsBase64(file);
|
||||
if (mimeType === 'image/svg+xml') {
|
||||
// Convert SVG to PNG
|
||||
base64Url = await svgBase64UrlToPngDataURL(base64Url);
|
||||
}
|
||||
addItems([
|
||||
{
|
||||
type: 'imageFile',
|
||||
name: file.name,
|
||||
base64Url,
|
||||
},
|
||||
]);
|
||||
} else if (mimeType.startsWith('video/')) {
|
||||
toast.error('Video files are not supported yet.');
|
||||
break;
|
||||
} else if (mimeType.startsWith('audio/')) {
|
||||
if (!/mpeg|wav/.test(mimeType)) {
|
||||
toast.error('Only mp3 and wav audio files are supported.');
|
||||
break;
|
||||
}
|
||||
|
||||
// plain base64, not a data URL
|
||||
const base64Data = await getFileAsBase64(file, false);
|
||||
addItems([
|
||||
{
|
||||
type: 'audioFile',
|
||||
name: file.name,
|
||||
mimeType,
|
||||
base64Data,
|
||||
},
|
||||
]);
|
||||
} else if (mimeType.startsWith('application/pdf')) {
|
||||
if (config.pdfAsImage && !isSupportVision) {
|
||||
toast(
|
||||
'Multimodal is not supported, PDF will be converted to text instead of image.'
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
if (config.pdfAsImage && isSupportVision) {
|
||||
// Convert PDF to images
|
||||
const base64Urls = await convertPDFToImage(file);
|
||||
addItems(
|
||||
base64Urls.map((base64Url) => ({
|
||||
type: 'imageFile',
|
||||
name: file.name,
|
||||
base64Url,
|
||||
},
|
||||
]);
|
||||
}
|
||||
};
|
||||
reader.readAsDataURL(file);
|
||||
} else if (
|
||||
mimeType.startsWith('video/') ||
|
||||
mimeType.startsWith('audio/')
|
||||
) {
|
||||
toast.error('Video and audio files are not supported yet.');
|
||||
break;
|
||||
} else if (mimeType.startsWith('application/pdf')) {
|
||||
if (config.pdfAsImage && !isSupportVision) {
|
||||
toast(
|
||||
'Multimodal is not supported, PDF will be converted to text instead of image.'
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
const promise =
|
||||
config.pdfAsImage && isSupportVision
|
||||
? convertPDFToImage(file).then((base64Urls) => {
|
||||
addItems(
|
||||
base64Urls.map((base64Url) => ({
|
||||
type: 'imageFile',
|
||||
name: file.name,
|
||||
base64Url,
|
||||
}))
|
||||
);
|
||||
})
|
||||
: convertPDFToText(file).then((content) => {
|
||||
if (isSupportVision) {
|
||||
toast.success(
|
||||
'PDF file converted to text. You can also convert it to image, see in Settings.'
|
||||
);
|
||||
}
|
||||
addItems([
|
||||
{
|
||||
type: 'textFile',
|
||||
name: file.name,
|
||||
content,
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
promise.catch((error) => {
|
||||
console.error(error);
|
||||
toast.error('Failed to parse PDF file.');
|
||||
});
|
||||
break;
|
||||
} else {
|
||||
// Because there can be many text file types (like code file), we will not check the mime type
|
||||
// and will just check if the file is not binary.
|
||||
const reader = new FileReader();
|
||||
reader.onload = (event) => {
|
||||
if (event.target?.result) {
|
||||
const content = event.target.result as string;
|
||||
if (!isLikelyNotBinary(content)) {
|
||||
toast.error('File is binary. Please upload a text file.');
|
||||
return;
|
||||
}
|
||||
}))
|
||||
);
|
||||
} else {
|
||||
// Convert PDF to text
|
||||
const content = await convertPDFToText(file);
|
||||
addItems([
|
||||
{
|
||||
type: 'textFile',
|
||||
@@ -138,10 +119,40 @@ export function useChatExtraContext(): ChatExtraContextApi {
|
||||
content,
|
||||
},
|
||||
]);
|
||||
if (isSupportVision) {
|
||||
toast.success(
|
||||
'PDF file converted to text. You can also convert it to image, see in Settings.'
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
reader.readAsText(file);
|
||||
break;
|
||||
} else {
|
||||
// Because there can be many text file types (like code file), we will not check the mime type
|
||||
// and will just check if the file is not binary.
|
||||
const reader = new FileReader();
|
||||
reader.onload = (event) => {
|
||||
if (event.target?.result) {
|
||||
const content = event.target.result as string;
|
||||
if (!isLikelyNotBinary(content)) {
|
||||
toast.error('File is binary. Please upload a text file.');
|
||||
return;
|
||||
}
|
||||
addItems([
|
||||
{
|
||||
type: 'textFile',
|
||||
name: file.name,
|
||||
content,
|
||||
},
|
||||
]);
|
||||
}
|
||||
};
|
||||
reader.readAsText(file);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
const errorMessage = `Error processing file: ${message}`;
|
||||
toast.error(errorMessage);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -154,6 +165,25 @@ export function useChatExtraContext(): ChatExtraContextApi {
|
||||
};
|
||||
}
|
||||
|
||||
async function getFileAsBase64(file: File, outputUrl = true): Promise<string> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (event) => {
|
||||
if (event.target?.result) {
|
||||
let result = event.target.result as string;
|
||||
if (!outputUrl) {
|
||||
// remove base64 url prefix and correct characters
|
||||
result = result.substring(result.indexOf(',') + 1);
|
||||
}
|
||||
resolve(result);
|
||||
} else {
|
||||
reject(new Error('Failed to read file.'));
|
||||
}
|
||||
};
|
||||
reader.readAsDataURL(file);
|
||||
});
|
||||
}
|
||||
|
||||
async function getFileAsBuffer(file: File): Promise<ArrayBuffer> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const reader = new FileReader();
|
||||
|
||||
@@ -89,6 +89,14 @@ export function normalizeMsgsForAPI(messages: Readonly<Message[]>) {
|
||||
type: 'image_url',
|
||||
image_url: { url: extra.base64Url },
|
||||
});
|
||||
} else if (extra.type === 'audioFile') {
|
||||
contentArr.push({
|
||||
type: 'input_audio',
|
||||
input_audio: {
|
||||
data: extra.base64Data,
|
||||
format: /wav/.test(extra.mimeType) ? 'wav' : 'mp3',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
throw new Error('Unknown extra type');
|
||||
}
|
||||
|
||||
@@ -51,6 +51,7 @@ export interface Message {
|
||||
export type MessageExtra =
|
||||
| MessageExtraTextFile
|
||||
| MessageExtraImageFile
|
||||
| MessageExtraAudioFile
|
||||
| MessageExtraContext;
|
||||
|
||||
export interface MessageExtraTextFile {
|
||||
@@ -65,6 +66,13 @@ export interface MessageExtraImageFile {
|
||||
base64Url: string;
|
||||
}
|
||||
|
||||
export interface MessageExtraAudioFile {
|
||||
type: 'audioFile';
|
||||
name: string;
|
||||
base64Data: string;
|
||||
mimeType: string;
|
||||
}
|
||||
|
||||
export interface MessageExtraContext {
|
||||
type: 'context';
|
||||
name: string;
|
||||
@@ -79,6 +87,10 @@ export type APIMessageContentPart =
|
||||
| {
|
||||
type: 'image_url';
|
||||
image_url: { url: string };
|
||||
}
|
||||
| {
|
||||
type: 'input_audio';
|
||||
input_audio: { data: string; format: 'wav' | 'mp3' };
|
||||
};
|
||||
|
||||
export type APIMessage = {
|
||||
@@ -120,6 +132,7 @@ export interface LlamaCppServerProps {
|
||||
n_ctx: number;
|
||||
modalities?: {
|
||||
vision: boolean;
|
||||
audio: boolean;
|
||||
};
|
||||
// TODO: support params
|
||||
}
|
||||
|
||||
+4
-2
@@ -579,6 +579,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
params.model = params.vocoder.model;
|
||||
params.embedding = true;
|
||||
params.ctx_shift = false; // silence warning
|
||||
params.n_ubatch = params.n_batch;
|
||||
|
||||
common_init_result llama_init_cts = common_init_from_params(params);
|
||||
|
||||
@@ -1020,8 +1022,8 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
||||
}
|
||||
GGML_ASSERT(batch.n_tokens == n_codes);
|
||||
|
||||
if (llama_decode(ctx_cts, batch) != 0) {
|
||||
LOG_ERR("%s: llama_decode() failed\n", __func__);
|
||||
if (llama_encode(ctx_cts, batch) != 0) {
|
||||
LOG_ERR("%s: llama_encode() failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user