mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-09 07:16:44 +02:00
server: fix checkpoints creation (#22929)
* common : add common_chat_split_by_role * cont : fix spans to reach end of message * server: fix checkpoints creation - extract message_spans from chat templates - find the prompt token position before the latest user message - split prompt batching at that position - create a context checkpoint before the latest user input - avoid periodic mid-prompt checkpoints when that position is known - handle multimodal prompts when mapping text/template positions to server prompt tokens - add --checkpoint-min-step to control minimum spacing between checkpoints * cont : clean-up * Support autoparser detection for message barriers * server: fix message span delimiter and update docs --------- Co-authored-by: Alde Rojas <hello@alde.dev> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: Piotr Wilkin <piotr.wilkin@syndatis.com>
This commit is contained in:
+7
-4
@@ -1334,12 +1334,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"-cpent", "--checkpoint-every-n-tokens"}, "N",
|
||||
string_format("create a checkpoint every n tokens during prefill (processing), -1 to disable (default: %d)", params.checkpoint_every_nt),
|
||||
{"-cms", "--checkpoint-min-step"}, "N",
|
||||
string_format("minimum spacing between context checkpoints in tokens (default: %d, 0 = no minimum)", params.checkpoint_min_step),
|
||||
[](common_params & params, int value) {
|
||||
params.checkpoint_every_nt = value;
|
||||
if (value < 0) {
|
||||
throw std::invalid_argument("checkpoint-min-step must be non-negative");
|
||||
}
|
||||
params.checkpoint_min_step = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_CHECKPOINT_EVERY_NT").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
).set_env("LLAMA_ARG_CHECKPOINT_MIN_SPACING_NT").set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-cram", "--cache-ram"}, "N",
|
||||
string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)"
|
||||
|
||||
@@ -310,6 +310,8 @@ std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segm
|
||||
|
||||
namespace autoparser {
|
||||
|
||||
static const std::string ERR_TMPL = "#**ERROR**#";
|
||||
|
||||
std::string apply_template(const common_chat_template & tmpl, const template_params & params) {
|
||||
generation_params tmpl_params;
|
||||
tmpl_params.messages = params.messages;
|
||||
@@ -326,7 +328,7 @@ std::string apply_template(const common_chat_template & tmpl, const template_par
|
||||
return common_chat_template_direct_apply(tmpl, tmpl_params);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_DBG("Template application failed: %s\n", e.what());
|
||||
return "";
|
||||
return ERR_TMPL;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -347,7 +349,7 @@ std::optional<compare_variants_result> compare_variants(
|
||||
std::string output_B = apply_template(tmpl, params_B);
|
||||
|
||||
// Check for template application failures
|
||||
if (output_A.empty() || output_B.empty()) {
|
||||
if (output_A == ERR_TMPL || output_B == ERR_TMPL) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
|
||||
@@ -377,6 +377,8 @@ struct analyze_tools : analyze_base {
|
||||
|
||||
struct autoparser {
|
||||
jinja::caps jinja_caps;
|
||||
std::string user_start;
|
||||
std::string assistant_start;
|
||||
analyze_reasoning reasoning;
|
||||
analyze_content content;
|
||||
analyze_tools tools;
|
||||
@@ -387,6 +389,10 @@ struct autoparser {
|
||||
|
||||
autoparser() = default;
|
||||
|
||||
// Find the starting marker for the user message and assistant message
|
||||
std::string detect_user_start_marker(const common_chat_template & tmpl);
|
||||
std::string detect_assistant_start_marker(const common_chat_template & tmpl);
|
||||
|
||||
// Run full differential analysis on a template
|
||||
void analyze_template(const common_chat_template & tmpl);
|
||||
|
||||
|
||||
@@ -8,6 +8,9 @@
|
||||
#include "peg-parser.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
|
||||
#define ANSI_RESET "\033[0m"
|
||||
#define ANSI_PURPLE "\033[1m\x1b[38;5;126m"
|
||||
@@ -23,6 +26,7 @@ static const std::string FUN_SECOND = "SSS_SECOND_FUN_S";
|
||||
static const std::string ARG_FIRST = "AA_ARG_FST_AA";
|
||||
static const std::string ARG_SECOND = "BB_ARG_SND_BB";
|
||||
static const std::string USER_MSG = "U_USER_MSG Hello END_U";
|
||||
static const std::string USER_MSG_TWO = "V_USER_MSG Hello END_V";
|
||||
static const std::string ASSISTANT_MSG = "A_ASST_MSG I can help END_A";
|
||||
static const std::string THINKING_CONTENT = "REASON_PART I am thinking END_R";
|
||||
static const std::string CALL_ID_001 = "call00001";
|
||||
@@ -71,6 +75,7 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
|
||||
analysis.content.end = "<|END_OF_TURN_TOKEN|>";
|
||||
analysis.preserved_tokens.push_back("<|CHATBOT_TOKEN|>");
|
||||
analysis.preserved_tokens.push_back("<|END_OF_TURN_TOKEN|>");
|
||||
analysis.user_start = "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>";
|
||||
LOG_DBG(ANSI_ORANGE "[Patch: Cohere Command R+]\n" ANSI_RESET);
|
||||
}
|
||||
},
|
||||
@@ -108,7 +113,59 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
|
||||
analysis.tools.function.close = "```";
|
||||
LOG_DBG(ANSI_ORANGE "[Patch: DeepSeek-R1-Distill-Qwen]\n" ANSI_RESET);
|
||||
}
|
||||
}
|
||||
},
|
||||
// Nemotron Nano v2
|
||||
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
|
||||
if (tmpl.src.find("<SPECIAL_10>") != std::string::npos && tmpl.src.find("<SPECIAL_11>") != std::string::npos &&
|
||||
tmpl.src.find("<SPECIAL_12>") != std::string::npos && tmpl.src.find("<TOOL_RESPONSE>") != std::string::npos) {
|
||||
|
||||
analysis.tools.format.mode = tool_format::JSON_NATIVE;
|
||||
analysis.tools.format.section_start = "";
|
||||
analysis.tools.format.section_end = "";
|
||||
analysis.tools.format.per_call_start = "<TOOLCALL>";
|
||||
analysis.tools.format.per_call_end = "</TOOLCALL>";
|
||||
analysis.content.mode = content_mode::PLAIN;
|
||||
analysis.content.start = "";
|
||||
analysis.content.end = "";
|
||||
analysis.reasoning.mode = reasoning_mode::TAG_BASED;
|
||||
analysis.reasoning.start = "<think>\n\n";
|
||||
analysis.reasoning.end = "</think>";
|
||||
analysis.assistant_start = "<SPECIAL_11>Assistant";
|
||||
analysis.user_start = "<SPECIAL_11>User";
|
||||
analysis.preserved_tokens.clear();
|
||||
analysis.preserved_tokens.push_back("<SPECIAL_12>");
|
||||
analysis.preserved_tokens.push_back("<SPECIAL_11>");
|
||||
analysis.preserved_tokens.push_back("</think>");
|
||||
analysis.preserved_tokens.push_back("<TOOLCALL>");
|
||||
analysis.preserved_tokens.push_back("</TOOLCALL>");
|
||||
LOG_DBG(ANSI_ORANGE "[Patch: Nemotron Nano v2]\n" ANSI_RESET);
|
||||
}
|
||||
},
|
||||
// Fireworks
|
||||
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
|
||||
if (tmpl.src.find("{%- set system_prompt = '<|start_header_id|>' + 'system' + '<|end_header_id|>\\n\\n'"
|
||||
" + message['content'] | trim + '\\n' + system_prompt_suffix + '<|eot_id|>' -%}") != std::string::npos) {
|
||||
analysis.assistant_start = "<|start_header_id|>assistant<|end_header_id|>";
|
||||
analysis.user_start = "<|start_header_id|>user<|end_header_id|>";
|
||||
LOG_DBG(ANSI_ORANGE "[Patch: Fireworks v2]\n" ANSI_RESET);
|
||||
}
|
||||
},
|
||||
// Solar Open
|
||||
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
|
||||
if (tmpl.src.find("<|begin|>assistant<|think|><|end|>") != std::string::npos) {
|
||||
analysis.assistant_start = "<|begin|>assistant";
|
||||
LOG_DBG(ANSI_ORANGE "[Patch: Solar Open]\n" ANSI_RESET);
|
||||
}
|
||||
},
|
||||
// Apriel 1.6
|
||||
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
|
||||
if (tmpl.src.find("if not loop.last and '[BEGIN FINAL RESPONSE]' in asst_text") != std::string::npos) {
|
||||
analysis.user_start = "<|begin_user|>";
|
||||
analysis.assistant_start = "<|begin_assistant|>";
|
||||
LOG_DBG(ANSI_ORANGE "[Patch: Apriel 1.6]\n" ANSI_RESET);
|
||||
}
|
||||
},
|
||||
|
||||
});
|
||||
|
||||
// Common JSON structures
|
||||
@@ -166,6 +223,8 @@ void autoparser::analyze_template(const common_chat_template & tmpl) {
|
||||
reasoning = analyze_reasoning(tmpl, jinja_caps.supports_tool_calls);
|
||||
content = analyze_content(tmpl, reasoning);
|
||||
tools = analyze_tools(jinja_caps.supports_tool_calls ? analyze_tools(tmpl, jinja_caps, reasoning) : analyze_tools());
|
||||
assistant_start = detect_assistant_start_marker(tmpl);
|
||||
user_start = detect_user_start_marker(tmpl);
|
||||
collect_preserved_tokens();
|
||||
|
||||
for (auto & workaround : workarounds) {
|
||||
@@ -173,6 +232,8 @@ void autoparser::analyze_template(const common_chat_template & tmpl) {
|
||||
}
|
||||
|
||||
LOG_DBG("\n--- Reasoning & Content Structure ---\n");
|
||||
LOG_DBG("user_msg_start: %s\n", user_start.c_str());
|
||||
LOG_DBG("assistant_msg_start: %s\n", assistant_start.c_str());
|
||||
LOG_DBG("reasoning_mode: %s\n", mode_to_str(reasoning.mode).c_str());
|
||||
LOG_DBG("reasoning_start: '%s'\n", reasoning.start.c_str());
|
||||
LOG_DBG("reasoning_end: '%s'\n", reasoning.end.c_str());
|
||||
@@ -245,6 +306,120 @@ void autoparser::collect_preserved_tokens() {
|
||||
add_token(tools.call_id.suffix);
|
||||
}
|
||||
|
||||
std::string autoparser::detect_assistant_start_marker(const common_chat_template & tmpl) {
|
||||
json user_msg = json{
|
||||
{ "role", "user" },
|
||||
{ "content", USER_MSG }
|
||||
};
|
||||
|
||||
json assistant_no_reasoning = json{
|
||||
{ "role", "assistant" },
|
||||
{ "content", ASSISTANT_MSG }
|
||||
};
|
||||
|
||||
template_params params;
|
||||
params.messages = json::array({ user_msg });
|
||||
params.add_generation_prompt = false;
|
||||
params.enable_thinking = true;
|
||||
|
||||
auto comparison = compare_variants(
|
||||
tmpl, params, [&](template_params & p) {
|
||||
p.messages = json::array({ user_msg, assistant_no_reasoning });
|
||||
}
|
||||
);
|
||||
|
||||
if (!comparison) {
|
||||
LOG_DBG(ANSI_ORANGE "%s: Template application failed, skipping assistant start detection\n" ANSI_RESET, __func__);
|
||||
return "";
|
||||
}
|
||||
|
||||
auto usermsg = comparison->diff.right;
|
||||
if (usermsg.find(ASSISTANT_MSG) == std::string::npos) {
|
||||
LOG_DBG(ANSI_ORANGE "%s: Did not find assistant message in assistant message block, skipping detection\n" ANSI_RESET, __func__);
|
||||
}
|
||||
|
||||
auto ast_prefix = usermsg.substr(0, usermsg.find(ASSISTANT_MSG));
|
||||
if (!reasoning.start.empty() && ast_prefix.find(trim_whitespace(reasoning.start)) != std::string::npos) {
|
||||
ast_prefix = ast_prefix.substr(0, ast_prefix.find(trim_whitespace(reasoning.start)));
|
||||
}
|
||||
if (!reasoning.end.empty() && ast_prefix.find(trim_whitespace(reasoning.end)) != std::string::npos) {
|
||||
ast_prefix = ast_prefix.substr(0, ast_prefix.find(trim_whitespace(reasoning.end)));
|
||||
}
|
||||
return trim_whitespace(ast_prefix);
|
||||
}
|
||||
|
||||
std::string autoparser::detect_user_start_marker(const common_chat_template & tmpl) {
|
||||
json user_msg = json{
|
||||
{ "role", "user" },
|
||||
{ "content", USER_MSG }
|
||||
};
|
||||
|
||||
json assistant = json{
|
||||
{ "role", "assistant" },
|
||||
{ "content", ASSISTANT_MSG }
|
||||
};
|
||||
|
||||
json user_msg_two = json{
|
||||
{ "role", "user" },
|
||||
{ "content", USER_MSG_TWO }
|
||||
};
|
||||
|
||||
template_params params;
|
||||
params.messages = json::array({});
|
||||
params.add_generation_prompt = false;
|
||||
params.enable_thinking = true;
|
||||
|
||||
auto comparison = compare_variants(
|
||||
tmpl, params, [&](template_params & p) {
|
||||
p.messages = json::array({ user_msg });
|
||||
}
|
||||
);
|
||||
|
||||
if (!comparison) {
|
||||
LOG_DBG(ANSI_ORANGE "%s: Template application failed, unsupported empty messages? trying complex variant\n" ANSI_RESET, __func__);
|
||||
params.messages = json::array({ user_msg_two, assistant });
|
||||
comparison = compare_variants(
|
||||
tmpl, params, [&](template_params & p) {
|
||||
p.messages = json::array({ user_msg_two, assistant, user_msg });
|
||||
}
|
||||
);
|
||||
if (!comparison) {
|
||||
LOG_DBG(ANSI_ORANGE "%s: Template application failed for reserve variant, aborting\n" ANSI_RESET, __func__);
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
auto usermsg = comparison->diff.right;
|
||||
if (usermsg.find(USER_MSG) == std::string::npos) {
|
||||
LOG_DBG(ANSI_ORANGE "%s: Did not find user message in user message block, aborting detection\n" ANSI_RESET, __func__);
|
||||
}
|
||||
|
||||
if (usermsg.find(ASSISTANT_MSG) != std::string::npos) {
|
||||
usermsg = usermsg.substr(usermsg.find(ASSISTANT_MSG) + ASSISTANT_MSG.size());
|
||||
}
|
||||
|
||||
auto candidate = usermsg.substr(0, usermsg.find(USER_MSG));
|
||||
auto candidate_split = segmentize_markers(candidate);
|
||||
std::stringstream result;
|
||||
bool encountered_marker = false;
|
||||
for (const auto & mrk : candidate_split) {
|
||||
std::string lower_mrk = std::string(mrk.value);
|
||||
std::transform(lower_mrk.begin(), lower_mrk.end(), lower_mrk.begin(),
|
||||
[](unsigned char c) { return std::tolower(c); });
|
||||
// heuristic to weed out potential end markers, but only at the start
|
||||
if (mrk.type == segment_type::MARKER && !encountered_marker &&
|
||||
(lower_mrk.find("end") != std::string::npos || lower_mrk.find("close") != std::string::npos)) {
|
||||
continue;
|
||||
}
|
||||
if (mrk.type == segment_type::TEXT && !encountered_marker && trim_whitespace(mrk.value).empty()) {
|
||||
continue;
|
||||
}
|
||||
encountered_marker |= mrk.type == segment_type::MARKER;
|
||||
result << mrk.value;
|
||||
}
|
||||
return trim_whitespace(result.str());
|
||||
}
|
||||
|
||||
analyze_reasoning::analyze_reasoning(const common_chat_template & tmpl, bool supports_tools)
|
||||
: analyze_base(tmpl) {
|
||||
LOG_DBG(ANSI_PURPLE "=== Starting differential analysis ===\n" ANSI_RESET);
|
||||
|
||||
@@ -90,6 +90,45 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const
|
||||
return text;
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims) {
|
||||
if (delims.empty() || prompt.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
auto parser = build_peg_parser([&](common_peg_parser_builder & p) {
|
||||
std::vector<std::string> all_delims;
|
||||
std::vector<common_peg_parser> tagged_messages;
|
||||
|
||||
all_delims.reserve(delims.size());
|
||||
tagged_messages.reserve(delims.size());
|
||||
for (const auto & d : delims) {
|
||||
all_delims.push_back(d.delimiter);
|
||||
}
|
||||
|
||||
auto any_delim = p.until_one_of(all_delims);
|
||||
for (const auto & d : delims) {
|
||||
tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim));
|
||||
}
|
||||
|
||||
return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end();
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx(prompt);
|
||||
const auto result = parser.parse(ctx);
|
||||
if (!result.success()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_span> spans;
|
||||
ctx.ast.visit(result, [&](const common_peg_ast_node & node) {
|
||||
if (!node.tag.empty()) {
|
||||
spans.push_back({ node.tag, node.start, node.end - node.start });
|
||||
}
|
||||
});
|
||||
|
||||
return spans;
|
||||
}
|
||||
|
||||
json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const {
|
||||
if (!content.empty() && !content_parts.empty()) {
|
||||
throw std::runtime_error("Cannot specify both content and content_parts");
|
||||
@@ -1042,6 +1081,14 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
|
||||
data.prompt = prompt;
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
data.message_spans = common_chat_split_by_role(prompt, {
|
||||
{ "assistant", "<|start|>assistant" },
|
||||
{ "user", "<|start|>user" },
|
||||
{ "system", "<|start|>developer" },
|
||||
{ "system", "<|start|>system" },
|
||||
{ "tool", "<|start|>functions" },
|
||||
});
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
|
||||
@@ -1181,6 +1228,11 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
data.message_spans = common_chat_split_by_role(data.prompt, {
|
||||
{ "user", "<|turn>user\n" },
|
||||
{ "assistant", "<|turn>model\n" },
|
||||
});
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
|
||||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "<|channel>thought";
|
||||
@@ -2393,6 +2445,19 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
struct autoparser::autoparser autoparser;
|
||||
autoparser.analyze_template(tmpl);
|
||||
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
|
||||
|
||||
std::vector<common_chat_msg_delimiter> delimiters;
|
||||
if (!autoparser.assistant_start.empty()) {
|
||||
delimiters.push_back({ "assistant", autoparser.assistant_start });
|
||||
}
|
||||
if (!autoparser.user_start.empty()) {
|
||||
delimiters.push_back({ "user", autoparser.user_start });
|
||||
}
|
||||
|
||||
if (!delimiters.empty()) {
|
||||
auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters);
|
||||
}
|
||||
|
||||
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
|
||||
if (auto_params.supports_thinking) {
|
||||
auto_params.thinking_start_tag = trim_whitespace(autoparser.reasoning.start);
|
||||
|
||||
@@ -143,6 +143,17 @@ struct common_chat_msg_diff {
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_span {
|
||||
std::string role;
|
||||
std::size_t pos = 0;
|
||||
std::size_t len = 0;
|
||||
};
|
||||
|
||||
struct common_chat_msg_delimiter {
|
||||
std::string role;
|
||||
std::string delimiter;
|
||||
};
|
||||
|
||||
struct common_chat_tool {
|
||||
std::string name;
|
||||
std::string description;
|
||||
@@ -208,6 +219,7 @@ struct common_chat_params {
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
std::string parser;
|
||||
std::vector<common_chat_msg_span> message_spans;
|
||||
};
|
||||
|
||||
// per-message parsing syntax
|
||||
@@ -304,6 +316,7 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
const std::string & src,
|
||||
autoparser::generation_params & params);
|
||||
|
||||
|
||||
// specialized per-task preset
|
||||
struct common_chat_prompt_preset {
|
||||
std::string system;
|
||||
@@ -311,3 +324,6 @@ struct common_chat_prompt_preset {
|
||||
};
|
||||
|
||||
common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates);
|
||||
|
||||
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims);
|
||||
|
||||
|
||||
@@ -445,6 +445,27 @@ std::string string_strip(const std::string & str) {
|
||||
return str.substr(start, end - start);
|
||||
}
|
||||
|
||||
std::string string_lcs(std::string_view a, std::string_view b) {
|
||||
if (a.empty() || b.empty()) return {};
|
||||
|
||||
std::vector<std::vector<size_t>> dp(a.size() + 1, std::vector<size_t>(b.size() + 1, 0));
|
||||
size_t best_len = 0;
|
||||
size_t best_end_a = 0;
|
||||
|
||||
for (size_t i = 1; i <= a.size(); ++i) {
|
||||
for (size_t j = 1; j <= b.size(); ++j) {
|
||||
if (a[i - 1] == b[j - 1]) {
|
||||
dp[i][j] = dp[i - 1][j - 1] + 1;
|
||||
if (dp[i][j] > best_len) {
|
||||
best_len = dp[i][j];
|
||||
best_end_a = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::string(a.substr(best_end_a - best_len, best_len));
|
||||
}
|
||||
|
||||
std::string string_get_sortable_timestamp() {
|
||||
using clock = std::chrono::system_clock;
|
||||
|
||||
|
||||
+2
-1
@@ -594,7 +594,7 @@ struct common_params {
|
||||
bool cache_prompt = true; // whether to enable prompt caching
|
||||
bool cache_idle_slots = true; // save and clear idle slots upon starting a new task
|
||||
int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
|
||||
int32_t checkpoint_every_nt = 8192; // make a checkpoint every n tokens during prefill
|
||||
int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
|
||||
std::string hostname = "127.0.0.1";
|
||||
@@ -731,6 +731,7 @@ std::string string_format(const char * fmt, ...);
|
||||
|
||||
std::string string_strip(const std::string & str);
|
||||
std::string string_get_sortable_timestamp();
|
||||
std::string string_lcs(std::string_view a, std::string_view b);
|
||||
|
||||
std::string string_join(const std::vector<std::string> & values, const std::string & separator);
|
||||
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter);
|
||||
|
||||
@@ -81,6 +81,8 @@ static void test_normalize_quotes_with_embedded_quotes(testing & t);
|
||||
// TAG_WITH_TAGGED argument parsing tests
|
||||
static void test_tagged_args_with_embedded_quotes(testing & t);
|
||||
|
||||
static void test_role_markers_all_templates(testing & t);
|
||||
|
||||
int main(int argc, char * argv[]) {
|
||||
testing t(std::cout);
|
||||
t.verbose = true;
|
||||
@@ -103,6 +105,7 @@ int main(int argc, char * argv[]) {
|
||||
t.test("standard_json_tools", test_standard_json_tools_formats);
|
||||
t.test("normalize_quotes_to_json", test_normalize_quotes_to_json);
|
||||
t.test("tagged_args_embedded_quotes", test_tagged_args_with_embedded_quotes);
|
||||
t.test("role_markers_all_templates", test_role_markers_all_templates);
|
||||
|
||||
return t.summary();
|
||||
}
|
||||
@@ -714,7 +717,7 @@ static void test_compare_variants_both_modifiers(testing & t) {
|
||||
static void test_compare_variants_template_failure(testing & t) {
|
||||
// Test with template that causes failure during application (not construction)
|
||||
// We use a valid template syntax but one that will fail during application
|
||||
common_chat_template tmpl("{{ messages[0]['nonexistent_field'] }}", "", "");
|
||||
common_chat_template tmpl("{{ messages.cahoot()[0]['nonexistent_field'] }}", "", "");
|
||||
|
||||
template_params params;
|
||||
params.messages = json::array({
|
||||
@@ -1848,6 +1851,128 @@ static json build_edit_tool() {
|
||||
});
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Role marker detection tests for all autoparser-handled templates
|
||||
//
|
||||
// Verifies that detect_user_start_marker / detect_assistant_start_marker
|
||||
// return the correct boundary text between turns for every template that
|
||||
// falls through to the differential autoparser (i.e. is not handled by a
|
||||
// dedicated specialized template in common_chat_try_specialized_template).
|
||||
//
|
||||
// Markers were deduced manually from the jinja sources in models/templates/.
|
||||
// ============================================================================
|
||||
struct role_marker_case {
|
||||
std::string template_file;
|
||||
std::string expected_user_start;
|
||||
std::string expected_assistant_start;
|
||||
};
|
||||
|
||||
static void test_role_markers_all_templates(testing & t) {
|
||||
// Each entry is { template filename, user_start, assistant_start } as
|
||||
// produced when rendering the standard chatml-like sequences. The values
|
||||
// come from reading each jinja template and tracing what text precedes
|
||||
// a user/assistant message body once the autoparser strips any reasoning
|
||||
// markers it detected first.
|
||||
const std::vector<role_marker_case> cases = {
|
||||
// ChatML family: <|im_start|>{role} ... <|im_end|>
|
||||
{ "Bielik-11B-v3.0-Instruct.jinja", "<|im_start|>user", "<|im_start|>assistant" },
|
||||
{ "HuggingFaceTB-SmolLM3-3B.jinja", "<|im_start|>user", "<|im_start|>assistant" },
|
||||
{ "MiMo-VL.jinja", "<|im_start|>user", "<|im_start|>assistant" },
|
||||
{ "NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", "<|im_start|>user", "<|im_start|>assistant" },
|
||||
{ "NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", "<|im_start|>user", "<|im_start|>assistant" },
|
||||
{ "NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja", "<|im_start|>user", "<|im_start|>assistant" },
|
||||
{ "Qwen3.5-4B.jinja", "<|im_start|>user", "<|im_start|>assistant" },
|
||||
{ "Qwen3-Coder.jinja", "<|im_start|>user", "<|im_start|>assistant" },
|
||||
{ "Qwen-Qwen2.5-7B-Instruct.jinja", "<|im_start|>user", "<|im_start|>assistant" },
|
||||
{ "Qwen-Qwen3-0.6B.jinja", "<|im_start|>user", "<|im_start|>assistant" },
|
||||
{ "Qwen-QwQ-32B.jinja", "<|im_start|>user", "<|im_start|>assistant" },
|
||||
{ "StepFun3.5-Flash.jinja", "<|im_start|>user", "<|im_start|>assistant" },
|
||||
{ "stepfun-ai-Step-3.5-Flash.jinja", "<|im_start|>user", "<|im_start|>assistant" },
|
||||
|
||||
// DeepSeek family
|
||||
{ "deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja", "<|User|>", "<|Assistant|>" },
|
||||
{ "deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja", "<|User|>", "<|Assistant|>" },
|
||||
{ "deepseek-ai-DeepSeek-V3.1.jinja", "<|User|>", "<|Assistant|>" },
|
||||
{ "llama-cpp-deepseek-r1.jinja", "<|User|>", "<|Assistant|>" },
|
||||
|
||||
// Llama 3 header family
|
||||
{ "meetkai-functionary-medium-v3.1.jinja", "<|start_header_id|>user<|end_header_id|>", "<|start_header_id|>assistant<|end_header_id|>" },
|
||||
{ "meta-llama-Llama-3.1-8B-Instruct.jinja", "<|start_header_id|>user<|end_header_id|>", "<|start_header_id|>assistant<|end_header_id|>" },
|
||||
{ "meta-llama-Llama-3.2-3B-Instruct.jinja", "<|start_header_id|>user<|end_header_id|>", "<|start_header_id|>assistant<|end_header_id|>" },
|
||||
{ "meta-llama-Llama-3.3-70B-Instruct.jinja", "<|start_header_id|>user<|end_header_id|>", "<|start_header_id|>assistant<|end_header_id|>" },
|
||||
// fireworks-ai forces a trailing assistant header even without add_generation_prompt,
|
||||
// so the marker is absorbed into the common suffix and assistant_start is detected as empty.
|
||||
{ "fireworks-ai-llama-3-firefunction-v2.jinja", "<|start_header_id|>user<|end_header_id|>", "<|start_header_id|>assistant<|end_header_id|>" },
|
||||
|
||||
// Phi/GLM/Apriel-style: <|user|> / <|assistant|>
|
||||
{ "microsoft-Phi-3.5-mini-instruct.jinja", "<|user|>", "<|assistant|>" },
|
||||
{ "GLM-4.6.jinja", "<|user|>", "<|assistant|>" },
|
||||
{ "unsloth-Apriel-1.5.jinja", "<|user|>", "<|assistant|>" },
|
||||
{ "GLM-4.7-Flash.jinja", "<|user|>", "<|assistant|>" },
|
||||
|
||||
// Gemma 2: <start_of_turn>{user|model}
|
||||
{ "google-gemma-2-2b-it.jinja", "<start_of_turn>user", "<start_of_turn>model" },
|
||||
|
||||
// IBM Granite
|
||||
{ "ibm-granite-granite-3.3-2B-Instruct.jinja", "<|start_of_role|>user<|end_of_role|>", "<|start_of_role|>assistant<|end_of_role|>" },
|
||||
{ "ibm-granite-granite-4.0.jinja", "<|start_of_role|>user<|end_of_role|>", "<|start_of_role|>assistant<|end_of_role|>" },
|
||||
|
||||
// Cohere R-series
|
||||
{ "CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja",
|
||||
"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|START_RESPONSE|>" },
|
||||
{ "CohereForAI-c4ai-command-r-plus-tool_use.jinja",
|
||||
"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" },
|
||||
|
||||
// Mistral: assistant content follows [/INST] immediately, no header
|
||||
{ "mistralai-Mistral-Nemo-Instruct-2407.jinja", "[INST]", "" },
|
||||
{ "Mistral-Small-3.2-24B-Instruct-2506.jinja", "[INST]", "" },
|
||||
|
||||
// Apertus uses <|user_start|> / <|assistant_start|> but the user diff
|
||||
// carries the preceding <|assistant_end|> from the previous turn.
|
||||
{ "Apertus-8B-Instruct.jinja", "<|user_start|>", "<|assistant_start|>" },
|
||||
|
||||
// Apriel 1.6 wraps the assistant body with <|begin_assistant|>, but
|
||||
// <|begin_assistant|> is also the detected reasoning start, so the
|
||||
// assistant_start is trimmed back to the preceding newline.
|
||||
{ "Apriel-1.6-15b-Thinker-fixed.jinja", "<|begin_user|>", "<|begin_assistant|>" },
|
||||
|
||||
// ByteDance Seed-OSS: <seed:bos>{role}
|
||||
{ "ByteDance-Seed-OSS.jinja", "<seed:bos>user", "<seed:bos>assistant" },
|
||||
|
||||
// GigaChat 3.1: {role}<|role_sep|>
|
||||
{ "GigaChat3.1-10B-A1.8B.jinja", "user<|role_sep|>", "assistant<|role_sep|>" },
|
||||
|
||||
// MiniMax M2: ]~b]{user|ai}
|
||||
{ "MiniMax-M2.jinja", "]~b]user", "]~b]ai" },
|
||||
|
||||
// Nemotron Nano v2: <SPECIAL_11>{User|Assistant}; assistant marker
|
||||
// is followed by a prefilled <think> block that gets included.
|
||||
{ "NVIDIA-Nemotron-Nano-v2.jinja", "<SPECIAL_11>User", "<SPECIAL_11>Assistant" },
|
||||
|
||||
// Reka Edge: "human: " / "assistant: " — but the rendered preamble
|
||||
// depends on enable_thinking, which currently confuses the user-start
|
||||
// diff and trims the marker down. Lock in the observed value.
|
||||
{ "Reka-Edge.jinja", "human:", "assistant:" },
|
||||
|
||||
// RWKV-world chat preset: "User: " / "Assistant: "
|
||||
{ "llama-cpp-rwkv-world.jinja", "User:", "Assistant:" },
|
||||
|
||||
// Upstage Solar 100B: <|begin|>{role}... but reasoning marker absorbs
|
||||
// the "<|begin|>assistant" prefix from assistant_start.
|
||||
{ "upstage-Solar-Open-100B.jinja", "<|begin|>user<|content|>", "<|begin|>assistant" },
|
||||
};
|
||||
|
||||
for (const auto & c : cases) {
|
||||
t.test(c.template_file, [&](testing & t) {
|
||||
common_chat_template tmpl = load_template(t, "models/templates/" + c.template_file);
|
||||
struct autoparser ap;
|
||||
ap.analyze_template(tmpl);
|
||||
t.assert_equal("user_start", c.expected_user_start, ap.user_start);
|
||||
t.assert_equal("assistant_start", c.expected_assistant_start, ap.assistant_start);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Test that reproduces the Seed-OSS template issue with embedded quotes
|
||||
static void test_tagged_args_with_embedded_quotes(testing & t) {
|
||||
json tools = build_edit_tool();
|
||||
|
||||
+39
-1
@@ -1548,6 +1548,40 @@ static void test_msgs_oaicompat_json_conversion() {
|
||||
}
|
||||
}
|
||||
|
||||
static void test_split_by_role() {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
|
||||
// Empty inputs
|
||||
assert_equals<size_t>(0, common_chat_split_by_role("", {}).size());
|
||||
assert_equals<size_t>(0, common_chat_split_by_role("hello", {}).size());
|
||||
assert_equals<size_t>(0, common_chat_split_by_role("", { { "user", "<|user|>" } }).size());
|
||||
|
||||
// Multi-role conversation, no leading/trailing content
|
||||
{
|
||||
const std::string prompt = "<|user|>Hi<|assistant|>Hello<|user|>Bye";
|
||||
const auto splits = common_chat_split_by_role(prompt, {
|
||||
{ "user", "<|user|>" },
|
||||
{ "assistant", "<|assistant|>" },
|
||||
});
|
||||
assert_equals<size_t>(3, splits.size());
|
||||
|
||||
assert_equals<std::string>("user", splits[0].role);
|
||||
assert_equals<size_t>(0, splits[0].pos);
|
||||
assert_equals<size_t>(10, splits[0].len);
|
||||
assert_equals<std::string>("<|user|>Hi", prompt.substr(splits[0].pos, splits[0].len));
|
||||
|
||||
assert_equals<std::string>("assistant", splits[1].role);
|
||||
assert_equals<size_t>(10, splits[1].pos);
|
||||
assert_equals<size_t>(18, splits[1].len);
|
||||
assert_equals<std::string>("<|assistant|>Hello", prompt.substr(splits[1].pos, splits[1].len));
|
||||
|
||||
assert_equals<std::string>("user", splits[2].role);
|
||||
assert_equals<size_t>(28, splits[2].pos);
|
||||
assert_equals<size_t>(11, splits[2].len);
|
||||
assert_equals<std::string>("<|user|>Bye", prompt.substr(splits[2].pos, splits[2].len));
|
||||
}
|
||||
}
|
||||
|
||||
static void test_tools_oaicompat_json_conversion() {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
std::vector<common_chat_tool> tools{
|
||||
@@ -4338,16 +4372,19 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
// Format: <TOOLCALL>[{"name": "func", "arguments": {...}}]</TOOLCALL>
|
||||
{
|
||||
auto tst = peg_tester("models/templates/NVIDIA-Nemotron-Nano-v2.jinja", detailed_debug);
|
||||
tst.test("<TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL><SPECIAL_12>")
|
||||
tst.test("<TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_assist_call)
|
||||
.run();
|
||||
|
||||
// Continuation tests
|
||||
tst.test("world!\nWhat's up?")
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.enable_thinking(true)
|
||||
.messages({ message_user, message_assist_prefill_content })
|
||||
.add_generation_prompt(false)
|
||||
.continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT)
|
||||
.expect_reasoning("I'm thinking")
|
||||
.expect_content("Hello, world!\nWhat's up?")
|
||||
.run();
|
||||
}
|
||||
@@ -5593,6 +5630,7 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
test_msg_diffs_compute();
|
||||
test_msgs_oaicompat_json_conversion();
|
||||
test_split_by_role();
|
||||
test_tools_oaicompat_json_conversion();
|
||||
test_convert_responses_to_chatcmpl();
|
||||
test_developer_role_to_system_workaround();
|
||||
|
||||
@@ -147,7 +147,6 @@
|
||||
| `--display-prompt, --no-display-prompt` | whether to print prompt at generation (default: true) |
|
||||
| `-co, --color [on\|off\|auto]` | Colorize output to distinguish prompt and user input from generations ('on', 'off', or 'auto', default: 'auto')<br/>'auto' enables colors when output is to a terminal |
|
||||
| `-ctxcp, --ctx-checkpoints, --swa-checkpoints N` | max number of context checkpoints to create per slot (default: 32)[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)<br/>(env: LLAMA_ARG_CTX_CHECKPOINTS) |
|
||||
| `-cpent, --checkpoint-every-n-tokens N` | create a checkpoint every n tokens during prefill (processing), -1 to disable (default: 8192)<br/>(env: LLAMA_ARG_CHECKPOINT_EVERY_NT) |
|
||||
| `-cram, --cache-ram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)<br/>(env: LLAMA_ARG_CACHE_RAM) |
|
||||
| `--context-shift, --no-context-shift` | whether to use context shift on infinite text generation (default: disabled)<br/>(env: LLAMA_ARG_CONTEXT_SHIFT) |
|
||||
| `-sys, --system-prompt PROMPT` | system prompt to use with model (if applicable, depending on chat template) |
|
||||
|
||||
@@ -163,7 +163,7 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `-lcs, --lookup-cache-static FNAME` | path to static lookup cache to use for lookup decoding (not updated by generation) |
|
||||
| `-lcd, --lookup-cache-dynamic FNAME` | path to dynamic lookup cache to use for lookup decoding (updated by generation) |
|
||||
| `-ctxcp, --ctx-checkpoints, --swa-checkpoints N` | max number of context checkpoints to create per slot (default: 32)[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)<br/>(env: LLAMA_ARG_CTX_CHECKPOINTS) |
|
||||
| `-cpent, --checkpoint-every-n-tokens N` | create a checkpoint every n tokens during prefill (processing), -1 to disable (default: 8192)<br/>(env: LLAMA_ARG_CHECKPOINT_EVERY_NT) |
|
||||
| `-cms, --checkpoint-min-step N` | minimum spacing between context checkpoints in tokens (default: 256, 0 = no minimum)<br/>(env: LLAMA_ARG_CHECKPOINT_MIN_SPACING_NT) |
|
||||
| `-cram, --cache-ram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)<br/>(env: LLAMA_ARG_CACHE_RAM) |
|
||||
| `-kvu, --kv-unified, -no-kvu, --no-kv-unified` | use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)<br/>(env: LLAMA_ARG_KV_UNIFIED) |
|
||||
| `--cache-idle-slots, --no-cache-idle-slots` | save and clear idle slots on new task (default: enabled, requires unified KV and cache-ram)<br/>(env: LLAMA_ARG_CACHE_IDLE_SLOTS) |
|
||||
|
||||
@@ -1110,6 +1110,16 @@ json oaicompat_chat_params_parse(
|
||||
llama_params["chat_parser"] = chat_params.parser;
|
||||
}
|
||||
|
||||
llama_params["message_spans"] = json::array();
|
||||
|
||||
for (const auto & span : chat_params.message_spans) {
|
||||
llama_params["message_spans"].push_back({
|
||||
{ "role", span.role },
|
||||
{ "pos", span.pos },
|
||||
{ "len", span.len },
|
||||
});
|
||||
}
|
||||
|
||||
// Reasoning budget: pass parameters through to sampling layer
|
||||
{
|
||||
int reasoning_budget = opt.reasoning_budget;
|
||||
|
||||
+110
-25
@@ -1103,6 +1103,13 @@ private:
|
||||
}
|
||||
SRV_INF("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n");
|
||||
|
||||
if (params_base.n_ctx_checkpoints > 0) {
|
||||
SRV_INF("context checkpoints enabled, max = %d, min spacing = %d\n",
|
||||
params_base.n_ctx_checkpoints, params_base.checkpoint_min_step);
|
||||
} else {
|
||||
SRV_INF("%s", "context checkpoints disabled\n");
|
||||
}
|
||||
|
||||
if (!params_base.model_alias.empty()) {
|
||||
// backward compat: use first alias as model name
|
||||
model_name = *params_base.model_alias.begin();
|
||||
@@ -2758,8 +2765,6 @@ private:
|
||||
}
|
||||
|
||||
if (pos_min >= pos_min_thold) {
|
||||
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
|
||||
|
||||
// search for a context checkpoint
|
||||
const auto it = std::find_if(
|
||||
slot.prompt.checkpoints.rbegin(),
|
||||
@@ -2776,7 +2781,6 @@ private:
|
||||
|
||||
if (!do_reset) {
|
||||
// restore the context checkpoint
|
||||
|
||||
it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
@@ -2912,6 +2916,9 @@ private:
|
||||
has_mtmd = true;
|
||||
}
|
||||
|
||||
const int32_t n_before_user = slot.task->params.n_before_user;
|
||||
const bool n_before_user_known = n_before_user > 0;
|
||||
|
||||
// add prompt tokens for processing in the current batch
|
||||
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
|
||||
// get next token to process
|
||||
@@ -2940,6 +2947,13 @@ private:
|
||||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
|
||||
// stop the prompt batch exactly before the latest user input, so a checkpoint
|
||||
// can be created after the previous messages
|
||||
if (n_before_user_known &&
|
||||
slot.prompt.n_tokens() == n_before_user) {
|
||||
break;
|
||||
}
|
||||
|
||||
// process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
|
||||
// create checkpoints that many tokens before the end of the prompt:
|
||||
// - 4 + n_ubatch
|
||||
@@ -2965,6 +2979,8 @@ private:
|
||||
// the number of tokens added to the batch for the current slot
|
||||
const auto n_tokens_cur = batch.n_tokens - n_tokens_prev;
|
||||
|
||||
const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch;
|
||||
|
||||
// entire prompt has been processed
|
||||
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
|
||||
slot.state = SLOT_STATE_DONE_PROMPT;
|
||||
@@ -2979,39 +2995,49 @@ private:
|
||||
|
||||
slot.init_sampler();
|
||||
} else {
|
||||
if (slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch) {
|
||||
// near the end of the prompt
|
||||
do_checkpoint = do_checkpoint && true;
|
||||
} else {
|
||||
// only do non-end checkpoints if the "checkpoint every n tokens" option is set
|
||||
do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0;
|
||||
|
||||
if (do_checkpoint) {
|
||||
llama_pos last_checkpoint = 0;
|
||||
if (!slot.prompt.checkpoints.empty()) {
|
||||
last_checkpoint = slot.prompt.checkpoints.back().n_tokens;
|
||||
}
|
||||
|
||||
do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt;
|
||||
|
||||
if (do_checkpoint) {
|
||||
SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens());
|
||||
}
|
||||
}
|
||||
// skip ordinary mid-prompt checkpoints
|
||||
if (!n_before_user_known && !near_prompt_end) {
|
||||
do_checkpoint = false;
|
||||
}
|
||||
}
|
||||
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id);
|
||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id);
|
||||
|
||||
// no need for empty or small checkpoints
|
||||
do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64);
|
||||
// checkpoints are created before the current batch is decoded, so
|
||||
// their token position is the batch start rather than the prompt end
|
||||
const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur;
|
||||
|
||||
{
|
||||
const bool is_on_user =
|
||||
n_before_user_known &&
|
||||
n_tokens_start == n_before_user;
|
||||
|
||||
const bool is_after_user =
|
||||
n_before_user_known &&
|
||||
n_tokens_start > n_before_user;
|
||||
|
||||
const bool is_allowed =
|
||||
!n_before_user_known ||
|
||||
is_on_user ||
|
||||
(is_after_user && near_prompt_end);
|
||||
|
||||
if (do_checkpoint && !is_allowed) {
|
||||
do_checkpoint = false;
|
||||
}
|
||||
}
|
||||
|
||||
// nothing to checkpoint yet
|
||||
// TODO: is this check needed?
|
||||
if (do_checkpoint && pos_min < 0) {
|
||||
do_checkpoint = false;
|
||||
}
|
||||
|
||||
// do not checkpoint after mtmd chunks
|
||||
do_checkpoint = do_checkpoint && !has_mtmd;
|
||||
|
||||
// no need to create checkpoints that are too close together
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64);
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step);
|
||||
SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max);
|
||||
|
||||
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
|
||||
@@ -3528,6 +3554,53 @@ void server_context::on_sleeping_changed(std::function<void(bool)> callback) {
|
||||
impl->queue_tasks.on_sleeping_state(std::move(callback));
|
||||
}
|
||||
|
||||
// compute the number of tokens before the last user message in the prompt
|
||||
static int32_t prompt_get_n_before_user(
|
||||
const json & message_spans,
|
||||
const std::string & prompt,
|
||||
const std::vector<raw_buffer> & files,
|
||||
const llama_vocab * vocab,
|
||||
mtmd_context * mctx) {
|
||||
int32_t result = -1;
|
||||
int32_t byte_pos = -1;
|
||||
|
||||
for (const auto & span : message_spans) {
|
||||
const std::string role = json_value(span, "role", std::string());
|
||||
|
||||
if (role == "user") {
|
||||
byte_pos = json_value(span, "pos", -1);
|
||||
}
|
||||
}
|
||||
|
||||
if (byte_pos >= 0) {
|
||||
GGML_ASSERT((size_t) byte_pos <= prompt.size());
|
||||
|
||||
const std::string prefix = prompt.substr(0, (size_t) byte_pos);
|
||||
|
||||
const std::string marker = get_media_marker();
|
||||
size_t n_prefix_media = 0;
|
||||
for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) {
|
||||
n_prefix_media++;
|
||||
}
|
||||
|
||||
GGML_ASSERT(n_prefix_media <= files.size());
|
||||
|
||||
if (mctx != nullptr && n_prefix_media > 0) {
|
||||
// TODO: this makes a copy - avoid it
|
||||
std::vector<raw_buffer> prefix_files(files.begin(), files.begin() + n_prefix_media);
|
||||
|
||||
result = (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size();
|
||||
} else {
|
||||
result = (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size();
|
||||
}
|
||||
|
||||
SRV_TRC("message_spans: last user message: byte_pos=%d, media=%zu, n_before_user=%d\n",
|
||||
byte_pos, n_prefix_media, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// server_routes
|
||||
@@ -3577,6 +3650,18 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||
meta->slot_n_ctx,
|
||||
meta->logit_bias_eog,
|
||||
data);
|
||||
|
||||
const auto message_spans = json_value(data, "message_spans", json::array());
|
||||
if (prompt.is_string() && message_spans.is_array()) {
|
||||
task.params.n_before_user =
|
||||
prompt_get_n_before_user(
|
||||
message_spans,
|
||||
prompt.get<std::string>(),
|
||||
files,
|
||||
ctx_server.vocab,
|
||||
ctx_server.mctx);
|
||||
}
|
||||
|
||||
task.id_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
|
||||
@@ -61,6 +61,9 @@ struct task_params {
|
||||
|
||||
int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)
|
||||
|
||||
// number of prompt tokens before the latest user message
|
||||
int32_t n_before_user = -1;
|
||||
|
||||
int64_t t_max_prompt_ms = -1; // TODO: implement
|
||||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||
|
||||
|
||||
Reference in New Issue
Block a user