Compare commits

...

9 Commits

Author SHA1 Message Date
Xuan-Son Nguyen bddfd2b113 server: refactor batch construction (#24843)
* server: refactor batch construction

* wip

* wip 2

* wip 3

* wip 4

* add abort_all_slots

* handle batch full more carefully

* fix assert

* rm debug log

* small nits

* (debug) add timings

* debug: force llama_synchronize for accurate timings

* address comments

* disable DEBUG_TIMINGS
2026-06-21 14:16:11 +02:00
Xuan-Son Nguyen 0d135df48c mtmd: fix mtmd_get_memory_usage (#24867) 2026-06-21 14:12:15 +02:00
Sigbjørn Skjæret bf533823cd jinja : implement call statement (#24847)
* implement call statement

* undo unintended change

* de-lambda

* simplify

* move caller context inside function handler
2026-06-21 14:04:52 +02:00
Xuan-Son Nguyen 2f89acc2bc mtmd: add load progress callback (#24865) 2026-06-21 13:40:52 +02:00
Xuan-Son Nguyen bfa3219177 server: add "verbose" field to schema (#24864) 2026-06-21 13:03:14 +02:00
Xuan-Son Nguyen d6d899580d server: real-time model load progress tracking via /models/sse (#24828)
* server: real-time model load progress tracking via /models/sse

* update docs

* add mutex for notify_to_router

* correct docs
2026-06-21 11:58:14 +02:00
Georgi Gerganov 8a118ee86c minor : clean-up whitespaces (#24862)
[no ci]
2026-06-21 11:37:12 +03:00
YiChen Lv d789527482 spec : Support Step3.5/3.7 flash mtp3 (#24340)
* add mtp_layer_offset + include nextn flags in graph reuse

* add llama_set_mtp_layer_offset + llama_model_n_nextn_layer API

* offset head select + require all MTP blocks

* speculative multi-head process()

* speculative multi-head draft()

* gather outputs via inp_out_ids

* cleanup

* fix core

* minor cleanup

* merged draft_multi_head into draft()

* mtp rename nextn

* Apply suggestions from code review

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* clean-up comments

* fix for multi seq

* apply suggestions && chain-heads comment

* add a reference for chain_heads discussion

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
2026-06-21 11:33:18 +03:00
Aldehir Rojas 063d9c156e common/peg : refactor until gbnf grammar generation (#24839)
* common/peg : refactor until gbnf grammar into an ac automaton

* cont : add a test with multiple strings

* cont : pad state with 0s so rules line up

* cont : clean up comments

* cont : use set everywhere

* cont : inline state num string padding

* cont : add a ref to PR

* cont : fix regression in server-tools.cpp
2026-06-20 21:15:06 -05:00
26 changed files with 1279 additions and 583 deletions
+89 -46
View File
@@ -686,59 +686,62 @@ value set_statement::execute_impl(context & ctx) {
return mk_val<value_undefined>();
}
static inline void bind_parameters(const std::string & name, const statements & this_args, const func_args & args, context & ctx) {
const size_t expected_count = this_args.size();
const size_t input_count = args.count();
JJ_DEBUG("Invoking '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count);
for (size_t i = 0; i < expected_count; ++i) {
if (i < input_count) {
if (is_stmt<identifier>(this_args[i])) {
// normal parameter
std::string param_name = cast_stmt<identifier>(this_args[i])->val;
value param_value = args.get_kwarg_or_pos(param_name, i);
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
ctx.set_val(param_name, param_value);
} else if (is_stmt<keyword_argument_expression>(this_args[i])) {
// default argument used as normal parameter
auto kwarg = cast_stmt<keyword_argument_expression>(this_args[i]);
if (!is_stmt<identifier>(kwarg->key)) {
throw std::runtime_error("Keyword argument key must be an identifier in '" + name + "'");
}
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
value param_value = args.get_kwarg_or_pos(param_name, i);
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
ctx.set_val(param_name, param_value);
} else {
throw std::runtime_error("Invalid parameter type in '" + name + "'");
}
} else {
auto & default_arg = this_args[i];
if (is_stmt<keyword_argument_expression>(default_arg)) {
auto kwarg = cast_stmt<keyword_argument_expression>(default_arg);
if (!is_stmt<identifier>(kwarg->key)) {
throw std::runtime_error("Keyword argument key must be an identifier in '" + name + "'");
}
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str());
ctx.set_val(param_name, kwarg->val->execute(args.ctx));
} else {
throw std::runtime_error("Not enough arguments provided to '" + name + "'");
}
//std::string param_name = cast_stmt<identifier>(default_args[i])->val;
//JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str());
//ctx.var[param_name] = default_args[i]->execute(ctx);
}
}
}
value macro_statement::execute_impl(context & ctx) {
if (!is_stmt<identifier>(this->name)) {
throw std::runtime_error("Macro name must be an identifier");
}
std::string name = cast_stmt<identifier>(this->name)->val;
const func_handler func = [this, name, &ctx](const func_args & args) -> value {
size_t expected_count = this->args.size();
size_t input_count = args.count();
const func_handler func = [this, name](const func_args & args) -> value {
context macro_ctx(args.ctx); // new scope for macro execution
JJ_DEBUG("Invoking macro '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count);
context macro_ctx(ctx); // new scope for macro execution
// bind parameters
for (size_t i = 0; i < expected_count; ++i) {
if (i < input_count) {
if (is_stmt<identifier>(this->args[i])) {
// normal parameter
std::string param_name = cast_stmt<identifier>(this->args[i])->val;
value param_value = args.get_kwarg_or_pos(param_name, i);
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
macro_ctx.set_val(param_name, param_value);
} else if (is_stmt<keyword_argument_expression>(this->args[i])) {
// default argument used as normal parameter
auto kwarg = cast_stmt<keyword_argument_expression>(this->args[i]);
if (!is_stmt<identifier>(kwarg->key)) {
throw std::runtime_error("Keyword argument key must be an identifier in macro '" + name + "'");
}
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
value param_value = args.get_kwarg_or_pos(param_name, i);
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
macro_ctx.set_val(param_name, param_value);
} else {
throw std::runtime_error("Invalid parameter type in macro '" + name + "'");
}
} else {
auto & default_arg = this->args[i];
if (is_stmt<keyword_argument_expression>(default_arg)) {
auto kwarg = cast_stmt<keyword_argument_expression>(default_arg);
if (!is_stmt<identifier>(kwarg->key)) {
throw std::runtime_error("Keyword argument key must be an identifier in macro '" + name + "'");
}
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str());
macro_ctx.set_val(param_name, kwarg->val->execute(ctx));
} else {
throw std::runtime_error("Not enough arguments provided to macro '" + name + "'");
}
//std::string param_name = cast_stmt<identifier>(default_args[i])->val;
//JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str());
//macro_ctx.var[param_name] = default_args[i]->execute(ctx);
}
}
bind_parameters(name, this->args, args, macro_ctx);
// execute macro body
JJ_DEBUG("Executing macro '%s' body with %zu statements", name.c_str(), this->body.size());
@@ -752,6 +755,46 @@ value macro_statement::execute_impl(context & ctx) {
return mk_val<value_undefined>();
}
value call_statement::execute_impl(context & ctx) {
auto call_expr = cast_stmt<call_expression>(this->call);
if (!call_expr) {
throw std::runtime_error("Call statement requires a valid call expression");
}
value callee_val = call_expr->callee->execute(ctx);
if (!is_val<value_func>(callee_val)) {
throw std::runtime_error("Callee is not a function: got " + callee_val->type());
}
auto * callee_func = cast_val<value_func>(callee_val);
context caller_ctx(ctx); // new scope for caller execution
const func_handler func = [this, caller_ctx = std::move(caller_ctx)](const func_args & args) -> value {
context block_ctx(caller_ctx); // new scope for block execution
bind_parameters("caller", this->caller_args, args, block_ctx);
JJ_DEBUG("Executing call body with %zu statements", this->body.size());
auto res = exec_statements(this->body, block_ctx);
JJ_DEBUG("Call body execution complete, result: %s", res->val_str.str().c_str());
return res;
};
context call_ctx(ctx);
call_ctx.set_val("caller", mk_val<value_func>("caller", func));
func_args args(call_ctx);
for (const auto & arg_expr : call_expr->args) {
auto arg_val = arg_expr->execute(ctx);
JJ_DEBUG(" Argument type: %s", arg_val->type().c_str());
args.push_back(arg_val);
}
JJ_DEBUG("Calling macro '%s' with %zu arguments", callee_func->name.c_str(), args.count());
return callee_func->invoke(args);
}
value member_expression::execute_impl(context & ctx) {
value object = this->object->execute(ctx);
+1
View File
@@ -552,6 +552,7 @@ struct call_statement : public statement {
for (const auto & arg : this->caller_args) chk_type<expression>(arg);
}
std::string type() const override { return "CallStatement"; }
value execute_impl(context & ctx) override;
};
struct ternary_expression : public expression {
+117 -77
View File
@@ -6,13 +6,14 @@
#include "unicode.h"
#include <algorithm>
#include <deque>
#include <initializer_list>
#include <map>
#include <memory>
#include <nlohmann/json.hpp>
#include <regex>
#include <set>
#include <stdexcept>
#include <unordered_set>
// Trick to catch missing branches
template <typename T>
@@ -88,40 +89,7 @@ struct trie {
return match_result{match_result::NO_MATCH};
}
struct prefix_and_next {
std::vector<uint32_t> prefix;
std::vector<uint32_t> next_chars;
};
std::vector<prefix_and_next> collect_prefix_and_next() {
std::vector<uint32_t> prefix;
std::vector<prefix_and_next> result;
collect_prefix_and_next(0, prefix, result);
return result;
}
private:
void collect_prefix_and_next(size_t index, std::vector<uint32_t> & prefix, std::vector<prefix_and_next> & out) {
if (!nodes[index].is_word) {
if (!nodes[index].children.empty()) {
std::vector<uint32_t> chars;
chars.reserve(nodes[index].children.size());
for (const auto & p : nodes[index].children) {
chars.push_back(p.first);
}
out.emplace_back(prefix_and_next{prefix, chars});
}
}
for (const auto & p : nodes[index].children) {
uint32_t ch = p.first;
auto child = p.second;
prefix.push_back(ch);
collect_prefix_and_next(child, prefix, out);
prefix.pop_back();
}
}
size_t create_node() {
size_t index = nodes.size();
nodes.emplace_back();
@@ -153,6 +121,65 @@ struct trie {
}
};
// Aho-Corasick automaton
struct aho_corasick {
trie t;
std::vector<size_t> fail; // failure links
std::vector<size_t> order; // states in BFS order
std::vector<bool> terminal; // match states (directly or via a suffix link)
std::set<uint32_t> alphabet; // every character with a transition
aho_corasick(const std::vector<std::string> & strings) : t(strings) {
const auto & nodes = t.nodes;
const size_t n = nodes.size();
fail.assign(n, 0);
order.reserve(n);
std::deque<size_t> queue{ 0 };
while (!queue.empty()) {
size_t u = queue.front();
queue.pop_front();
order.push_back(u);
for (const auto & [ch, v] : nodes[u].children) {
if (u != 0) {
size_t f = fail[u];
while (f && nodes[f].children.find(ch) == nodes[f].children.end()) {
f = fail[f];
}
auto it = nodes[f].children.find(ch);
fail[v] = (it != nodes[f].children.end() && it->second != v) ? it->second : 0;
}
queue.push_back(v);
}
}
terminal.assign(n, false);
for (size_t u : order) {
terminal[u] = nodes[u].is_word || (u != 0 && terminal[fail[u]]);
}
for (const auto & node : nodes) {
for (const auto & [ch, v] : node.children) {
alphabet.insert(ch);
}
}
}
size_t num_states() const { return t.nodes.size(); }
bool is_terminal(size_t s) const { return terminal[s]; }
// follow failure links until a transition on `ch` exists.
size_t next(size_t state, uint32_t ch) const {
const auto & nodes = t.nodes;
while (state && nodes[state].children.find(ch) == nodes[state].children.end()) {
state = fail[state];
}
auto it = nodes[state].children.find(ch);
return it != nodes[state].children.end() ? it->second : 0;
}
};
static std::pair<uint32_t, size_t> parse_hex_escape(const std::string & str, size_t pos, int hex_count) {
if (pos + hex_count > str.length()) {
return {0, 0};
@@ -992,12 +1019,12 @@ void common_peg_arena::resolve_refs() {
}
std::string common_peg_arena::dump(common_peg_parser_id id) const {
std::unordered_set<common_peg_parser_id> visited;
std::set<common_peg_parser_id> visited;
return dump_impl(id, visited);
}
std::string common_peg_arena::dump_impl(common_peg_parser_id id,
std::unordered_set<common_peg_parser_id> & visited) const {
std::set<common_peg_parser_id> & visited) const {
// Check for cycles
if (visited.count(id)) {
return "[cycle]";
@@ -1502,61 +1529,74 @@ static std::string gbnf_escape_char_class(uint32_t c) {
return std::string(buf);
}
static std::string gbnf_excluding_pattern(const std::vector<std::string> & strings) {
trie matcher(strings);
auto pieces = matcher.collect_prefix_and_next();
// GBNF grammar matching strings that contain no string in `strings` as a
// substring. Emits the complement of an Aho-Corasick automaton DFA and returns
// the start state rule name.
//
// ref: https://github.com/ggml-org/llama.cpp/pull/24839
static std::string gbnf_excluding_grammar(const common_grammar_builder & builder,
const std::string & prefix,
const std::vector<std::string> & strings) {
aho_corasick ac(strings);
std::string pattern;
std::string trailing; // optional proper-prefix of a delimiter, allowed only at the very end
for (size_t i = 0; i < pieces.size(); ++i) {
if (i > 0) {
pattern += " | ";
auto state_name = [&](size_t s) -> std::string {
if (s == 0) {
return prefix;
}
std::string num = std::to_string(s);
num = num.size() == 1 ? ("0" + num) : num;
return prefix + "-" + num;
};
const auto & pre = pieces[i].prefix;
const auto & chars = pieces[i].next_chars;
std::string cls;
cls.reserve(chars.size());
auto char_class = [](const std::vector<uint32_t> & chars, bool negate) {
std::string s = negate ? "[^" : "[";
for (uint32_t ch : chars) {
cls += gbnf_escape_char_class(ch);
s += gbnf_escape_char_class(ch);
}
return s + "]";
};
for (size_t q = 0; q < ac.num_states(); q++) {
if (ac.is_terminal(q)) {
continue; // match states are dropped
}
if (!pre.empty()) {
std::string pre_literal = gbnf_format_literal(common_unicode_cpts_to_utf8(pre));
pattern += pre_literal + " [^" + cls + "]";
// Each interior alternative consumes a delimiter-prefix plus a disambiguating
// char, so the repetition alone cannot match a value that *ends* on a proper
// prefix of a delimiter (e.g. a trailing "\n" when the delimiter is
// "\n</parameter>\n"). The runtime until() (greedy first-match) accepts such
// values, so without this the grammar would reject input the parser accepts.
// Allow the value to terminate on any proper prefix as an optional tail.
// This makes the grammar a slight superset of the runtime language (a value
// may end on the longest prefix, which greedy first-match would not itself
// produce); harmless for constrained generation, which only needs to admit
// every runtime-valid string.
if (!trailing.empty()) {
trailing += " | ";
std::map<size_t, std::vector<uint32_t>> buckets;
std::vector<uint32_t> excluded;
for (uint32_t c : ac.alphabet) {
size_t d = ac.next(q, c);
if (ac.is_terminal(d)) {
excluded.push_back(c); // completes a forbidden string -> omit
} else if (d != 0) {
buckets[d].push_back(c); // specific non-root destination
excluded.push_back(c);
}
trailing += pre_literal;
} else {
pattern += "[^" + cls + "]";
}
std::string rhs = "|"; // every state is accepting
for (const auto & [d, chars] : buckets) {
rhs += " " + char_class(chars, false) + " " + state_name(d) + " |";
}
rhs += " " + char_class(excluded, true) + " " + state_name(0);
builder.add_rule(state_name(q), rhs);
}
std::string result = "(" + pattern + ")*";
if (!trailing.empty()) {
result += " (" + trailing + ")?";
// An empty delimiter makes the start state terminal. Emit an entry rule
// that matches nothing so the returned reference stays valid.
if (ac.is_terminal(0)) {
builder.add_rule(prefix, "|");
}
return result;
return state_name(0);
}
static std::unordered_set<std::string> collect_reachable_rules(
static std::set<std::string> collect_reachable_rules(
const common_peg_arena & arena,
const common_peg_parser_id & rule
) {
std::unordered_set<std::string> reachable;
std::unordered_set<std::string> visited;
std::set<std::string> reachable;
std::set<std::string> visited;
std::function<void(common_peg_parser_id)> visit = [&](common_peg_parser_id id) {
const auto & parser = arena.get(id);
@@ -1765,7 +1805,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
if (p.delimiters.empty()) {
return ".*";
}
return gbnf_excluding_pattern(p.delimiters);
return gbnf_excluding_grammar(builder, "until-" + std::to_string(id), p.delimiters);
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
if (schema_delegates(p)) {
return to_gbnf(p.child);
@@ -1789,7 +1829,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
};
// Collect reachable rules
std::unordered_set<std::string> reachable_rules;
std::set<std::string> reachable_rules;
if (lazy) {
// Collect rules reachable from trigger rules
+2 -2
View File
@@ -3,8 +3,8 @@
#include <nlohmann/json_fwd.hpp>
#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <string>
#include <string_view>
#include <functional>
@@ -335,7 +335,7 @@ class common_peg_arena {
friend class common_peg_parser_builder;
private:
std::string dump_impl(common_peg_parser_id id, std::unordered_set<common_peg_parser_id> & visited) const;
std::string dump_impl(common_peg_parser_id id, std::set<common_peg_parser_id> & visited) const;
common_peg_parser_id add_parser(common_peg_parser_variant parser);
void add_rule(const std::string & name, common_peg_parser_id id);
+102 -35
View File
@@ -905,7 +905,13 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
int32_t n_embd = 0;
bool is_mem_shared = false;
// One MTP draft driver, three modes (set once in the ctor):
// is_mem_shared (gemma4): shares the target KV, runs all heads in one graph.
// chain_heads (step35): n_mtp_layers trained heads, one per draft step.
// neither (qwen35 / qwen35moe): a single trained MTP head.
int32_t n_mtp_layers = 1;
bool is_mem_shared = false; // gemma4
bool chain_heads = false; // derived in the ctor: n_mtp_layers > 1 && !is_mem_shared
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
// The last h-row of one process() call needs the first token of the NEXT
@@ -920,10 +926,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
std::vector<std::vector<float>> verify_h;
std::vector<int32_t> verify_h_rows;
// Per-seq draft length from the last draft() call, used in accept() to
// roll back ctx_dft's recurrent state past the AR draft's redundant
// pre-advancement before process() mirrored the verify batch.
std::vector<uint16_t> last_n_drafted;
std::vector<int> i_last;
std::vector<std::vector<float>> chain_h;
common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq)
@@ -936,6 +940,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft));
GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) &&
"MTP input row width must match the target h_nextn width");
n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft)));
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
@@ -982,16 +987,25 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt;
chain_heads = n_mtp_layers > 1 && !is_mem_shared;
if (chain_heads) {
this->params.n_max = std::min(this->params.n_max, n_mtp_layers);
chain_h.assign(n_seq, {});
for (auto & c : chain_h) {
c.reserve((size_t) (this->params.n_max + 1) * n_embd);
}
}
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
i_last.assign(n_seq, -1);
i_batch_beg.assign(n_seq, -1);
i_batch_end.assign(n_seq, -1);
verify_h.assign(n_seq, {});
verify_h_rows.assign(n_seq, 0);
last_n_drafted.assign(n_seq, 0);
}
~common_speculative_impl_draft_mtp() override {
@@ -1097,9 +1111,34 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
auto * mem_dft = llama_get_memory(ctx_dft);
bool ok = true;
for (int head = 0; head < n_mtp_layers; ++head) {
if (chain_heads) {
// ref: https://github.com/ggml-org/llama.cpp/pull/24340/changes#r3413498544
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
}
llama_memory_seq_rm(mem_dft, seq_id, batch_in.pos[i_batch_beg[seq_id]], -1);
}
llama_set_nextn_layer_offset(ctx_dft, head);
}
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
__func__, head, (int) rc, (int) batch_in.pos[0]);
ok = false;
break;
}
}
if (chain_heads) {
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
}
if (!ok) {
return false;
}
}
@@ -1134,7 +1173,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
int n_drafting = 0;
std::vector<bool> drafting(n_seq);
const float * h_row = nullptr;
const size_t row_bytes = (size_t) n_embd * sizeof(float);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
@@ -1149,22 +1187,43 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
common_sampler_reset(smpls[seq_id].get());
common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, pending_h[seq_id].data(), row_bytes);
h_row = pending_h[seq_id].data();
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
}
i_last[seq_id] = batch.n_tokens - 1;
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
return;
if (chain_heads) {
chain_h[seq_id].assign(pending_h[seq_id].begin(), pending_h[seq_id].end());
}
}
int i = 0;
while (n_drafting > 0) {
int i_batch = 0;
// each step decodes under a different head, i.e. a different decoder layer, and
// KV is per layer. process() filled this layer's KV only for positions < n_past
// (prompt + accepted prefix) — nothing in the draft region yet. so reset the
// draft region (the seq_rm lower bound is n_past, leaving the prompt KV intact)
// and select head i so it rebuilds its own layer's KV there; decoding just the
// latest token would leave its attention reading cells only another head wrote.
if (chain_heads) {
auto * mem_dft = llama_get_memory(ctx_dft);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (drafting[seq_id]) {
llama_memory_seq_rm(mem_dft, seq_id, dparams[seq_id].n_past, -1);
}
}
llama_set_nextn_layer_offset(ctx_dft, i);
}
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
break;
}
// rebuild the batch for the next step: the growing-KV paths re-add only the
// new token (the KV already holds the prefix), while chained heads re-add the
// whole prefix at the next head. dropped sequences are simply not re-added.
common_batch_clear(batch);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
@@ -1174,9 +1233,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
auto * smpl = smpls[seq_id].get();
common_sampler_sample(smpl, ctx_dft, i_batch, true);
h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_batch);
++i_batch;
common_sampler_sample(smpl, ctx_dft, i_last[seq_id], true);
const float * h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_last[seq_id]);
const auto * cur_p = common_sampler_get_candidates(smpl, true);
@@ -1210,30 +1268,41 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
continue;
}
if (is_mem_shared) {
if (chain_heads) {
// ref: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448031546
chain_h[seq_id].insert(chain_h[seq_id].end(), h_row, h_row + n_embd);
const int n_rows = (int) result.size() + 1; // id_last + tokens drafted so far
for (int t = 0; t < n_rows; ++t) {
const llama_token tok = (t == 0) ? dp.id_last : result[t - 1];
common_batch_add(batch, tok, dp.n_past + t, { seq_id }, t == n_rows - 1);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd,
chain_h[seq_id].data() + (size_t) t * n_embd, row_bytes);
}
} else if (is_mem_shared) {
// note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens
// ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37
common_batch_add(batch, id, dp.n_past, { seq_id }, true);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
} else {
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
}
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
i_last[seq_id] = batch.n_tokens - 1;
}
if (batch.n_tokens == 0) {
break;
}
// evaluate the drafted tokens on the draft model
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
break;
}
++i;
}
if (chain_heads) {
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
}
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {
@@ -1243,8 +1312,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
if (dp.result->size() < (size_t) params.n_min) {
dp.result->clear();
}
last_n_drafted[seq_id] = (uint16_t) dp.result->size();
}
}
@@ -1857,7 +1924,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr;
bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
@@ -1895,7 +1962,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
if (has_draft_eagle3) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params));
}
if (has_mtp) {
if (has_draft_mtp) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
}
}
+9 -8
View File
@@ -558,14 +558,15 @@ extern "C" {
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_ctx_train (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_layer_nextn(const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
// Get the model's RoPE frequency scaling factor
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
+8
View File
@@ -1156,6 +1156,10 @@ void llama_context::set_embeddings_layer_inp(uint32_t lid, bool enable) {
sched_need_reserve = true;
}
void llama_context::set_nextn_layer_offset(int32_t offset) {
cparams.nextn_layer_offset = offset;
}
void llama_context::set_causal_attn(bool value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
@@ -3699,6 +3703,10 @@ void llama_set_embeddings_layer_inp(llama_context * ctx, uint32_t lid, bool valu
ctx->set_embeddings_layer_inp(lid, value);
}
void llama_set_nextn_layer_offset(llama_context * ctx, int32_t offset) {
ctx->set_nextn_layer_offset(offset);
}
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
if (!ctx) {
return nullptr;
+1
View File
@@ -115,6 +115,7 @@ struct llama_context {
void set_embeddings (bool value);
void set_embeddings_nextn(bool value, bool masked);
void set_embeddings_layer_inp(uint32_t lid, bool enable);
void set_nextn_layer_offset(int32_t offset);
void set_causal_attn(bool value);
void set_warmup(bool value);
+2
View File
@@ -18,6 +18,8 @@ struct llama_cparams {
int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing
int32_t nextn_layer_offset = 0;
float rope_freq_base;
float rope_freq_scale;
+5
View File
@@ -95,6 +95,11 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked);
// Select which appended NextN block the DECODER_MTP graph runs (offset past
// the trunk: il = n_layer() + offset). Used by the speculative NextN driver to
// chain multiple trained NextN heads. Default 0 (first head).
LLAMA_API void llama_set_nextn_layer_offset(struct llama_context * ctx, int32_t offset);
// mirrors:
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);
+9 -2
View File
@@ -682,9 +682,16 @@ struct llm_graph_params {
}
}
// TODO: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448035248
if (cparams.nextn_layer_offset != other.cparams.nextn_layer_offset) {
return false;
}
return
cparams.embeddings == other.cparams.embeddings &&
cparams.causal_attn == other.cparams.causal_attn &&
cparams.embeddings == other.cparams.embeddings &&
cparams.embeddings_nextn == other.cparams.embeddings_nextn &&
cparams.embeddings_nextn_masked == other.cparams.embeddings_nextn_masked &&
cparams.causal_attn == other.cparams.causal_attn &&
arch == other.arch &&
gtype == other.gtype &&
cvec == other.cvec &&
+4
View File
@@ -2312,6 +2312,10 @@ int32_t llama_model_n_layer(const llama_model * model) {
return model->hparams.n_layer();
}
int32_t llama_model_n_layer_nextn(const llama_model * model) {
return model->hparams.n_layer_nextn;
}
int32_t llama_model_n_head(const llama_model * model) {
return model->hparams.n_head();
}
+27 -28
View File
@@ -112,7 +112,7 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED);
};
auto load_block_mtp = [&](int i, bool is_first_mtp) {
auto load_block_mtp = [&](int i) {
auto & layer = layers[i];
const uint32_t n_head_l = hparams.n_head(i);
@@ -121,15 +121,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
// The MTP block is a full Step3p5 decoder layer (mtp_block) plus the
// NextN-specific wiring (enorm/hnorm/eh_proj + optional shared head).
// `mtp_flags` becomes NOT_REQUIRED when the GGUF is trunk-only.
//
// Only the FIRST MTP block (i == n_main) is required for the
// single-block MTP runtime; trailing MTP blocks are always tolerated
// as missing so pruned GGUFs (block 0 only) load cleanly. Override
// mtp_flags to NOT_REQUIRED for those.
const int eff_mtp_flags = is_first_mtp ? mtp_flags : (mtp_flags | TENSOR_NOT_REQUIRED);
// Multi-block MTP: every declared MTP block is required (the draft chain
// runs all n_layer_nextn heads), so each block uses the captured
// `mtp_flags` directly — already NOT_REQUIRED for a trunk-only GGUF,
// which keeps that path correct.
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, eff_mtp_flags);
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, mtp_flags);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED);
@@ -140,12 +137,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED);
}
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, eff_mtp_flags);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, eff_mtp_flags);
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, mtp_flags);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, mtp_flags);
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, eff_mtp_flags);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, mtp_flags);
// dense MLP (leading dense blocks) — present if the MTP block isn't MoE
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
@@ -165,9 +162,9 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED);
// NextN-specific tensors that define the MTP block.
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, eff_mtp_flags);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, eff_mtp_flags);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, eff_mtp_flags);
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, mtp_flags);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, mtp_flags);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, mtp_flags);
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
@@ -176,13 +173,11 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
for (int i = 0; i < n_layer; ++i) {
load_block_trunk(i, trunk_flags);
}
// Only the first MTP block (i == n_main) is required at runtime — the
// single-block-MTP graph in build_arch_graph always uses that one.
// Trailing MTP blocks are loaded if present (so an un-pruned GGUF with
// all MTP layers still works) but tolerated when absent via the pruning
// path. See scripts/prune_step35_extra_mtp.py for the pruner.
// All n_layer_nextn MTP blocks are required — the multi-block draft chain
// runs every head (head k at offset k). The GGUF declares the count via
// step35.nextn_predict_layers.
for (int i = n_layer; i < n_layer_all; ++i) {
load_block_mtp(i, /*is_first_mtp=*/ i == n_layer);
load_block_mtp(i);
}
}
@@ -372,13 +367,14 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
: llm_graph_context(params) {
GGML_ASSERT(hparams.n_layer_nextn > 0 && "STEP35 MTP requires n_layer_nextn > 0");
// Single-block MTP only: always run the first trained MTP block (Qwen
// MTP / vLLM single-MTP-layer style). Multi-block round-robin proved to
// be a much deeper refactor than this PR justifies; the trailing MTP
// blocks are loaded with TENSOR_NOT_REQUIRED so pruned GGUFs (with just
// block 0) also work — see load_arch_tensors below and
// scripts/prune_step35_extra_mtp.py.
const int il = hparams.n_layer();
// Multi-block MTP: the DECODER_MTP graph runs the MTP head selected by
// cparams.nextn_layer_offset (0 = first trained head). The speculative driver
// bumps the offset per draft step to chain heads 45->46->47. offset 0 keeps
// single-block behavior identical to before.
const int il = hparams.n_layer() + cparams.nextn_layer_offset;
GGML_ASSERT(cparams.nextn_layer_offset >= 0 &&
cparams.nextn_layer_offset < (int) hparams.n_layer_nextn &&
"nextn_layer_offset out of range [0, n_layer_nextn)");
const auto & layer = model.layers[il];
GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj");
@@ -536,6 +532,9 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "mtp_post_ffn", il);
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
// Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
cb(cur, "h_nextn", -1);
res->t_h_nextn = cur;
+79 -1
View File
@@ -129,8 +129,86 @@ void test_gbnf_generation(testing &t) {
});
assert_gbnf_equal(t, R"""(
root ::= ([^<] | "<" [^/] | "</" [^t] | "</t" [^a] | "</ta" [^g] | "</tag" [^>])* ("<" | "</" | "</t" | "</ta" | "</tag")?
root ::= until-0
space ::= | " " | "\n"{1,2} [ \t]{0,20}
until-0 ::= | [<] until-0-01 | [^<] until-0
until-0-01 ::= | [<] until-0-01 | [/] until-0-02 | [^/<] until-0
until-0-02 ::= | [<] until-0-01 | [t] until-0-03 | [^<t] until-0
until-0-03 ::= | [<] until-0-01 | [a] until-0-04 | [^<a] until-0
until-0-04 ::= | [<] until-0-01 | [g] until-0-05 | [^<g] until-0
until-0-05 ::= | [<] until-0-01 | [^<>] until-0
)""", gbnf);
});
t.test("until grammar overlapping delimiter", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.until("\n</parameter>\n");
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= until-0
space ::= | " " | "\n"{1,2} [ \t]{0,20}
until-0 ::= | [\n] until-0-01 | [^\n] until-0
until-0-01 ::= | [\n] until-0-01 | [<] until-0-02 | [^\n<] until-0
until-0-02 ::= | [\n] until-0-01 | [/] until-0-03 | [^\n/] until-0
until-0-03 ::= | [\n] until-0-01 | [p] until-0-04 | [^\np] until-0
until-0-04 ::= | [\n] until-0-01 | [a] until-0-05 | [^\na] until-0
until-0-05 ::= | [\n] until-0-01 | [r] until-0-06 | [^\nr] until-0
until-0-06 ::= | [\n] until-0-01 | [a] until-0-07 | [^\na] until-0
until-0-07 ::= | [\n] until-0-01 | [m] until-0-08 | [^\nm] until-0
until-0-08 ::= | [\n] until-0-01 | [e] until-0-09 | [^\ne] until-0
until-0-09 ::= | [\n] until-0-01 | [t] until-0-10 | [^\nt] until-0
until-0-10 ::= | [\n] until-0-01 | [e] until-0-11 | [^\ne] until-0
until-0-11 ::= | [\n] until-0-01 | [r] until-0-12 | [^\nr] until-0
until-0-12 ::= | [\n] until-0-01 | [>] until-0-13 | [^\n>] until-0
until-0-13 ::= | [^\n] until-0
)""", gbnf);
});
// DeepSeek-V3.2 tag prefix. The DSML token (DSML) embeds U+FF5C,
// so the delimiter mixes ASCII and multi-byte codepoints.
t.test("until grammar unicode delimiter", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.until("<DSML");
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= until-0
space ::= | " " | "\n"{1,2} [ \t]{0,20}
until-0 ::= | [<] until-0-01 | [^<] until-0
until-0-01 ::= | [<] until-0-01 | [\uFF5C] until-0-02 | [^<\uFF5C] until-0
until-0-02 ::= | [<] until-0-01 | [D] until-0-03 | [^<D] until-0
until-0-03 ::= | [<] until-0-01 | [S] until-0-04 | [^<S] until-0
until-0-04 ::= | [<] until-0-01 | [M] until-0-05 | [^<M] until-0
until-0-05 ::= | [<] until-0-01 | [L] until-0-06 | [^<L] until-0
until-0-06 ::= | [<] until-0-01 | [^<\uFF5C] until-0
)""", gbnf);
});
t.test("until grammar multiple delimiters", [](testing &t) {
auto parser = build_peg_parser([](common_peg_parser_builder & p) {
return p.until_one_of({"ab", "cd", "ef"});
});
auto gbnf = build_grammar([&](const common_grammar_builder & builder) {
parser.build_grammar(builder);
});
assert_gbnf_equal(t, R"""(
root ::= until-0
space ::= | " " | "\n"{1,2} [ \t]{0,20}
until-0 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^ace] until-0
until-0-01 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^abce] until-0
until-0-03 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^acde] until-0
until-0-05 ::= | [a] until-0-01 | [c] until-0-03 | [e] until-0-05 | [^acef] until-0
)""", gbnf);
});
+26
View File
@@ -995,6 +995,32 @@ static void test_macros(testing & t) {
json::object(),
"Hello, John Smith,Hi, Jane Doe"
);
test_template(t, "macro with caller",
"\
{%- macro nest_dict(o, i, ff='') %}\n\
{{- caller(ff) }}\n\
{%- for k, v in o|items %}\n\
{{- i + k + ': ' }}\n\
{%- if v is mapping %}\n\
{{- '{' }}\n\
{% call(f) nest_dict(v, i + ' ') %}\n\
{{- 'fail' if ff is undefined }}\n\
{%- endcall %}\n\
{{- i + '}' }}\n\
{% else %}\n\
{{- v|string }}\n\
{% endif %}\n\
{%- endfor %}\n\
{%- endmacro %}\n\
{%- call(f) nest_dict({'root1': 1, 'root2': {'nest1': 1, 'nest2': {'nest3': 2}}}, ' ', 'Dict') %}\n\
{{- 'fail' if ff is defined }}\n\
{{- f + ' {' }}\n\
{% endcall %}\n\
{{- '}' }}",
json::object(),
"Dict {\n root1: 1\n root2: {\n nest1: 1\n nest2: {\n nest3: 2\n }\n }\n}"
);
}
static void test_namespace(testing & t) {
+59 -24
View File
@@ -1045,8 +1045,17 @@ struct clip_model_loader {
bool has_vision = false;
bool has_audio = false;
mtmd_progress_callback progress_callback = nullptr;
void * progress_callback_user_data = nullptr;
// TODO @ngxson : we should not pass clip_ctx here, it should be clip_model
clip_model_loader(const char * fname, bool skip_tensors = false) : fname(fname) {
clip_model_loader(const char * fname,
bool skip_tensors = false,
mtmd_progress_callback progress_cb = nullptr,
void * progress_user_data = nullptr)
: fname(fname),
progress_callback(progress_cb),
progress_callback_user_data(progress_user_data) {
struct ggml_context * meta = nullptr;
struct gguf_init_params params = {
@@ -2787,37 +2796,60 @@ struct clip_model_loader {
}
// load data
if (!ctx_clip.no_alloc) {
{
std::vector<uint8_t> read_buf;
// start loading event
if (progress_callback){
progress_callback(0.0, progress_callback_user_data);
}
// compute total tensor data size for progress reporting
size_t total_data_size = 0;
for (auto & t : tensors_to_load) {
total_data_size += ggml_nbytes(t);
}
// alloc memory and offload data
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend);
ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft));
ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
for (auto & t : tensors_to_load) {
ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
GGML_ASSERT(cur && "tensor not found in ctx_data");
auto it_off = tensor_offset.find(t->name);
GGML_ASSERT(it_off != tensor_offset.end() && "no offset for tensor");
const size_t offset = it_off->second;
fin.seekg(offset, std::ios::beg);
if (!fin) {
throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name));
}
size_t num_bytes = ggml_nbytes(cur);
if (ggml_backend_buft_is_host(buft)) {
// for the CPU and Metal backend, we can read directly into the tensor
fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
} else {
// read into a temporary buffer first, then copy to device memory
read_buf.resize(num_bytes);
fin.read(reinterpret_cast<char *>(read_buf.data()), num_bytes);
ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
// read the weight from file
if (!ctx_clip.no_alloc) {
size_t data_loaded = 0;
for (auto & t : tensors_to_load) {
ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
GGML_ASSERT(cur && "tensor not found in ctx_data");
auto it_off = tensor_offset.find(t->name);
GGML_ASSERT(it_off != tensor_offset.end() && "no offset for tensor");
const size_t offset = it_off->second;
fin.seekg(offset, std::ios::beg);
if (!fin) {
throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name));
}
size_t num_bytes = ggml_nbytes(cur);
if (ggml_backend_buft_is_host(buft)) {
// for the CPU and Metal backend, we can read directly into the tensor
fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
} else {
// read into a temporary buffer first, then copy to device memory
read_buf.resize(num_bytes);
fin.read(reinterpret_cast<char *>(read_buf.data()), num_bytes);
ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
}
data_loaded += num_bytes;
if (progress_callback && total_data_size > 0) {
const float progress = (float)data_loaded / (float)total_data_size;
if (!progress_callback(progress, progress_callback_user_data)) {
throw std::runtime_error(string_format("%s: model loading cancelled by progress_callback\n", __func__));
}
}
}
LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str());
} else {
LOG_DBG("%s: no_alloc is set, skipping tensor data loading (%zu tensors)\n", __func__, tensors_to_load.size());
}
fin.close();
LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str());
}
}
@@ -3105,7 +3137,10 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
clip_ctx * ctx_audio = nullptr;
try {
clip_model_loader loader(fname);
clip_model_loader loader(fname,
/* skip_tensors */ false,
ctx_params.progress_callback,
ctx_params.progress_callback_user_data);
bool skip_audio = false;
if (loader.has_vision) {
+2
View File
@@ -54,6 +54,8 @@ struct clip_context_params {
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
bool no_alloc;
mtmd_progress_callback progress_callback;
void * progress_callback_user_data;
};
struct clip_init_result {
+8 -1
View File
@@ -251,6 +251,8 @@ mtmd_context_params mtmd_context_params_default() {
/* cb_eval */ nullptr,
/* cb_eval_user_data */ nullptr,
/* batch_max_tokens */ 1024,
/* progress_callback */ nullptr,
/* progress_callback_user_data */ nullptr,
};
return params;
}
@@ -345,6 +347,8 @@ struct mtmd_context {
/* cb_eval */ ctx_params.cb_eval,
/* cb_eval_user_data */ ctx_params.cb_eval_user_data,
/* no_alloc */ no_alloc,
/* progress_callback */ ctx_params.progress_callback,
/* progress_callback_user_data */ ctx_params.progress_callback_user_data,
};
auto res = clip_init(mmproj_fname, ctx_clip_params);
@@ -2133,9 +2137,12 @@ std::map<ggml_backend_dev_t, size_t> mtmd_get_memory_usage(const char * mmproj_f
mtmd::context_ptr ctx;
auto saved_log_callback = g_logger_state.log_callback;
auto saved_log_user_data = g_logger_state.log_callback_user_data;
ctx_params.progress_callback = nullptr;
try {
mtmd_log_set(stub_log_callback, nullptr); // suppress logging
ctx.reset(new mtmd_context(mmproj_fname, nullptr, ctx_params));
ctx.reset(new mtmd_context(mmproj_fname, nullptr, ctx_params, true));
mtmd_log_set(saved_log_callback, saved_log_user_data); // restore log callback
std::map<ggml_backend_dev_t, size_t> total_mem;
auto merge = [&](const struct clip_ctx * c) {
+8
View File
@@ -83,6 +83,8 @@ typedef struct mtmd_input_chunks mtmd_input_chunks;
typedef struct mtmd_input_text mtmd_input_text;
typedef struct mtmd_batch mtmd_batch;
typedef bool (*mtmd_progress_callback)(float progress, void * user_data);
struct mtmd_context_params {
bool use_gpu;
bool print_timings;
@@ -104,6 +106,12 @@ struct mtmd_context_params {
int32_t batch_max_tokens; // maximum number of output tokens in a batch
// (note: this is not a hard-limit, the first image will always be added even if it exceeds this limit)
// (default: 1024)
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
// If the provided progress_callback returns true, model loading continues.
// If it returns false, model loading is immediately aborted.
mtmd_progress_callback progress_callback;
void * progress_callback_user_data;
};
MTMD_API const char * mtmd_default_marker(void);
+26 -2
View File
@@ -1859,9 +1859,33 @@ Example events:
{
"model": "...",
"event": "download_finished",
"event": "model_status",
"data": {
"status": "loading"
"status": "loading",
"progress": {
"stage": "fit_params",
"value": 0.5 // from 0.0 to 1.0 ; note: not all stages have this "value"
}
}
}
{
"model": "...",
"event": "model_status",
"data": {
"status": "loaded",
"info": {
// note: only include info on first load
// waking up from sleep doesn't have this
}
}
}
{
"model": "...",
"event": "model_status",
"data": {
"status": "sleeping"
}
}
File diff suppressed because it is too large Load Diff
+17 -3
View File
@@ -442,6 +442,7 @@ void server_models::load_models() {
/* last_used */ 0,
/* args */ std::vector<std::string>(),
/* loaded_info */ {},
/* progress */ {},
/* exit_code */ 0,
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
/* multimodal */ mtmd_caps{false, false},
@@ -608,6 +609,7 @@ void server_models::load_models() {
/* last_used */ 0,
/* args */ std::vector<std::string>(),
/* loaded_info */ {},
/* progress */ {},
/* exit_code */ 0,
/* stop_timeout */ DEFAULT_STOP_TIMEOUT,
/* multimodal */ mtmd_caps{false, false},
@@ -1140,6 +1142,9 @@ void server_models::update_status(const std::string & name, const update_status_
if (!args.loaded_info.is_null()) {
meta.loaded_info = args.loaded_info;
}
if (!args.progress.is_null()) {
meta.progress = args.progress;
}
}
// broadcast status change to SSE
{
@@ -1152,6 +1157,9 @@ void server_models::update_status(const std::string & name, const update_status_
if (!args.loaded_info.is_null()) {
data["info"] = args.loaded_info;
}
if (!args.progress.is_null()) {
data["progress"] = args.progress;
}
// note: notify_sse doesn't acquire the lock, so no deadlock here
notify_sse("status_change", name, data);
}
@@ -1322,8 +1330,12 @@ void server_models::handle_child_state(const std::string & name, const std::stri
switch (state) {
case SERVER_STATE_LOADING:
{
// do nothing for now
// TODO: report loading progress for first load and wakeup from sleep
update_status(name, {
SERVER_MODEL_STATUS_LOADING,
0,
nullptr, // no loaded_info yet
payload,
});
} break;
case SERVER_STATE_READY:
{
@@ -1331,7 +1343,8 @@ void server_models::handle_child_state(const std::string & name, const std::stri
SERVER_MODEL_STATUS_LOADED,
0,
// note: payload can be empty if this is a wakeup from sleep
payload.size() > 0 ? payload : nullptr
payload.size() > 0 ? payload : nullptr,
{}, // reset progress info
});
} break;
case SERVER_STATE_SLEEPING:
@@ -1384,6 +1397,7 @@ void server_child::notify_to_router(const std::string & state, const json & payl
{"state", state},
{"payload", payload},
};
std::lock_guard<std::mutex> lk(mtx_stdout);
common_log_pause(common_log_main());
fflush(stdout);
fprintf(stdout, "%s%s\n", CMD_CHILD_TO_ROUTER_STATE, safe_json_to_str(data).c_str());
+7 -1
View File
@@ -72,6 +72,7 @@ struct server_model_meta {
int64_t last_used = 0; // for LRU unloading
std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
json loaded_info; // info to be reflected via /v1/models endpoint ; if in DOWNLOADING state, it should contain download progress info
json progress; // reflect load or download progress info, if any
int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown
mtmd_caps multimodal; // multimodal capabilities
@@ -170,12 +171,14 @@ public:
// to stop the download, call unload()
void download(common_params_model && model, common_download_opts && opts);
// update the status of a model instance (thread-safe)
struct update_status_args {
server_model_status status;
int exit_code = 0; // only valid if status == UNLOADED
json loaded_info = nullptr;
json progress = nullptr;
};
// update the status of a model instance (thread-safe)
// also send SSE notification to /models/sse endpoint
void update_status(const std::string & name, const update_status_args & args);
void update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok = true);
@@ -208,6 +211,9 @@ public:
};
struct server_child {
// serializes the notify_to_router writes
std::mutex mtx_stdout;
// return true if the current process is a child server instance
bool is_child();
+3
View File
@@ -14,6 +14,9 @@ std::vector<std::unique_ptr<field>> make_llama_cmpl_schema(const common_params &
fields.emplace_back(f);
};
add((new field_bool("verbose", params.verbose))
->set_desc("Include __verbose field in the response with additional debug information"));
add((new field_bool("timings_per_token", params.timings_per_token))
->set_desc("Include prompt processing and text generation speed information in each response"));
+1
View File
@@ -11,6 +11,7 @@
#include <cstring>
#include <climits>
#include <algorithm>
#include <unordered_set>
namespace fs = std::filesystem;
@@ -603,3 +603,23 @@ def test_chat_completions_token_count():
})
assert res.status_code == 200
assert res.body["input_tokens"] > 5
def test_verbose_debug():
global server
server.start()
for verbose in [True, False]:
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 2,
"messages": [
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
"verbose": verbose,
})
assert res.status_code == 200
if verbose:
assert "__verbose" in res.body
assert "Book" in res.body["__verbose"]["prompt"]
else:
assert "__verbose" not in res.body