mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-22 21:57:40 +02:00
Compare commits
34 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 721354fbdf | |||
| 6ee0f65793 | |||
| 099b579acb | |||
| f8cc15f163 | |||
| 37957e8531 | |||
| d0f9d2e5ac | |||
| 0ef6f06d55 | |||
| 52b3df0023 | |||
| 7c082bc417 | |||
| bddfd2b113 | |||
| 0d135df48c | |||
| bf533823cd | |||
| 2f89acc2bc | |||
| bfa3219177 | |||
| d6d899580d | |||
| 8a118ee86c | |||
| d789527482 | |||
| 063d9c156e | |||
| c57607016a | |||
| 4a80943174 | |||
| 84de01a1f1 | |||
| 75f460ac28 | |||
| 8452824611 | |||
| e27f308597 | |||
| 67e9fd3b74 | |||
| 796f41bedc | |||
| 37a77fb057 | |||
| f4043fec01 | |||
| f449e05537 | |||
| 2b686a9120 | |||
| 4b48a53b6c | |||
| e475fa2b5f | |||
| 175147e8f6 | |||
| fabde3bf51 |
@@ -4,20 +4,6 @@ ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
### Build Llama.cpp stage
|
||||
FROM docker.io/gcc:${GCC_VERSION} AS build
|
||||
|
||||
@@ -34,8 +20,6 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
WORKDIR /app
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
--mount=type=cache,target=/app/build \
|
||||
cmake -S . -B build -G Ninja \
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
build*/
|
||||
|
||||
tools/ui/node_modules/
|
||||
tools/ui/dist/
|
||||
|
||||
models/*
|
||||
|
||||
|
||||
@@ -58,6 +58,13 @@ jobs:
|
||||
git tag ${{ steps.srctag.outputs.name }} || exit 0
|
||||
git push origin ${{ steps.srctag.outputs.name }} || exit 0
|
||||
|
||||
build_ui:
|
||||
name: Build UI
|
||||
needs: create_tag
|
||||
uses: ./.github/workflows/ui-build.yml
|
||||
with:
|
||||
hf_ui_version: ${{ needs.create_tag.outputs.source_tag }}
|
||||
|
||||
prepare_matrices:
|
||||
name: Prepare Docker matrices
|
||||
runs-on: ubuntu-24.04
|
||||
@@ -79,7 +86,7 @@ jobs:
|
||||
[
|
||||
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-arm" },
|
||||
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x" },
|
||||
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x", "prebuilt_ui": true },
|
||||
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
|
||||
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.3.0", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
@@ -135,7 +142,7 @@ jobs:
|
||||
|
||||
push_to_registry:
|
||||
name: Push Docker image to Docker Registry
|
||||
needs: [prepare_matrices, create_tag]
|
||||
needs: [prepare_matrices, create_tag, build_ui]
|
||||
|
||||
runs-on: ${{ matrix.config.runs_on }}
|
||||
strategy:
|
||||
@@ -150,6 +157,13 @@ jobs:
|
||||
fetch-depth: 0
|
||||
ref: ${{ needs.create_tag.outputs.source_tag }}
|
||||
|
||||
- name: Download prebuilt UI
|
||||
if: ${{ matrix.config.prebuilt_ui == true }}
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
|
||||
with:
|
||||
name: ui-build
|
||||
path: tools/ui/dist
|
||||
|
||||
- name: Set up QEMU
|
||||
if: ${{ contains(matrix.config.platforms, 'linux/amd64') }}
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4
|
||||
|
||||
@@ -1627,6 +1627,7 @@ jobs:
|
||||
**Windows:**
|
||||
- [Windows x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-x64.zip)
|
||||
- [Windows arm64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-arm64.zip)
|
||||
- [Windows arm64 (OpenCL Adreno)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-opencl-adreno-arm64.zip)
|
||||
- [Windows x64 (CUDA 12)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip) - [CUDA 12.4 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-12.4-x64.zip)
|
||||
- [Windows x64 (CUDA 13)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-13.3-x64.zip) - [CUDA 13.3 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-13.3-x64.zip)
|
||||
- [Windows x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-vulkan-x64.zip)
|
||||
|
||||
+51
-9
@@ -17,6 +17,7 @@
|
||||
# define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#include <shellapi.h>
|
||||
#endif
|
||||
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
@@ -302,7 +303,6 @@ static handle_model_result common_params_handle_model(struct common_params_model
|
||||
|
||||
if (!model.docker_repo.empty()) {
|
||||
model.path = common_docker_resolve_model(model.docker_repo);
|
||||
model.name = model.docker_repo;
|
||||
} else if (!model.hf_repo.empty()) {
|
||||
// If -m was used with -hf, treat the model "path" as the hf_file to download
|
||||
if (model.hf_file.empty() && !model.path.empty()) {
|
||||
@@ -322,7 +322,6 @@ static handle_model_result common_params_handle_model(struct common_params_model
|
||||
throw std::runtime_error("failed to download model from Hugging Face");
|
||||
}
|
||||
|
||||
model.name = model.hf_repo;
|
||||
model.path = download_result.model_path;
|
||||
|
||||
if (!download_result.mmproj_path.empty()) {
|
||||
@@ -397,7 +396,7 @@ static bool parse_bool_value(const std::string & value) {
|
||||
// CLI argument parsing functions
|
||||
//
|
||||
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex) {
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex, common_download_callback * callback) {
|
||||
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
|
||||
params.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
|
||||
@@ -409,6 +408,10 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex)
|
||||
opts.download_mtp = spec_type_draft_mtp;
|
||||
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
|
||||
|
||||
if (callback) {
|
||||
opts.callback = callback;
|
||||
}
|
||||
|
||||
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
|
||||
// so we should not auto-discover mtp/mmproj siblings for them
|
||||
common_download_opts sub_opts = opts;
|
||||
@@ -585,8 +588,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
|
||||
}
|
||||
|
||||
// export_graph_ops loads only metadata
|
||||
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
const bool skip_model_download =
|
||||
// server will call common_params_handle_models() later, so we skip it here
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_SERVER ||
|
||||
// export_graph_ops loads only metadata
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
|
||||
if (!skip_model_download) {
|
||||
// handle model and download
|
||||
@@ -595,7 +601,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
if (params.model.path.empty()
|
||||
&& ctx_arg.ex != LLAMA_EXAMPLE_SERVER
|
||||
&& !params.usage
|
||||
&& !params.completion) {
|
||||
throw std::invalid_argument("error: --model is required\n");
|
||||
@@ -893,7 +898,44 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
struct utf8_argv {
|
||||
std::vector<std::string> buf;
|
||||
std::vector<char*> ptrs;
|
||||
};
|
||||
|
||||
static utf8_argv make_utf8_argv() {
|
||||
utf8_argv out;
|
||||
int wargc = 0;
|
||||
LPWSTR* wargv = CommandLineToArgvW(GetCommandLineW(), &wargc);
|
||||
if (!wargv) return out;
|
||||
|
||||
out.buf.reserve(wargc);
|
||||
for (int i = 0; i < wargc; ++i) {
|
||||
int n = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS, wargv[i], -1, nullptr, 0, nullptr, nullptr);
|
||||
if (n <= 0) { out.buf.emplace_back(); continue; }
|
||||
auto& s = out.buf.emplace_back();
|
||||
s.resize(static_cast<size_t>(n - 1));
|
||||
(void)WideCharToMultiByte(CP_UTF8, 0, wargv[i], -1, s.data(), n, nullptr, nullptr);
|
||||
}
|
||||
LocalFree(wargv);
|
||||
|
||||
out.ptrs.reserve(out.buf.size() + 1);
|
||||
for (auto& s : out.buf) out.ptrs.push_back(s.data());
|
||||
out.ptrs.push_back(nullptr);
|
||||
return out;
|
||||
}
|
||||
#endif
|
||||
|
||||
bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
|
||||
#ifdef _WIN32
|
||||
auto utf8 = make_utf8_argv();
|
||||
// repair argv only when it matches the process command line
|
||||
if (static_cast<int>(utf8.buf.size()) == argc) {
|
||||
argv = utf8.ptrs.data();
|
||||
}
|
||||
#endif
|
||||
|
||||
auto ctx_arg = common_params_parser_init(params, ex, print_usage);
|
||||
const common_params params_org = ctx_arg.params; // the example can modify the default params
|
||||
|
||||
@@ -2861,7 +2903,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.server_tools = parse_csv_row(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TOOLS"));
|
||||
add_opt(common_arg(
|
||||
add_opt(common_arg(
|
||||
{"-ag", "--agent"},
|
||||
{"-no-ag", "--no-agent"},
|
||||
"whether to enable CORS proxy and all built-in tools - do not enable in untrusted environments (default: disabled)",
|
||||
@@ -2911,7 +2953,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY"));
|
||||
add_opt(common_arg(
|
||||
{"--api-key-file"}, "FNAME",
|
||||
"path to file containing API keys (default: none)",
|
||||
"path to file containing API keys, one per line; lines starting with a hash are treated as comments (default: none)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
std::ifstream key_file(value);
|
||||
if (!key_file) {
|
||||
@@ -2919,7 +2961,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
std::string key;
|
||||
while (std::getline(key_file, key)) {
|
||||
if (!key.empty()) {
|
||||
if (!key.empty() && key[0] != '#') {
|
||||
params.api_keys.push_back(key);
|
||||
}
|
||||
}
|
||||
|
||||
+5
-1
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "download.h"
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
@@ -133,7 +134,10 @@ void common_params_add_preset_options(std::vector<common_arg> & args);
|
||||
// return true if the model is ready to use
|
||||
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
|
||||
// if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. ETag check failed)
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex);
|
||||
bool common_params_handle_models(
|
||||
common_params & params,
|
||||
llama_example curr_ex,
|
||||
common_download_callback * callback = nullptr);
|
||||
|
||||
// initialize argument parser context - used by test-arg-parser and preset
|
||||
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
||||
|
||||
@@ -395,10 +395,11 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
||||
arguments.name_suffix) +
|
||||
arguments.value_prefix +
|
||||
(schema_info.resolves_to_string(param_schema) ?
|
||||
p.tool_arg_string_value(until_suffix) :
|
||||
p.tool_arg_json_value(p.schema(
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false))) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)));
|
||||
p.ac(p.tool_arg_string_value(until_suffix) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)), arguments.value_suffix) :
|
||||
(p.tool_arg_json_value(p.schema(
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)))));
|
||||
|
||||
auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg);
|
||||
if (is_required) {
|
||||
|
||||
@@ -1074,6 +1074,18 @@ std::vector<common_file_info> fs_list(const std::string & path, bool include_dir
|
||||
return files;
|
||||
}
|
||||
|
||||
std::ifstream fs_open_ifstream(const std::string & fname, std::ios_base::openmode mode) {
|
||||
#ifdef _WIN32
|
||||
int wlen = MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, NULL, 0);
|
||||
if (!wlen) { return std::ifstream(); }
|
||||
std::vector<wchar_t> wfname(wlen);
|
||||
(void)MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, wfname.data(), wlen);
|
||||
return std::ifstream(wfname.data(), mode);
|
||||
#else
|
||||
return std::ifstream(fname, mode);
|
||||
#endif
|
||||
}
|
||||
|
||||
//
|
||||
// TTY utils
|
||||
//
|
||||
|
||||
+13
-1
@@ -295,7 +295,16 @@ struct common_params_model {
|
||||
std::string hf_repo = ""; // HF repo // NOLINT
|
||||
std::string hf_file = ""; // HF file // NOLINT
|
||||
std::string docker_repo = ""; // Docker repo // NOLINT
|
||||
std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
|
||||
|
||||
std::string get_name() {
|
||||
if (!hf_repo.empty()) {
|
||||
return hf_repo;
|
||||
}
|
||||
if (!docker_repo.empty()) {
|
||||
return docker_repo;
|
||||
}
|
||||
return path;
|
||||
}
|
||||
};
|
||||
|
||||
// draft-model-based speculative decoding parameters
|
||||
@@ -842,6 +851,9 @@ struct common_file_info {
|
||||
};
|
||||
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories);
|
||||
|
||||
// fs open, also handle UTF8 on Windows
|
||||
std::ifstream fs_open_ifstream(const std::string & fname, std::ios_base::openmode mode);
|
||||
|
||||
//
|
||||
// TTY utils
|
||||
//
|
||||
|
||||
+89
-46
@@ -686,59 +686,62 @@ value set_statement::execute_impl(context & ctx) {
|
||||
return mk_val<value_undefined>();
|
||||
}
|
||||
|
||||
static inline void bind_parameters(const std::string & name, const statements & this_args, const func_args & args, context & ctx) {
|
||||
const size_t expected_count = this_args.size();
|
||||
const size_t input_count = args.count();
|
||||
|
||||
JJ_DEBUG("Invoking '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count);
|
||||
for (size_t i = 0; i < expected_count; ++i) {
|
||||
if (i < input_count) {
|
||||
if (is_stmt<identifier>(this_args[i])) {
|
||||
// normal parameter
|
||||
std::string param_name = cast_stmt<identifier>(this_args[i])->val;
|
||||
value param_value = args.get_kwarg_or_pos(param_name, i);
|
||||
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
|
||||
ctx.set_val(param_name, param_value);
|
||||
} else if (is_stmt<keyword_argument_expression>(this_args[i])) {
|
||||
// default argument used as normal parameter
|
||||
auto kwarg = cast_stmt<keyword_argument_expression>(this_args[i]);
|
||||
if (!is_stmt<identifier>(kwarg->key)) {
|
||||
throw std::runtime_error("Keyword argument key must be an identifier in '" + name + "'");
|
||||
}
|
||||
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
|
||||
value param_value = args.get_kwarg_or_pos(param_name, i);
|
||||
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
|
||||
ctx.set_val(param_name, param_value);
|
||||
} else {
|
||||
throw std::runtime_error("Invalid parameter type in '" + name + "'");
|
||||
}
|
||||
} else {
|
||||
auto & default_arg = this_args[i];
|
||||
if (is_stmt<keyword_argument_expression>(default_arg)) {
|
||||
auto kwarg = cast_stmt<keyword_argument_expression>(default_arg);
|
||||
if (!is_stmt<identifier>(kwarg->key)) {
|
||||
throw std::runtime_error("Keyword argument key must be an identifier in '" + name + "'");
|
||||
}
|
||||
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
|
||||
JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str());
|
||||
ctx.set_val(param_name, kwarg->val->execute(args.ctx));
|
||||
} else {
|
||||
throw std::runtime_error("Not enough arguments provided to '" + name + "'");
|
||||
}
|
||||
//std::string param_name = cast_stmt<identifier>(default_args[i])->val;
|
||||
//JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str());
|
||||
//ctx.var[param_name] = default_args[i]->execute(ctx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
value macro_statement::execute_impl(context & ctx) {
|
||||
if (!is_stmt<identifier>(this->name)) {
|
||||
throw std::runtime_error("Macro name must be an identifier");
|
||||
}
|
||||
std::string name = cast_stmt<identifier>(this->name)->val;
|
||||
|
||||
const func_handler func = [this, name, &ctx](const func_args & args) -> value {
|
||||
size_t expected_count = this->args.size();
|
||||
size_t input_count = args.count();
|
||||
const func_handler func = [this, name](const func_args & args) -> value {
|
||||
context macro_ctx(args.ctx); // new scope for macro execution
|
||||
|
||||
JJ_DEBUG("Invoking macro '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count);
|
||||
context macro_ctx(ctx); // new scope for macro execution
|
||||
|
||||
// bind parameters
|
||||
for (size_t i = 0; i < expected_count; ++i) {
|
||||
if (i < input_count) {
|
||||
if (is_stmt<identifier>(this->args[i])) {
|
||||
// normal parameter
|
||||
std::string param_name = cast_stmt<identifier>(this->args[i])->val;
|
||||
value param_value = args.get_kwarg_or_pos(param_name, i);
|
||||
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
|
||||
macro_ctx.set_val(param_name, param_value);
|
||||
} else if (is_stmt<keyword_argument_expression>(this->args[i])) {
|
||||
// default argument used as normal parameter
|
||||
auto kwarg = cast_stmt<keyword_argument_expression>(this->args[i]);
|
||||
if (!is_stmt<identifier>(kwarg->key)) {
|
||||
throw std::runtime_error("Keyword argument key must be an identifier in macro '" + name + "'");
|
||||
}
|
||||
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
|
||||
value param_value = args.get_kwarg_or_pos(param_name, i);
|
||||
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
|
||||
macro_ctx.set_val(param_name, param_value);
|
||||
} else {
|
||||
throw std::runtime_error("Invalid parameter type in macro '" + name + "'");
|
||||
}
|
||||
} else {
|
||||
auto & default_arg = this->args[i];
|
||||
if (is_stmt<keyword_argument_expression>(default_arg)) {
|
||||
auto kwarg = cast_stmt<keyword_argument_expression>(default_arg);
|
||||
if (!is_stmt<identifier>(kwarg->key)) {
|
||||
throw std::runtime_error("Keyword argument key must be an identifier in macro '" + name + "'");
|
||||
}
|
||||
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
|
||||
JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str());
|
||||
macro_ctx.set_val(param_name, kwarg->val->execute(ctx));
|
||||
} else {
|
||||
throw std::runtime_error("Not enough arguments provided to macro '" + name + "'");
|
||||
}
|
||||
//std::string param_name = cast_stmt<identifier>(default_args[i])->val;
|
||||
//JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str());
|
||||
//macro_ctx.var[param_name] = default_args[i]->execute(ctx);
|
||||
}
|
||||
}
|
||||
bind_parameters(name, this->args, args, macro_ctx);
|
||||
|
||||
// execute macro body
|
||||
JJ_DEBUG("Executing macro '%s' body with %zu statements", name.c_str(), this->body.size());
|
||||
@@ -752,6 +755,46 @@ value macro_statement::execute_impl(context & ctx) {
|
||||
return mk_val<value_undefined>();
|
||||
}
|
||||
|
||||
value call_statement::execute_impl(context & ctx) {
|
||||
auto call_expr = cast_stmt<call_expression>(this->call);
|
||||
if (!call_expr) {
|
||||
throw std::runtime_error("Call statement requires a valid call expression");
|
||||
}
|
||||
|
||||
value callee_val = call_expr->callee->execute(ctx);
|
||||
if (!is_val<value_func>(callee_val)) {
|
||||
throw std::runtime_error("Callee is not a function: got " + callee_val->type());
|
||||
}
|
||||
auto * callee_func = cast_val<value_func>(callee_val);
|
||||
|
||||
context caller_ctx(ctx); // new scope for caller execution
|
||||
|
||||
const func_handler func = [this, caller_ctx = std::move(caller_ctx)](const func_args & args) -> value {
|
||||
context block_ctx(caller_ctx); // new scope for block execution
|
||||
|
||||
bind_parameters("caller", this->caller_args, args, block_ctx);
|
||||
|
||||
JJ_DEBUG("Executing call body with %zu statements", this->body.size());
|
||||
auto res = exec_statements(this->body, block_ctx);
|
||||
JJ_DEBUG("Call body execution complete, result: %s", res->val_str.str().c_str());
|
||||
return res;
|
||||
};
|
||||
|
||||
context call_ctx(ctx);
|
||||
call_ctx.set_val("caller", mk_val<value_func>("caller", func));
|
||||
|
||||
func_args args(call_ctx);
|
||||
|
||||
for (const auto & arg_expr : call_expr->args) {
|
||||
auto arg_val = arg_expr->execute(ctx);
|
||||
JJ_DEBUG(" Argument type: %s", arg_val->type().c_str());
|
||||
args.push_back(arg_val);
|
||||
}
|
||||
|
||||
JJ_DEBUG("Calling macro '%s' with %zu arguments", callee_func->name.c_str(), args.count());
|
||||
return callee_func->invoke(args);
|
||||
}
|
||||
|
||||
value member_expression::execute_impl(context & ctx) {
|
||||
value object = this->object->execute(ctx);
|
||||
|
||||
|
||||
@@ -552,6 +552,7 @@ struct call_statement : public statement {
|
||||
for (const auto & arg : this->caller_args) chk_type<expression>(arg);
|
||||
}
|
||||
std::string type() const override { return "CallStatement"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
};
|
||||
|
||||
struct ternary_expression : public expression {
|
||||
|
||||
@@ -233,27 +233,27 @@ struct BuiltinRule {
|
||||
};
|
||||
|
||||
static std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
|
||||
{"boolean", {"(\"true\" | \"false\") space", {}}},
|
||||
{"boolean", {"(\"true\" | \"false\")", {}}},
|
||||
{"decimal-part", {"[0-9]{1,16}", {}}},
|
||||
{"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}},
|
||||
{"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}},
|
||||
{"integer", {"(\"-\"? integral-part) space", {"integral-part"}}},
|
||||
{"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)?", {"integral-part", "decimal-part"}}},
|
||||
{"integer", {"(\"-\"? integral-part)", {"integral-part"}}},
|
||||
{"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}},
|
||||
{"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}},
|
||||
{"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}},
|
||||
{"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}},
|
||||
{"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? space \"}\"", {"string", "value"}}},
|
||||
{"array", {"\"[\" space ( value (\",\" space value)* )? space \"]\"", {"value"}}},
|
||||
{"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\"", {}}},
|
||||
{"char", {"[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}},
|
||||
{"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}},
|
||||
{"null", {"\"null\" space", {}}},
|
||||
{"string", {"\"\\\"\" char* \"\\\"\"", {"char"}}},
|
||||
{"null", {"\"null\"", {}}},
|
||||
};
|
||||
|
||||
static std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
|
||||
{"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}},
|
||||
{"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}},
|
||||
{"date-time", {"date \"T\" time", {"date", "time"}}},
|
||||
{"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}},
|
||||
{"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}},
|
||||
{"date-time-string", {"\"\\\"\" date-time \"\\\"\" space", {"date-time"}}}
|
||||
{"date-string", {"\"\\\"\" date \"\\\"\"", {"date"}}},
|
||||
{"time-string", {"\"\\\"\" time \"\\\"\"", {"time"}}},
|
||||
{"date-time-string", {"\"\\\"\" date-time \"\\\"\"", {"date-time"}}}
|
||||
};
|
||||
|
||||
static bool is_reserved_name(const std::string & name) {
|
||||
@@ -551,16 +551,16 @@ private:
|
||||
}
|
||||
return join_seq();
|
||||
};
|
||||
return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space");
|
||||
return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\"");
|
||||
}
|
||||
|
||||
/*
|
||||
Returns a rule that matches a JSON string that is none of the provided strings
|
||||
|
||||
not_strings({"a"})
|
||||
-> ["] ( [a] char+ | [^"a] char* )? ["] space
|
||||
-> ["] ( [a] char+ | [^"a] char* )? ["]
|
||||
not_strings({"and", "also"})
|
||||
-> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
|
||||
-> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["]
|
||||
*/
|
||||
std::string _not_strings(const std::vector<std::string> & strings) {
|
||||
|
||||
@@ -619,7 +619,7 @@ private:
|
||||
if (!trie.is_end_of_string) {
|
||||
out << "?";
|
||||
}
|
||||
out << " [\"] space";
|
||||
out << " [\"]";
|
||||
return out.str();
|
||||
}
|
||||
|
||||
@@ -725,7 +725,7 @@ private:
|
||||
rule += " )?";
|
||||
}
|
||||
|
||||
rule += " \"}\" space";
|
||||
rule += " space \"}\"";
|
||||
|
||||
return rule;
|
||||
}
|
||||
@@ -858,14 +858,14 @@ public:
|
||||
return _add_rule(rule_name, _generate_union_rule(name, schema_types));
|
||||
}
|
||||
if (schema.contains("const")) {
|
||||
return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space");
|
||||
return _add_rule(rule_name, _generate_constant_rule(schema["const"]));
|
||||
}
|
||||
if (schema.contains("enum")) {
|
||||
std::vector<std::string> enum_values;
|
||||
for (const auto & v : schema["enum"]) {
|
||||
enum_values.push_back(_generate_constant_rule(v));
|
||||
}
|
||||
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
|
||||
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ")");
|
||||
}
|
||||
if ((schema_type.is_null() || schema_type == "object")
|
||||
&& (schema.contains("properties") ||
|
||||
@@ -933,7 +933,7 @@ public:
|
||||
}
|
||||
}
|
||||
if (!enum_intersection.empty()) {
|
||||
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space");
|
||||
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ")");
|
||||
}
|
||||
}
|
||||
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
|
||||
@@ -948,7 +948,7 @@ public:
|
||||
}
|
||||
rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i));
|
||||
}
|
||||
rule += " \"]\" space";
|
||||
rule += " space \"]\"";
|
||||
return _add_rule(rule_name, rule);
|
||||
}
|
||||
std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
|
||||
@@ -956,7 +956,7 @@ public:
|
||||
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
|
||||
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
|
||||
|
||||
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
|
||||
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " space \"]\"");
|
||||
}
|
||||
if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
|
||||
return _visit_pattern(schema["pattern"], rule_name);
|
||||
@@ -972,7 +972,7 @@ public:
|
||||
std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
|
||||
int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
|
||||
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
|
||||
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
|
||||
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\"");
|
||||
}
|
||||
if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
|
||||
int64_t min_value = std::numeric_limits<int64_t>::min();
|
||||
@@ -990,7 +990,7 @@ public:
|
||||
std::stringstream out;
|
||||
out << "(";
|
||||
build_min_max_int(min_value, max_value, out);
|
||||
out << ") space";
|
||||
out << ")";
|
||||
return _add_rule(rule_name, out.str());
|
||||
}
|
||||
if (schema.empty() || schema_type == "object") {
|
||||
|
||||
+202
-89
@@ -6,13 +6,14 @@
|
||||
#include "unicode.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <deque>
|
||||
#include <initializer_list>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <regex>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
#include <unordered_set>
|
||||
|
||||
// Trick to catch missing branches
|
||||
template <typename T>
|
||||
@@ -88,40 +89,7 @@ struct trie {
|
||||
return match_result{match_result::NO_MATCH};
|
||||
}
|
||||
|
||||
struct prefix_and_next {
|
||||
std::vector<uint32_t> prefix;
|
||||
std::vector<uint32_t> next_chars;
|
||||
};
|
||||
|
||||
std::vector<prefix_and_next> collect_prefix_and_next() {
|
||||
std::vector<uint32_t> prefix;
|
||||
std::vector<prefix_and_next> result;
|
||||
collect_prefix_and_next(0, prefix, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
void collect_prefix_and_next(size_t index, std::vector<uint32_t> & prefix, std::vector<prefix_and_next> & out) {
|
||||
if (!nodes[index].is_word) {
|
||||
if (!nodes[index].children.empty()) {
|
||||
std::vector<uint32_t> chars;
|
||||
chars.reserve(nodes[index].children.size());
|
||||
for (const auto & p : nodes[index].children) {
|
||||
chars.push_back(p.first);
|
||||
}
|
||||
out.emplace_back(prefix_and_next{prefix, chars});
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & p : nodes[index].children) {
|
||||
uint32_t ch = p.first;
|
||||
auto child = p.second;
|
||||
prefix.push_back(ch);
|
||||
collect_prefix_and_next(child, prefix, out);
|
||||
prefix.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
size_t create_node() {
|
||||
size_t index = nodes.size();
|
||||
nodes.emplace_back();
|
||||
@@ -153,6 +121,65 @@ struct trie {
|
||||
}
|
||||
};
|
||||
|
||||
// Aho-Corasick automaton
|
||||
struct aho_corasick {
|
||||
trie t;
|
||||
std::vector<size_t> fail; // failure links
|
||||
std::vector<size_t> order; // states in BFS order
|
||||
std::vector<bool> terminal; // match states (directly or via a suffix link)
|
||||
std::set<uint32_t> alphabet; // every character with a transition
|
||||
|
||||
aho_corasick(const std::vector<std::string> & strings) : t(strings) {
|
||||
const auto & nodes = t.nodes;
|
||||
const size_t n = nodes.size();
|
||||
|
||||
fail.assign(n, 0);
|
||||
order.reserve(n);
|
||||
|
||||
std::deque<size_t> queue{ 0 };
|
||||
while (!queue.empty()) {
|
||||
size_t u = queue.front();
|
||||
queue.pop_front();
|
||||
order.push_back(u);
|
||||
for (const auto & [ch, v] : nodes[u].children) {
|
||||
if (u != 0) {
|
||||
size_t f = fail[u];
|
||||
while (f && nodes[f].children.find(ch) == nodes[f].children.end()) {
|
||||
f = fail[f];
|
||||
}
|
||||
auto it = nodes[f].children.find(ch);
|
||||
fail[v] = (it != nodes[f].children.end() && it->second != v) ? it->second : 0;
|
||||
}
|
||||
queue.push_back(v);
|
||||
}
|
||||
}
|
||||
|
||||
terminal.assign(n, false);
|
||||
for (size_t u : order) {
|
||||
terminal[u] = nodes[u].is_word || (u != 0 && terminal[fail[u]]);
|
||||
}
|
||||
|
||||
for (const auto & node : nodes) {
|
||||
for (const auto & [ch, v] : node.children) {
|
||||
alphabet.insert(ch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t num_states() const { return t.nodes.size(); }
|
||||
bool is_terminal(size_t s) const { return terminal[s]; }
|
||||
|
||||
// follow failure links until a transition on `ch` exists.
|
||||
size_t next(size_t state, uint32_t ch) const {
|
||||
const auto & nodes = t.nodes;
|
||||
while (state && nodes[state].children.find(ch) == nodes[state].children.end()) {
|
||||
state = fail[state];
|
||||
}
|
||||
auto it = nodes[state].children.find(ch);
|
||||
return it != nodes[state].children.end() ? it->second : 0;
|
||||
}
|
||||
};
|
||||
|
||||
static std::pair<uint32_t, size_t> parse_hex_escape(const std::string & str, size_t pos, int hex_count) {
|
||||
if (pos + hex_count > str.length()) {
|
||||
return {0, 0};
|
||||
@@ -894,6 +921,10 @@ struct parser_executor {
|
||||
common_peg_parse_result operator()(const common_peg_gbnf_parser & p) {
|
||||
return arena.parse(p.child, ctx, start_pos);
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_ac_parser & p) {
|
||||
return arena.parse(p.child, ctx, start_pos);
|
||||
}
|
||||
};
|
||||
|
||||
common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const {
|
||||
@@ -962,7 +993,8 @@ void common_peg_arena::resolve_refs() {
|
||||
std::is_same_v<T, common_peg_not_parser> ||
|
||||
std::is_same_v<T, common_peg_tag_parser> ||
|
||||
std::is_same_v<T, common_peg_atomic_parser> ||
|
||||
std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
std::is_same_v<T, common_peg_gbnf_parser> ||
|
||||
std::is_same_v<T, common_peg_ac_parser>) {
|
||||
p.child = resolve_ref(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
p.child = resolve_ref(p.child);
|
||||
@@ -992,12 +1024,12 @@ void common_peg_arena::resolve_refs() {
|
||||
}
|
||||
|
||||
std::string common_peg_arena::dump(common_peg_parser_id id) const {
|
||||
std::unordered_set<common_peg_parser_id> visited;
|
||||
std::set<common_peg_parser_id> visited;
|
||||
return dump_impl(id, visited);
|
||||
}
|
||||
|
||||
std::string common_peg_arena::dump_impl(common_peg_parser_id id,
|
||||
std::unordered_set<common_peg_parser_id> & visited) const {
|
||||
std::set<common_peg_parser_id> & visited) const {
|
||||
// Check for cycles
|
||||
if (visited.count(id)) {
|
||||
return "[cycle]";
|
||||
@@ -1043,6 +1075,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
|
||||
return "Atomic(" + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
|
||||
return "Ac(" + string_join(p.delimiters, " | ") + ", " + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_any_parser>) {
|
||||
return "Any";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_space_parser>) {
|
||||
@@ -1342,7 +1376,7 @@ common_peg_parser common_peg_parser_builder::json_object() {
|
||||
common_peg_parser common_peg_parser_builder::json_array() {
|
||||
return rule("json-array", [this]() {
|
||||
auto ws = space();
|
||||
auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))});
|
||||
auto elements = sequence({json(), zero_or_more(sequence({ws, literal(","), ws, json()}))});
|
||||
return sequence({
|
||||
literal("["),
|
||||
ws,
|
||||
@@ -1452,6 +1486,13 @@ common_peg_parser common_peg_parser_builder::json_member(const std::string & key
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::ac(const common_peg_parser & p, const std::vector<std::string> & delimiters) {
|
||||
if (delimiters.empty()) {
|
||||
throw std::runtime_error("ac parser requires at least one delimiter");
|
||||
}
|
||||
return add(common_peg_ac_parser{p, delimiters});
|
||||
}
|
||||
|
||||
static std::string gbnf_escape_char_class(uint32_t c) {
|
||||
if (c == '-' || c == ']' || c == '[' || c == '\\') {
|
||||
return "\\" + std::string(1, (char) c);
|
||||
@@ -1502,61 +1543,118 @@ static std::string gbnf_escape_char_class(uint32_t c) {
|
||||
return std::string(buf);
|
||||
}
|
||||
|
||||
static std::string gbnf_excluding_pattern(const std::vector<std::string> & strings) {
|
||||
trie matcher(strings);
|
||||
auto pieces = matcher.collect_prefix_and_next();
|
||||
|
||||
std::string pattern;
|
||||
std::string trailing; // optional proper-prefix of a delimiter, allowed only at the very end
|
||||
for (size_t i = 0; i < pieces.size(); ++i) {
|
||||
if (i > 0) {
|
||||
pattern += " | ";
|
||||
}
|
||||
|
||||
const auto & pre = pieces[i].prefix;
|
||||
const auto & chars = pieces[i].next_chars;
|
||||
|
||||
std::string cls;
|
||||
cls.reserve(chars.size());
|
||||
for (uint32_t ch : chars) {
|
||||
cls += gbnf_escape_char_class(ch);
|
||||
}
|
||||
|
||||
if (!pre.empty()) {
|
||||
std::string pre_literal = gbnf_format_literal(common_unicode_cpts_to_utf8(pre));
|
||||
pattern += pre_literal + " [^" + cls + "]";
|
||||
// Each interior alternative consumes a delimiter-prefix plus a disambiguating
|
||||
// char, so the repetition alone cannot match a value that *ends* on a proper
|
||||
// prefix of a delimiter (e.g. a trailing "\n" when the delimiter is
|
||||
// "\n</parameter>\n"). The runtime until() (greedy first-match) accepts such
|
||||
// values, so without this the grammar would reject input the parser accepts.
|
||||
// Allow the value to terminate on any proper prefix as an optional tail.
|
||||
// This makes the grammar a slight superset of the runtime language (a value
|
||||
// may end on the longest prefix, which greedy first-match would not itself
|
||||
// produce); harmless for constrained generation, which only needs to admit
|
||||
// every runtime-valid string.
|
||||
if (!trailing.empty()) {
|
||||
trailing += " | ";
|
||||
}
|
||||
trailing += pre_literal;
|
||||
} else {
|
||||
pattern += "[^" + cls + "]";
|
||||
}
|
||||
static std::string gbnf_char_class(const std::vector<uint32_t> & chars, bool negate) {
|
||||
std::string s = negate ? "[^" : "[";
|
||||
for (uint32_t ch : chars) {
|
||||
s += gbnf_escape_char_class(ch);
|
||||
}
|
||||
|
||||
std::string result = "(" + pattern + ")*";
|
||||
if (!trailing.empty()) {
|
||||
result += " (" + trailing + ")?";
|
||||
}
|
||||
return result;
|
||||
return s + "]";
|
||||
}
|
||||
|
||||
static std::unordered_set<std::string> collect_reachable_rules(
|
||||
static std::string gbnf_ac_grammar(
|
||||
const common_grammar_builder & builder,
|
||||
const std::string & prefix,
|
||||
const std::vector<std::string> & strings,
|
||||
const std::function<std::string(const std::vector<uint32_t> &,
|
||||
const std::map<size_t, std::vector<uint32_t>> &,
|
||||
const std::vector<uint32_t> &,
|
||||
const std::function<std::string(size_t)> &)> & build_rule) {
|
||||
aho_corasick ac(strings);
|
||||
|
||||
auto state_name = [&](size_t s) -> std::string {
|
||||
if (s == 0) {
|
||||
return prefix;
|
||||
}
|
||||
std::string num = std::to_string(s);
|
||||
num = num.size() == 1 ? ("0" + num) : num;
|
||||
return prefix + "-" + num;
|
||||
};
|
||||
|
||||
for (size_t q = 0; q < ac.num_states(); q++) {
|
||||
if (ac.is_terminal(q)) {
|
||||
continue; // match states
|
||||
}
|
||||
|
||||
std::map<size_t, std::vector<uint32_t>> buckets;
|
||||
std::vector<uint32_t> completing; // chars that complete a delimiter
|
||||
std::vector<uint32_t> specific; // chars with an explicit transition
|
||||
for (uint32_t c : ac.alphabet) {
|
||||
size_t d = ac.next(q, c);
|
||||
if (ac.is_terminal(d)) {
|
||||
completing.push_back(c);
|
||||
specific.push_back(c);
|
||||
} else if (d != 0) {
|
||||
buckets[d].push_back(c); // specific non-root destination
|
||||
specific.push_back(c);
|
||||
}
|
||||
}
|
||||
|
||||
builder.add_rule(state_name(q), build_rule(completing, buckets, specific, state_name));
|
||||
}
|
||||
|
||||
// An empty delimiter makes the start state terminal. Emit an entry rule
|
||||
// that matches the empty string so the returned reference stays valid.
|
||||
if (ac.is_terminal(0)) {
|
||||
builder.add_rule(prefix, "|");
|
||||
}
|
||||
|
||||
return state_name(0);
|
||||
}
|
||||
|
||||
// GBNF grammar matching strings that contain no string in `strings` as a
|
||||
// substring. Emits the complement of an Aho-Corasick automaton DFA and returns
|
||||
// the start state rule name.
|
||||
//
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/24839
|
||||
static std::string gbnf_excluding_grammar(const common_grammar_builder & builder,
|
||||
const std::string & prefix,
|
||||
const std::vector<std::string> & strings) {
|
||||
return gbnf_ac_grammar(builder, prefix, strings,
|
||||
[](const std::vector<uint32_t> & /*completing*/,
|
||||
const std::map<size_t, std::vector<uint32_t>> & buckets,
|
||||
const std::vector<uint32_t> & specific,
|
||||
const std::function<std::string(size_t)> & state_name) {
|
||||
// every state is accepting and completing chars get no
|
||||
// alternative, so a forbidden string can never be matched
|
||||
std::string rhs = "|";
|
||||
for (const auto & [d, chars] : buckets) {
|
||||
rhs += " " + gbnf_char_class(chars, false) + " " + state_name(d) + " |";
|
||||
}
|
||||
rhs += " " + gbnf_char_class(specific, true) + " " + state_name(0);
|
||||
return rhs;
|
||||
});
|
||||
}
|
||||
|
||||
// GBNF grammar matching everything up to and including the first occurrence of
|
||||
// any string in `strings`. Emits the Aho-Corasick automaton DFA and returns
|
||||
// the start state rule name.
|
||||
static std::string gbnf_including_grammar(const common_grammar_builder & builder,
|
||||
const std::string & prefix,
|
||||
const std::vector<std::string> & strings) {
|
||||
return gbnf_ac_grammar(builder, prefix, strings,
|
||||
[](const std::vector<uint32_t> & completing,
|
||||
const std::map<size_t, std::vector<uint32_t>> & buckets,
|
||||
const std::vector<uint32_t> & specific,
|
||||
const std::function<std::string(size_t)> & state_name) {
|
||||
std::vector<std::string> alts;
|
||||
if (!completing.empty()) {
|
||||
alts.push_back(gbnf_char_class(completing, false)); // terminate on match
|
||||
}
|
||||
for (const auto & [d, chars] : buckets) {
|
||||
alts.push_back(gbnf_char_class(chars, false) + " " + state_name(d));
|
||||
}
|
||||
// every other character keeps scanning from the start state
|
||||
alts.push_back(gbnf_char_class(specific, true) + " " + state_name(0));
|
||||
return string_join(alts, " | ");
|
||||
});
|
||||
}
|
||||
|
||||
static std::set<std::string> collect_reachable_rules(
|
||||
const common_peg_arena & arena,
|
||||
const common_peg_parser_id & rule
|
||||
) {
|
||||
std::unordered_set<std::string> reachable;
|
||||
std::unordered_set<std::string> visited;
|
||||
std::set<std::string> reachable;
|
||||
std::set<std::string> visited;
|
||||
|
||||
std::function<void(common_peg_parser_id)> visit = [&](common_peg_parser_id id) {
|
||||
const auto & parser = arena.get(id);
|
||||
@@ -1588,6 +1686,7 @@ static std::unordered_set<std::string> collect_reachable_rules(
|
||||
std::is_same_v<T, common_peg_tag_parser> ||
|
||||
std::is_same_v<T, common_peg_atomic_parser> ||
|
||||
std::is_same_v<T, common_peg_gbnf_parser> ||
|
||||
std::is_same_v<T, common_peg_ac_parser> ||
|
||||
std::is_same_v<T, common_peg_schema_parser>) {
|
||||
visit(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
@@ -1765,7 +1864,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
if (p.delimiters.empty()) {
|
||||
return ".*";
|
||||
}
|
||||
return gbnf_excluding_pattern(p.delimiters);
|
||||
return gbnf_excluding_grammar(builder, "until-" + std::to_string(id), p.delimiters);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
|
||||
if (schema_delegates(p)) {
|
||||
return to_gbnf(p.child);
|
||||
@@ -1782,6 +1881,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
return to_gbnf(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return p.grammar;
|
||||
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
|
||||
return gbnf_including_grammar(builder, "ac-" + std::to_string(id), p.delimiters);
|
||||
} else {
|
||||
static_assert(is_always_false_v<T>);
|
||||
}
|
||||
@@ -1789,7 +1890,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
};
|
||||
|
||||
// Collect reachable rules
|
||||
std::unordered_set<std::string> reachable_rules;
|
||||
std::set<std::string> reachable_rules;
|
||||
|
||||
if (lazy) {
|
||||
// Collect rules reachable from trigger rules
|
||||
@@ -1918,6 +2019,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
|
||||
};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
|
||||
return json{{"type", "ac"}, {"child", p.child}, {"delimiters", p.delimiters}};
|
||||
}
|
||||
}, variant);
|
||||
}
|
||||
@@ -2090,6 +2193,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
|
||||
};
|
||||
}
|
||||
|
||||
if (type == "ac") {
|
||||
if (!j.contains("child") || !j.contains("delimiters") || !j["delimiters"].is_array() || j["delimiters"].empty()) {
|
||||
throw std::runtime_error("ac parser requires 'child' and a non-empty 'delimiters' array");
|
||||
}
|
||||
return common_peg_ac_parser{
|
||||
j["child"].get<common_peg_parser_id>(),
|
||||
j["delimiters"].get<std::vector<std::string>>(),
|
||||
};
|
||||
}
|
||||
|
||||
throw std::runtime_error("Unknown parser type: " + type);
|
||||
}
|
||||
|
||||
|
||||
+16
-3
@@ -3,8 +3,8 @@
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <functional>
|
||||
@@ -275,6 +275,11 @@ struct common_peg_gbnf_parser {
|
||||
std::string grammar;
|
||||
};
|
||||
|
||||
struct common_peg_ac_parser {
|
||||
common_peg_parser_id child;
|
||||
std::vector<std::string> delimiters;
|
||||
};
|
||||
|
||||
// Variant holding all parser types
|
||||
using common_peg_parser_variant = std::variant<
|
||||
common_peg_epsilon_parser,
|
||||
@@ -296,7 +301,8 @@ using common_peg_parser_variant = std::variant<
|
||||
common_peg_ref_parser,
|
||||
common_peg_atomic_parser,
|
||||
common_peg_tag_parser,
|
||||
common_peg_gbnf_parser
|
||||
common_peg_gbnf_parser,
|
||||
common_peg_ac_parser
|
||||
>;
|
||||
|
||||
class common_peg_arena {
|
||||
@@ -335,7 +341,7 @@ class common_peg_arena {
|
||||
friend class common_peg_parser_builder;
|
||||
|
||||
private:
|
||||
std::string dump_impl(common_peg_parser_id id, std::unordered_set<common_peg_parser_id> & visited) const;
|
||||
std::string dump_impl(common_peg_parser_id id, std::set<common_peg_parser_id> & visited) const;
|
||||
|
||||
common_peg_parser_id add_parser(common_peg_parser_variant parser);
|
||||
void add_rule(const std::string & name, common_peg_parser_id id);
|
||||
@@ -514,6 +520,13 @@ class common_peg_parser_builder {
|
||||
// the child's grammar. Parsing delegates entirely to the child.
|
||||
common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); }
|
||||
|
||||
// Wraps a child parser but emits a GBNF grammar built from the Aho-Corasick
|
||||
// automaton of `delimiters`, matching everything up to and including the
|
||||
// first delimiter. Parsing delegates entirely to the child, which is
|
||||
// responsible for consuming the delimiter (e.g. until(D) + literal(D)).
|
||||
common_peg_parser ac(const common_peg_parser & p, const std::vector<std::string> & delimiters);
|
||||
common_peg_parser ac(const common_peg_parser & p, const std::string & delimiter) { return ac(p, std::vector<std::string>{delimiter}); }
|
||||
|
||||
void set_root(const common_peg_parser & p);
|
||||
|
||||
common_peg_arena build();
|
||||
|
||||
+102
-35
@@ -905,7 +905,13 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
int32_t n_embd = 0;
|
||||
|
||||
bool is_mem_shared = false;
|
||||
// One MTP draft driver, three modes (set once in the ctor):
|
||||
// is_mem_shared (gemma4): shares the target KV, runs all heads in one graph.
|
||||
// chain_heads (step35): n_mtp_layers trained heads, one per draft step.
|
||||
// neither (qwen35 / qwen35moe): a single trained MTP head.
|
||||
int32_t n_mtp_layers = 1;
|
||||
bool is_mem_shared = false; // gemma4
|
||||
bool chain_heads = false; // derived in the ctor: n_mtp_layers > 1 && !is_mem_shared
|
||||
|
||||
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
|
||||
// The last h-row of one process() call needs the first token of the NEXT
|
||||
@@ -920,10 +926,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
std::vector<std::vector<float>> verify_h;
|
||||
std::vector<int32_t> verify_h_rows;
|
||||
|
||||
// Per-seq draft length from the last draft() call, used in accept() to
|
||||
// roll back ctx_dft's recurrent state past the AR draft's redundant
|
||||
// pre-advancement before process() mirrored the verify batch.
|
||||
std::vector<uint16_t> last_n_drafted;
|
||||
std::vector<int> i_last;
|
||||
std::vector<std::vector<float>> chain_h;
|
||||
|
||||
common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq)
|
||||
@@ -936,6 +940,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft));
|
||||
GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) &&
|
||||
"MTP input row width must match the target h_nextn width");
|
||||
n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft)));
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
|
||||
@@ -982,16 +987,25 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
|
||||
|
||||
is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt;
|
||||
chain_heads = n_mtp_layers > 1 && !is_mem_shared;
|
||||
|
||||
if (chain_heads) {
|
||||
this->params.n_max = std::min(this->params.n_max, n_mtp_layers);
|
||||
|
||||
chain_h.assign(n_seq, {});
|
||||
for (auto & c : chain_h) {
|
||||
c.reserve((size_t) (this->params.n_max + 1) * n_embd);
|
||||
}
|
||||
}
|
||||
|
||||
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
|
||||
|
||||
i_last.assign(n_seq, -1);
|
||||
i_batch_beg.assign(n_seq, -1);
|
||||
i_batch_end.assign(n_seq, -1);
|
||||
|
||||
verify_h.assign(n_seq, {});
|
||||
verify_h_rows.assign(n_seq, 0);
|
||||
|
||||
last_n_drafted.assign(n_seq, 0);
|
||||
}
|
||||
|
||||
~common_speculative_impl_draft_mtp() override {
|
||||
@@ -1097,9 +1111,34 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
|
||||
}
|
||||
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
|
||||
auto * mem_dft = llama_get_memory(ctx_dft);
|
||||
|
||||
bool ok = true;
|
||||
for (int head = 0; head < n_mtp_layers; ++head) {
|
||||
if (chain_heads) {
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/24340/changes#r3413498544
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
if (i_batch_beg[seq_id] < 0) {
|
||||
continue;
|
||||
}
|
||||
llama_memory_seq_rm(mem_dft, seq_id, batch_in.pos[i_batch_beg[seq_id]], -1);
|
||||
}
|
||||
llama_set_nextn_layer_offset(ctx_dft, head);
|
||||
}
|
||||
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
|
||||
__func__, head, (int) rc, (int) batch_in.pos[0]);
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (chain_heads) {
|
||||
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
|
||||
}
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1134,7 +1173,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
int n_drafting = 0;
|
||||
std::vector<bool> drafting(n_seq);
|
||||
|
||||
const float * h_row = nullptr;
|
||||
const size_t row_bytes = (size_t) n_embd * sizeof(float);
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
@@ -1149,22 +1187,43 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
common_sampler_reset(smpls[seq_id].get());
|
||||
|
||||
common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);
|
||||
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, pending_h[seq_id].data(), row_bytes);
|
||||
|
||||
h_row = pending_h[seq_id].data();
|
||||
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
|
||||
}
|
||||
i_last[seq_id] = batch.n_tokens - 1;
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
|
||||
return;
|
||||
if (chain_heads) {
|
||||
chain_h[seq_id].assign(pending_h[seq_id].begin(), pending_h[seq_id].end());
|
||||
}
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
|
||||
while (n_drafting > 0) {
|
||||
int i_batch = 0;
|
||||
// each step decodes under a different head, i.e. a different decoder layer, and
|
||||
// KV is per layer. process() filled this layer's KV only for positions < n_past
|
||||
// (prompt + accepted prefix) — nothing in the draft region yet. so reset the
|
||||
// draft region (the seq_rm lower bound is n_past, leaving the prompt KV intact)
|
||||
// and select head i so it rebuilds its own layer's KV there; decoding just the
|
||||
// latest token would leave its attention reading cells only another head wrote.
|
||||
if (chain_heads) {
|
||||
auto * mem_dft = llama_get_memory(ctx_dft);
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
if (drafting[seq_id]) {
|
||||
llama_memory_seq_rm(mem_dft, seq_id, dparams[seq_id].n_past, -1);
|
||||
}
|
||||
}
|
||||
llama_set_nextn_layer_offset(ctx_dft, i);
|
||||
}
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
// rebuild the batch for the next step: the growing-KV paths re-add only the
|
||||
// new token (the KV already holds the prefix), while chained heads re-add the
|
||||
// whole prefix at the next head. dropped sequences are simply not re-added.
|
||||
common_batch_clear(batch);
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
@@ -1174,9 +1233,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
auto * smpl = smpls[seq_id].get();
|
||||
|
||||
common_sampler_sample(smpl, ctx_dft, i_batch, true);
|
||||
h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_batch);
|
||||
++i_batch;
|
||||
common_sampler_sample(smpl, ctx_dft, i_last[seq_id], true);
|
||||
const float * h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_last[seq_id]);
|
||||
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
@@ -1210,30 +1268,41 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (is_mem_shared) {
|
||||
if (chain_heads) {
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448031546
|
||||
chain_h[seq_id].insert(chain_h[seq_id].end(), h_row, h_row + n_embd);
|
||||
|
||||
const int n_rows = (int) result.size() + 1; // id_last + tokens drafted so far
|
||||
for (int t = 0; t < n_rows; ++t) {
|
||||
const llama_token tok = (t == 0) ? dp.id_last : result[t - 1];
|
||||
common_batch_add(batch, tok, dp.n_past + t, { seq_id }, t == n_rows - 1);
|
||||
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd,
|
||||
chain_h[seq_id].data() + (size_t) t * n_embd, row_bytes);
|
||||
}
|
||||
} else if (is_mem_shared) {
|
||||
// note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens
|
||||
// ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37
|
||||
common_batch_add(batch, id, dp.n_past, { seq_id }, true);
|
||||
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
|
||||
} else {
|
||||
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
|
||||
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
|
||||
}
|
||||
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
|
||||
|
||||
i_last[seq_id] = batch.n_tokens - 1;
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
// evaluate the drafted tokens on the draft model
|
||||
ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
++i;
|
||||
}
|
||||
|
||||
if (chain_heads) {
|
||||
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
|
||||
}
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
auto & dp = dparams[seq_id];
|
||||
if (!dp.drafting) {
|
||||
@@ -1243,8 +1312,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
if (dp.result->size() < (size_t) params.n_min) {
|
||||
dp.result->clear();
|
||||
}
|
||||
|
||||
last_n_drafted[seq_id] = (uint16_t) dp.result->size();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1857,7 +1924,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
|
||||
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
|
||||
bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr;
|
||||
bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
|
||||
bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
|
||||
|
||||
|
||||
|
||||
@@ -1895,7 +1962,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
if (has_draft_eagle3) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params));
|
||||
}
|
||||
if (has_mtp) {
|
||||
if (has_draft_mtp) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,7 +126,7 @@ class BailingMoeV2Model(TextModel):
|
||||
if (rope_dim := hparams.get("head_dim")) is None:
|
||||
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
|
||||
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5)))
|
||||
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
|
||||
|
||||
+7
-1
@@ -1119,8 +1119,10 @@ class TextModel(ModelBase):
|
||||
|
||||
rope_theta = self.find_hparam(["global_rope_theta", "rope_global_theta", "rope_theta_global", "rope_theta", "rotary_emb_base"], optional=True)
|
||||
local_rope_theta = self.find_hparam(["local_rope_theta", "rope_local_theta", "rope_theta_local", "swa_rope_theta", "rope_local_base_freq"], optional=True)
|
||||
partial_rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"], optional=True)
|
||||
original_max_position_embeddings = self.find_hparam(["original_max_position_embeddings"], optional=True)
|
||||
|
||||
# Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters
|
||||
# Ensure global params are mirrored in rope_parameters
|
||||
if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters:
|
||||
if local_rope_theta is not None:
|
||||
self.rope_parameters["sliding_attention"] = {"rope_theta": local_rope_theta}
|
||||
@@ -1128,6 +1130,10 @@ class TextModel(ModelBase):
|
||||
self.rope_parameters["rope_theta"] = rope_theta
|
||||
if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None:
|
||||
self.rope_parameters["rope_type"] = rope_type
|
||||
if "partial_rotary_factor" not in self.rope_parameters and partial_rotary_factor is not None:
|
||||
self.rope_parameters["partial_rotary_factor"] = partial_rotary_factor
|
||||
if "original_max_position_embeddings" not in self.rope_parameters and original_max_position_embeddings is not None:
|
||||
self.rope_parameters["original_max_position_embeddings"] = original_max_position_embeddings
|
||||
|
||||
@classmethod
|
||||
def __init_subclass__(cls):
|
||||
|
||||
@@ -148,7 +148,7 @@ class ChatGLMModel(TextModel):
|
||||
rope_dim = self.hparams["attention_dim"]
|
||||
else:
|
||||
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5)))
|
||||
self.gguf_writer.add_add_bos_token(False)
|
||||
rope_freq = 10000
|
||||
if "rope_ratio" in self.hparams:
|
||||
|
||||
+1
-1
@@ -161,7 +161,7 @@ class DeciModel(TextModel):
|
||||
factor = rope_params.get("factor", 8.0)
|
||||
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
|
||||
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
|
||||
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
|
||||
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
@@ -24,7 +24,7 @@ class ExaoneModel(TextModel):
|
||||
|
||||
assert (hparams["activation_function"] == "silu")
|
||||
|
||||
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True)
|
||||
rotary_factor = self.rope_parameters.get("partial_rotary_factor")
|
||||
rotary_factor = rotary_factor if rotary_factor is not None else 1.0
|
||||
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
|
||||
|
||||
@@ -39,7 +39,7 @@ class ExaoneModel(TextModel):
|
||||
factor = rope_params.get("factor", 8.0)
|
||||
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
|
||||
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
|
||||
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
|
||||
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
@@ -104,7 +104,7 @@ class Exaone4Model(TextModel):
|
||||
factor = rope_params.get("factor", 16.0)
|
||||
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
|
||||
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
|
||||
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
|
||||
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
+1
-1
@@ -693,7 +693,7 @@ class Gemma4Model(Gemma3Model):
|
||||
self.gguf_writer.add_head_count_kv(value_arr)
|
||||
|
||||
# handle n_rot differently for global vs swa layers
|
||||
partial_rotary_factor_swa = self.hparams.get("partial_rotary_factor", 1.0)
|
||||
partial_rotary_factor_swa = self.rope_parameters.get("partial_rotary_factor", 1.0)
|
||||
n_rot_full = int(head_dim_full) # "proportional" is used, see generate_extra_tensors
|
||||
n_rot_swa = int(head_dim_swa * partial_rotary_factor_swa)
|
||||
self.gguf_writer.add_rope_dimension_count(n_rot_full)
|
||||
|
||||
+2
-2
@@ -124,7 +124,7 @@ class Glm4MoeModel(TextModel):
|
||||
self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
)
|
||||
self.gguf_writer.add_rope_dimension_count(
|
||||
int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))
|
||||
int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5))
|
||||
)
|
||||
|
||||
# MoE parameters - Use only routed expert count (shared experts handled separately)
|
||||
@@ -226,7 +226,7 @@ class GlmMoeDsaModel(DeepseekV2Model):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
rope_dim = self.hparams["qk_rope_head_dim"]
|
||||
partial_rotary_factor = self.hparams.get("partial_rotary_factor", 1.0)
|
||||
partial_rotary_factor = self.rope_parameters.get("partial_rotary_factor", 1.0)
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * partial_rotary_factor))
|
||||
|
||||
# NextN/MTP prediction layers
|
||||
|
||||
+1
-1
@@ -289,7 +289,7 @@ class LlamaModel(TextModel):
|
||||
factor = rope_params.get("factor", 8.0)
|
||||
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
|
||||
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
|
||||
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
|
||||
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
+1
-1
@@ -154,7 +154,7 @@ class MimoV2Model(TextModel):
|
||||
self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
|
||||
|
||||
rope_dim = int(self.hparams["head_dim"] * self.hparams["partial_rotary_factor"])
|
||||
rope_dim = int(self.hparams["head_dim"] * self.rope_parameters["partial_rotary_factor"])
|
||||
self.gguf_writer.add_rope_dimension_count(rope_dim)
|
||||
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon", 1e-5))
|
||||
|
||||
+6
-10
@@ -32,11 +32,9 @@ class MiniCPMModel(TextModel):
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
rope_dims = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
|
||||
rope_scaling = self.find_hparam(['rope_scaling'], True)
|
||||
if rope_scaling is not None:
|
||||
long_factors = rope_scaling.get('long_factor', None)
|
||||
short_factors = rope_scaling.get('short_factor', None)
|
||||
|
||||
long_factors = self.rope_parameters.get('long_factor')
|
||||
short_factors = self.rope_parameters.get('short_factor')
|
||||
if long_factors or short_factors:
|
||||
if long_factors is None or short_factors is None:
|
||||
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
|
||||
|
||||
@@ -85,13 +83,11 @@ class MiniCPM3Model(TextModel):
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
rope_scaling = self.find_hparam(['rope_scaling'], True)
|
||||
if rope_scaling is not None:
|
||||
long_factors = self.rope_parameters.get('long_factor')
|
||||
short_factors = self.rope_parameters.get('short_factor')
|
||||
if long_factors or short_factors:
|
||||
rope_dims = self.hparams["qk_rope_head_dim"]
|
||||
|
||||
long_factors = rope_scaling.get('long_factor', None)
|
||||
short_factors = rope_scaling.get('short_factor', None)
|
||||
|
||||
if long_factors is None or short_factors is None:
|
||||
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
|
||||
|
||||
|
||||
@@ -125,17 +125,18 @@ class NemotronModel(TextModel):
|
||||
self.gguf_writer.add_layer_norm_eps(f_norm_eps)
|
||||
|
||||
# * Partial RoPE
|
||||
rot_pct = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"])
|
||||
rot_pct = self.rope_parameters["partial_rotary_factor"]
|
||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||
self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
|
||||
|
||||
# * RopeScaling for Nemotron
|
||||
if "rope_scaling" not in self.hparams or self.hparams["rope_scaling"] is None:
|
||||
factor = self.hparams.get("factor") or self.rope_parameters.get("factor")
|
||||
if factor is None:
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
|
||||
else:
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||
self.gguf_writer.add_rope_scaling_factor(self.hparams["factor"])
|
||||
self.gguf_writer.add_rope_scaling_factor(factor)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# * Adding +1 to LayerNorm's weights here to implement layernorm1p w/o changing anything on the GGML engine side
|
||||
|
||||
+9
-11
@@ -18,7 +18,7 @@ class Phi2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.PHI2
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
rot_pct = self.find_hparam(["partial_rotary_factor"])
|
||||
rot_pct = self.rope_parameters["partial_rotary_factor"]
|
||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||
|
||||
@@ -149,8 +149,8 @@ class Phi3MiniModel(TextModel):
|
||||
n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
|
||||
rms_eps = self.find_hparam(["rms_norm_eps"])
|
||||
max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
|
||||
orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
|
||||
rot_pct = self.hparams.get("partial_rotary_factor", 1.0)
|
||||
orig_max_pos_embds = self.rope_parameters["original_max_position_embeddings"]
|
||||
rot_pct = self.rope_parameters.get("partial_rotary_factor", 1.0)
|
||||
rope_dims = int(rot_pct * n_embd) // n_head
|
||||
|
||||
self.gguf_writer.add_context_length(max_pos_embds)
|
||||
@@ -174,18 +174,19 @@ class Phi3MiniModel(TextModel):
|
||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||
max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
|
||||
orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
|
||||
rot_pct = self.hparams.get("partial_rotary_factor", 1.0)
|
||||
orig_max_pos_embds = self.rope_parameters["original_max_position_embeddings"]
|
||||
rot_pct = self.rope_parameters.get("partial_rotary_factor", 1.0)
|
||||
rope_dims = int(rot_pct * n_embd) // n_head
|
||||
|
||||
# write rope scaling for long context (128k) model
|
||||
rope_scaling = self.find_hparam(['rope_scaling'], True)
|
||||
if rope_scaling is None:
|
||||
long_factors = self.rope_parameters.get('long_factor')
|
||||
short_factors = self.rope_parameters.get('short_factor')
|
||||
if not long_factors:
|
||||
return
|
||||
|
||||
scale = max_pos_embds / orig_max_pos_embds
|
||||
|
||||
rope_scaling_type = rope_scaling.get('rope_type', rope_scaling.get('type', '')).lower()
|
||||
rope_scaling_type = self.rope_parameters.get('rope_type', '').lower()
|
||||
if len(rope_scaling_type) == 0:
|
||||
raise KeyError('Missing the required key rope_scaling.type')
|
||||
|
||||
@@ -198,9 +199,6 @@ class Phi3MiniModel(TextModel):
|
||||
|
||||
self.gguf_writer.add_rope_scaling_attn_factors(attn_factor)
|
||||
|
||||
long_factors = rope_scaling.get('long_factor', None)
|
||||
short_factors = rope_scaling.get('short_factor', None)
|
||||
|
||||
if long_factors is None or short_factors is None:
|
||||
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
|
||||
|
||||
|
||||
+1
-1
@@ -280,7 +280,7 @@ class Qwen3NextModel(Qwen2MoeModel):
|
||||
self.gguf_writer.add_full_attention_interval(self.hparams.get("full_attention_interval", 4))
|
||||
if (rope_dim := self.hparams.get("head_dim")) is None:
|
||||
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25)))
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.25)))
|
||||
|
||||
@classmethod
|
||||
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
|
||||
|
||||
@@ -28,7 +28,7 @@ class StableLMModel(TextModel):
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
|
||||
rotary_factor = self.rope_parameters["partial_rotary_factor"]
|
||||
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
|
||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])
|
||||
|
||||
+1
-1
@@ -314,7 +314,7 @@ class Step35Model(TextModel):
|
||||
factor = float(rope_params.get("factor", 8.0))
|
||||
low_freq_factor = float(rope_params.get("low_freq_factor", 1.0))
|
||||
high_freq_factor = float(rope_params.get("high_freq_factor", 4.0))
|
||||
old_context_len = int(rope_params.get("original_max_position_embeddings", self.hparams.get("original_max_position_embeddings", 8192)))
|
||||
old_context_len = int(rope_params.get("original_max_position_embeddings", 8192))
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
+1
-1
@@ -29,7 +29,7 @@ With Termux, you can install and run `llama.cpp` as if the environment were Linu
|
||||
|
||||
```
|
||||
$ apt update && apt upgrade -y
|
||||
$ apt install git cmake
|
||||
$ apt install git cmake libandroid-spawn
|
||||
```
|
||||
|
||||
Then, follow the [build instructions](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md), specifically for CMake.
|
||||
|
||||
@@ -198,18 +198,18 @@ class BuiltinRule:
|
||||
SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}'
|
||||
|
||||
PRIMITIVE_RULES = {
|
||||
'boolean' : BuiltinRule('("true" | "false") space', []),
|
||||
'boolean' : BuiltinRule('("true" | "false")', []),
|
||||
'decimal-part' : BuiltinRule('[0-9]{1,16}', []),
|
||||
'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []),
|
||||
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
|
||||
'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']),
|
||||
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?', ['integral-part', 'decimal-part']),
|
||||
'integer' : BuiltinRule('("-"? integral-part)', ['integral-part']),
|
||||
'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
|
||||
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
|
||||
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
|
||||
'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []),
|
||||
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? space "}"', ['string', 'value']),
|
||||
'array' : BuiltinRule('"[" space ( value ("," space value)* )? space "]"', ['value']),
|
||||
'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\""', []),
|
||||
'char' : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []),
|
||||
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
|
||||
'null' : BuiltinRule('"null" space', []),
|
||||
'string' : BuiltinRule(r'"\"" char* "\""', ['char']),
|
||||
'null' : BuiltinRule('"null"', []),
|
||||
}
|
||||
|
||||
# TODO: support "uri", "email" string formats
|
||||
@@ -217,9 +217,9 @@ STRING_FORMAT_RULES = {
|
||||
'date' : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
|
||||
'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
|
||||
'date-time' : BuiltinRule('date "T" time', ['date', 'time']),
|
||||
'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']),
|
||||
'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']),
|
||||
'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
|
||||
'date-string' : BuiltinRule('"\\"" date "\\""', ['date']),
|
||||
'time-string' : BuiltinRule('"\\"" time "\\""', ['time']),
|
||||
'date-time-string': BuiltinRule('"\\"" date-time "\\""', ['date-time']),
|
||||
}
|
||||
|
||||
DOTALL = '[\\U00000000-\\U0010FFFF]'
|
||||
@@ -319,7 +319,7 @@ class SchemaConverter:
|
||||
out.append(f'[^"{"".join(rejects)}] {char_rule}*')
|
||||
visit(trie)
|
||||
|
||||
out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space')
|
||||
out.append(f' ){"" if trie.is_end_of_string else "?"} ["]')
|
||||
return ''.join(out)
|
||||
|
||||
def _add_rule(self, name, rule):
|
||||
@@ -549,7 +549,7 @@ class SchemaConverter:
|
||||
return self._add_rule(
|
||||
name,
|
||||
to_rule(transform()) if self._raw_pattern \
|
||||
else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space")
|
||||
else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\"")
|
||||
|
||||
|
||||
def _resolve_ref(self, ref):
|
||||
@@ -580,10 +580,10 @@ class SchemaConverter:
|
||||
return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type]))
|
||||
|
||||
elif 'const' in schema:
|
||||
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space')
|
||||
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))
|
||||
|
||||
elif 'enum' in schema:
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space'
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ')'
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
elif schema_type in (None, 'object') and \
|
||||
@@ -624,7 +624,7 @@ class SchemaConverter:
|
||||
enum_intersection &= s
|
||||
|
||||
if enum_intersection:
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ')'
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
|
||||
@@ -638,12 +638,12 @@ class SchemaConverter:
|
||||
' "," space '.join(
|
||||
self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
|
||||
for i, item in enumerate(items)) +
|
||||
' "]" space')
|
||||
' space "]"')
|
||||
else:
|
||||
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
|
||||
min_items = schema.get("minItems", 0)
|
||||
max_items = schema.get("maxItems")
|
||||
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')
|
||||
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' space "]"')
|
||||
|
||||
elif schema_type in (None, 'string') and 'pattern' in schema:
|
||||
return self._visit_pattern(schema['pattern'], rule_name)
|
||||
@@ -663,7 +663,7 @@ class SchemaConverter:
|
||||
min_len = schema.get('minLength', 0)
|
||||
max_len = schema.get('maxLength')
|
||||
|
||||
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')
|
||||
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\""')
|
||||
|
||||
elif schema_type in (None, 'integer') and \
|
||||
('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema):
|
||||
@@ -680,7 +680,7 @@ class SchemaConverter:
|
||||
|
||||
out = ["("]
|
||||
_generate_min_max_int(min_value, max_value, out)
|
||||
out.append(") space")
|
||||
out.append(")")
|
||||
return self._add_rule(rule_name, ''.join(out))
|
||||
|
||||
elif (schema_type == 'object') or (len(schema) == 0):
|
||||
@@ -765,7 +765,7 @@ class SchemaConverter:
|
||||
rule += ' )'
|
||||
rule += ' )?'
|
||||
|
||||
rule += ' "}" space'
|
||||
rule += ' space "}"'
|
||||
|
||||
return rule
|
||||
|
||||
|
||||
@@ -2417,15 +2417,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
|
||||
// Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size
|
||||
GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
|
||||
|
||||
parallel_for_ggml(params, n_batch, [&](int begin, int end) {
|
||||
for (int batch_idx = begin; batch_idx < end; ++batch_idx) {
|
||||
parallel_for_ggml(params, n_batch * M, [&](int begin, int end) {
|
||||
for (int idx = begin; idx < end; ++idx) {
|
||||
int batch_idx = idx / M;
|
||||
int m = idx % M;
|
||||
int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
|
||||
const float * A_data = (const float *)((const char *)src1->data + src1_offset);
|
||||
char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A;
|
||||
|
||||
for (int m = 0; m < M; ++m) {
|
||||
from_float<vec_dot_type>(A_data + m * K, wdata_batch + m * row_size_A, K);
|
||||
}
|
||||
from_float<vec_dot_type>(A_data + m * K, wdata_batch + m * row_size_A, K);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -183,24 +183,25 @@ static inline void hvx_transpose_32x32_f32(HVX_Vector m[32]) {
|
||||
// transposed into VTCM.
|
||||
//
|
||||
// VTCM layouts (per thread):
|
||||
// src1_T : {d_inner_per_thread, d_conv} — staged once per launch (small).
|
||||
// src0_T : {d_inner_tile, ncs} — staged per d_inner-tile.
|
||||
// src1_T : {d_inner_stride, d_conv} - staged once per launch (small).
|
||||
// src0_T : {d_inner_tile, ncs} - staged per d_inner-tile.
|
||||
//
|
||||
// d_inner_tile is chosen so that per-thread VTCM stays under the budget.
|
||||
// Each thread iterates ceil(d_inner_per_thread d_inner_tile) tiles serially.
|
||||
#define HTP_SSM_CONV_VTCM_BUDGET (1u << 20) // 1 MiB per thread
|
||||
|
||||
// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_per_thread, d_conv} (VTCM)
|
||||
// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_stride, d_conv} (VTCM)
|
||||
static inline void transpose_src1(const float * src1_data,
|
||||
uint32_t src1_stride_inner,
|
||||
uint32_t i1_off,
|
||||
uint32_t d_inner_per_thread,
|
||||
uint32_t d_inner_stride,
|
||||
uint32_t d_conv,
|
||||
float * src1_T) {
|
||||
for (uint32_t i = 0; i < d_inner_per_thread; ++i) {
|
||||
const float * src_row = src1_data + (i1_off + i) * src1_stride_inner;
|
||||
for (uint32_t j = 0; j < d_conv; ++j) {
|
||||
src1_T[j * d_inner_per_thread + i] = src_row[j];
|
||||
src1_T[j * d_inner_stride + i] = src_row[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -280,6 +281,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
|
||||
}
|
||||
|
||||
const uint32_t d_inner_per_thread = ir1 - ir0;
|
||||
const uint32_t d_inner_stride = scctx->nrows_per_thread;
|
||||
const uint32_t d_inner_tile = scctx->d_inner_tile;
|
||||
|
||||
const float * src0_data = (const float *) src0->data;
|
||||
@@ -290,8 +292,8 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
|
||||
float * src0_T = (float *)(octx->src0_spad.data + ith * octx->src0_spad.size_per_thread);
|
||||
float * src1_T = (float *)(octx->src1_spad.data + ith * octx->src1_spad.size_per_thread);
|
||||
|
||||
// Stage src1 weights once into VTCM in {d_inner_per_thread, d_conv} layout.
|
||||
transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_conv, src1_T);
|
||||
// Stage src1 weights once into VTCM in {d_inner_stride, d_conv} layout.
|
||||
transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_inner_stride, d_conv, src1_T);
|
||||
|
||||
const uint32_t C_TILE = VLEN_FP32;
|
||||
|
||||
@@ -314,7 +316,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
|
||||
HVX_Vector acc = hvx_vec_splat_f32(0.0f);
|
||||
for (uint32_t j = 0; j < d_conv; ++j) {
|
||||
HVX_Vector x = *(const HVX_Vector *) (src0_T + (t + j) * d_inner_tile + cb);
|
||||
HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_per_thread + tile_off + cb);
|
||||
HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_stride + tile_off + cb);
|
||||
acc = Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(x, w));
|
||||
}
|
||||
HVX_Vector res = Q6_Vsf_equals_Vqf32(acc);
|
||||
@@ -362,8 +364,7 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
|
||||
use_hvx = 1;
|
||||
}
|
||||
|
||||
scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads;
|
||||
scctx.nrows_per_thread += (scctx.nrows_per_thread & 1);
|
||||
scctx.nrows_per_thread = hex_round_up((d_inner + n_threads - 1) / n_threads, VLEN_FP32);
|
||||
|
||||
const uint32_t d_inner_per_thread = scctx.nrows_per_thread;
|
||||
const uint32_t ncs = src0->ne[0];
|
||||
|
||||
@@ -293,6 +293,11 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
|
||||
(sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2,
|
||||
ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0),
|
||||
ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_BF16) {
|
||||
op()((const sycl::ext::oneapi::bfloat16 *) src0->data, (const float *) src1->data,
|
||||
(sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2,
|
||||
ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0),
|
||||
ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
|
||||
#endif
|
||||
} else {
|
||||
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type),
|
||||
|
||||
@@ -43,14 +43,44 @@ static __dpct_inline__ T op_sgn(T x) {
|
||||
return x > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_abs(T x) {
|
||||
return sycl::fabs(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fabs(x); // or experimental namespace if needed
|
||||
} else {
|
||||
return sycl::fabs(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_expm1(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return static_cast<sycl::ext::oneapi::bfloat16>(
|
||||
sycl::expm1(static_cast<float>(x))
|
||||
);
|
||||
} else {
|
||||
return sycl::expm1(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_elu(T x) {
|
||||
return (x > static_cast<T>(0.f)) ? x : sycl::expm1(x);
|
||||
return (x > static_cast<T>(0.f)) ? x : op_expm1(x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_tanh(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
constexpr int ver = __INTEL_LLVM_COMPILER;
|
||||
#if defined(__INTEL_LLVM_COMPILER) && (__INTEL_LLVM_COMPILER >= 20260000)
|
||||
return sycl::ext::oneapi::experimental::tanh(x);
|
||||
#else
|
||||
return static_cast<T>(sycl::tanh(static_cast<float>(x)));
|
||||
#endif
|
||||
} else {
|
||||
return sycl::tanh(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -59,74 +89,106 @@ static __dpct_inline__ T op_gelu(T x) {
|
||||
const T SQRT_2_OVER_PI = static_cast<T>(0.79788456080286535587989211986876f);
|
||||
return static_cast<T>(0.5f) * x *
|
||||
(static_cast<T>(1.0f) +
|
||||
sycl::tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
|
||||
op_tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_exp(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::exp(x);
|
||||
} else {
|
||||
return sycl::exp(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_silu(T x) {
|
||||
return x / (static_cast<T>(1.0f) + sycl::native::exp(-x));
|
||||
return x / (static_cast<T>(1.0f) + op_exp(-x));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_gelu_quick(T x) {
|
||||
const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
|
||||
return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x)));
|
||||
static __dpct_inline__ T op_erf(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return static_cast<sycl::ext::oneapi::bfloat16>(
|
||||
sycl::erf(static_cast<float>(x))
|
||||
);
|
||||
} else {
|
||||
return sycl::erf(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_gelu_erf(T x) {
|
||||
const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
|
||||
return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + sycl::erf(x * SQRT_2_INV));
|
||||
return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + op_erf(x * SQRT_2_INV));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_tanh(T x) {
|
||||
return sycl::tanh(x);
|
||||
static __dpct_inline__ T op_gelu_quick(T x) {
|
||||
const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
|
||||
return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + op_exp(GELU_QUICK_COEF_LOCAL * x)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_relu(T x) {
|
||||
return sycl::fmax(x, static_cast<T>(0));
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fmax(x, static_cast<T>(0));
|
||||
} else {
|
||||
return sycl::fmax(x, static_cast<T>(0));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_sigmoid(T x) {
|
||||
return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(-x));
|
||||
return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + op_exp(-x));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_sqrt(T x) {
|
||||
return sycl::sqrt(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::sqrt(x);
|
||||
} else {
|
||||
return sycl::sqrt(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_sin(T x) {
|
||||
return sycl::sin(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::sin(x);
|
||||
} else {
|
||||
return sycl::sin(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_cos(T x) {
|
||||
return sycl::cos(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::cos(x);
|
||||
} else {
|
||||
return sycl::cos(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_hardsigmoid(T x) {
|
||||
return sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fmin(
|
||||
static_cast<T>(1.0f), sycl::ext::oneapi::experimental::fmax(
|
||||
static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
} else {
|
||||
return sycl::fmin(static_cast<T>(1.0f),
|
||||
sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_hardswish(T x) {
|
||||
return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_exp(T x) {
|
||||
return sycl::exp(x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_expm1(T x) {
|
||||
return sycl::expm1(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return x * sycl::ext::oneapi::experimental::fmin(static_cast<T>(1.0f), sycl::ext::oneapi::experimental::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
} else {
|
||||
return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -134,13 +196,17 @@ static __dpct_inline__ T op_log(T x) {
|
||||
if (x <= static_cast<T>(0)) {
|
||||
return neg_infinity<T>();
|
||||
}
|
||||
return sycl::log(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::log(x);
|
||||
} else {
|
||||
return sycl::log(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_softplus(T x) {
|
||||
const float xf = (float) x;
|
||||
const float ax = sycl::fabs(xf);
|
||||
const float ax = op_abs(xf);
|
||||
const float m = sycl::fmax(xf, 0.0f);
|
||||
const float y = m + sycl::log1p(sycl::exp(-ax));
|
||||
return (T) y;
|
||||
@@ -159,8 +225,14 @@ static __dpct_inline__ T op_step(T x) {
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) {
|
||||
T neg_slope_T = static_cast<T>(negative_slope);
|
||||
return sycl::fmax(x, static_cast<T>(0)) +
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fmax(x, static_cast<T>(0)) +
|
||||
sycl::ext::oneapi::experimental::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
|
||||
|
||||
} else {
|
||||
return sycl::fmax(x, static_cast<T>(0)) +
|
||||
sycl::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -175,22 +247,40 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_floor(T x) {
|
||||
return sycl::floor(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::floor(x);
|
||||
} else {
|
||||
return sycl::floor(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_ceil(T x) {
|
||||
return sycl::ceil(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::ceil(x);
|
||||
} else {
|
||||
return sycl::ceil(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_round(T x) {
|
||||
return sycl::round(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return static_cast<sycl::ext::oneapi::bfloat16>(
|
||||
sycl::round(static_cast<float>(x))
|
||||
);
|
||||
} else {
|
||||
return sycl::round(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_trunc(T x) {
|
||||
return sycl::trunc(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::trunc(x);
|
||||
} else {
|
||||
return sycl::trunc(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename F>
|
||||
@@ -339,7 +429,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
||||
const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
|
||||
});
|
||||
}
|
||||
@@ -354,8 +444,8 @@ static void arange_kernel(T * dst, const int k, T start, T step,
|
||||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16 || dst->src[0]->type == GGML_TYPE_BF16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_BF16);
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
@@ -367,6 +457,14 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx,
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#ifdef GGML_SYCL_HAS_BF16
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
auto data_pts = cast_data<sycl::ext::oneapi::bfloat16>(dst);
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
auto data_pts = cast_data<float>(dst);
|
||||
@@ -480,7 +578,7 @@ static inline void ggml_sycl_op_unary(
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
||||
sycl::range<1>(256)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_generic_kernel(
|
||||
src, dst_ptr, k_elements,
|
||||
ne0, ne1, ne2, ne3,
|
||||
@@ -508,7 +606,7 @@ static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_ten
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
arange_kernel(dst_ptr, k, start, step, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -602,7 +700,7 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -640,7 +738,7 @@ static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tenso
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -653,7 +751,7 @@ static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -666,7 +764,7 @@ static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -681,7 +779,7 @@ static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1);
|
||||
});
|
||||
}, negative_slope);
|
||||
@@ -694,7 +792,7 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -711,7 +809,7 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1);
|
||||
});
|
||||
}, min_val, max_val);
|
||||
@@ -774,7 +872,8 @@ static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tens
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -785,7 +884,8 @@ static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tens
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -796,7 +896,8 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -811,7 +912,6 @@ __dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alp
|
||||
return out_glu;
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
static void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k,
|
||||
const int64_t n, const int64_t o0, const int64_t o1,
|
||||
@@ -845,7 +945,7 @@ static void swiglu_oai_sycl(const T * x,
|
||||
const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -899,7 +999,8 @@ static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -910,7 +1011,8 @@ static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggm
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -3788,7 +3788,7 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
|
||||
ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
|
||||
}
|
||||
|
||||
static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
static void ggml_backend_webgpu_request_adapter(wgpu::Instance & instance, wgpu::Adapter & adapter) {
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
@@ -3800,17 +3800,20 @@ static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
options.nextInChain = &adapterTogglesDesc;
|
||||
#endif
|
||||
|
||||
ctx->webgpu_global_ctx->instance.WaitAny(
|
||||
ctx->webgpu_global_ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
return;
|
||||
}
|
||||
ctx->webgpu_global_ctx->adapter = std::move(adapter);
|
||||
}),
|
||||
UINT64_MAX);
|
||||
instance.WaitAny(instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
return;
|
||||
}
|
||||
adapter = std::move(_adapter);
|
||||
}),
|
||||
UINT64_MAX);
|
||||
}
|
||||
|
||||
static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
ggml_backend_webgpu_request_adapter(ctx->webgpu_global_ctx->instance, ctx->webgpu_global_ctx->adapter);
|
||||
GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr);
|
||||
|
||||
ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits);
|
||||
@@ -4543,20 +4546,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
||||
// Probe for adapter support
|
||||
wgpu::Adapter adapter;
|
||||
if (ctx->webgpu_global_ctx->instance != nullptr) {
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
|
||||
// probe for adapter support
|
||||
ctx->webgpu_global_ctx->instance.WaitAny(
|
||||
ctx->webgpu_global_ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
return;
|
||||
}
|
||||
adapter = std::move(_adapter);
|
||||
}),
|
||||
UINT64_MAX);
|
||||
ggml_backend_webgpu_request_adapter(ctx->webgpu_global_ctx->instance, adapter);
|
||||
}
|
||||
|
||||
// WebGPU backend requires f16 support and, on native, implicit device synchronization.
|
||||
|
||||
+7
-10
@@ -600,18 +600,15 @@ FILE * ggml_fopen(const char * fname, const char * mode) {
|
||||
// convert fname (UTF-8)
|
||||
wchar_t * wfname = ggml_mbstowcs(fname);
|
||||
if (wfname) {
|
||||
// convert mode (ANSI)
|
||||
wchar_t * wmode = GGML_MALLOC((strlen(mode) + 1) * sizeof(wchar_t));
|
||||
wchar_t * wmode_p = wmode;
|
||||
do {
|
||||
*wmode_p++ = (wchar_t)*mode;
|
||||
} while (*mode++);
|
||||
|
||||
// open file
|
||||
file = _wfopen(wfname, wmode);
|
||||
// convert mode (UTF-8)
|
||||
wchar_t * wmode = ggml_mbstowcs(mode);
|
||||
if (wmode) {
|
||||
// open file
|
||||
file = _wfopen(wfname, wmode);
|
||||
GGML_FREE(wmode);
|
||||
}
|
||||
|
||||
GGML_FREE(wfname);
|
||||
GGML_FREE(wmode);
|
||||
}
|
||||
|
||||
return file;
|
||||
|
||||
+9
-8
@@ -558,14 +558,15 @@ extern "C" {
|
||||
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
|
||||
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
|
||||
|
||||
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_ctx_train (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_layer_nextn(const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
|
||||
|
||||
// Get the model's RoPE frequency scaling factor
|
||||
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
|
||||
|
||||
@@ -1156,6 +1156,10 @@ void llama_context::set_embeddings_layer_inp(uint32_t lid, bool enable) {
|
||||
sched_need_reserve = true;
|
||||
}
|
||||
|
||||
void llama_context::set_nextn_layer_offset(int32_t offset) {
|
||||
cparams.nextn_layer_offset = offset;
|
||||
}
|
||||
|
||||
void llama_context::set_causal_attn(bool value) {
|
||||
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
||||
|
||||
@@ -3699,6 +3703,10 @@ void llama_set_embeddings_layer_inp(llama_context * ctx, uint32_t lid, bool valu
|
||||
ctx->set_embeddings_layer_inp(lid, value);
|
||||
}
|
||||
|
||||
void llama_set_nextn_layer_offset(llama_context * ctx, int32_t offset) {
|
||||
ctx->set_nextn_layer_offset(offset);
|
||||
}
|
||||
|
||||
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
|
||||
if (!ctx) {
|
||||
return nullptr;
|
||||
|
||||
@@ -115,6 +115,7 @@ struct llama_context {
|
||||
void set_embeddings (bool value);
|
||||
void set_embeddings_nextn(bool value, bool masked);
|
||||
void set_embeddings_layer_inp(uint32_t lid, bool enable);
|
||||
void set_nextn_layer_offset(int32_t offset);
|
||||
void set_causal_attn(bool value);
|
||||
void set_warmup(bool value);
|
||||
|
||||
|
||||
@@ -18,6 +18,8 @@ struct llama_cparams {
|
||||
int32_t n_threads; // number of threads to use for generation
|
||||
int32_t n_threads_batch; // number of threads to use for batch processing
|
||||
|
||||
int32_t nextn_layer_offset = 0;
|
||||
|
||||
float rope_freq_base;
|
||||
float rope_freq_scale;
|
||||
|
||||
|
||||
@@ -95,6 +95,11 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c
|
||||
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
|
||||
LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked);
|
||||
|
||||
// Select which appended NextN block the DECODER_MTP graph runs (offset past
|
||||
// the trunk: il = n_layer() + offset). Used by the speculative NextN driver to
|
||||
// chain multiple trained NextN heads. Default 0 (first head).
|
||||
LLAMA_API void llama_set_nextn_layer_offset(struct llama_context * ctx, int32_t offset);
|
||||
|
||||
// mirrors:
|
||||
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||
LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);
|
||||
|
||||
+9
-2
@@ -682,9 +682,16 @@ struct llm_graph_params {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448035248
|
||||
if (cparams.nextn_layer_offset != other.cparams.nextn_layer_offset) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return
|
||||
cparams.embeddings == other.cparams.embeddings &&
|
||||
cparams.causal_attn == other.cparams.causal_attn &&
|
||||
cparams.embeddings == other.cparams.embeddings &&
|
||||
cparams.embeddings_nextn == other.cparams.embeddings_nextn &&
|
||||
cparams.embeddings_nextn_masked == other.cparams.embeddings_nextn_masked &&
|
||||
cparams.causal_attn == other.cparams.causal_attn &&
|
||||
arch == other.arch &&
|
||||
gtype == other.gtype &&
|
||||
cvec == other.cvec &&
|
||||
|
||||
@@ -2312,6 +2312,10 @@ int32_t llama_model_n_layer(const llama_model * model) {
|
||||
return model->hparams.n_layer();
|
||||
}
|
||||
|
||||
int32_t llama_model_n_layer_nextn(const llama_model * model) {
|
||||
return model->hparams.n_layer_nextn;
|
||||
}
|
||||
|
||||
int32_t llama_model_n_head(const llama_model * model) {
|
||||
return model->hparams.n_head();
|
||||
}
|
||||
|
||||
+2
-2
@@ -932,8 +932,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
|
||||
// copy the KV pairs from the input file
|
||||
gguf_set_kv (ctx_out.get(), ml.metadata);
|
||||
gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
|
||||
gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV
|
||||
gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION).c_str(), GGML_QNT_VERSION);
|
||||
gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_GENERAL_FILE_TYPE).c_str(), ftype);
|
||||
|
||||
// Remove split metadata
|
||||
gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
|
||||
|
||||
@@ -2813,8 +2813,6 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
llama_sampler_softmax_impl(cur_p, true);
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
|
||||
|
||||
@@ -101,11 +101,11 @@ void llama_model_glm_dsa::load_arch_tensors(llama_model_loader &) {
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags);
|
||||
|
||||
// DSA indexer
|
||||
layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags);
|
||||
layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags);
|
||||
layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags);
|
||||
layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags);
|
||||
layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags);
|
||||
layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED);
|
||||
layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED);
|
||||
layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags | TENSOR_NOT_REQUIRED);
|
||||
layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED);
|
||||
layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED);
|
||||
if (i < (int) hparams.n_layer_dense_lead) {
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags);
|
||||
|
||||
+27
-28
@@ -112,7 +112,7 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED);
|
||||
};
|
||||
|
||||
auto load_block_mtp = [&](int i, bool is_first_mtp) {
|
||||
auto load_block_mtp = [&](int i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
const uint32_t n_head_l = hparams.n_head(i);
|
||||
@@ -121,15 +121,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
|
||||
|
||||
// The MTP block is a full Step3p5 decoder layer (mtp_block) plus the
|
||||
// NextN-specific wiring (enorm/hnorm/eh_proj + optional shared head).
|
||||
// `mtp_flags` becomes NOT_REQUIRED when the GGUF is trunk-only.
|
||||
//
|
||||
// Only the FIRST MTP block (i == n_main) is required for the
|
||||
// single-block MTP runtime; trailing MTP blocks are always tolerated
|
||||
// as missing so pruned GGUFs (block 0 only) load cleanly. Override
|
||||
// mtp_flags to NOT_REQUIRED for those.
|
||||
const int eff_mtp_flags = is_first_mtp ? mtp_flags : (mtp_flags | TENSOR_NOT_REQUIRED);
|
||||
// Multi-block MTP: every declared MTP block is required (the draft chain
|
||||
// runs all n_layer_nextn heads), so each block uses the captured
|
||||
// `mtp_flags` directly — already NOT_REQUIRED for a trunk-only GGUF,
|
||||
// which keeps that path correct.
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, eff_mtp_flags);
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, mtp_flags);
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
@@ -140,12 +137,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
|
||||
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, eff_mtp_flags);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, eff_mtp_flags);
|
||||
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, mtp_flags);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, mtp_flags);
|
||||
|
||||
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, eff_mtp_flags);
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, mtp_flags);
|
||||
|
||||
// dense MLP (leading dense blocks) — present if the MTP block isn't MoE
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
|
||||
@@ -165,9 +162,9 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
// NextN-specific tensors that define the MTP block.
|
||||
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, eff_mtp_flags);
|
||||
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, eff_mtp_flags);
|
||||
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, eff_mtp_flags);
|
||||
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, mtp_flags);
|
||||
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, mtp_flags);
|
||||
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, mtp_flags);
|
||||
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
|
||||
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
|
||||
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
|
||||
@@ -176,13 +173,11 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
load_block_trunk(i, trunk_flags);
|
||||
}
|
||||
// Only the first MTP block (i == n_main) is required at runtime — the
|
||||
// single-block-MTP graph in build_arch_graph always uses that one.
|
||||
// Trailing MTP blocks are loaded if present (so an un-pruned GGUF with
|
||||
// all MTP layers still works) but tolerated when absent via the pruning
|
||||
// path. See scripts/prune_step35_extra_mtp.py for the pruner.
|
||||
// All n_layer_nextn MTP blocks are required — the multi-block draft chain
|
||||
// runs every head (head k at offset k). The GGUF declares the count via
|
||||
// step35.nextn_predict_layers.
|
||||
for (int i = n_layer; i < n_layer_all; ++i) {
|
||||
load_block_mtp(i, /*is_first_mtp=*/ i == n_layer);
|
||||
load_block_mtp(i);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -372,13 +367,14 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
|
||||
: llm_graph_context(params) {
|
||||
GGML_ASSERT(hparams.n_layer_nextn > 0 && "STEP35 MTP requires n_layer_nextn > 0");
|
||||
|
||||
// Single-block MTP only: always run the first trained MTP block (Qwen
|
||||
// MTP / vLLM single-MTP-layer style). Multi-block round-robin proved to
|
||||
// be a much deeper refactor than this PR justifies; the trailing MTP
|
||||
// blocks are loaded with TENSOR_NOT_REQUIRED so pruned GGUFs (with just
|
||||
// block 0) also work — see load_arch_tensors below and
|
||||
// scripts/prune_step35_extra_mtp.py.
|
||||
const int il = hparams.n_layer();
|
||||
// Multi-block MTP: the DECODER_MTP graph runs the MTP head selected by
|
||||
// cparams.nextn_layer_offset (0 = first trained head). The speculative driver
|
||||
// bumps the offset per draft step to chain heads 45->46->47. offset 0 keeps
|
||||
// single-block behavior identical to before.
|
||||
const int il = hparams.n_layer() + cparams.nextn_layer_offset;
|
||||
GGML_ASSERT(cparams.nextn_layer_offset >= 0 &&
|
||||
cparams.nextn_layer_offset < (int) hparams.n_layer_nextn &&
|
||||
"nextn_layer_offset out of range [0, n_layer_nextn)");
|
||||
const auto & layer = model.layers[il];
|
||||
|
||||
GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj");
|
||||
@@ -536,6 +532,9 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "mtp_post_ffn", il);
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
|
||||
// Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
|
||||
cb(cur, "h_nextn", -1);
|
||||
res->t_h_nextn = cur;
|
||||
|
||||
@@ -129,7 +129,154 @@ void test_gbnf_generation(testing &t) {
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= ([^<] | "<" [^/] | "</" [^t] | "</t" [^a] | "</ta" [^g] | "</tag" [^>])* ("<" | "</" | "</t" | "</ta" | "</tag")?
|
||||
root ::= until-0
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
until-0 ::= | [<] until-0-01 | [^<] until-0
|
||||
until-0-01 ::= | [<] until-0-01 | [/] until-0-02 | [^/<] until-0
|
||||
until-0-02 ::= | [<] until-0-01 | [t] until-0-03 | [^<t] until-0
|
||||
until-0-03 ::= | [<] until-0-01 | [a] until-0-04 | [^<a] until-0
|
||||
until-0-04 ::= | [<] until-0-01 | [g] until-0-05 | [^<g] until-0
|
||||
until-0-05 ::= | [<] until-0-01 | [^<>] until-0
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("until grammar overlapping delimiter", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.until("\n</parameter>\n");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= until-0
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
until-0 ::= | [\n] until-0-01 | [^\n] until-0
|
||||
until-0-01 ::= | [\n] until-0-01 | [<] until-0-02 | [^\n<] until-0
|
||||
until-0-02 ::= | [\n] until-0-01 | [/] until-0-03 | [^\n/] until-0
|
||||
until-0-03 ::= | [\n] until-0-01 | [p] until-0-04 | [^\np] until-0
|
||||
until-0-04 ::= | [\n] until-0-01 | [a] until-0-05 | [^\na] until-0
|
||||
until-0-05 ::= | [\n] until-0-01 | [r] until-0-06 | [^\nr] until-0
|
||||
until-0-06 ::= | [\n] until-0-01 | [a] until-0-07 | [^\na] until-0
|
||||
until-0-07 ::= | [\n] until-0-01 | [m] until-0-08 | [^\nm] until-0
|
||||
until-0-08 ::= | [\n] until-0-01 | [e] until-0-09 | [^\ne] until-0
|
||||
until-0-09 ::= | [\n] until-0-01 | [t] until-0-10 | [^\nt] until-0
|
||||
until-0-10 ::= | [\n] until-0-01 | [e] until-0-11 | [^\ne] until-0
|
||||
until-0-11 ::= | [\n] until-0-01 | [r] until-0-12 | [^\nr] until-0
|
||||
until-0-12 ::= | [\n] until-0-01 | [>] until-0-13 | [^\n>] until-0
|
||||
until-0-13 ::= | [^\n] until-0
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
// DeepSeek-V3.2 tag prefix. The DSML token (|DSML|) embeds U+FF5C,
|
||||
// so the delimiter mixes ASCII and multi-byte codepoints.
|
||||
t.test("until grammar unicode delimiter", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.until("<|DSML|");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= until-0
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
until-0 ::= | [<] until-0-01 | [^<] until-0
|
||||
until-0-01 ::= | [<] until-0-01 | [\uFF5C] until-0-02 | [^<\uFF5C] until-0
|
||||
until-0-02 ::= | [<] until-0-01 | [D] until-0-03 | [^<D] until-0
|
||||
until-0-03 ::= | [<] until-0-01 | [S] until-0-04 | [^<S] until-0
|
||||
until-0-04 ::= | [<] until-0-01 | [M] until-0-05 | [^<M] until-0
|
||||
until-0-05 ::= | [<] until-0-01 | [L] until-0-06 | [^<L] until-0
|
||||
until-0-06 ::= | [<] until-0-01 | [^<\uFF5C] until-0
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("until grammar multiple delimiters", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.until_one_of({"ab", "cd", "ef"});
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
root ::= until-0
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
until-0 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^ace] until-0
|
||||
until-0-01 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^abce] until-0
|
||||
until-0-03 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^acde] until-0
|
||||
until-0-05 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^acef] until-0
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("ac grammar", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.ac(p.until("</tag>") + p.literal("</tag>"), "</tag>");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
ac-3 ::= [<] ac-3-01 | [^<] ac-3
|
||||
ac-3-01 ::= [<] ac-3-01 | [/] ac-3-02 | [^/<] ac-3
|
||||
ac-3-02 ::= [<] ac-3-01 | [t] ac-3-03 | [^<t] ac-3
|
||||
ac-3-03 ::= [<] ac-3-01 | [a] ac-3-04 | [^<a] ac-3
|
||||
ac-3-04 ::= [<] ac-3-01 | [g] ac-3-05 | [^<g] ac-3
|
||||
ac-3-05 ::= [>] | [<] ac-3-01 | [^<>] ac-3
|
||||
root ::= ac-3
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("ac grammar terminates at first delimiter", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.ac(p.until("\n</parameter>\n") + p.literal("\n</parameter>\n"), "\n</parameter>\n");
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
ac-3 ::= [\n] ac-3-01 | [^\n] ac-3
|
||||
ac-3-01 ::= [\n] ac-3-01 | [<] ac-3-02 | [^\n<] ac-3
|
||||
ac-3-02 ::= [\n] ac-3-01 | [/] ac-3-03 | [^\n/] ac-3
|
||||
ac-3-03 ::= [\n] ac-3-01 | [p] ac-3-04 | [^\np] ac-3
|
||||
ac-3-04 ::= [\n] ac-3-01 | [a] ac-3-05 | [^\na] ac-3
|
||||
ac-3-05 ::= [\n] ac-3-01 | [r] ac-3-06 | [^\nr] ac-3
|
||||
ac-3-06 ::= [\n] ac-3-01 | [a] ac-3-07 | [^\na] ac-3
|
||||
ac-3-07 ::= [\n] ac-3-01 | [m] ac-3-08 | [^\nm] ac-3
|
||||
ac-3-08 ::= [\n] ac-3-01 | [e] ac-3-09 | [^\ne] ac-3
|
||||
ac-3-09 ::= [\n] ac-3-01 | [t] ac-3-10 | [^\nt] ac-3
|
||||
ac-3-10 ::= [\n] ac-3-01 | [e] ac-3-11 | [^\ne] ac-3
|
||||
ac-3-11 ::= [\n] ac-3-01 | [r] ac-3-12 | [^\nr] ac-3
|
||||
ac-3-12 ::= [\n] ac-3-01 | [>] ac-3-13 | [^\n>] ac-3
|
||||
ac-3-13 ::= [\n] | [^\n] ac-3
|
||||
root ::= ac-3
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
t.test("ac grammar multiple delimiters", [](testing &t) {
|
||||
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
|
||||
return p.ac(p.eps(), std::vector<std::string>{"ab", "cd", "ef"});
|
||||
});
|
||||
|
||||
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
|
||||
parser.build_grammar(builder);
|
||||
});
|
||||
|
||||
assert_gbnf_equal(t, R"""(
|
||||
ac-1 ::= [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^ace] ac-1
|
||||
ac-1-01 ::= [b] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^abce] ac-1
|
||||
ac-1-03 ::= [d] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^acde] ac-1
|
||||
ac-1-05 ::= [f] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^acef] ac-1
|
||||
root ::= ac-1
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""", gbnf);
|
||||
});
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#undef NDEBUG
|
||||
#include <cassert>
|
||||
|
||||
int main(void) {
|
||||
static void test(void) {
|
||||
common_params params;
|
||||
|
||||
printf("test-arg-parser: make sure there is no duplicated arguments in any examples\n\n");
|
||||
@@ -210,3 +210,13 @@ int main(void) {
|
||||
|
||||
printf("test-arg-parser: all tests OK\n\n");
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
try {
|
||||
test();
|
||||
} catch (std::exception & e) {
|
||||
fprintf(stderr, "test-arg-parser: exception: %s\n", e.what());
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
+2
-2
@@ -5022,14 +5022,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).expect_reconstruction().run();
|
||||
|
||||
tst.test(
|
||||
"```json\n\"42\" \n```")
|
||||
"```json\n\"42\"\n```")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.json_schema(const_schema)
|
||||
.expect_content(R"("42")")
|
||||
.run();
|
||||
|
||||
tst.test(
|
||||
"\"42\" \n")
|
||||
"\"42\"\n")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.json_schema(const_schema)
|
||||
.expect_content(R"("42")")
|
||||
|
||||
@@ -995,6 +995,32 @@ static void test_macros(testing & t) {
|
||||
json::object(),
|
||||
"Hello, John Smith,Hi, Jane Doe"
|
||||
);
|
||||
|
||||
test_template(t, "macro with caller",
|
||||
"\
|
||||
{%- macro nest_dict(o, i, ff='') %}\n\
|
||||
{{- caller(ff) }}\n\
|
||||
{%- for k, v in o|items %}\n\
|
||||
{{- i + k + ': ' }}\n\
|
||||
{%- if v is mapping %}\n\
|
||||
{{- '{' }}\n\
|
||||
{% call(f) nest_dict(v, i + ' ') %}\n\
|
||||
{{- 'fail' if ff is undefined }}\n\
|
||||
{%- endcall %}\n\
|
||||
{{- i + '}' }}\n\
|
||||
{% else %}\n\
|
||||
{{- v|string }}\n\
|
||||
{% endif %}\n\
|
||||
{%- endfor %}\n\
|
||||
{%- endmacro %}\n\
|
||||
{%- call(f) nest_dict({'root1': 1, 'root2': {'nest1': 1, 'nest2': {'nest3': 2}}}, ' ', 'Dict') %}\n\
|
||||
{{- 'fail' if ff is defined }}\n\
|
||||
{{- f + ' {' }}\n\
|
||||
{% endcall %}\n\
|
||||
{{- '}' }}",
|
||||
json::object(),
|
||||
"Dict {\n root1: 1\n root2: {\n nest1: 1\n nest2: {\n nest3: 2\n }\n }\n}"
|
||||
);
|
||||
}
|
||||
|
||||
static void test_namespace(testing & t) {
|
||||
|
||||
@@ -92,7 +92,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minimum": 0
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([0] | [1-9] [0-9]{0,15}) space
|
||||
root ::= ([0] | [1-9] [0-9]{0,15})
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -105,7 +105,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minimum": 1
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([1-9] [0-9]{0,15}) space
|
||||
root ::= ([1-9] [0-9]{0,15})
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -118,7 +118,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minimum": 3
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([1-2] [0-9]{1,15} | [3-9] [0-9]{0,15}) space
|
||||
root ::= ([1-2] [0-9]{1,15} | [3-9] [0-9]{0,15})
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -131,7 +131,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minimum": 9
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([1-8] [0-9]{1,15} | [9] [0-9]{0,15}) space
|
||||
root ::= ([1-8] [0-9]{1,15} | [9] [0-9]{0,15})
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -144,7 +144,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minimum": 10
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([1] ([0-9]{1,15}) | [2-9] [0-9]{1,15}) space
|
||||
root ::= ([1] ([0-9]{1,15}) | [2-9] [0-9]{1,15})
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -157,7 +157,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minimum": 25
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([1] [0-9]{2,15} | [2] ([0-4] [0-9]{1,14} | [5-9] [0-9]{0,14}) | [3-9] [0-9]{1,15}) space
|
||||
root ::= ([1] [0-9]{2,15} | [2] ([0-4] [0-9]{1,14} | [5-9] [0-9]{0,14}) | [3-9] [0-9]{1,15})
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -170,7 +170,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maximum": 30
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-2] [0-9] | [3] "0")) space
|
||||
root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-2] [0-9] | [3] "0"))
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -183,7 +183,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minimum": -5
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" ([0-5]) | [0] | [1-9] [0-9]{0,15}) space
|
||||
root ::= ("-" ([0-5]) | [0] | [1-9] [0-9]{0,15})
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -196,7 +196,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minimum": -123
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0] | [1-9] [0-9]{0,15}) space
|
||||
root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0] | [1-9] [0-9]{0,15})
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -209,7 +209,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maximum": -5
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" ([0-4] [0-9]{1,15} | [5-9] [0-9]{0,15})) space
|
||||
root ::= ("-" ([0-4] [0-9]{1,15} | [5-9] [0-9]{0,15}))
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -222,7 +222,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maximum": 1
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" [1-9] [0-9]{0,15} | [0-1]) space
|
||||
root ::= ("-" [1-9] [0-9]{0,15} | [0-1])
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -235,7 +235,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maximum": 100
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-8] [0-9] | [9] [0-9]) | "100") space
|
||||
root ::= ("-" [1-9] [0-9]{0,15} | [0-9] | ([1-8] [0-9] | [9] [0-9]) | "100")
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -249,7 +249,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maximum": 23
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([0-9] | ([1] [0-9] | [2] [0-3])) space
|
||||
root ::= ([0-9] | ([1] [0-9] | [2] [0-3]))
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -263,7 +263,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maximum": 300
|
||||
})""",
|
||||
R"""(
|
||||
root ::= (([1] ([5-9]) | [2-9] [0-9]) | ([1-2] [0-9]{2} | [3] "00")) space
|
||||
root ::= (([1] ([5-9]) | [2-9] [0-9]) | ([1-2] [0-9]{2} | [3] "00"))
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -277,7 +277,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maximum": 30
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ([5-9] | ([1-2] [0-9] | [3] "0")) space
|
||||
root ::= ([5-9] | ([1-2] [0-9] | [3] "0"))
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -291,7 +291,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maximum": 42
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0-9] | ([1-3] [0-9] | [4] [0-2])) space
|
||||
root ::= ("-" ([0-9] | ([1-8] [0-9] | [9] [0-9]) | "1" ([0-1] [0-9] | [2] [0-3])) | [0-9] | ([1-3] [0-9] | [4] [0-2]))
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -305,7 +305,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maximum": 10
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("-" ([0-9] | "10") | [0-9] | "10") space
|
||||
root ::= ("-" ([0-9] | "10") | [0-9] | "10")
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -333,17 +333,17 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"empty schema (object)",
|
||||
"{}",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
array ::= "[" space ( value ("," space value)* )? space "]"
|
||||
boolean ::= ("true" | "false")
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
null ::= "null"
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? space "}"
|
||||
root ::= object
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
});
|
||||
@@ -361,17 +361,17 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
date ::= [0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )
|
||||
date-string ::= "\"" date "\"" space
|
||||
date-string ::= "\"" date "\""
|
||||
date-time ::= date "T" time
|
||||
date-time-string ::= "\"" date-time "\"" space
|
||||
root ::= "[" space tuple-0 "," space uuid "," space tuple-2 "," space tuple-3 "]" space
|
||||
date-time-string ::= "\"" date-time "\""
|
||||
root ::= "[" space tuple-0 "," space uuid "," space tuple-2 "," space tuple-3 space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
time ::= ([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )
|
||||
time-string ::= "\"" time "\"" space
|
||||
time-string ::= "\"" time "\""
|
||||
tuple-0 ::= date-string
|
||||
tuple-2 ::= time-string
|
||||
tuple-3 ::= date-time-string
|
||||
uuid ::= "\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space
|
||||
uuid ::= "\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -383,7 +383,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char* "\"" space
|
||||
root ::= "\"" char* "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -397,7 +397,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char+ "\"" space
|
||||
root ::= "\"" char+ "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -411,7 +411,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char{3,} "\"" space
|
||||
root ::= "\"" char{3,} "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -425,7 +425,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char{0,3} "\"" space
|
||||
root ::= "\"" char{0,3} "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -440,7 +440,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "\"" char{1,4} "\"" space
|
||||
root ::= "\"" char{1,4} "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -452,7 +452,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"type": "boolean"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("true" | "false") space
|
||||
root ::= ("true" | "false")
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -465,7 +465,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
root ::= ("-"? integral-part) space
|
||||
root ::= ("-"? integral-part)
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -477,7 +477,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"const": "foo"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"foo\"" space
|
||||
root ::= "\"foo\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -489,7 +489,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"const": 123
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "123" space
|
||||
root ::= "123"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -501,7 +501,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"enum": ["red", "amber", "green", null, 42, ["foo"]]
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("\"red\"" | "\"amber\"" | "\"green\"" | "null" | "42" | "[\"foo\"]") space
|
||||
root ::= ("\"red\"" | "\"amber\"" | "\"green\"" | "null" | "42" | "[\"foo\"]")
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -515,9 +515,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "[" space (string ("," space string)*)? "]" space
|
||||
root ::= "[" space (string ("," space string)*)? space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -529,12 +529,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"prefixItems": { "type": "string" }
|
||||
})""",
|
||||
R"""(
|
||||
alternative-0 ::= "[" space (string ("," space string)*)? "]" space
|
||||
alternative-0 ::= "[" space (string ("," space string)*)? space "]"
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
null ::= "null" space
|
||||
null ::= "null"
|
||||
root ::= alternative-0 | null
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -546,9 +546,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "[" space string "]" space
|
||||
root ::= "[" space string space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -562,10 +562,10 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "[" space string "," space number "]" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
root ::= "[" space string "," space number space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -577,18 +577,18 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"items": {}
|
||||
})""",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
array ::= "[" space ( value ("," space value)* )? space "]"
|
||||
boolean ::= ("true" | "false")
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
item ::= object
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= "[" space (item ("," space item)*)? "]" space
|
||||
null ::= "null"
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? space "}"
|
||||
root ::= "[" space (item ("," space item)*)? space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
});
|
||||
@@ -602,18 +602,18 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"prefixItems": { "type": "string" }
|
||||
})""",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
array ::= "[" space ( value ("," space value)* )? space "]"
|
||||
boolean ::= ("true" | "false")
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
item ::= object
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= "[" space (item ("," space item)*)? "]" space
|
||||
null ::= "null"
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? space "}"
|
||||
root ::= "[" space (item ("," space item)*)? space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
});
|
||||
@@ -627,7 +627,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
root ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -642,8 +642,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"minItems": 2
|
||||
})""",
|
||||
R"""(
|
||||
boolean ::= ("true" | "false") space
|
||||
root ::= "[" space boolean ("," space boolean)+ "]" space
|
||||
boolean ::= ("true" | "false")
|
||||
root ::= "[" space boolean ("," space boolean)+ space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -658,8 +658,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maxItems": 0
|
||||
})""",
|
||||
R"""(
|
||||
boolean ::= ("true" | "false") space
|
||||
root ::= "[" space "]" space
|
||||
boolean ::= ("true" | "false")
|
||||
root ::= "[" space space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -674,8 +674,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maxItems": 1
|
||||
})""",
|
||||
R"""(
|
||||
boolean ::= ("true" | "false") space
|
||||
root ::= "[" space boolean? "]" space
|
||||
boolean ::= ("true" | "false")
|
||||
root ::= "[" space boolean? space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -690,8 +690,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maxItems": 2
|
||||
})""",
|
||||
R"""(
|
||||
boolean ::= ("true" | "false") space
|
||||
root ::= "[" space (boolean ("," space boolean)?)? "]" space
|
||||
boolean ::= ("true" | "false")
|
||||
root ::= "[" space (boolean ("," space boolean)?)? space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -708,11 +708,11 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integer ::= ("-"? integral-part) space
|
||||
integer ::= ("-"? integral-part)
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
item ::= number | integer
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "[" space item ("," space item){2,4} "]" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
root ::= "[" space item ("," space item){2,4} space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -730,8 +730,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maxItems": 5
|
||||
})""",
|
||||
R"""(
|
||||
item ::= ("-" ([0-9] | "1" [0-2]) | [0-9] | ([1-8] [0-9] | [9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7])) space
|
||||
root ::= "[" space item ("," space item){2,4} "]" space
|
||||
item ::= ("-" ([0-9] | "1" [0-2]) | [0-9] | ([1-8] [0-9] | [9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7]))
|
||||
root ::= "[" space item ("," space item){2,4} space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -749,8 +749,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"maxItems": 5
|
||||
})""",
|
||||
R"""(
|
||||
item ::= (([1] ([2-9]) | [2-9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7])) space
|
||||
root ::= "[" space item ("," space item){2,4} "]" space
|
||||
item ::= (([1] ([2-9]) | [2-9] [0-9]) | ([1] [0-9]{2} | [2] "0" [0-7]))
|
||||
root ::= "[" space item ("," space item){2,4} space "]"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -763,7 +763,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"pattern": "^abc?d*efg+(hij)?kl$"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" ("ab" "c"? "d"* "ef" "g"+ ("hij")? "kl") "\"" space
|
||||
root ::= "\"" ("ab" "c"? "d"* "ef" "g"+ ("hij")? "kl") "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -776,7 +776,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"pattern": "^\\[\\]\\{\\}\\(\\)\\|\\+\\*\\?$"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" ("[]{}()|+*?") "\"" space
|
||||
root ::= "\"" ("[]{}()|+*?") "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -789,7 +789,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"pattern": "^\"$"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" ("\"") "\"" space
|
||||
root ::= "\"" ("\"") "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -802,7 +802,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"pattern": "^A|B|C|D$"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" ("A" | "B" | "C" | "D") "\"" space
|
||||
root ::= "\"" ("A" | "B" | "C" | "D") "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -816,7 +816,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
dot ::= [^\x0A\x0D]
|
||||
root ::= "\"" (("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot) "\"" space
|
||||
root ::= "\"" (("(" root-1{1,3} ")")? root-1{3,3} "-" root-1{4,4} " " "a"{3,5} "nd" dot dot dot) "\""
|
||||
root-1 ::= [0-9]
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
@@ -845,9 +845,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
b-kv ::= "\"b\"" space ":" space string
|
||||
c-kv ::= "\"c\"" space ":" space string
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "{" space b-kv "," space c-kv "," space a-kv "}" space
|
||||
root ::= "{" space b-kv "," space c-kv "," space a-kv space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -865,9 +865,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
a-kv ::= "\"a\"" space ":" space string
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "{" space (a-kv )? "}" space
|
||||
root ::= "{" space (a-kv )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -889,9 +889,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
b-rest ::= ( "," space c-kv )?
|
||||
c-kv ::= "\"c\"" space ":" space string
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
root ::= "{" space (a-kv a-rest | b-kv b-rest | c-kv )? "}" space
|
||||
root ::= "{" space (a-kv a-rest | b-kv b-rest | c-kv )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -915,9 +915,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
d-kv ::= "\"d\"" space ":" space string
|
||||
d-rest ::= ( "," space c-kv )?
|
||||
root ::= "{" space b-kv "," space a-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
|
||||
root ::= "{" space b-kv "," space a-kv ( "," space ( d-kv d-rest | c-kv ) )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -930,14 +930,14 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
additional-kv ::= string ":" space additional-value
|
||||
additional-value ::= "[" space (number ("," space number)*)? "]" space
|
||||
additional-value ::= "[" space (number ("," space number)*)? space "]"
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space (additional-kv ( "," space additional-kv )* )? "}" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
root ::= "{" space (additional-kv ( "," space additional-kv )* )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -949,17 +949,17 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"additionalProperties": true
|
||||
})""",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
array ::= "[" space ( value ("," space value)* )? space "]"
|
||||
boolean ::= ("true" | "false")
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
null ::= "null"
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? space "}"
|
||||
root ::= object
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
});
|
||||
@@ -971,17 +971,17 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"type": "object"
|
||||
})""",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
array ::= "[" space ( value ("," space value)* )? space "]"
|
||||
boolean ::= ("true" | "false")
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
null ::= "null"
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? space "}"
|
||||
root ::= object
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
});
|
||||
@@ -994,7 +994,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"additionalProperties": false
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "{" space "}" space
|
||||
root ::= "{" space space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1012,15 +1012,15 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
a-kv ::= "\"a\"" space ":" space number
|
||||
additional-k ::= ["] ( [a] char+ | [^"a] char* )? ["] space
|
||||
additional-k ::= ["] ( [a] char+ | [^"a] char* )? ["]
|
||||
additional-kv ::= additional-k ":" space string
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space a-kv ( "," space ( additional-kv ( "," space additional-kv )* ) )? "}" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
root ::= "{" space a-kv ( "," space ( additional-kv ( "," space additional-kv )* ) )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -1037,13 +1037,13 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
a-kv ::= "\"a\"" space ":" space number
|
||||
a-rest ::= ( "," space additional-kv )*
|
||||
additional-k ::= ["] ( [a] char+ | [^"a] char* )? ["] space
|
||||
additional-k ::= ["] ( [a] char+ | [^"a] char* )? ["]
|
||||
additional-kv ::= additional-k ":" space number
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space (a-kv a-rest | additional-kv ( "," space additional-kv )* )? "}" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
root ::= "{" space (a-kv a-rest | additional-kv ( "," space additional-kv )* )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1061,7 +1061,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"additionalProperties": {"type": "number"}
|
||||
})""",
|
||||
R"""(
|
||||
additional-k ::= ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
|
||||
additional-k ::= ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["]
|
||||
additional-kv ::= additional-k ":" space number
|
||||
also-kv ::= "\"also\"" space ":" space number
|
||||
also-rest ::= ( "," space additional-kv )*
|
||||
@@ -1069,8 +1069,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space and-kv ( "," space ( also-kv also-rest | additional-kv ( "," space additional-kv )* ) )? "}" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
root ::= "{" space and-kv ( "," space ( also-kv also-rest | additional-kv ( "," space additional-kv )* ) )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1090,13 +1090,13 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
-rest ::= ( "," space a-kv )? a-rest
|
||||
a-kv ::= "\"a\"" space ":" space integer
|
||||
a-rest ::= ( "," space additional-kv )*
|
||||
additional-k ::= ["] ( [a] char+ | [^"a] char* ) ["] space
|
||||
additional-k ::= ["] ( [a] char+ | [^"a] char* ) ["]
|
||||
additional-kv ::= additional-k ":" space integer
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
integer ::= ("-"? integral-part) space
|
||||
integer ::= ("-"? integral-part)
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
root ::= ("-"? integral-part) space
|
||||
root0 ::= "{" space (-kv -rest | a-kv a-rest | additional-kv ( "," space additional-kv )* )? "}" space
|
||||
root ::= ("-"? integral-part)
|
||||
root0 ::= "{" space (-kv -rest | a-kv a-rest | additional-kv ( "," space additional-kv )* )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1116,12 +1116,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
a-rest ::= ( "," space aa-kv )? aa-rest
|
||||
aa-kv ::= "\"aa\"" space ":" space integer
|
||||
aa-rest ::= ( "," space additional-kv )*
|
||||
additional-k ::= ["] ( [a] ([a] char+ | [^"a] char*) | [^"a] char* )? ["] space
|
||||
additional-k ::= ["] ( [a] ([a] char+ | [^"a] char*) | [^"a] char* )? ["]
|
||||
additional-kv ::= additional-k ":" space integer
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
integer ::= ("-"? integral-part) space
|
||||
integer ::= ("-"? integral-part)
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
root ::= "{" space (a-kv a-rest | aa-kv aa-rest | additional-kv ( "," space additional-kv )* )? "}" space
|
||||
root ::= "{" space (a-kv a-rest | aa-kv aa-rest | additional-kv ( "," space additional-kv )* )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1141,12 +1141,12 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
ab-rest ::= ( "," space ac-kv )? ac-rest
|
||||
ac-kv ::= "\"ac\"" space ":" space integer
|
||||
ac-rest ::= ( "," space additional-kv )*
|
||||
additional-k ::= ["] ( [a] ([b] char+ | [c] char+ | [^"bc] char*) | [^"a] char* )? ["] space
|
||||
additional-k ::= ["] ( [a] ([b] char+ | [c] char+ | [^"bc] char*) | [^"a] char* )? ["]
|
||||
additional-kv ::= additional-k ":" space integer
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
integer ::= ("-"? integral-part) space
|
||||
integer ::= ("-"? integral-part)
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
root ::= "{" space (ab-kv ab-rest | ac-kv ac-rest | additional-kv ( "," space additional-kv )* )? "}" space
|
||||
root ::= "{" space (ab-kv ab-rest | ac-kv ac-rest | additional-kv ( "," space additional-kv )* )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1173,11 +1173,11 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
})""",
|
||||
R"""(
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
ref-definitions-foo ::= "{" space ref-definitions-foo-a-kv "}" space
|
||||
ref-definitions-foo ::= "{" space ref-definitions-foo-a-kv space "}"
|
||||
ref-definitions-foo-a-kv ::= "\"a\"" space ":" space string
|
||||
root ::= ref-definitions-foo
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -1204,10 +1204,10 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
alternative-1 ::= ref-definitions-bar
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
ref-definitions-bar ::= "{" space (ref-definitions-bar-b-kv )? "}" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
ref-definitions-bar ::= "{" space (ref-definitions-bar-b-kv )? space "}"
|
||||
ref-definitions-bar-b-kv ::= "\"b\"" space ":" space number
|
||||
ref-definitions-foo ::= "{" space (ref-definitions-foo-a-kv )? "}" space
|
||||
ref-definitions-foo ::= "{" space (ref-definitions-foo-a-kv )? space "}"
|
||||
ref-definitions-foo-a-kv ::= "\"a\"" space ":" space number
|
||||
root ::= alternative-0 | alternative-1
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
@@ -1241,14 +1241,14 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
b ::= b-0 | boolean
|
||||
b-0 ::= string
|
||||
b-kv ::= "\"b\"" space ":" space b
|
||||
boolean ::= ("true" | "false") space
|
||||
boolean ::= ("true" | "false")
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space (a-kv a-rest | b-kv )? "}" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
root ::= "{" space (a-kv a-rest | b-kv )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
)"""
|
||||
});
|
||||
|
||||
@@ -1290,8 +1290,8 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
d-rest ::= ( "," space c-kv )?
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
root ::= "{" space a-kv "," space b-kv ( "," space ( d-kv d-rest | c-kv ) )? "}" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
root ::= "{" space a-kv "," space b-kv ( "," space ( d-kv d-rest | c-kv ) )? space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1311,7 +1311,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
}
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("\"a\"" | "\"b\"") space
|
||||
root ::= ("\"a\"" | "\"b\"")
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1336,7 +1336,7 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
}
|
||||
})""",
|
||||
R"""(
|
||||
root ::= ("\"b\"" | "\"c\"") space
|
||||
root ::= ("\"b\"" | "\"c\"")
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1378,13 +1378,13 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
R"""(
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
number- ::= "{" space number-number-kv "}" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
number- ::= "{" space number-number-kv space "}"
|
||||
number-kv ::= "\"number\"" space ":" space number-
|
||||
number-number ::= "{" space number-number-root-kv "}" space
|
||||
number-number ::= "{" space number-number-root-kv space "}"
|
||||
number-number-kv ::= "\"number\"" space ":" space number-number
|
||||
number-number-root-kv ::= "\"root\"" space ":" space number
|
||||
root ::= "{" space number-kv "}" space
|
||||
root ::= "{" space number-kv space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1394,17 +1394,17 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"description only (no type) treated as unconstrained",
|
||||
R"""({"description": "The 0-based index of the last line to be retrieved (inclusive). If None, read until the end of the file."})""",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
array ::= "[" space ( value ("," space value)* )? space "]"
|
||||
boolean ::= ("true" | "false")
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
null ::= "null"
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? space "}"
|
||||
root ::= value
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
string ::= "\"" char* "\""
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
});
|
||||
@@ -1428,9 +1428,9 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
"type": "object"
|
||||
})""",
|
||||
R"""(
|
||||
code ::= "\" \\r \\n \\\" \\\\ \"" space
|
||||
code ::= "\" \\r \\n \\\" \\\\ \""
|
||||
code-kv ::= "\"code\"" space ":" space code
|
||||
root ::= "{" space code-kv "}" space
|
||||
root ::= "{" space code-kv space "}"
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)"""
|
||||
});
|
||||
@@ -1547,7 +1547,7 @@ int main() {
|
||||
"pattern": "^(?:foo|bar)baz$"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" (("foo" | "bar") "baz") "\"" space
|
||||
root ::= "\"" (("foo" | "bar") "baz") "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""",
|
||||
});
|
||||
@@ -1560,7 +1560,7 @@ int main() {
|
||||
"pattern": "^(?:(?:ab)+c)?d$"
|
||||
})""",
|
||||
R"""(
|
||||
root ::= "\"" ((("ab")+ "c")? "d") "\"" space
|
||||
root ::= "\"" ((("ab")+ "c")? "d") "\""
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
)""",
|
||||
});
|
||||
|
||||
@@ -360,9 +360,9 @@ int main(void) {
|
||||
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.032727f, 0.241818f, 0.241818f}, 2.0f, 1.1f, 2, 5, {});
|
||||
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
|
||||
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f);
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.0f, 0.0f, 0.428571f, 0.571429f}, 1.00f);
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 0.00f); // top_n_sigma == 0 now represents a no-op rather than greedy decoding as of PR#13345
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f);
|
||||
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.1f, 0.2f, 0.3f, 0.4f}, 3.00f);
|
||||
|
||||
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
|
||||
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
|
||||
|
||||
+1
-1
@@ -202,7 +202,7 @@ struct cli_context {
|
||||
|
||||
// TODO: support remote files in the future (http, https, etc)
|
||||
std::string load_input_file(const std::string & fname, bool is_media) {
|
||||
std::ifstream file(fname, std::ios::binary);
|
||||
std::ifstream file = fs_open_ifstream(fname, std::ios::binary);
|
||||
if (!file) {
|
||||
return "";
|
||||
}
|
||||
|
||||
@@ -13,6 +13,14 @@
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <fstream>
|
||||
|
||||
#ifdef _WIN32
|
||||
#ifndef NOMINMAX
|
||||
#define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
// Internal header for clip.cpp
|
||||
|
||||
@@ -661,6 +669,22 @@ struct clip_image_f32_batch {
|
||||
// common utils
|
||||
//
|
||||
|
||||
#ifdef _WIN32
|
||||
static std::ifstream open_ifstream_binary(const std::string & fname) {
|
||||
int wlen = MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, NULL, 0);
|
||||
if (!wlen) {
|
||||
throw std::runtime_error("failed to convert filename to UTF-16: " + fname);
|
||||
}
|
||||
std::vector<wchar_t> wfname(wlen);
|
||||
(void)MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, wfname.data(), wlen);
|
||||
return std::ifstream(wfname.data(), std::ios::binary);
|
||||
}
|
||||
#else
|
||||
static std::ifstream open_ifstream_binary(const std::string & fname) {
|
||||
return std::ifstream(fname, std::ios::binary);
|
||||
}
|
||||
#endif
|
||||
|
||||
static std::string string_format(const char * fmt, ...) {
|
||||
va_list ap;
|
||||
va_list ap2;
|
||||
|
||||
+60
-25
@@ -1045,8 +1045,17 @@ struct clip_model_loader {
|
||||
bool has_vision = false;
|
||||
bool has_audio = false;
|
||||
|
||||
mtmd_progress_callback progress_callback = nullptr;
|
||||
void * progress_callback_user_data = nullptr;
|
||||
|
||||
// TODO @ngxson : we should not pass clip_ctx here, it should be clip_model
|
||||
clip_model_loader(const char * fname, bool skip_tensors = false) : fname(fname) {
|
||||
clip_model_loader(const char * fname,
|
||||
bool skip_tensors = false,
|
||||
mtmd_progress_callback progress_cb = nullptr,
|
||||
void * progress_user_data = nullptr)
|
||||
: fname(fname),
|
||||
progress_callback(progress_cb),
|
||||
progress_callback_user_data(progress_user_data) {
|
||||
struct ggml_context * meta = nullptr;
|
||||
|
||||
struct gguf_init_params params = {
|
||||
@@ -1752,7 +1761,7 @@ struct clip_model_loader {
|
||||
std::map<std::string, size_t> tensor_offset;
|
||||
std::vector<ggml_tensor *> tensors_to_load;
|
||||
|
||||
auto fin = std::ifstream(fname, std::ios::binary);
|
||||
auto fin = open_ifstream_binary(fname);
|
||||
if (!fin) {
|
||||
throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str()));
|
||||
}
|
||||
@@ -2787,37 +2796,60 @@ struct clip_model_loader {
|
||||
}
|
||||
|
||||
// load data
|
||||
if (!ctx_clip.no_alloc) {
|
||||
{
|
||||
std::vector<uint8_t> read_buf;
|
||||
|
||||
// start loading event
|
||||
if (progress_callback){
|
||||
progress_callback(0.0, progress_callback_user_data);
|
||||
}
|
||||
|
||||
// compute total tensor data size for progress reporting
|
||||
size_t total_data_size = 0;
|
||||
for (auto & t : tensors_to_load) {
|
||||
total_data_size += ggml_nbytes(t);
|
||||
}
|
||||
|
||||
// alloc memory and offload data
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend);
|
||||
ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft));
|
||||
ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
|
||||
for (auto & t : tensors_to_load) {
|
||||
ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
|
||||
GGML_ASSERT(cur && "tensor not found in ctx_data");
|
||||
auto it_off = tensor_offset.find(t->name);
|
||||
GGML_ASSERT(it_off != tensor_offset.end() && "no offset for tensor");
|
||||
const size_t offset = it_off->second;
|
||||
fin.seekg(offset, std::ios::beg);
|
||||
if (!fin) {
|
||||
throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name));
|
||||
}
|
||||
size_t num_bytes = ggml_nbytes(cur);
|
||||
if (ggml_backend_buft_is_host(buft)) {
|
||||
// for the CPU and Metal backend, we can read directly into the tensor
|
||||
fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
|
||||
} else {
|
||||
// read into a temporary buffer first, then copy to device memory
|
||||
read_buf.resize(num_bytes);
|
||||
fin.read(reinterpret_cast<char *>(read_buf.data()), num_bytes);
|
||||
ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
|
||||
// read the weight from file
|
||||
if (!ctx_clip.no_alloc) {
|
||||
size_t data_loaded = 0;
|
||||
for (auto & t : tensors_to_load) {
|
||||
ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
|
||||
GGML_ASSERT(cur && "tensor not found in ctx_data");
|
||||
auto it_off = tensor_offset.find(t->name);
|
||||
GGML_ASSERT(it_off != tensor_offset.end() && "no offset for tensor");
|
||||
const size_t offset = it_off->second;
|
||||
fin.seekg(offset, std::ios::beg);
|
||||
if (!fin) {
|
||||
throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name));
|
||||
}
|
||||
size_t num_bytes = ggml_nbytes(cur);
|
||||
if (ggml_backend_buft_is_host(buft)) {
|
||||
// for the CPU and Metal backend, we can read directly into the tensor
|
||||
fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
|
||||
} else {
|
||||
// read into a temporary buffer first, then copy to device memory
|
||||
read_buf.resize(num_bytes);
|
||||
fin.read(reinterpret_cast<char *>(read_buf.data()), num_bytes);
|
||||
ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
|
||||
}
|
||||
data_loaded += num_bytes;
|
||||
if (progress_callback && total_data_size > 0) {
|
||||
const float progress = (float)data_loaded / (float)total_data_size;
|
||||
if (!progress_callback(progress, progress_callback_user_data)) {
|
||||
throw std::runtime_error(string_format("%s: model loading cancelled by progress_callback\n", __func__));
|
||||
}
|
||||
}
|
||||
}
|
||||
LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str());
|
||||
} else {
|
||||
LOG_DBG("%s: no_alloc is set, skipping tensor data loading (%zu tensors)\n", __func__, tensors_to_load.size());
|
||||
}
|
||||
fin.close();
|
||||
|
||||
LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -3105,7 +3137,10 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
|
||||
clip_ctx * ctx_audio = nullptr;
|
||||
|
||||
try {
|
||||
clip_model_loader loader(fname);
|
||||
clip_model_loader loader(fname,
|
||||
/* skip_tensors */ false,
|
||||
ctx_params.progress_callback,
|
||||
ctx_params.progress_callback_user_data);
|
||||
bool skip_audio = false;
|
||||
|
||||
if (loader.has_vision) {
|
||||
|
||||
@@ -54,6 +54,8 @@ struct clip_context_params {
|
||||
ggml_backend_sched_eval_callback cb_eval;
|
||||
void * cb_eval_user_data;
|
||||
bool no_alloc;
|
||||
mtmd_progress_callback progress_callback;
|
||||
void * progress_callback_user_data;
|
||||
};
|
||||
|
||||
struct clip_init_result {
|
||||
|
||||
@@ -396,6 +396,9 @@ int main(int argc, char ** argv) {
|
||||
|
||||
int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict;
|
||||
|
||||
console::init(params.simple_io, params.use_color);
|
||||
atexit([]() { console::cleanup(); });
|
||||
|
||||
// Ctrl+C handling
|
||||
{
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
|
||||
@@ -582,13 +582,29 @@ mtmd_helper_bitmap_wrapper mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx,
|
||||
}
|
||||
|
||||
mtmd_helper_bitmap_wrapper mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char * fname, bool placeholder) {
|
||||
std::vector<unsigned char> buf;
|
||||
#ifdef _WIN32
|
||||
int wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, NULL, 0);
|
||||
if (!wlen) {
|
||||
LOG_ERR("Unable to convert filename to UTF-16: %s\n", fname);
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
std::vector<wchar_t> wfname(wlen);
|
||||
wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, wfname.data(), wlen);
|
||||
if (!wlen) {
|
||||
LOG_ERR("Unable to convert filename to UTF-16: %s\n", fname);
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
FILE * f = _wfopen(wfname.data(), L"rb");
|
||||
#else
|
||||
FILE * f = fopen(fname, "rb");
|
||||
#endif
|
||||
if (!f) {
|
||||
LOG_ERR("Unable to open file %s: %s\n", fname, strerror(errno));
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
|
||||
std::vector<unsigned char> buf;
|
||||
|
||||
fseek(f, 0, SEEK_END);
|
||||
long file_size = ftell(f);
|
||||
fseek(f, 0, SEEK_SET);
|
||||
|
||||
+8
-1
@@ -251,6 +251,8 @@ mtmd_context_params mtmd_context_params_default() {
|
||||
/* cb_eval */ nullptr,
|
||||
/* cb_eval_user_data */ nullptr,
|
||||
/* batch_max_tokens */ 1024,
|
||||
/* progress_callback */ nullptr,
|
||||
/* progress_callback_user_data */ nullptr,
|
||||
};
|
||||
return params;
|
||||
}
|
||||
@@ -345,6 +347,8 @@ struct mtmd_context {
|
||||
/* cb_eval */ ctx_params.cb_eval,
|
||||
/* cb_eval_user_data */ ctx_params.cb_eval_user_data,
|
||||
/* no_alloc */ no_alloc,
|
||||
/* progress_callback */ ctx_params.progress_callback,
|
||||
/* progress_callback_user_data */ ctx_params.progress_callback_user_data,
|
||||
};
|
||||
|
||||
auto res = clip_init(mmproj_fname, ctx_clip_params);
|
||||
@@ -2133,9 +2137,12 @@ std::map<ggml_backend_dev_t, size_t> mtmd_get_memory_usage(const char * mmproj_f
|
||||
mtmd::context_ptr ctx;
|
||||
auto saved_log_callback = g_logger_state.log_callback;
|
||||
auto saved_log_user_data = g_logger_state.log_callback_user_data;
|
||||
|
||||
ctx_params.progress_callback = nullptr;
|
||||
|
||||
try {
|
||||
mtmd_log_set(stub_log_callback, nullptr); // suppress logging
|
||||
ctx.reset(new mtmd_context(mmproj_fname, nullptr, ctx_params));
|
||||
ctx.reset(new mtmd_context(mmproj_fname, nullptr, ctx_params, true));
|
||||
mtmd_log_set(saved_log_callback, saved_log_user_data); // restore log callback
|
||||
std::map<ggml_backend_dev_t, size_t> total_mem;
|
||||
auto merge = [&](const struct clip_ctx * c) {
|
||||
|
||||
@@ -83,6 +83,8 @@ typedef struct mtmd_input_chunks mtmd_input_chunks;
|
||||
typedef struct mtmd_input_text mtmd_input_text;
|
||||
typedef struct mtmd_batch mtmd_batch;
|
||||
|
||||
typedef bool (*mtmd_progress_callback)(float progress, void * user_data);
|
||||
|
||||
struct mtmd_context_params {
|
||||
bool use_gpu;
|
||||
bool print_timings;
|
||||
@@ -104,6 +106,12 @@ struct mtmd_context_params {
|
||||
int32_t batch_max_tokens; // maximum number of output tokens in a batch
|
||||
// (note: this is not a hard-limit, the first image will always be added even if it exceeds this limit)
|
||||
// (default: 1024)
|
||||
|
||||
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
|
||||
// If the provided progress_callback returns true, model loading continues.
|
||||
// If it returns false, model loading is immediately aborted.
|
||||
mtmd_progress_callback progress_callback;
|
||||
void * progress_callback_user_data;
|
||||
};
|
||||
|
||||
MTMD_API const char * mtmd_default_marker(void);
|
||||
|
||||
@@ -180,6 +180,17 @@ That requires `JSON.stringify` when formatted to message content:
|
||||
}
|
||||
```
|
||||
|
||||
### Router mode: how child <--> router communicates
|
||||
|
||||
Upon spawning a new child process using `subprocess`, both child and router listen to the stdout/stderr (combined)
|
||||
|
||||
For the direction from child to router:
|
||||
- Generic messages are logs, it will be forwarded to router's stdout
|
||||
- Special state update messages are prefixed by `cmd_child_to_router:state:`, followed by a JSON. See `server_models::handle_child_state` for more
|
||||
|
||||
For the direction from router to child:
|
||||
- When server sends `cmd_router_to_child:exit`, the child should exit gracefully --> if after `DEFAULT_STOP_TIMEOUT` and the child is still running, force-kill it
|
||||
|
||||
### Model management API (router mode)
|
||||
|
||||
Model management API was added via PR [#23976](https://github.com/ggml-org/llama.cpp/pull/23976)
|
||||
@@ -193,9 +204,9 @@ Instead of building everything from the ground up (like what most AI agents will
|
||||
|
||||
The flow for downloading a new model:
|
||||
- POST request comes in --> `post_router_models` --> validation
|
||||
- `server_models::download()` is called
|
||||
- Sets up a new thread `inst.th` and runs the download inside
|
||||
- If a stop request comes in, set `stop_download` to `true`
|
||||
- A new `llama-server` subprocess will be spawned with special `SERVER_CHILD_MODE_DOWNLOAD`
|
||||
- Child process runs the download and report status back to router via stdin/out
|
||||
- If a stop request comes in, the router asks the child process to stop (same mechanism as running a model in child process)
|
||||
- Otherwise, upon completion, we call `load_models()` to refresh the list of models
|
||||
|
||||
### Notable Related PRs
|
||||
|
||||
+43
-8
@@ -198,7 +198,7 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
|
||||
| `--rerank, --reranking` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
|
||||
| `--api-key KEY` | API key to use for authentication, multiple keys can be provided as a comma-separated list (default: none)<br/>(env: LLAMA_API_KEY) |
|
||||
| `--api-key-file FNAME` | path to file containing API keys (default: none)<br/>(env: LLAMA_ARG_API_KEY_FILE) |
|
||||
| `--api-key-file FNAME` | path to file containing API keys, one per line; lines starting with a hash are treated as comments (default: none)<br/>(env: LLAMA_ARG_API_KEY_FILE) |
|
||||
| `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key<br/>(env: LLAMA_ARG_SSL_KEY_FILE) |
|
||||
| `--ssl-cert-file FNAME` | path to file a PEM-encoded SSL certificate<br/>(env: LLAMA_ARG_SSL_CERT_FILE) |
|
||||
| `--chat-template-kwargs STRING` | sets additional params for the json template parser, must be a valid json object string, e.g. '{"key1":"value1","key2":"value2"}'<br/>(env: LLAMA_ARG_CHAT_TEMPLATE_KWARGS) |
|
||||
@@ -1230,8 +1230,6 @@ print(completion.choices[0].text)
|
||||
|
||||
Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used.
|
||||
|
||||
If model supports multimodal, you can input the media file via `image_url` content part. We support both base64 and remote URL as input. See OAI documentation for more.
|
||||
|
||||
*Options:*
|
||||
|
||||
See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). llama.cpp `/completion`-specific features such as `mirostat` are also supported.
|
||||
@@ -1250,9 +1248,18 @@ The `response_format` parameter supports both plain JSON output (e.g. `{"type":
|
||||
|
||||
`parallel_tool_calls` : Whether to enable parallel/multiple tool calls (only supported on some models, verification is based on jinja template).
|
||||
|
||||
For multimodal input:
|
||||
- Content type `image_url` and `input_audio` are the same as OAI schema
|
||||
- Content type `input_video` is an extension from OAI schema. For now, it only accepts base64 input
|
||||
For multimodal input (typed content, `messages[i].content[j]`):
|
||||
- If `type == "image_url"`:
|
||||
- `image_url.url` can be a remote URL, base64 (raw or URI-encoded via `data:image/...;base64`) or path to local file
|
||||
- Accepts formats supported by `stb_image` (jpeg, png, tga, bmp, gif, ...)
|
||||
- If `type == "input_audio"`:
|
||||
- Either `input_audio.data` or `input_audio.url` can be specified, can be a remote URL, raw base64 or path to local file
|
||||
- Accepts formats supported by `miniaudio` (mp3, wav, flac)
|
||||
- `input_audio.format` will be ignored, the file format will be determined automatically
|
||||
- If `type == "input_video"`:
|
||||
- Either `input_video.data` or `input_video.url` can be specified, can be a remote URL, raw base64 or path to local file
|
||||
- Accepts formats supported by `ffmpeg`
|
||||
- Note: for local file, make sure to set `--media-path`. File path must be prefixed by `file://`
|
||||
|
||||
*Examples:*
|
||||
|
||||
@@ -1859,9 +1866,37 @@ Example events:
|
||||
|
||||
{
|
||||
"model": "...",
|
||||
"event": "download_finished",
|
||||
"event": "model_status",
|
||||
"data": {
|
||||
"status": "loading"
|
||||
"status": "loading",
|
||||
"progress": {
|
||||
"stages": ["text_model", "spec_model", "mmproj_model"],
|
||||
"current": "text_model",
|
||||
"value": 0.5
|
||||
}
|
||||
}
|
||||
}
|
||||
// note for "loading" status:
|
||||
// - subsequent events will follow the same order of "stages" list
|
||||
// - mmap is may report incorrect progress on some platforms; if you need exact progress, use --no-mmap
|
||||
|
||||
{
|
||||
"model": "...",
|
||||
"event": "model_status",
|
||||
"data": {
|
||||
"status": "loaded",
|
||||
"info": {
|
||||
// note: only include info on first load
|
||||
// waking up from sleep doesn't have this
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
"model": "...",
|
||||
"event": "model_status",
|
||||
"data": {
|
||||
"status": "sleeping"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include <random>
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
|
||||
json format_error_response(const std::string & message, const enum error_type type) {
|
||||
std::string type_str;
|
||||
@@ -816,12 +817,21 @@ json oaicompat_completion_params_parse(const json & body) {
|
||||
return llama_params;
|
||||
}
|
||||
|
||||
// media_path always end with '/', see arg.cpp
|
||||
// url can be
|
||||
// - http(s):// for remote files
|
||||
// - file:// for local files (only allowed if media_path is set)
|
||||
// - data: for base64 encoded data with uri scheme (e.g. data:image/png;base64,...)
|
||||
// - raw base64 encoded data
|
||||
static void handle_media(
|
||||
std::vector<raw_buffer> & out_files,
|
||||
json & media_obj,
|
||||
const std::string & media_path) {
|
||||
std::string url = json_value(media_obj, "url", std::string());
|
||||
const std::string & url,
|
||||
const std::string & media_path,
|
||||
bool accept_base64_uri) {
|
||||
if (!media_path.empty()) {
|
||||
// should already be enforced by arg.cpp, but checking just in case
|
||||
GGML_ASSERT(media_path.back() == DIRECTORY_SEPARATOR);
|
||||
}
|
||||
|
||||
if (string_starts_with(url, "http")) {
|
||||
// download remote image
|
||||
// TODO @ngxson : maybe make these params configurable
|
||||
@@ -857,20 +867,28 @@ static void handle_media(
|
||||
data.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
out_files.push_back(data);
|
||||
|
||||
} else {
|
||||
} else if (accept_base64_uri && string_starts_with(url, "data:")) {
|
||||
// try to decode base64 image
|
||||
std::vector<std::string> parts = string_split<std::string>(url, /*separator*/ ',');
|
||||
if (parts.size() != 2) {
|
||||
throw std::runtime_error("Invalid url value");
|
||||
throw std::runtime_error("Invalid uri-encoded base64 value");
|
||||
} else if (!string_starts_with(parts[0], "data:image/")) {
|
||||
throw std::runtime_error("Invalid url format: " + parts[0]);
|
||||
throw std::runtime_error("Invalid uri format: " + parts[0]);
|
||||
} else if (!string_ends_with(parts[0], "base64")) {
|
||||
throw std::runtime_error("url must be base64 encoded");
|
||||
throw std::runtime_error("uri must be base64 encoded");
|
||||
} else {
|
||||
auto base64_data = parts[1];
|
||||
auto decoded_data = base64_decode(base64_data);
|
||||
out_files.push_back(decoded_data);
|
||||
}
|
||||
|
||||
} else {
|
||||
// try as raw base64 string
|
||||
auto decoded_data = base64_decode(url);
|
||||
if (decoded_data.empty()) {
|
||||
throw std::runtime_error("Invalid base64 value");
|
||||
}
|
||||
out_files.push_back(decoded_data);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -956,14 +974,15 @@ json oaicompat_chat_params_parse(
|
||||
}
|
||||
|
||||
for (auto & p : content) {
|
||||
std::string type = json_value(p, "type", std::string());
|
||||
std::string type = json_value(p, "type", std::string());
|
||||
if (type == "image_url") {
|
||||
if (!opt.allow_image) {
|
||||
throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
|
||||
}
|
||||
|
||||
json image_url = json_value(p, "image_url", json::object());
|
||||
handle_media(out_files, image_url, opt.media_path);
|
||||
std::string url = json_value(image_url, "url", std::string());
|
||||
handle_media(out_files, url, opt.media_path, true);
|
||||
|
||||
p["type"] = "media_marker";
|
||||
p["text"] = get_media_marker();
|
||||
@@ -974,17 +993,11 @@ json oaicompat_chat_params_parse(
|
||||
throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
|
||||
}
|
||||
|
||||
json input_audio = json_value(p, "input_audio", json::object());
|
||||
std::string data = json_value(input_audio, "data", std::string());
|
||||
std::string format = json_value(input_audio, "format", std::string());
|
||||
// while we also support flac, we don't allow it here so we matches the OAI spec
|
||||
if (format != "wav" && format != "mp3") {
|
||||
throw std::invalid_argument("input_audio.format must be either 'wav' or 'mp3'");
|
||||
}
|
||||
auto decoded_data = base64_decode(data); // expected to be base64 encoded
|
||||
out_files.push_back(decoded_data);
|
||||
|
||||
// TODO: add audio_url support by reusing handle_media()
|
||||
// note: don't need to validate "format", it's redundant
|
||||
json input_audio = json_value(p, "input_audio", json::object());
|
||||
std::string url = json_value(input_audio, "data",
|
||||
json_value(input_audio, "url", std::string()));
|
||||
handle_media(out_files, url, opt.media_path, false);
|
||||
|
||||
p["type"] = "media_marker";
|
||||
p["text"] = get_media_marker();
|
||||
@@ -995,10 +1008,10 @@ json oaicompat_chat_params_parse(
|
||||
throw std::runtime_error("video input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
|
||||
}
|
||||
|
||||
json input_video = json_value(p, "input_video", json::object());
|
||||
std::string data = json_value(input_video, "data", std::string());
|
||||
auto decoded_data = base64_decode(data); // expected to be base64 encoded
|
||||
out_files.push_back(decoded_data);
|
||||
json input_video = json_value(p, "input_video", json::object());
|
||||
std::string url = json_value(input_video, "data",
|
||||
json_value(input_video, "url", std::string()));
|
||||
handle_media(out_files, url, opt.media_path, false);
|
||||
|
||||
p["type"] = "media_marker";
|
||||
p["text"] = get_media_marker();
|
||||
@@ -1238,7 +1251,7 @@ json format_response_rerank(
|
||||
// other utils
|
||||
//
|
||||
|
||||
std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
|
||||
std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx, size_t n_top) {
|
||||
std::vector<llama_token_data> cur;
|
||||
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
@@ -1257,21 +1270,34 @@ std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int i
|
||||
}
|
||||
}
|
||||
|
||||
// sort tokens by logits
|
||||
std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
});
|
||||
// sort tokens by logits (partial: only the leading `n_top` need ordering)
|
||||
if (n_top > cur.size()) {
|
||||
n_top = cur.size();
|
||||
}
|
||||
if (n_top > 0) {
|
||||
std::partial_sort(cur.begin(), cur.begin() + n_top, cur.end(),
|
||||
[](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
});
|
||||
}
|
||||
|
||||
// apply softmax
|
||||
float max_l = cur[0].logit;
|
||||
float max_l = -std::numeric_limits<float>::infinity();
|
||||
if (n_top > 0) {
|
||||
max_l = cur[0].logit; // partial_sort guarantees the absolute maximum is at index 0
|
||||
} else {
|
||||
for (const auto & t : cur) {
|
||||
max_l = std::max(max_l, t.logit);
|
||||
}
|
||||
}
|
||||
float cum_sum = 0.0f;
|
||||
for (size_t i = 0; i < cur.size(); ++i) {
|
||||
float p = expf(cur[i].logit - max_l);
|
||||
cur[i].p = p;
|
||||
for (auto & t : cur) {
|
||||
float p = expf(t.logit - max_l);
|
||||
t.p = p;
|
||||
cum_sum += p;
|
||||
}
|
||||
for (size_t i = 0; i < cur.size(); ++i) {
|
||||
cur[i].p /= cum_sum;
|
||||
for (auto & t : cur) {
|
||||
t.p /= cum_sum;
|
||||
}
|
||||
|
||||
return cur;
|
||||
|
||||
@@ -326,7 +326,7 @@ json format_response_rerank(
|
||||
// other utils
|
||||
//
|
||||
|
||||
std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx);
|
||||
std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx, size_t n_top);
|
||||
|
||||
std::string safe_json_to_str(const json & data);
|
||||
|
||||
|
||||
+694
-385
File diff suppressed because it is too large
Load Diff
@@ -22,8 +22,7 @@ struct server_context_meta {
|
||||
bool has_inp_image;
|
||||
bool has_inp_audio;
|
||||
bool has_inp_video;
|
||||
json json_ui_settings; // Primary: new name
|
||||
json json_webui_settings; // Deprecated: use json_ui_settings instead (kept for backward compat)
|
||||
json json_ui_settings;
|
||||
int slot_n_ctx;
|
||||
enum llama_pooling_type pooling_type;
|
||||
|
||||
@@ -53,6 +52,33 @@ struct server_context_meta {
|
||||
uint64_t model_size;
|
||||
};
|
||||
|
||||
enum server_state {
|
||||
SERVER_STATE_DOWNLOADING,
|
||||
SERVER_STATE_LOADING,
|
||||
SERVER_STATE_READY,
|
||||
SERVER_STATE_SLEEPING,
|
||||
};
|
||||
|
||||
static std::string server_state_to_str(server_state state) {
|
||||
switch (state) {
|
||||
case SERVER_STATE_DOWNLOADING: return "downloading";
|
||||
case SERVER_STATE_LOADING: return "loading";
|
||||
case SERVER_STATE_READY: return "ready";
|
||||
case SERVER_STATE_SLEEPING: return "sleeping";
|
||||
default: GGML_ASSERT(false && "invalid server_state");
|
||||
}
|
||||
}
|
||||
|
||||
static server_state server_state_from_str(const std::string & str) {
|
||||
if (str == "downloading") return SERVER_STATE_DOWNLOADING;
|
||||
if (str == "loading") return SERVER_STATE_LOADING;
|
||||
if (str == "ready") return SERVER_STATE_READY;
|
||||
if (str == "sleeping") return SERVER_STATE_SLEEPING;
|
||||
GGML_ASSERT(false && "invalid server_state string");
|
||||
}
|
||||
|
||||
using server_state_callback_t = std::function<void(server_state, json /* payload */)>;
|
||||
|
||||
struct server_context {
|
||||
std::unique_ptr<server_context_impl> impl;
|
||||
|
||||
@@ -80,9 +106,8 @@ struct server_context {
|
||||
// not thread-safe, should only be used from the main thread
|
||||
server_context_meta get_meta() const;
|
||||
|
||||
// register a callback to be called when sleeping state changes
|
||||
// must be set before load_model() is called
|
||||
void on_sleeping_changed(std::function<void(bool)> callback);
|
||||
// note: must be set before load_model() is called
|
||||
void set_state_callback(server_state_callback_t callback);
|
||||
};
|
||||
|
||||
|
||||
|
||||
@@ -7,9 +7,18 @@
|
||||
#include <unordered_set>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
|
||||
#include "server-http.h"
|
||||
|
||||
static std::string proxy_header_to_lower(std::string header) {
|
||||
std::transform(header.begin(), header.end(), header.begin(), [](unsigned char c) {
|
||||
return std::tolower(c);
|
||||
});
|
||||
return header;
|
||||
}
|
||||
|
||||
static server_http_res_ptr proxy_request(const server_http_req & req, std::string method) {
|
||||
std::string target_url = req.get_param("url");
|
||||
common_http_url parsed_url = common_http_parse_url(target_url);
|
||||
@@ -33,11 +42,18 @@ static server_http_res_ptr proxy_request(const server_http_req & req, std::strin
|
||||
SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.port, parsed_url.path.c_str());
|
||||
|
||||
std::map<std::string, std::string> headers;
|
||||
const std::string proxy_header_prefix = "x-llama-server-proxy-header-";
|
||||
for (auto [key, value] : req.headers) {
|
||||
auto new_key = key;
|
||||
if (string_starts_with(new_key, "x-proxy-header-")) {
|
||||
string_replace_all(new_key, "x-proxy-header-", "");
|
||||
const std::string lowered_key = proxy_header_to_lower(key);
|
||||
if (!string_starts_with(lowered_key, proxy_header_prefix)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto new_key = key.substr(proxy_header_prefix.size());
|
||||
if (new_key.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
headers[new_key] = value;
|
||||
}
|
||||
|
||||
|
||||
+309
-177
@@ -1,5 +1,6 @@
|
||||
#include "server-common.h"
|
||||
#include "server-models.h"
|
||||
#include "server-context.h"
|
||||
|
||||
#include "build-info.h"
|
||||
#include "preset.h"
|
||||
@@ -44,9 +45,7 @@ extern char **environ;
|
||||
#define DEFAULT_STOP_TIMEOUT 10 // seconds
|
||||
|
||||
#define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit"
|
||||
#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready" // also sent when waking up from sleep
|
||||
#define CMD_CHILD_TO_ROUTER_SLEEP "cmd_child_to_router:sleep"
|
||||
#define CMD_CHILD_TO_ROUTER_INFO "cmd_child_to_router:info:" // followed by json string
|
||||
#define CMD_CHILD_TO_ROUTER_STATE "cmd_child_to_router:state:" // followed by json string
|
||||
|
||||
// address for child process, this is needed because router may run on 0.0.0.0
|
||||
// ref: https://github.com/ggml-org/llama.cpp/issues/17862
|
||||
@@ -65,6 +64,17 @@ struct server_subproc {
|
||||
return sproc.has_value() && subprocess_alive(&sproc.value());
|
||||
}
|
||||
|
||||
void request_exit() {
|
||||
if (sproc.has_value()) {
|
||||
FILE * stdin_file = subprocess_stdin(&sproc.value());
|
||||
if (stdin_file) {
|
||||
fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
|
||||
fflush(stdin_file);
|
||||
}
|
||||
}
|
||||
stopped.store(true, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
void terminate() {
|
||||
if (!sproc.has_value()) {
|
||||
return;
|
||||
@@ -324,7 +334,7 @@ void server_models::notify_sse(const std::string & event, const std::string & mo
|
||||
}
|
||||
|
||||
void server_models::load_models() {
|
||||
// Phase 1: load presets from all sources — pure I/O, no lock needed
|
||||
// Phase 1: load presets from all sources - pure I/O, no lock needed
|
||||
// 1. cached models
|
||||
common_presets cached_models = ctx_preset.load_from_cache();
|
||||
SRV_INF("Loaded %zu cached model presets\n", cached_models.size());
|
||||
@@ -377,7 +387,7 @@ void server_models::load_models() {
|
||||
return source_map.count(name) ? source_map.at(name) : SERVER_MODEL_SOURCE_PRESET;
|
||||
};
|
||||
|
||||
// Helpers that read `mapping` — must be called while holding the lock.
|
||||
// Helpers that read `mapping` - must be called while holding the lock.
|
||||
std::unordered_set<std::string> custom_names;
|
||||
for (const auto & [name, preset] : custom_presets) custom_names.insert(name);
|
||||
auto join_set = [](const std::set<std::string> & s) {
|
||||
@@ -443,6 +453,7 @@ void server_models::load_models() {
|
||||
/* last_used */ 0,
|
||||
/* args */ std::vector<std::string>(),
|
||||
/* loaded_info */ {},
|
||||
/* progress */ {},
|
||||
/* exit_code */ 0,
|
||||
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
|
||||
/* multimodal */ mtmd_caps{false, false},
|
||||
@@ -523,7 +534,7 @@ void server_models::load_models() {
|
||||
}
|
||||
}
|
||||
|
||||
// join outside the lock — monitoring thread calls update_status (needs lock)
|
||||
// join outside the lock - monitoring thread calls update_status (needs lock)
|
||||
lk.unlock();
|
||||
for (auto & th : threads_to_join) th.join();
|
||||
lk.lock();
|
||||
@@ -609,6 +620,7 @@ void server_models::load_models() {
|
||||
/* last_used */ 0,
|
||||
/* args */ std::vector<std::string>(),
|
||||
/* loaded_info */ {},
|
||||
/* progress */ {},
|
||||
/* exit_code */ 0,
|
||||
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
|
||||
/* multimodal */ mtmd_caps{false, false},
|
||||
@@ -621,7 +633,7 @@ void server_models::load_models() {
|
||||
|
||||
apply_stop_timeout();
|
||||
|
||||
// clear reload flag before unlocking for autoload — load() blocks on !is_reloading,
|
||||
// clear reload flag before unlocking for autoload - load() blocks on !is_reloading,
|
||||
// so clearing it here (while still locked) prevents a deadlock in the autoload calls below
|
||||
is_reloading = false;
|
||||
cv.notify_all();
|
||||
@@ -814,17 +826,23 @@ void server_models::unload_lru() {
|
||||
}
|
||||
|
||||
void server_models::load(const std::string & name) {
|
||||
if (!has_model(name)) {
|
||||
throw std::runtime_error("model name=" + name + " is not found");
|
||||
load(name, load_options{});
|
||||
}
|
||||
|
||||
void server_models::load(const std::string & name, const load_options & opts) {
|
||||
if (!opts.custom_meta.has_value()) {
|
||||
if (!has_model(name)) {
|
||||
throw std::runtime_error("model name=" + name + " is not found");
|
||||
}
|
||||
unload_lru();
|
||||
}
|
||||
unload_lru();
|
||||
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
// edge case: block until any in-progress reload has finished so we always load
|
||||
// against the freshest preset and a consistent mapping state
|
||||
cv.wait(lk, [this]() { return !is_reloading; });
|
||||
|
||||
auto meta = mapping[name].meta;
|
||||
auto meta = opts.custom_meta.has_value() ? *opts.custom_meta : mapping[name].meta;
|
||||
if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
|
||||
SRV_INF("model %s is not ready\n", name.c_str());
|
||||
return;
|
||||
@@ -868,6 +886,12 @@ void server_models::load(const std::string & name) {
|
||||
std::vector<std::string> child_env = base_env; // copy
|
||||
child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
|
||||
|
||||
if (opts.mode == SERVER_CHILD_MODE_DOWNLOAD) {
|
||||
inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
|
||||
child_env.push_back("LLAMA_SERVER_CHILD_MODE=download");
|
||||
child_env.push_back("LLAMA_ARG_HF_REPO=" + name);
|
||||
}
|
||||
|
||||
SRV_INF("%s", "spawning server instance with args:\n");
|
||||
for (const auto & arg : child_args) {
|
||||
SRV_INF(" %s\n", arg.c_str());
|
||||
@@ -885,13 +909,17 @@ void server_models::load(const std::string & name) {
|
||||
if (result != 0) {
|
||||
throw std::runtime_error("failed to spawn server instance");
|
||||
}
|
||||
|
||||
inst.stdin_file = subprocess_stdin(&inst.subproc->get());
|
||||
}
|
||||
|
||||
// start a thread to manage the child process
|
||||
// captured variables are guaranteed to be destroyed only after the thread is joined
|
||||
inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() {
|
||||
inst.th = std::thread([
|
||||
this, name,
|
||||
child_proc = inst.subproc,
|
||||
port = inst.meta.port,
|
||||
stop_timeout = inst.meta.stop_timeout,
|
||||
child_mode = opts.mode
|
||||
]() {
|
||||
FILE * stdin_file = subprocess_stdin(&child_proc->get());
|
||||
FILE * stdout_file = subprocess_stdout(&child_proc->get()); // combined stdout/stderr
|
||||
|
||||
@@ -904,12 +932,8 @@ void server_models::load(const std::string & name) {
|
||||
while (fgets(buffer, vec_buf.size(), stdout_file) != nullptr) {
|
||||
LOG("[%5d] %s", port, buffer);
|
||||
std::string str(buffer);
|
||||
if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_READY)) {
|
||||
this->update_status(name, SERVER_MODEL_STATUS_LOADED, 0);
|
||||
} else if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_INFO)) {
|
||||
this->update_loaded_info(name, str);
|
||||
} else if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_SLEEP)) {
|
||||
this->update_status(name, SERVER_MODEL_STATUS_SLEEPING, 0);
|
||||
if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_STATE)) {
|
||||
this->handle_child_state(name, str);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -928,7 +952,7 @@ void server_models::load(const std::string & name) {
|
||||
return is_stopping() || child_proc->stopped.load(std::memory_order_acquire);
|
||||
});
|
||||
}
|
||||
// child crashed or finished on its own — skip graceful shutdown sequence
|
||||
// child crashed or finished on its own, skip graceful shutdown sequence
|
||||
if (child_proc->stopped.load(std::memory_order_acquire)) {
|
||||
return;
|
||||
}
|
||||
@@ -976,7 +1000,14 @@ void server_models::load(const std::string & name) {
|
||||
subprocess_destroy(&child_proc->get());
|
||||
|
||||
// update status and exit code
|
||||
this->update_status(name, SERVER_MODEL_STATUS_UNLOADED, exit_code);
|
||||
if (child_mode == SERVER_CHILD_MODE_DOWNLOAD) {
|
||||
// instance will be cleaned up on next load_models() call
|
||||
} else {
|
||||
this->update_status(name, {
|
||||
SERVER_MODEL_STATUS_UNLOADED,
|
||||
exit_code
|
||||
});
|
||||
}
|
||||
SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
|
||||
});
|
||||
|
||||
@@ -984,7 +1015,7 @@ void server_models::load(const std::string & name) {
|
||||
{
|
||||
auto & old_instance = mapping[name];
|
||||
// old process should have exited already, but just in case, we clean it up here
|
||||
if (old_instance.subproc->is_alive()) {
|
||||
if (old_instance.subproc && old_instance.subproc->is_alive()) {
|
||||
SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
|
||||
old_instance.subproc->terminate(); // force kill
|
||||
}
|
||||
@@ -1001,90 +1032,13 @@ void server_models::load(const std::string & name) {
|
||||
cv.notify_all();
|
||||
}
|
||||
|
||||
// callback for model downloading functionality
|
||||
struct server_models_download_res : public common_download_callback {
|
||||
common_params_model model;
|
||||
common_download_opts opts;
|
||||
|
||||
std::function<bool()> should_stop;
|
||||
std::function<void(const common_download_progress & p)> on_progress;
|
||||
|
||||
bool is_ok = false;
|
||||
|
||||
bool run() {
|
||||
try {
|
||||
common_download_model(model, opts);
|
||||
is_ok = true;
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("download failed for model name=%s: %s\n", model.name.c_str(), e.what());
|
||||
is_ok = false;
|
||||
}
|
||||
return is_ok;
|
||||
}
|
||||
void on_start(const common_download_progress & p) override {
|
||||
on_progress(p);
|
||||
}
|
||||
void on_update(const common_download_progress & p) override {
|
||||
on_progress(p);
|
||||
}
|
||||
void on_done(const common_download_progress &, bool ok) override {
|
||||
is_ok = ok;
|
||||
}
|
||||
bool is_cancelled() const override {
|
||||
return should_stop();
|
||||
}
|
||||
};
|
||||
|
||||
void server_models::download(common_params_model && model, common_download_opts && opts) {
|
||||
std::string name = model.name;
|
||||
GGML_ASSERT(name == model.hf_repo);
|
||||
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
if (mapping.find(name) != mapping.end()) {
|
||||
throw std::runtime_error("model name=" + name + " already exists");
|
||||
}
|
||||
|
||||
instance_t inst;
|
||||
inst.meta.name = name;
|
||||
inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING;
|
||||
inst.subproc = std::make_shared<server_subproc>();
|
||||
|
||||
auto dl = std::make_unique<server_models_download_res>();
|
||||
dl->model = model; // copy
|
||||
dl->opts = opts; // copy
|
||||
|
||||
dl->should_stop = [sp = inst.subproc]() {
|
||||
return sp->stopped.load(std::memory_order_relaxed);
|
||||
};
|
||||
|
||||
dl->on_progress = [this, name](const common_download_progress & p) {
|
||||
update_download_progress(name, p, false);
|
||||
};
|
||||
|
||||
inst.th = std::thread([this, dl = std::move(dl)]() {
|
||||
dl->opts.callback = dl.get();
|
||||
bool ok = dl->run();
|
||||
SRV_INF("download finished for model name=%s with status=%s\n",
|
||||
dl->model.name.c_str(), ok ? "success" : "failure");
|
||||
update_download_progress(dl->model.name, {}, true, ok);
|
||||
// need_reload is set inside update_download_progress under the mutex;
|
||||
// the next load_models() call will clean up this instance
|
||||
});
|
||||
|
||||
mapping[name] = std::move(inst);
|
||||
notify_sse("status_update", name, {
|
||||
{"status", server_model_status_to_string(SERVER_MODEL_STATUS_DOWNLOADING)},
|
||||
});
|
||||
cv.notify_all();
|
||||
}
|
||||
|
||||
void server_models::unload(const std::string & name) {
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
auto it = mapping.find(name);
|
||||
if (it != mapping.end()) {
|
||||
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
||||
SRV_INF("cancelling download for model name=%s\n", name.c_str());
|
||||
it->second.subproc->stopped.store(true, std::memory_order_relaxed);
|
||||
it->second.subproc->request_exit();
|
||||
// for convenience, we wait the status change here
|
||||
wait(lk, name, [](const server_model_meta & new_meta) {
|
||||
return new_meta.status != SERVER_MODEL_STATUS_DOWNLOADING;
|
||||
@@ -1130,21 +1084,33 @@ void server_models::unload_all() {
|
||||
}
|
||||
}
|
||||
|
||||
void server_models::update_status(const std::string & name, server_model_status status, int exit_code) {
|
||||
void server_models::update_status(const std::string & name, const update_status_args & args) {
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
auto it = mapping.find(name);
|
||||
if (it != mapping.end()) {
|
||||
auto & meta = it->second.meta;
|
||||
meta.status = status;
|
||||
meta.exit_code = exit_code;
|
||||
meta.status = args.status;
|
||||
meta.exit_code = args.exit_code;
|
||||
if (!args.loaded_info.is_null()) {
|
||||
meta.loaded_info = args.loaded_info;
|
||||
}
|
||||
if (!args.progress.is_null()) {
|
||||
meta.progress = args.progress;
|
||||
}
|
||||
}
|
||||
// broadcast status change to SSE
|
||||
{
|
||||
json data = {
|
||||
{"status", server_model_status_to_string(status)},
|
||||
{"status", server_model_status_to_string(args.status)},
|
||||
};
|
||||
if (status == SERVER_MODEL_STATUS_UNLOADED) {
|
||||
data["exit_code"] = exit_code;
|
||||
if (args.status == SERVER_MODEL_STATUS_UNLOADED) {
|
||||
data["exit_code"] = args.exit_code;
|
||||
}
|
||||
if (!args.loaded_info.is_null()) {
|
||||
data["info"] = args.loaded_info;
|
||||
}
|
||||
if (!args.progress.is_null()) {
|
||||
data["progress"] = args.progress;
|
||||
}
|
||||
// note: notify_sse doesn't acquire the lock, so no deadlock here
|
||||
notify_sse("status_change", name, data);
|
||||
@@ -1152,29 +1118,6 @@ void server_models::update_status(const std::string & name, server_model_status
|
||||
cv.notify_all();
|
||||
}
|
||||
|
||||
void server_models::update_loaded_info(const std::string & name, std::string & raw_info) {
|
||||
if (!string_starts_with(raw_info, CMD_CHILD_TO_ROUTER_INFO)) {
|
||||
SRV_WRN("invalid loaded info format from child for model name=%s: %s\n", name.c_str(), raw_info.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
json info;
|
||||
try {
|
||||
info = json::parse(raw_info.substr(strlen(CMD_CHILD_TO_ROUTER_INFO)));
|
||||
} catch (const std::exception & e) {
|
||||
SRV_WRN("failed to parse loaded info from child for model name=%s: %s\n", name.c_str(), e.what());
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
auto it = mapping.find(name);
|
||||
if (it != mapping.end()) {
|
||||
auto & meta = it->second.meta;
|
||||
meta.loaded_info = info;
|
||||
}
|
||||
cv.notify_all();
|
||||
}
|
||||
|
||||
void server_models::update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok) {
|
||||
json curr;
|
||||
{
|
||||
@@ -1207,37 +1150,65 @@ void server_models::update_download_progress(const std::string & name, const com
|
||||
}
|
||||
|
||||
bool server_models::remove(const std::string & name) {
|
||||
auto meta = get_meta(name);
|
||||
// do everything under one lock acquisition; avoid get_meta() /
|
||||
// unload() because they can trigger load_models() which erases
|
||||
// transient DOWNLOADING / DOWNLOADED entries as a side-effect
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
|
||||
if (!meta.has_value()) {
|
||||
auto it = mapping.find(name);
|
||||
if (it == mapping.end()) {
|
||||
throw std::runtime_error("model name=" + name + " is not found");
|
||||
}
|
||||
if (meta->source != SERVER_MODEL_SOURCE_CACHE) {
|
||||
if (it->second.meta.source != SERVER_MODEL_SOURCE_CACHE) {
|
||||
throw std::runtime_error("model name=" + name + " is not removable (not from cache)");
|
||||
}
|
||||
|
||||
unload(name); // cancel download or stop running instance
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(mutex);
|
||||
// a cancelled download lands on DOWNLOADED; a stopped instance lands on UNLOADED
|
||||
wait(lk, name, [](const server_model_meta & new_meta) {
|
||||
return new_meta.status == SERVER_MODEL_STATUS_UNLOADED
|
||||
|| new_meta.status == SERVER_MODEL_STATUS_DOWNLOADED;
|
||||
});
|
||||
// join before erasing - after status reaches UNLOADED/DOWNLOADED the thread no
|
||||
// longer acquires this mutex, so joining while holding it is safe
|
||||
if (mapping[name].th.joinable()) {
|
||||
mapping[name].th.join();
|
||||
if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) {
|
||||
// cancel in-flight download
|
||||
SRV_INF("cancelling download for model name=%s\n", name.c_str());
|
||||
it->second.subproc->request_exit();
|
||||
} else if (it->second.meta.is_running()) {
|
||||
// stop running instance
|
||||
SRV_INF("stopping model instance name=%s\n", name.c_str());
|
||||
stopping_models.insert(name);
|
||||
if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) {
|
||||
it->second.subproc->terminate();
|
||||
}
|
||||
// remove the model from disk (hold lock to prevent concurrent load)
|
||||
bool ok = common_download_remove(name);
|
||||
if (ok) {
|
||||
mapping.erase(name);
|
||||
}
|
||||
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "failed");
|
||||
notify_sse("model_remove", name, {});
|
||||
return ok;
|
||||
cv_stop.notify_all();
|
||||
}
|
||||
|
||||
// wait until the monitoring thread finishes
|
||||
wait(lk, name, [](const server_model_meta & meta) {
|
||||
return meta.status == SERVER_MODEL_STATUS_UNLOADED
|
||||
|| meta.status == SERVER_MODEL_STATUS_DOWNLOADED;
|
||||
});
|
||||
|
||||
// re-find after wait - load_models() may have erased the entry during the wait
|
||||
it = mapping.find(name);
|
||||
if (it == mapping.end()) {
|
||||
// load_models() already joined the thread and erased the entry;
|
||||
// we just need to clean up the cached files on disk
|
||||
lk.unlock();
|
||||
bool ok = common_download_remove(name);
|
||||
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial");
|
||||
notify_sse("model_remove", name, {});
|
||||
return true;
|
||||
}
|
||||
|
||||
// join before erasing - thread no longer acquires this mutex
|
||||
if (it->second.th.joinable()) {
|
||||
it->second.th.join();
|
||||
}
|
||||
|
||||
// remove from disk (best-effort: cancelled downloads may have no cached files)
|
||||
bool ok = common_download_remove(name);
|
||||
mapping.erase(name);
|
||||
if (!ok) {
|
||||
SRV_WRN("removing model name=%s from disk returned false (no cached files?)\n", name.c_str());
|
||||
}
|
||||
SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial");
|
||||
notify_sse("model_remove", name, {});
|
||||
return true;
|
||||
}
|
||||
|
||||
void server_models::wait(const std::string & name, std::function<bool(const server_model_meta &)> predicate) {
|
||||
@@ -1252,7 +1223,9 @@ void server_models::wait(std::unique_lock<std::mutex> & lk, const std::string &
|
||||
return predicate(it->second.meta);
|
||||
|
||||
}
|
||||
return false;
|
||||
// model was removed from mapping by another code path (e.g. load_models()).
|
||||
// nothing left to wait for - tell the caller to proceed.
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1323,21 +1296,168 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co
|
||||
return proxy;
|
||||
}
|
||||
|
||||
bool server_models::is_child_server() {
|
||||
void server_models::handle_child_state(const std::string & name, const std::string & raw_input) {
|
||||
server_state state;
|
||||
json payload;
|
||||
|
||||
try {
|
||||
json data = json::parse(raw_input.substr(strlen(CMD_CHILD_TO_ROUTER_STATE)));
|
||||
state = server_state_from_str(json_value(data, "state", std::string()));
|
||||
payload = json_value(data, "payload", json{});
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("failed to parse child state update for name=%s: %s\n", name.c_str(), e.what());
|
||||
return;
|
||||
}
|
||||
|
||||
switch (state) {
|
||||
case SERVER_STATE_DOWNLOADING:
|
||||
{
|
||||
std::string result = json_value(payload, "result", std::string());
|
||||
std::string url = json_value(payload, "url", std::string());
|
||||
auto request_exit = [&]() {
|
||||
std::lock_guard<std::mutex> lk(mutex);
|
||||
auto it = mapping.find(name);
|
||||
if (it != mapping.end()) {
|
||||
return it->second.subproc->request_exit();
|
||||
}
|
||||
};
|
||||
if (result == "download_finished") {
|
||||
update_download_progress(name, {}, true, true);
|
||||
request_exit();
|
||||
} else if (result == "download_failed") {
|
||||
update_download_progress(name, {}, true, false);
|
||||
request_exit();
|
||||
} else if (!url.empty()) {
|
||||
common_download_progress p;
|
||||
p.url = url;
|
||||
p.downloaded = json_value(payload, "downloaded", (size_t)0);
|
||||
p.total = json_value(payload, "total", (size_t)0);
|
||||
update_download_progress(name, p, false);
|
||||
}
|
||||
} break;
|
||||
case SERVER_STATE_LOADING:
|
||||
{
|
||||
update_status(name, {
|
||||
SERVER_MODEL_STATUS_LOADING,
|
||||
0,
|
||||
nullptr, // no loaded_info yet
|
||||
payload,
|
||||
});
|
||||
} break;
|
||||
case SERVER_STATE_READY:
|
||||
{
|
||||
update_status(name, {
|
||||
SERVER_MODEL_STATUS_LOADED,
|
||||
0,
|
||||
// note: payload can be empty if this is a wakeup from sleep
|
||||
payload.size() > 0 ? payload : nullptr,
|
||||
{}, // reset progress info
|
||||
});
|
||||
} break;
|
||||
case SERVER_STATE_SLEEPING:
|
||||
{
|
||||
update_status(name, { SERVER_MODEL_STATUS_SLEEPING });
|
||||
} break;
|
||||
default:
|
||||
// should never happen, but just in case
|
||||
GGML_ASSERT(false && "unexpected state from child server");
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// server_child
|
||||
//
|
||||
|
||||
bool server_child::is_child() {
|
||||
const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
|
||||
return router_port != nullptr;
|
||||
}
|
||||
|
||||
std::thread server_models::setup_child_server(const std::function<void(int)> & shutdown_handler, const json & model_info) {
|
||||
// send a notification to the router server that a model instance is ready
|
||||
common_log_pause(common_log_main());
|
||||
fflush(stdout);
|
||||
fprintf(stdout, "%s\n", CMD_CHILD_TO_ROUTER_READY);
|
||||
fflush(stdout);
|
||||
fprintf(stdout, "%s%s\n", CMD_CHILD_TO_ROUTER_INFO, safe_json_to_str(model_info).c_str());
|
||||
fflush(stdout);
|
||||
common_log_resume(common_log_main());
|
||||
server_child_mode server_child::get_mode() {
|
||||
const char * mode = std::getenv("LLAMA_SERVER_CHILD_MODE");
|
||||
std::string mode_str(mode ? mode : "");
|
||||
if (mode_str == "download") {
|
||||
return SERVER_CHILD_MODE_DOWNLOAD;
|
||||
} else {
|
||||
return SERVER_CHILD_MODE_NORMAL;
|
||||
}
|
||||
}
|
||||
|
||||
struct server_download_state : public common_download_callback {
|
||||
server_child * self;
|
||||
std::function<bool()> should_stop;
|
||||
std::atomic<int64_t> last_progress_time{0}; // multiple files downloading in different threads
|
||||
bool is_ok = false;
|
||||
|
||||
server_download_state(server_child * s) : self(s) {}
|
||||
|
||||
bool run(common_params & params) {
|
||||
try {
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, this);
|
||||
is_ok = true;
|
||||
} catch (const std::exception & e) {
|
||||
auto model_name = params.model.get_name();
|
||||
SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what());
|
||||
is_ok = false;
|
||||
}
|
||||
return is_ok;
|
||||
}
|
||||
void on_progress(const common_download_progress & p) {
|
||||
json data = {
|
||||
{"url", p.url},
|
||||
{"downloaded", p.downloaded},
|
||||
{"total", p.total},
|
||||
};
|
||||
self->notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), data);
|
||||
}
|
||||
void on_start(const common_download_progress & p) override {
|
||||
on_progress(p);
|
||||
}
|
||||
void on_update(const common_download_progress & p) override {
|
||||
int64_t now = ggml_time_ms();
|
||||
// throttle progress updates to avoid flooding logs
|
||||
if (now - last_progress_time.load(std::memory_order_relaxed) >= 100) {
|
||||
on_progress(p);
|
||||
last_progress_time.store(now, std::memory_order_relaxed);
|
||||
}
|
||||
}
|
||||
void on_done(const common_download_progress & p, bool) override {
|
||||
on_progress(p);
|
||||
}
|
||||
bool is_cancelled() const override {
|
||||
return should_stop ? should_stop() : false;
|
||||
}
|
||||
};
|
||||
|
||||
int server_child::run_download(common_params & params) {
|
||||
auto cancelled = std::make_shared<std::atomic<bool>>(false);
|
||||
|
||||
// monitor stdin for cancellation command from the router
|
||||
std::thread signal_thread = setup([cancelled](int) {
|
||||
cancelled->store(true, std::memory_order_relaxed);
|
||||
});
|
||||
|
||||
server_download_state dl(this);
|
||||
dl.should_stop = [cancelled]() {
|
||||
return cancelled->load(std::memory_order_relaxed);
|
||||
};
|
||||
|
||||
bool ok = dl.run(params);
|
||||
|
||||
notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), {
|
||||
{"result", ok ? "download_finished" : "download_failed"},
|
||||
});
|
||||
|
||||
// router should send CMD_ROUTER_TO_CHILD_EXIT after receiving the result
|
||||
if (signal_thread.joinable()) {
|
||||
signal_thread.join();
|
||||
}
|
||||
|
||||
SRV_INF("download completed %s\n", ok ? "successfully" : "with errors");
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::thread server_child::setup(const std::function<void(int)> & shutdown_handler) {
|
||||
// setup thread for monitoring stdin
|
||||
return std::thread([shutdown_handler]() {
|
||||
// wait for EOF on stdin
|
||||
@@ -1363,10 +1483,15 @@ std::thread server_models::setup_child_server(const std::function<void(int)> & s
|
||||
});
|
||||
}
|
||||
|
||||
void server_models::notify_router_sleeping_state(bool is_sleeping) {
|
||||
void server_child::notify_to_router(const std::string & state, const json & payload) {
|
||||
json data = {
|
||||
{"state", state},
|
||||
{"payload", payload},
|
||||
};
|
||||
std::lock_guard<std::mutex> lk(mtx_stdout);
|
||||
common_log_pause(common_log_main());
|
||||
fflush(stdout);
|
||||
fprintf(stdout, "%s\n", is_sleeping ? CMD_CHILD_TO_ROUTER_SLEEP : CMD_CHILD_TO_ROUTER_READY);
|
||||
fprintf(stdout, "%s%s\n", CMD_CHILD_TO_ROUTER_STATE, safe_json_to_str(data).c_str());
|
||||
fflush(stdout);
|
||||
common_log_resume(common_log_main());
|
||||
}
|
||||
@@ -1474,7 +1599,6 @@ void server_models_routes::init_routes() {
|
||||
}},
|
||||
// New key
|
||||
{"ui_settings", ui_settings},
|
||||
{"webui_settings", webui_settings},
|
||||
{"build_info", std::string(llama_build_info())},
|
||||
{"cors_proxy_enabled", params.ui_mcp_proxy},
|
||||
});
|
||||
@@ -1606,7 +1730,7 @@ void server_models_routes::init_routes() {
|
||||
res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
|
||||
return res;
|
||||
}
|
||||
if (!model->is_running()) {
|
||||
if (!model->is_running() && model->status != SERVER_MODEL_STATUS_DOWNLOADING) {
|
||||
res_err(res, format_error_response("model is not running", ERROR_TYPE_INVALID_REQUEST));
|
||||
return res;
|
||||
}
|
||||
@@ -1645,11 +1769,11 @@ void server_models_routes::init_routes() {
|
||||
common_params_model model;
|
||||
common_download_opts opts;
|
||||
|
||||
model.name = name;
|
||||
model.hf_repo = name;
|
||||
opts.bearer_token = params.hf_token;
|
||||
opts.download_mmproj = true;
|
||||
opts.download_mtp = true;
|
||||
// note: we only check main model, no need sidecar here
|
||||
opts.download_mmproj = false;
|
||||
opts.download_mtp = false;
|
||||
|
||||
// first, only check if the model is valid and can be downloaded
|
||||
opts.skip_download = true;
|
||||
@@ -1670,10 +1794,21 @@ void server_models_routes::init_routes() {
|
||||
throw std::invalid_argument("model validation failed, unable to download");
|
||||
}
|
||||
|
||||
// reject if model already exists
|
||||
if (models.has_model(name)) {
|
||||
throw std::invalid_argument("model '" + name + "' already exists");
|
||||
}
|
||||
|
||||
// then, proceed with the actual download
|
||||
opts.skip_download = false;
|
||||
SRV_INF("starting download for model '%s'\n", name.c_str());
|
||||
models.download(std::move(model), std::move(opts));
|
||||
{
|
||||
server_models::load_options load_opts;
|
||||
load_opts.mode = SERVER_CHILD_MODE_DOWNLOAD;
|
||||
load_opts.custom_meta = server_model_meta{};
|
||||
load_opts.custom_meta->source = SERVER_MODEL_SOURCE_CACHE;
|
||||
load_opts.custom_meta->name = name;
|
||||
models.load(name, load_opts);
|
||||
}
|
||||
|
||||
res_ok(res, {{"success", true}});
|
||||
return res;
|
||||
@@ -1687,10 +1822,7 @@ void server_models_routes::init_routes() {
|
||||
throw std::invalid_argument("model must be a non-empty string");
|
||||
}
|
||||
|
||||
bool ok = models.remove(name);
|
||||
if (!ok) {
|
||||
throw std::runtime_error("failed to remove model '" + name + "'");
|
||||
}
|
||||
models.remove(name); // throws on error
|
||||
|
||||
res_ok(res, {{"success", true}});
|
||||
return res;
|
||||
|
||||
@@ -40,6 +40,11 @@ enum server_model_source {
|
||||
SERVER_MODEL_SOURCE_CACHE,
|
||||
};
|
||||
|
||||
enum server_child_mode {
|
||||
SERVER_CHILD_MODE_NORMAL, // load the model and run normally
|
||||
SERVER_CHILD_MODE_DOWNLOAD, // download the model and exit
|
||||
};
|
||||
|
||||
static std::string server_model_status_to_string(server_model_status status) {
|
||||
switch (status) {
|
||||
case SERVER_MODEL_STATUS_DOWNLOADING: return "downloading";
|
||||
@@ -72,6 +77,7 @@ struct server_model_meta {
|
||||
int64_t last_used = 0; // for LRU unloading
|
||||
std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
|
||||
json loaded_info; // info to be reflected via /v1/models endpoint ; if in DOWNLOADING state, it should contain download progress info
|
||||
json progress; // reflect load or download progress info, if any
|
||||
int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
|
||||
int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown
|
||||
mtmd_caps multimodal; // multimodal capabilities
|
||||
@@ -104,7 +110,6 @@ private:
|
||||
std::shared_ptr<server_subproc> subproc; // shared between main thread and monitoring thread
|
||||
std::thread th;
|
||||
server_model_meta meta;
|
||||
FILE * stdin_file = nullptr;
|
||||
};
|
||||
|
||||
std::mutex mutex;
|
||||
@@ -160,19 +165,28 @@ public:
|
||||
// return a copy of all model metadata (thread-safe)
|
||||
std::vector<server_model_meta> get_all_meta();
|
||||
|
||||
struct load_options {
|
||||
server_child_mode mode = SERVER_CHILD_MODE_NORMAL;
|
||||
// used for spawning a downloading child process
|
||||
std::optional<server_model_meta> custom_meta = std::nullopt;
|
||||
};
|
||||
|
||||
// load and unload model instances
|
||||
// these functions are thread-safe
|
||||
void load(const std::string & name);
|
||||
void load(const std::string & name, const load_options & opts);
|
||||
void unload(const std::string & name);
|
||||
void unload_all();
|
||||
|
||||
// download a new model, progress is reported via SSE
|
||||
// to stop the download, call unload()
|
||||
void download(common_params_model && model, common_download_opts && opts);
|
||||
|
||||
struct update_status_args {
|
||||
server_model_status status;
|
||||
int exit_code = 0; // only valid if status == UNLOADED
|
||||
json loaded_info = nullptr;
|
||||
json progress = nullptr;
|
||||
};
|
||||
// update the status of a model instance (thread-safe)
|
||||
void update_status(const std::string & name, server_model_status status, int exit_code);
|
||||
void update_loaded_info(const std::string & name, std::string & raw_info);
|
||||
// also send SSE notification to /models/sse endpoint
|
||||
void update_status(const std::string & name, const update_status_args & args);
|
||||
void update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok = true);
|
||||
|
||||
// remove a cache model from disk and update the list (thread-safe)
|
||||
@@ -193,21 +207,38 @@ public:
|
||||
// proxy an HTTP request to the model instance
|
||||
server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used);
|
||||
|
||||
// handle message sent from server_child::notify_to_router()
|
||||
// raw input must starts with CMD_CHILD_TO_ROUTER_STATE, followed by a JSON string
|
||||
// this function is not thread-safe, must be called from instance's monitoring thread
|
||||
// payload per state:
|
||||
// state = loading -> payload = {} (TODO: add progress info)
|
||||
// state = ready -> payload = model_info (json), or {} if wakeup from sleeping
|
||||
// state = sleeping -> payload = {}
|
||||
void handle_child_state(const std::string & name, const std::string & raw_input);
|
||||
};
|
||||
|
||||
struct server_child {
|
||||
// serializes the notify_to_router writes
|
||||
std::mutex mtx_stdout;
|
||||
std::atomic<bool> is_finished_downloading = false; // set by run_download
|
||||
|
||||
// return true if the current process is a child server instance
|
||||
static bool is_child_server();
|
||||
bool is_child();
|
||||
server_child_mode get_mode();
|
||||
int run_download(common_params & params);
|
||||
|
||||
// notify the router server that a model instance is ready
|
||||
// register the shutdown_handler to be called by the router
|
||||
// return the monitoring thread (to be joined by the caller)
|
||||
static std::thread setup_child_server(const std::function<void(int)> & shutdown_handler, const json & model_info);
|
||||
std::thread setup(const std::function<void(int)> & shutdown_handler);
|
||||
|
||||
// notify the router server that the sleeping state has changed
|
||||
static void notify_router_sleeping_state(bool sleeping);
|
||||
// notify router server for status changes (e.g. loading, downloading, sleeping, etc.)
|
||||
// message will be handled by server_models::handle_child_state() on the router side
|
||||
void notify_to_router(const std::string & state_name, const json & payload);
|
||||
};
|
||||
|
||||
struct server_models_routes {
|
||||
common_params params;
|
||||
json ui_settings = json::object(); // Primary: new name
|
||||
json webui_settings = json::object(); // Deprecated: use ui_settings (kept for compat)
|
||||
std::atomic<bool> stopping = false; // for graceful disconnecting SSE clients during shutdown
|
||||
server_models models;
|
||||
server_models_routes(const common_params & params, int argc, char ** argv)
|
||||
@@ -217,7 +248,6 @@ struct server_models_routes {
|
||||
try {
|
||||
json json_settings = json::parse(cfg);
|
||||
ui_settings = json_settings;
|
||||
webui_settings = json_settings; // Deprecated: keep in sync
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to parse UI config: %s\n", __func__, e.what());
|
||||
throw;
|
||||
|
||||
@@ -14,6 +14,9 @@ std::vector<std::unique_ptr<field>> make_llama_cmpl_schema(const common_params &
|
||||
fields.emplace_back(f);
|
||||
};
|
||||
|
||||
add((new field_bool("verbose", params.verbose))
|
||||
->set_desc("Include __verbose field in the response with additional debug information"));
|
||||
|
||||
add((new field_bool("timings_per_token", params.timings_per_token))
|
||||
->set_desc("Include prompt processing and text generation speed information in each response"));
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <cstring>
|
||||
#include <climits>
|
||||
#include <algorithm>
|
||||
#include <unordered_set>
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
@@ -568,9 +569,13 @@ struct server_tool_edit_file : server_tool {
|
||||
}
|
||||
int n = (int) lines.size();
|
||||
if (e.line_start == -1) {
|
||||
// -1 means end of file; line_end is ignored — normalize to point past last line
|
||||
e.line_start = n + 1;
|
||||
e.line_end = n + 1;
|
||||
// -1 targets end of file -> valid for append only; line_end is ignored
|
||||
if (e.mode != "append") {
|
||||
return {{"error", "line_start -1 (end of file) is only valid for append mode"}};
|
||||
}
|
||||
// append at end of file: insert position is the current line count
|
||||
e.line_start = n;
|
||||
e.line_end = n;
|
||||
} else {
|
||||
if (e.line_start < 1 || e.line_end < e.line_start) {
|
||||
return {{"error", string_format("invalid line range [%d, %d]", e.line_start, e.line_end)}};
|
||||
@@ -611,8 +616,8 @@ struct server_tool_edit_file : server_tool {
|
||||
} else if (e.mode == "delete") {
|
||||
lines.erase(lines.begin() + idx_start, lines.begin() + idx_end + 1);
|
||||
} else { // append
|
||||
// idx_end + 1 may equal lines.size() when line_start == -1 (end of file)
|
||||
lines.insert(lines.begin() + idx_end + 1, new_lines.begin(), new_lines.end());
|
||||
// insert after idx_end; idx_end + 1 == lines.size() for end-of-file append
|
||||
lines.insert(lines.begin() + (idx_end + 1), new_lines.begin(), new_lines.end());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+28
-12
@@ -90,8 +90,10 @@ int llama_server(int argc, char ** argv) {
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
// router server never loads a model and must not touch the GPU
|
||||
const bool is_router_server = params.model.path.empty()
|
||||
&& params.model.hf_repo.empty();
|
||||
|
||||
// skip device enumeration so the CUDA primary context stays uncreated
|
||||
const bool is_router_server = params.model.path.empty();
|
||||
common_params_print_info(params, !is_router_server);
|
||||
|
||||
if (!is_router_server) {
|
||||
@@ -113,8 +115,9 @@ int llama_server(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// for consistency between server router mode and single-model mode, we set the same model name as alias
|
||||
if (params.model_alias.empty() && !params.model.name.empty()) {
|
||||
params.model_alias.insert(params.model.name);
|
||||
auto model_name = params.model.get_name();
|
||||
if (params.model_alias.empty() && !model_name.empty()) {
|
||||
params.model_alias.insert(model_name);
|
||||
}
|
||||
|
||||
// struct that contains llama context and inference
|
||||
@@ -131,6 +134,7 @@ int llama_server(int argc, char ** argv) {
|
||||
//
|
||||
|
||||
// register API routes
|
||||
server_child child; // only used in non-router mode
|
||||
server_routes routes(params, ctx_server);
|
||||
server_tools tools;
|
||||
|
||||
@@ -251,6 +255,17 @@ int llama_server(int argc, char ** argv) {
|
||||
ctx_http.post("/tools", ex_wrapper(tools.handle_post));
|
||||
}
|
||||
|
||||
//
|
||||
// Handle downloading model
|
||||
//
|
||||
|
||||
if (child.is_child() && child.get_mode() == SERVER_CHILD_MODE_DOWNLOAD) {
|
||||
return child.run_download(params);
|
||||
} else if (!is_router_server) {
|
||||
// single-model mode (NOT spawned by router)
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER);
|
||||
}
|
||||
|
||||
//
|
||||
// Start the server
|
||||
//
|
||||
@@ -300,15 +315,16 @@ int llama_server(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// load the model
|
||||
SRV_INF("%s", "loading model\n");
|
||||
|
||||
if (server_models::is_child_server()) {
|
||||
ctx_server.on_sleeping_changed([&](bool sleeping) {
|
||||
server_models::notify_router_sleeping_state(sleeping);
|
||||
// setup communication child --> router if necessary
|
||||
if (child.is_child()) {
|
||||
ctx_server.set_state_callback([&](server_state state, json payload) {
|
||||
child.notify_to_router(server_state_to_str(state), payload);
|
||||
});
|
||||
}
|
||||
|
||||
// load the model
|
||||
SRV_INF("%s", "loading model\n");
|
||||
|
||||
if (!ctx_server.load_model(params)) {
|
||||
clean_up();
|
||||
if (ctx_http.thread.joinable()) {
|
||||
@@ -365,9 +381,9 @@ int llama_server(int argc, char ** argv) {
|
||||
|
||||
// optionally, notify router server that this instance is ready
|
||||
std::thread monitor_thread;
|
||||
if (server_models::is_child_server()) {
|
||||
json model_info = routes.get_model_info();
|
||||
monitor_thread = server_models::setup_child_server(shutdown_handler, model_info);
|
||||
if (child.is_child()) {
|
||||
monitor_thread = child.setup(shutdown_handler);
|
||||
child.notify_to_router(server_state_to_str(SERVER_STATE_READY), routes.get_model_info());
|
||||
}
|
||||
|
||||
// this call blocks the main thread until queue_tasks.terminate() is called
|
||||
|
||||
@@ -79,9 +79,9 @@ def test_load_split_model():
|
||||
assert match_regex("(little|girl)+", res.body["content"])
|
||||
|
||||
|
||||
def test_no_webui():
|
||||
def test_no_ui():
|
||||
global server
|
||||
# default: webui enabled
|
||||
# default: UI enabled
|
||||
server.start()
|
||||
url = f"http://{server.server_host}:{server.server_port}"
|
||||
res = requests.get(url)
|
||||
@@ -89,8 +89,8 @@ def test_no_webui():
|
||||
assert "<!doctype html>" in res.text
|
||||
server.stop()
|
||||
|
||||
# with --no-webui
|
||||
server.no_webui = True
|
||||
# with --no-ui, the UI should be disabled
|
||||
server.no_ui = True
|
||||
server.start()
|
||||
res = requests.get(url)
|
||||
assert res.status_code == 404
|
||||
|
||||
@@ -603,3 +603,23 @@ def test_chat_completions_token_count():
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body["input_tokens"] > 5
|
||||
|
||||
|
||||
def test_verbose_debug():
|
||||
global server
|
||||
server.start()
|
||||
for verbose in [True, False]:
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 2,
|
||||
"messages": [
|
||||
{"role": "system", "content": "Book"},
|
||||
{"role": "user", "content": "What is the best book"},
|
||||
],
|
||||
"verbose": verbose,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
if verbose:
|
||||
assert "__verbose" in res.body
|
||||
assert "Book" in res.body["__verbose"]["prompt"]
|
||||
else:
|
||||
assert "__verbose" not in res.body
|
||||
|
||||
@@ -12,7 +12,7 @@ def create_server():
|
||||
|
||||
def test_mcp_no_proxy():
|
||||
global server
|
||||
server.webui_mcp_proxy = False
|
||||
server.ui_mcp_proxy = False
|
||||
server.start()
|
||||
|
||||
res = server.make_request("GET", "/cors-proxy")
|
||||
@@ -21,7 +21,7 @@ def test_mcp_no_proxy():
|
||||
|
||||
def test_mcp_proxy():
|
||||
global server
|
||||
server.webui_mcp_proxy = True
|
||||
server.ui_mcp_proxy = True
|
||||
server.start()
|
||||
|
||||
url = f"http://{server.server_host}:{server.server_port}/cors-proxy?url=http://example.com"
|
||||
@@ -32,7 +32,7 @@ def test_mcp_proxy():
|
||||
|
||||
def test_mcp_proxy_custom_port():
|
||||
global server
|
||||
server.webui_mcp_proxy = True
|
||||
server.ui_mcp_proxy = True
|
||||
server.start()
|
||||
|
||||
# try getting the server's models API via the proxy
|
||||
|
||||
@@ -257,14 +257,25 @@ def test_router_reload_models():
|
||||
|
||||
|
||||
MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16"
|
||||
MODEL_DOWNLOAD_TIMEOUT = 300
|
||||
MODEL_DOWNLOAD_TIMEOUT = 30
|
||||
|
||||
|
||||
def _listen_sse(server: ServerProcess, collected: list, stop: threading.Event):
|
||||
"""Collect /models/sse events into `collected` until `stop` is set."""
|
||||
def _listen_sse(
|
||||
server: ServerProcess, collected: list, stop: threading.Event, ready: threading.Event | None = None
|
||||
):
|
||||
"""Collect /models/sse events into `collected` until `stop` is set.
|
||||
|
||||
When `ready` is provided, it is set once the streaming response is open,
|
||||
i.e. the server has accepted the connection and registered us as a
|
||||
subscriber. Callers that trigger one-shot events (e.g. download_finished)
|
||||
must wait on `ready` before acting, otherwise the event can be broadcast
|
||||
before this client is subscribed and be lost.
|
||||
"""
|
||||
url = f"http://{server.server_host}:{server.server_port}/models/sse"
|
||||
try:
|
||||
with requests.get(url, stream=True, timeout=MODEL_DOWNLOAD_TIMEOUT) as resp:
|
||||
if ready is not None:
|
||||
ready.set()
|
||||
for line_bytes in resp.iter_lines():
|
||||
if stop.is_set():
|
||||
break
|
||||
@@ -294,11 +305,17 @@ def test_router_download_model():
|
||||
|
||||
sse_events: list = []
|
||||
stop = threading.Event()
|
||||
sse_ready = threading.Event()
|
||||
sse_thread = threading.Thread(
|
||||
target=_listen_sse, args=(server, sse_events, stop), daemon=True
|
||||
target=_listen_sse, args=(server, sse_events, stop, sse_ready), daemon=True
|
||||
)
|
||||
sse_thread.start()
|
||||
|
||||
# wait for the SSE client to be subscribed before triggering the download,
|
||||
# otherwise the one-shot download_finished event can be broadcast before
|
||||
# this client is registered and be lost
|
||||
assert sse_ready.wait(10), "SSE client failed to connect"
|
||||
|
||||
# Trigger the download
|
||||
res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID})
|
||||
assert res.status_code == 200
|
||||
@@ -328,13 +345,17 @@ def test_router_delete_model():
|
||||
|
||||
# Ensure the model exists (download it if needed)
|
||||
if MODEL_DOWNLOAD_ID not in _get_model_ids(is_reload=False):
|
||||
res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID})
|
||||
assert res.status_code == 200
|
||||
sse_events: list = []
|
||||
stop = threading.Event()
|
||||
sse_ready = threading.Event()
|
||||
threading.Thread(
|
||||
target=_listen_sse, args=(server, sse_events, stop), daemon=True
|
||||
target=_listen_sse, args=(server, sse_events, stop, sse_ready), daemon=True
|
||||
).start()
|
||||
# subscribe before triggering the download so the one-shot
|
||||
# download_finished event is not lost (see test_router_download_model)
|
||||
assert sse_ready.wait(10), "SSE client failed to connect"
|
||||
res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID})
|
||||
assert res.status_code == 200
|
||||
finished = _wait_for_sse_event(
|
||||
sse_events, "download_finished", MODEL_DOWNLOAD_ID, MODEL_DOWNLOAD_TIMEOUT
|
||||
)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import pytest
|
||||
from openai import OpenAI
|
||||
from utils import *
|
||||
import threading
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
@@ -105,6 +107,49 @@ def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
|
||||
assert res.headers[cors_header] == cors_header_value
|
||||
|
||||
|
||||
def test_cors_proxy_only_forwards_explicit_proxy_headers():
|
||||
class CaptureHeadersHandler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
self.server.captured_headers = dict(self.headers)
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
self.wfile.write(b"ok")
|
||||
|
||||
def log_message(self, format, *args):
|
||||
pass
|
||||
|
||||
target = ThreadingHTTPServer(("127.0.0.1", 0), CaptureHeadersHandler)
|
||||
target.captured_headers = {}
|
||||
target_thread = threading.Thread(target=target.serve_forever, daemon=True)
|
||||
target_thread.start()
|
||||
|
||||
try:
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.api_key = TEST_API_KEY
|
||||
server.ui_mcp_proxy = True
|
||||
server.start()
|
||||
|
||||
res = server.make_request("GET", f"/cors-proxy?url=http://127.0.0.1:{target.server_port}/capture", headers={
|
||||
"Authorization": f"Bearer {TEST_API_KEY}",
|
||||
"Proxy-Authorization": "Basic secret",
|
||||
"X-Api-Key": TEST_API_KEY,
|
||||
"Cookie": "session=secret",
|
||||
"x-llama-server-proxy-header-accept": "application/json",
|
||||
"x-llama-server-proxy-header-authorization": "Bearer explicit",
|
||||
})
|
||||
|
||||
assert res.status_code == 200
|
||||
captured = {key.lower(): value for key, value in target.captured_headers.items()}
|
||||
assert captured["accept"] == "application/json"
|
||||
assert captured["authorization"] == "Bearer explicit"
|
||||
assert "proxy-authorization" not in captured
|
||||
assert "x-api-key" not in captured
|
||||
assert "cookie" not in captured
|
||||
finally:
|
||||
target.shutdown()
|
||||
target.server_close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"media_path, image_url, success",
|
||||
[
|
||||
|
||||
@@ -94,7 +94,7 @@ class ServerProcess:
|
||||
enable_ctx_shift: int | None = False
|
||||
spec_draft_n_min: int | None = None
|
||||
spec_draft_n_max: int | None = None
|
||||
no_webui: bool | None = None
|
||||
no_ui: bool | None = None
|
||||
jinja: bool | None = None
|
||||
reasoning_format: Literal['deepseek', 'none', 'nothink'] | None = None
|
||||
reasoning: Literal['on', 'off', 'auto'] | None = None
|
||||
@@ -107,7 +107,7 @@ class ServerProcess:
|
||||
cache_ram: int | None = None
|
||||
no_cache_idle_slots: bool = False
|
||||
log_path: str | None = None
|
||||
webui_mcp_proxy: bool = False
|
||||
ui_mcp_proxy: bool = False
|
||||
backend_sampling: bool = False
|
||||
gcp_compat: bool = False
|
||||
|
||||
@@ -225,8 +225,8 @@ class ServerProcess:
|
||||
server_args.extend(["--spec-draft-n-max", self.spec_draft_n_max])
|
||||
if self.spec_draft_n_min:
|
||||
server_args.extend(["--spec-draft-n-min", self.spec_draft_n_min])
|
||||
if self.no_webui:
|
||||
server_args.append("--no-webui")
|
||||
if self.no_ui:
|
||||
server_args.append("--no-ui")
|
||||
if self.no_models_autoload:
|
||||
server_args.append("--no-models-autoload")
|
||||
if self.jinja:
|
||||
@@ -251,8 +251,8 @@ class ServerProcess:
|
||||
server_args.extend(["--cache-ram", self.cache_ram])
|
||||
if self.no_cache_idle_slots:
|
||||
server_args.append("--no-cache-idle-slots")
|
||||
if self.webui_mcp_proxy:
|
||||
server_args.append("--webui-mcp-proxy")
|
||||
if self.ui_mcp_proxy:
|
||||
server_args.append("--ui-mcp-proxy")
|
||||
if self.backend_sampling:
|
||||
server_args.append("--backend_sampling")
|
||||
if self.gcp_compat:
|
||||
|
||||
Vendored
+10
@@ -19,6 +19,10 @@ import type {
|
||||
ApiErrorResponse,
|
||||
ApiLlamaCppServerProps,
|
||||
ApiModelDataEntry,
|
||||
ApiModelLoadStage,
|
||||
ApiModelsSseProgress,
|
||||
ApiModelsSseData,
|
||||
ApiModelsSseEvent,
|
||||
ApiModelListResponse,
|
||||
ApiProcessingState,
|
||||
ApiRouterModelMeta,
|
||||
@@ -52,6 +56,7 @@ import type {
|
||||
// Model types
|
||||
ModelModalities,
|
||||
ModelOption,
|
||||
ModelLoadProgress,
|
||||
// Settings types
|
||||
SettingsChatServiceOptions,
|
||||
SettingsConfigValue,
|
||||
@@ -83,6 +88,10 @@ declare global {
|
||||
ApiErrorResponse,
|
||||
ApiLlamaCppServerProps,
|
||||
ApiModelDataEntry,
|
||||
ApiModelLoadStage,
|
||||
ApiModelsSseProgress,
|
||||
ApiModelsSseData,
|
||||
ApiModelsSseEvent,
|
||||
ApiModelListResponse,
|
||||
ApiProcessingState,
|
||||
ApiRouterModelMeta,
|
||||
@@ -120,6 +129,7 @@ declare global {
|
||||
// Model types
|
||||
ModelModalities,
|
||||
ModelOption,
|
||||
ModelLoadProgress,
|
||||
// Settings types
|
||||
SettingsChatServiceOptions,
|
||||
SettingsConfigValue,
|
||||
|
||||
+12
-3
@@ -10,7 +10,7 @@
|
||||
import { getMessageEditContext } from '$lib/contexts';
|
||||
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
|
||||
import { isLoading, isChatStreaming } from '$lib/stores/chat.svelte';
|
||||
import { copyToClipboard, deriveAgenticSections } from '$lib/utils';
|
||||
import { copyToClipboard, deriveAgenticSections, modelLoadProgressText } from '$lib/utils';
|
||||
import { AgenticSectionType } from '$lib/enums';
|
||||
import { REASONING_TAGS } from '$lib/constants/agentic';
|
||||
import { tick } from 'svelte';
|
||||
@@ -185,6 +185,13 @@
|
||||
let hasNoContent = $derived(!message?.content?.trim());
|
||||
let isActivelyProcessing = $derived(isCurrentlyLoading || isStreaming);
|
||||
|
||||
// during a router auto-load the message has no model yet, so target the selected one
|
||||
let loadTargetModel = $derived(message.model ?? modelsStore.selectedModelName);
|
||||
let modelLoadProgress = $derived(
|
||||
isRouter && loadTargetModel ? modelsStore.getLoadProgress(loadTargetModel) : null
|
||||
);
|
||||
let modelLoadingText = $derived(modelLoadProgressText(modelLoadProgress));
|
||||
|
||||
let showProcessingInfoTop = $derived(
|
||||
message?.role === MessageRole.ASSISTANT &&
|
||||
isActivelyProcessing &&
|
||||
@@ -220,7 +227,8 @@
|
||||
<div class="mt-6 w-full max-w-[48rem]" in:fade>
|
||||
<div class="processing-container">
|
||||
<span class="processing-text">
|
||||
{processingState.getPromptProgressText() ??
|
||||
{modelLoadingText ??
|
||||
processingState.getPromptProgressText() ??
|
||||
processingState.getProcessingMessage() ??
|
||||
'Processing...'}
|
||||
</span>
|
||||
@@ -252,7 +260,8 @@
|
||||
<div class="mt-4 w-full max-w-[48rem]" in:fade>
|
||||
<div class="processing-container">
|
||||
<span class="processing-text">
|
||||
{processingState.getPromptProgressText() ??
|
||||
{modelLoadingText ??
|
||||
processingState.getPromptProgressText() ??
|
||||
processingState.getProcessingMessage() ??
|
||||
'Processing...'}
|
||||
</span>
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
import type { ModelOption } from '$lib/types/models';
|
||||
import { ServerModelStatus } from '$lib/enums';
|
||||
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
|
||||
import { modelLoadFraction, modelLoadProgressText } from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
option: ModelOption;
|
||||
@@ -50,11 +51,15 @@
|
||||
(serverStatus === ServerModelStatus.LOADED || isSleeping) && !isOperationInProgress
|
||||
);
|
||||
let isLoading = $derived(serverStatus === ServerModelStatus.LOADING || isOperationInProgress);
|
||||
|
||||
let loadProgress = $derived(isLoading ? modelsStore.getLoadProgress(option.model) : null);
|
||||
let loadPercent = $derived(Math.round(modelLoadFraction(loadProgress) * 100));
|
||||
let loadTitle = $derived(modelLoadProgressText(loadProgress));
|
||||
</script>
|
||||
|
||||
<div
|
||||
class={[
|
||||
'group flex w-full items-center gap-2 rounded-sm p-2 text-left text-sm transition focus:outline-none',
|
||||
'group relative flex w-full items-center gap-2 rounded-sm p-2 text-left text-sm transition focus:outline-none',
|
||||
'cursor-pointer hover:bg-muted focus:bg-muted',
|
||||
(isSelected || isHighlighted) && 'bg-accent text-accent-foreground',
|
||||
!(isSelected || isHighlighted) && 'hover:bg-accent hover:text-accent-foreground',
|
||||
@@ -62,6 +67,7 @@
|
||||
]}
|
||||
role="option"
|
||||
aria-selected={isSelected || isHighlighted}
|
||||
title={loadTitle}
|
||||
tabindex="0"
|
||||
onclick={() => onSelect(option.id)}
|
||||
onmouseenter={onMouseEnter}
|
||||
@@ -188,4 +194,15 @@
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
{#if isLoading}
|
||||
<div
|
||||
class="pointer-events-none absolute inset-x-0 bottom-0 h-0.5 overflow-hidden rounded-b-sm bg-muted"
|
||||
>
|
||||
<div
|
||||
class="h-full bg-primary transition-[width] duration-200 ease-out"
|
||||
style="width: {loadPercent}%"
|
||||
></div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
export const API_MODELS = {
|
||||
LIST: '/v1/models',
|
||||
LOAD: '/models/load',
|
||||
UNLOAD: '/models/unload'
|
||||
UNLOAD: '/models/unload',
|
||||
SSE: '/models/sse'
|
||||
};
|
||||
|
||||
// chat completion routes, the control route drives realtime inference (e.g. end reasoning)
|
||||
|
||||
@@ -37,6 +37,8 @@ export * from './mcp-form';
|
||||
export * from './mcp-resource';
|
||||
export * from './message-export';
|
||||
export * from './model-id';
|
||||
export * from './model-loading';
|
||||
export * from './sse';
|
||||
export * from './precision';
|
||||
export * from './processing-info';
|
||||
export * from './pwa';
|
||||
|
||||
@@ -51,6 +51,9 @@ export const EXPECTED_THEMED_ICON_PAIR_COUNT = 2;
|
||||
/** CORS proxy URL query parameter name */
|
||||
export const CORS_PROXY_URL_PARAM = 'url';
|
||||
|
||||
/** Header prefix for headers that should be forwarded by the CORS proxy */
|
||||
export const CORS_PROXY_HEADER_PREFIX = 'x-llama-server-proxy-header-';
|
||||
|
||||
/** Number of trailing characters to keep visible when partially redacting mcp-session-id */
|
||||
export const MCP_SESSION_ID_VISIBLE_CHARS = 5;
|
||||
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
/**
|
||||
* Labels shown while a model loads, keyed by the stage reported on /models/sse.
|
||||
*/
|
||||
export const MODEL_LOAD_STAGE_LABELS: Record<ApiModelLoadStage, string> = {
|
||||
text_model: 'Loading weights',
|
||||
spec_model: 'Loading draft',
|
||||
mmproj_model: 'Loading projector'
|
||||
};
|
||||
|
||||
/**
|
||||
* Share of the bar reserved for each load phase after text_model.
|
||||
* text_model fills the rest, so a plain model reaches 100% on its own.
|
||||
*/
|
||||
export const MODEL_LOAD_TAIL_SHARE = 0.1;
|
||||
@@ -0,0 +1,16 @@
|
||||
/**
|
||||
* Server-sent events wire format, shared by the chat stream and the
|
||||
* /models/sse status feed (text/event-stream).
|
||||
*/
|
||||
|
||||
// blank line between two events
|
||||
export const SSE_RECORD_SEPARATOR = '\n\n';
|
||||
|
||||
// line break inside an event
|
||||
export const SSE_LINE_SEPARATOR = '\n';
|
||||
|
||||
// data field prefix, the value follows after an optional space
|
||||
export const SSE_DATA_PREFIX = 'data:';
|
||||
|
||||
// end-of-stream marker on the chat completion stream
|
||||
export const SSE_DONE_MARKER = '[DONE]';
|
||||
@@ -54,7 +54,7 @@ export {
|
||||
|
||||
export { ModelModality } from './model.enums';
|
||||
|
||||
export { ServerRole, ServerModelStatus } from './server.enums';
|
||||
export { ServerRole, ServerModelStatus, ServerModelsSseEventType } from './server.enums';
|
||||
|
||||
export { ParameterSource, SyncableParameterType, SettingsFieldType } from './settings.enums';
|
||||
|
||||
|
||||
@@ -19,3 +19,17 @@ export enum ServerModelStatus {
|
||||
SLEEPING = 'sleeping',
|
||||
FAILED = 'failed'
|
||||
}
|
||||
|
||||
/**
|
||||
* /models/sse event type enum - discriminates the records broadcast on the
|
||||
* model status feed in ROUTER mode. Matches the event names emitted by
|
||||
* tools/server/server-models.cpp from the C++ server.
|
||||
*/
|
||||
export enum ServerModelsSseEventType {
|
||||
STATUS_CHANGE = 'status_change',
|
||||
MODEL_STATUS = 'model_status',
|
||||
STATUS_UPDATE = 'status_update',
|
||||
MODELS_RELOAD = 'models_reload',
|
||||
MODEL_REMOVE = 'model_remove',
|
||||
DOWNLOAD_PROGRESS = 'download_progress'
|
||||
}
|
||||
|
||||
@@ -10,7 +10,10 @@ import {
|
||||
SETTINGS_KEYS,
|
||||
API_CHAT,
|
||||
API_SLOTS,
|
||||
CONTROL_ACTION
|
||||
CONTROL_ACTION,
|
||||
SSE_LINE_SEPARATOR,
|
||||
SSE_DATA_PREFIX,
|
||||
SSE_DONE_MARKER
|
||||
} from '$lib/constants';
|
||||
import {
|
||||
AttachmentType,
|
||||
@@ -18,8 +21,7 @@ import {
|
||||
FileTypeAudio,
|
||||
MessageRole,
|
||||
MimeTypeAudio,
|
||||
ReasoningFormat,
|
||||
UrlProtocol
|
||||
ReasoningFormat
|
||||
} from '$lib/enums';
|
||||
import type {
|
||||
ApiChatMessageContentPart,
|
||||
@@ -642,15 +644,15 @@ export class ChatService {
|
||||
if (abortSignal?.aborted) break;
|
||||
|
||||
chunk += decoder.decode(value, { stream: true });
|
||||
const lines = chunk.split('\n');
|
||||
const lines = chunk.split(SSE_LINE_SEPARATOR);
|
||||
chunk = lines.pop() || '';
|
||||
|
||||
for (const line of lines) {
|
||||
if (abortSignal?.aborted) break;
|
||||
|
||||
if (line.startsWith(UrlProtocol.DATA)) {
|
||||
const data = line.slice(6);
|
||||
if (data === '[DONE]') {
|
||||
if (line.startsWith(SSE_DATA_PREFIX)) {
|
||||
const data = line.slice(SSE_DATA_PREFIX.length).trim();
|
||||
if (data === SSE_DONE_MARKER) {
|
||||
streamFinished = true;
|
||||
|
||||
continue;
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
DEFAULT_MCP_CONFIG,
|
||||
DEFAULT_CLIENT_VERSION,
|
||||
DEFAULT_IMAGE_MIME_TYPE,
|
||||
CORS_PROXY_HEADER_PREFIX,
|
||||
MCP_PARTIAL_REDACT_HEADERS,
|
||||
CORS_PROXY_ENDPOINT
|
||||
} from '$lib/constants';
|
||||
@@ -133,6 +134,20 @@ export class MCPService {
|
||||
return details;
|
||||
}
|
||||
|
||||
private static addRequestHeaders(
|
||||
requestHeaders: Headers,
|
||||
headers: HeadersInit,
|
||||
useProxy: boolean
|
||||
) {
|
||||
for (const [key, value] of new Headers(headers).entries()) {
|
||||
const proxiedKey =
|
||||
useProxy && !key.toLowerCase().startsWith(CORS_PROXY_HEADER_PREFIX)
|
||||
? `${CORS_PROXY_HEADER_PREFIX}${key}`
|
||||
: key;
|
||||
requestHeaders.set(proxiedKey, value);
|
||||
}
|
||||
}
|
||||
|
||||
private static summarizeError(error: unknown): Record<string, unknown> {
|
||||
if (error instanceof Error) {
|
||||
return {
|
||||
@@ -271,15 +286,11 @@ export class MCPService {
|
||||
const requestHeaders = new Headers(baseInit.headers);
|
||||
|
||||
if (typeof Request !== 'undefined' && input instanceof Request) {
|
||||
for (const [key, value] of input.headers.entries()) {
|
||||
requestHeaders.set(key, value);
|
||||
}
|
||||
this.addRequestHeaders(requestHeaders, input.headers, useProxy);
|
||||
}
|
||||
|
||||
if (init?.headers) {
|
||||
for (const [key, value] of new Headers(init.headers).entries()) {
|
||||
requestHeaders.set(key, value);
|
||||
}
|
||||
this.addRequestHeaders(requestHeaders, init.headers, useProxy);
|
||||
}
|
||||
|
||||
const request = this.createDiagnosticRequestDetails(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { base } from '$app/paths';
|
||||
import { SvelteMap, SvelteSet } from 'svelte/reactivity';
|
||||
import { toast } from 'svelte-sonner';
|
||||
import { ServerModelStatus, ModelModality } from '$lib/enums';
|
||||
import { ServerModelStatus, ServerModelsSseEventType, ModelModality } from '$lib/enums';
|
||||
import { ModelsService } from '$lib/services/models.service';
|
||||
import { PropsService } from '$lib/services/props.service';
|
||||
import { serverStore, isRouterMode } from '$lib/stores/server.svelte';
|
||||
@@ -8,11 +9,15 @@ import {
|
||||
detectThinkingSupport,
|
||||
detectThinkingSupportWithReason
|
||||
} from '$lib/utils/chat-template-thinking-detector';
|
||||
import { TTLCache } from '$lib/utils';
|
||||
import { TTLCache, getAuthHeaders } from '$lib/utils';
|
||||
import {
|
||||
MODEL_PROPS_CACHE_TTL_MS,
|
||||
MODEL_PROPS_CACHE_MAX_ENTRIES,
|
||||
FAVORITE_MODELS_LOCALSTORAGE_KEY
|
||||
FAVORITE_MODELS_LOCALSTORAGE_KEY,
|
||||
API_MODELS,
|
||||
SSE_RECORD_SEPARATOR,
|
||||
SSE_LINE_SEPARATOR,
|
||||
SSE_DATA_PREFIX
|
||||
} from '$lib/constants';
|
||||
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
@@ -55,6 +60,15 @@ class ModelsStore {
|
||||
private modelUsage = $state<Map<string, SvelteSet<string>>>(new Map());
|
||||
private modelLoadingStates = new SvelteMap<string, boolean>();
|
||||
|
||||
// /models/sse feed state, the single source of truth for status and load progress
|
||||
private statusAbort: AbortController | null = null;
|
||||
private statusReaderActive = false;
|
||||
private loadProgress = new SvelteMap<string, ModelLoadProgress>();
|
||||
private statusWaiters = new Map<
|
||||
string,
|
||||
{ target: ServerModelStatus; resolve: () => void; reject: (e: Error) => void }
|
||||
>();
|
||||
|
||||
favoriteModelIds = $state<Set<string>>(this.loadFavoritesFromStorage());
|
||||
|
||||
/**
|
||||
@@ -626,49 +640,218 @@ class ModelsStore {
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* WORKAROUND: Polling for model status after load/unload operations.
|
||||
*
|
||||
* Currently, `/models/load` and `/models/unload` return success before
|
||||
* the operation actually completes on the server.
|
||||
*
|
||||
* TODO: Remove polling once llama-server properly waits for the operation
|
||||
* to complete before returning success.
|
||||
*/
|
||||
|
||||
private static readonly STATUS_POLL_INTERVAL = 500;
|
||||
// reconnect delay after the feed drops or the server is not ready yet
|
||||
private static readonly SSE_RECONNECT_MS = 1000;
|
||||
|
||||
/**
|
||||
* Poll for expected model status after load/unload operation.
|
||||
* Keeps polling until the model reaches the expected status or fails.
|
||||
* Open the /models/sse feed and keep it live with auto reconnect.
|
||||
* Idempotent and router mode only. The feed drives status and progress,
|
||||
* so it replaces any post-operation polling.
|
||||
*/
|
||||
private async pollForModelStatus(
|
||||
modelId: string,
|
||||
expectedStatus: ServerModelStatus
|
||||
): Promise<void> {
|
||||
let attempt = 0;
|
||||
while (true) {
|
||||
await this.fetchRouterModels();
|
||||
subscribeStatus(): void {
|
||||
if (this.statusReaderActive) return;
|
||||
if (!isRouterMode()) return;
|
||||
|
||||
const currentStatus = this.getModelStatus(modelId);
|
||||
if (currentStatus === expectedStatus) return;
|
||||
this.statusReaderActive = true;
|
||||
this.statusAbort = new AbortController();
|
||||
void this.runStatusReader(this.statusAbort.signal);
|
||||
}
|
||||
|
||||
if (currentStatus === ServerModelStatus.FAILED) {
|
||||
throw new Error(
|
||||
`Model failed to ${expectedStatus === ServerModelStatus.LOADED ? 'load' : 'unload'}`
|
||||
);
|
||||
/**
|
||||
* Close the /models/sse feed and drop transient progress.
|
||||
*/
|
||||
unsubscribeStatus(): void {
|
||||
this.statusReaderActive = false;
|
||||
this.statusAbort?.abort();
|
||||
this.statusAbort = null;
|
||||
this.loadProgress.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Current load progress for a model, or null when not loading.
|
||||
*/
|
||||
getLoadProgress(modelId: string): ModelLoadProgress | null {
|
||||
return this.loadProgress.get(modelId) ?? null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Read the feed and reconnect until unsubscribed. Splits the byte stream
|
||||
* into SSE records on the blank line boundary.
|
||||
*/
|
||||
private async runStatusReader(signal: AbortSignal): Promise<void> {
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
while (!signal.aborted) {
|
||||
try {
|
||||
const response = await fetch(`${base}${API_MODELS.SSE}`, {
|
||||
headers: getAuthHeaders(),
|
||||
signal
|
||||
});
|
||||
|
||||
if (response.ok && response.body) {
|
||||
const reader = response.body.getReader();
|
||||
let buffer = '';
|
||||
|
||||
while (!signal.aborted) {
|
||||
const { value, done } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
|
||||
let boundary = buffer.indexOf(SSE_RECORD_SEPARATOR);
|
||||
while (boundary !== -1) {
|
||||
this.handleStatusRecord(buffer.slice(0, boundary));
|
||||
buffer = buffer.slice(boundary + SSE_RECORD_SEPARATOR.length);
|
||||
boundary = buffer.indexOf(SSE_RECORD_SEPARATOR);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// network drop or abort falls through to the reconnect delay
|
||||
}
|
||||
|
||||
if (
|
||||
expectedStatus === ServerModelStatus.LOADED &&
|
||||
currentStatus === ServerModelStatus.UNLOADED &&
|
||||
attempt > 2
|
||||
) {
|
||||
throw new Error('Model was unloaded unexpectedly during loading');
|
||||
}
|
||||
if (signal.aborted) return;
|
||||
|
||||
attempt++;
|
||||
await new Promise((resolve) => setTimeout(resolve, ModelsStore.STATUS_POLL_INTERVAL));
|
||||
await new Promise((resolve) => setTimeout(resolve, ModelsStore.SSE_RECONNECT_MS));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse one SSE record. The payload rides in the data lines as a JSON
|
||||
* envelope that carries its own model, event and data fields.
|
||||
*/
|
||||
private handleStatusRecord(record: string): void {
|
||||
const payload = record
|
||||
.split(SSE_LINE_SEPARATOR)
|
||||
.filter((line) => line.startsWith(SSE_DATA_PREFIX))
|
||||
.map((line) => line.slice(SSE_DATA_PREFIX.length).trim())
|
||||
.join(SSE_LINE_SEPARATOR);
|
||||
|
||||
if (payload.length === 0) return;
|
||||
|
||||
let envelope: ApiModelsSseEvent;
|
||||
try {
|
||||
envelope = JSON.parse(payload);
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
|
||||
this.applyStatusEvent(envelope);
|
||||
}
|
||||
|
||||
/**
|
||||
* Route one feed record by event kind. Only the status_* events carry a
|
||||
* status payload, models_reload triggers a list refresh, model_remove drops
|
||||
* the row, download_* belong to the download surface, not here.
|
||||
*/
|
||||
private applyStatusEvent(event: ApiModelsSseEvent): void {
|
||||
switch (event.event) {
|
||||
case ServerModelsSseEventType.STATUS_CHANGE:
|
||||
case ServerModelsSseEventType.MODEL_STATUS:
|
||||
case ServerModelsSseEventType.STATUS_UPDATE:
|
||||
this.applyModelStatus(event);
|
||||
break;
|
||||
case ServerModelsSseEventType.MODELS_RELOAD:
|
||||
void this.fetchRouterModels();
|
||||
break;
|
||||
case ServerModelsSseEventType.MODEL_REMOVE:
|
||||
this.removeRouterModel(event.model);
|
||||
break;
|
||||
case ServerModelsSseEventType.DOWNLOAD_PROGRESS:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply a status envelope: update the model row, track or clear progress,
|
||||
* settle any pending load or unload awaiter.
|
||||
*/
|
||||
private applyModelStatus(event: ApiModelsSseEvent): void {
|
||||
const model = event.model;
|
||||
const data = event.data;
|
||||
if (!model || !data?.status) return;
|
||||
|
||||
const status = data.status;
|
||||
|
||||
this.setRouterModelStatus(model, status);
|
||||
|
||||
if (status === ServerModelStatus.LOADING) {
|
||||
if (data.progress) this.loadProgress.set(model, data.progress);
|
||||
} else {
|
||||
this.loadProgress.delete(model);
|
||||
}
|
||||
|
||||
if (status === ServerModelStatus.LOADED) {
|
||||
void this.updateModelModalities(model);
|
||||
}
|
||||
|
||||
const failed =
|
||||
status === ServerModelStatus.FAILED ||
|
||||
(status === ServerModelStatus.UNLOADED && (data.exit_code ?? 0) !== 0);
|
||||
|
||||
if (failed) {
|
||||
this.rejectStatus(model, new Error(`Model failed: ${this.toDisplayName(model)}`));
|
||||
return;
|
||||
}
|
||||
|
||||
this.settleStatus(model, status);
|
||||
}
|
||||
|
||||
/**
|
||||
* Drop a model row reported gone by the feed and settle its awaiters.
|
||||
*/
|
||||
private removeRouterModel(modelId: string): void {
|
||||
if (this.routerModels.findIndex((m) => m.id === modelId) === -1) return;
|
||||
|
||||
this.routerModels = this.routerModels.filter((m) => m.id !== modelId);
|
||||
this.loadProgress.delete(modelId);
|
||||
this.rejectStatus(modelId, new Error(`Model removed: ${this.toDisplayName(modelId)}`));
|
||||
}
|
||||
|
||||
/**
|
||||
* Update one model row status in place, reassigning to trigger reactivity.
|
||||
*/
|
||||
private setRouterModelStatus(modelId: string, status: ServerModelStatus): void {
|
||||
const idx = this.routerModels.findIndex((m) => m.id === modelId);
|
||||
if (idx === -1) return;
|
||||
|
||||
const current = this.routerModels[idx];
|
||||
if (current.status.value === status) return;
|
||||
|
||||
const next = [...this.routerModels];
|
||||
next[idx] = { ...current, status: { ...current.status, value: status } };
|
||||
this.routerModels = next;
|
||||
}
|
||||
|
||||
/**
|
||||
* Register an awaiter that resolves when the feed reports target status.
|
||||
* One operation runs per model at a time, so one awaiter per model is kept.
|
||||
*/
|
||||
private waitForStatus(modelId: string, target: ServerModelStatus): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
this.statusWaiters.set(modelId, { target, resolve, reject });
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve and drop the awaiter when the model reaches its target status.
|
||||
*/
|
||||
private settleStatus(modelId: string, status: ServerModelStatus): void {
|
||||
const waiter = this.statusWaiters.get(modelId);
|
||||
if (waiter && waiter.target === status) {
|
||||
this.statusWaiters.delete(modelId);
|
||||
waiter.resolve();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reject and drop the awaiter for a model.
|
||||
*/
|
||||
private rejectStatus(modelId: string, error: Error): void {
|
||||
const waiter = this.statusWaiters.get(modelId);
|
||||
if (waiter) {
|
||||
this.statusWaiters.delete(modelId);
|
||||
waiter.reject(error);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -679,12 +862,18 @@ class ModelsStore {
|
||||
this.modelLoadingStates.set(modelId, true);
|
||||
this.error = null;
|
||||
|
||||
// the feed drives completion, so it must be live before the request
|
||||
this.subscribeStatus();
|
||||
|
||||
const reachedLoaded = this.waitForStatus(modelId, ServerModelStatus.LOADED);
|
||||
reachedLoaded.catch(() => {});
|
||||
|
||||
try {
|
||||
await ModelsService.load(modelId);
|
||||
await this.pollForModelStatus(modelId, ServerModelStatus.LOADED);
|
||||
await this.updateModelModalities(modelId);
|
||||
await reachedLoaded;
|
||||
toast.success(`Model loaded: ${this.toDisplayName(modelId)}`);
|
||||
} catch (error) {
|
||||
this.rejectStatus(modelId, error instanceof Error ? error : new Error('load failed'));
|
||||
this.error = error instanceof Error ? error.message : 'Failed to load model';
|
||||
toast.error(`Failed to load model: ${this.toDisplayName(modelId)}`);
|
||||
throw error;
|
||||
@@ -700,11 +889,17 @@ class ModelsStore {
|
||||
this.modelLoadingStates.set(modelId, true);
|
||||
this.error = null;
|
||||
|
||||
this.subscribeStatus();
|
||||
|
||||
const reachedUnloaded = this.waitForStatus(modelId, ServerModelStatus.UNLOADED);
|
||||
reachedUnloaded.catch(() => {});
|
||||
|
||||
try {
|
||||
await ModelsService.unload(modelId);
|
||||
await this.pollForModelStatus(modelId, ServerModelStatus.UNLOADED);
|
||||
await reachedUnloaded;
|
||||
toast.info(`Model unloaded: ${this.toDisplayName(modelId)}`);
|
||||
} catch (error) {
|
||||
this.rejectStatus(modelId, error instanceof Error ? error : new Error('unload failed'));
|
||||
this.error = error instanceof Error ? error.message : 'Failed to unload model';
|
||||
toast.error(`Failed to unload model: ${this.toDisplayName(modelId)}`);
|
||||
throw error;
|
||||
@@ -783,6 +978,9 @@ class ModelsStore {
|
||||
}
|
||||
|
||||
clear(): void {
|
||||
this.unsubscribeStatus();
|
||||
this.statusWaiters.forEach((waiter) => waiter.reject(new Error('Models store cleared')));
|
||||
this.statusWaiters.clear();
|
||||
this.models = [];
|
||||
this.routerModels = [];
|
||||
this.loading = false;
|
||||
|
||||
Vendored
+47
-1
@@ -1,4 +1,10 @@
|
||||
import type { ContentPartType, FileTypeAudio, ServerModelStatus, ServerRole } from '$lib/enums';
|
||||
import type {
|
||||
ContentPartType,
|
||||
FileTypeAudio,
|
||||
ServerModelStatus,
|
||||
ServerModelsSseEventType,
|
||||
ServerRole
|
||||
} from '$lib/enums';
|
||||
import type { ChatMessagePromptProgress, ChatRole } from './chat';
|
||||
|
||||
export type AudioInputFormat = FileTypeAudio.WAV | FileTypeAudio.MP3;
|
||||
@@ -96,6 +102,46 @@ export interface ApiModelDataEntry {
|
||||
meta?: Record<string, unknown> | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load stage reported by the /models/sse feed, in load order.
|
||||
*/
|
||||
export type ApiModelLoadStage = 'text_model' | 'spec_model' | 'mmproj_model';
|
||||
|
||||
/**
|
||||
* Load progress snapshot: the full ordered stage plan, the active stage,
|
||||
* and its fractional value (0.0 -> 1.0).
|
||||
*/
|
||||
export interface ApiModelsSseProgress {
|
||||
stages: ApiModelLoadStage[];
|
||||
current: ApiModelLoadStage;
|
||||
value: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Status payload carried by a /models/sse envelope.
|
||||
* exit_code appears on unload.
|
||||
*/
|
||||
export interface ApiModelsSseData {
|
||||
status: ServerModelStatus;
|
||||
progress?: ApiModelsSseProgress;
|
||||
exit_code?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Event kind multiplexed on the /models/sse feed.
|
||||
* Only the status_* events carry a status payload, models_reload signals a
|
||||
* full list refresh, model_remove drops a row, download_* drive download UI.
|
||||
*/
|
||||
/**
|
||||
* One /models/sse record. event discriminates the kind, model names the
|
||||
* target instance, data carries the status payload when present.
|
||||
*/
|
||||
export interface ApiModelsSseEvent {
|
||||
model: string;
|
||||
event: ServerModelsSseEventType;
|
||||
data: ApiModelsSseData;
|
||||
}
|
||||
|
||||
export interface ApiModelDetails {
|
||||
name: string;
|
||||
model: string;
|
||||
|
||||
@@ -11,6 +11,10 @@ export type {
|
||||
ApiChatMessageData,
|
||||
ApiModelStatus,
|
||||
ApiModelDataEntry,
|
||||
ApiModelLoadStage,
|
||||
ApiModelsSseProgress,
|
||||
ApiModelsSseData,
|
||||
ApiModelsSseEvent,
|
||||
ApiModelDetails,
|
||||
ApiModelListResponse,
|
||||
ApiLlamaCppServerProps,
|
||||
@@ -70,7 +74,12 @@ export type {
|
||||
} from './database';
|
||||
|
||||
// Model types
|
||||
export type { ModelModalities, ModelOption, ModalityCapabilities } from './models';
|
||||
export type {
|
||||
ModelModalities,
|
||||
ModelOption,
|
||||
ModelLoadProgress,
|
||||
ModalityCapabilities
|
||||
} from './models';
|
||||
|
||||
// Settings types
|
||||
export type {
|
||||
|
||||
Vendored
+12
-1
@@ -1,4 +1,4 @@
|
||||
import type { ApiModelDataEntry, ApiModelDetails } from '$lib/types/api';
|
||||
import type { ApiModelDataEntry, ApiModelDetails, ApiModelLoadStage } from '$lib/types/api';
|
||||
|
||||
export interface ModelModalities {
|
||||
vision: boolean;
|
||||
@@ -20,6 +20,17 @@ export interface ModelOption {
|
||||
tags?: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Ephemeral UI-only load progress for one model instance.
|
||||
* Lives only while a load runs, driven by the /models/sse feed.
|
||||
* stage is absent until the feed reports its first stage.
|
||||
*/
|
||||
export interface ModelLoadProgress {
|
||||
stages: ApiModelLoadStage[];
|
||||
current: ApiModelLoadStage;
|
||||
value: number;
|
||||
}
|
||||
|
||||
export interface ParsedModelId {
|
||||
raw: string;
|
||||
orgName: string | null;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { REDACTED_HEADERS } from '$lib/constants';
|
||||
import { CORS_PROXY_HEADER_PREFIX, REDACTED_HEADERS } from '$lib/constants';
|
||||
import { redactValue } from './redact';
|
||||
|
||||
/**
|
||||
@@ -52,11 +52,20 @@ export function sanitizeHeaders(
|
||||
|
||||
for (const [key, value] of normalized.entries()) {
|
||||
const normalizedKey = key.toLowerCase();
|
||||
const partialChars = partialRedactHeaders?.get(normalizedKey);
|
||||
const unproxiedKey = normalizedKey.startsWith(CORS_PROXY_HEADER_PREFIX)
|
||||
? normalizedKey.slice(CORS_PROXY_HEADER_PREFIX.length)
|
||||
: normalizedKey;
|
||||
const partialChars =
|
||||
partialRedactHeaders?.get(normalizedKey) ?? partialRedactHeaders?.get(unproxiedKey);
|
||||
|
||||
if (partialChars !== undefined) {
|
||||
sanitized[key] = redactValue(value, partialChars);
|
||||
} else if (REDACTED_HEADERS.has(normalizedKey) || redactedHeaders.has(normalizedKey)) {
|
||||
} else if (
|
||||
REDACTED_HEADERS.has(normalizedKey) ||
|
||||
REDACTED_HEADERS.has(unproxiedKey) ||
|
||||
redactedHeaders.has(normalizedKey) ||
|
||||
redactedHeaders.has(unproxiedKey)
|
||||
) {
|
||||
sanitized[key] = redactValue(value);
|
||||
} else {
|
||||
sanitized[key] = value;
|
||||
|
||||
@@ -3,7 +3,11 @@
|
||||
*/
|
||||
|
||||
import { base } from '$app/paths';
|
||||
import { CORS_PROXY_ENDPOINT, CORS_PROXY_URL_PARAM } from '$lib/constants';
|
||||
import {
|
||||
CORS_PROXY_ENDPOINT,
|
||||
CORS_PROXY_HEADER_PREFIX,
|
||||
CORS_PROXY_URL_PARAM
|
||||
} from '$lib/constants';
|
||||
|
||||
/**
|
||||
* Build a proxied URL that routes through llama-server's CORS proxy.
|
||||
@@ -28,7 +32,7 @@ export function buildProxiedHeaders(headers: Record<string, string>): Record<str
|
||||
const proxiedHeaders: Record<string, string> = {};
|
||||
|
||||
for (const [key, value] of Object.entries(headers)) {
|
||||
proxiedHeaders[`x-proxy-header-${key}`] = value;
|
||||
proxiedHeaders[`${CORS_PROXY_HEADER_PREFIX}${key}`] = value;
|
||||
}
|
||||
|
||||
return proxiedHeaders;
|
||||
|
||||
@@ -44,6 +44,9 @@ export { buildProxiedUrl, buildProxiedHeaders } from './cors-proxy';
|
||||
// URL utilities
|
||||
export { extractRootDomain, sanitizeExternalUrl } from './url';
|
||||
|
||||
// Progress helpers
|
||||
export { modelLoadFraction, modelLoadProgressText } from './progress';
|
||||
|
||||
// Conversation utilities
|
||||
export { createMessageCountMap, getMessageCount } from './conversation-utils';
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user