mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-09 07:16:44 +02:00
qwen35: use post-norm hidden state for MTP (#24025)
* qwen35: use post-norm hidden state for MTP * rename pre_norm to nextn * fix step35
This commit is contained in:
+10
-10
@@ -3,7 +3,7 @@
|
||||
#include "common.h"
|
||||
#include "ggml.h"
|
||||
#include "llama.h"
|
||||
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_pre_norm / llama_get_embeddings_pre_norm_ith (used by MTP)
|
||||
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_nextn / llama_get_embeddings_nextn_ith (used by MTP)
|
||||
#include "log.h"
|
||||
#include "ngram-cache.h"
|
||||
#include "ngram-map.h"
|
||||
@@ -162,7 +162,7 @@ struct common_speculative_impl {
|
||||
virtual bool need_embd() const = 0;
|
||||
|
||||
// true if this implementation requires the target context to extract pre-norm embeddings
|
||||
virtual bool need_embd_pre_norm() const { return false; }
|
||||
virtual bool need_embd_nextn() const { return false; }
|
||||
};
|
||||
|
||||
struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
@@ -487,8 +487,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
}
|
||||
}
|
||||
|
||||
llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false);
|
||||
llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true);
|
||||
llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false);
|
||||
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
|
||||
|
||||
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
|
||||
|
||||
@@ -583,7 +583,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
// ^--- this is a problem
|
||||
// TODO:this is generally true, but would be nice to assert it
|
||||
{
|
||||
const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt);
|
||||
const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt);
|
||||
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
|
||||
|
||||
//{
|
||||
@@ -625,7 +625,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
verify_h[seq_id].resize((size_t) n_rows * n_embd);
|
||||
|
||||
for (int32_t i = 0; i < n_rows; ++i) {
|
||||
const float * h = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_beg[seq_id] + i);
|
||||
const float * h = llama_get_embeddings_nextn_ith(ctx_tgt, i_batch_beg[seq_id] + i);
|
||||
std::memcpy(verify_h[seq_id].data() + (size_t) i * n_embd, h, row_bytes);
|
||||
}
|
||||
|
||||
@@ -686,7 +686,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
auto * smpl = smpls[seq_id].get();
|
||||
|
||||
common_sampler_sample(smpl, ctx_dft, i_batch, true);
|
||||
h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, i_batch);
|
||||
h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_batch);
|
||||
++i_batch;
|
||||
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
@@ -772,7 +772,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool need_embd_pre_norm() const override {
|
||||
bool need_embd_nextn() const override {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
@@ -1539,13 +1539,13 @@ bool common_speculative_need_embd(common_speculative * spec) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool common_speculative_need_embd_pre_norm(common_speculative * spec) {
|
||||
bool common_speculative_need_embd_nextn(common_speculative * spec) {
|
||||
if (spec == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto & impl : spec->impls) {
|
||||
if (impl->need_embd_pre_norm()) {
|
||||
if (impl->need_embd_nextn()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,8 +59,8 @@ bool common_speculative_process(common_speculative * spec, const llama_batch & b
|
||||
// true if any implementation requires target post-norm embeddings to be extracted
|
||||
bool common_speculative_need_embd(common_speculative * spec);
|
||||
|
||||
// true if any implementation requires target pre-norm embeddings to be extracted
|
||||
bool common_speculative_need_embd_pre_norm(common_speculative * spec);
|
||||
// true if any implementation requires target nextn embeddings to be extracted
|
||||
bool common_speculative_need_embd_nextn(common_speculative * spec);
|
||||
|
||||
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
|
||||
void common_speculative_draft(common_speculative * spec);
|
||||
|
||||
+71
-70
@@ -58,19 +58,20 @@ llama_context::llama_context(
|
||||
cparams.n_rs_seq = 0;
|
||||
}
|
||||
|
||||
cparams.n_threads = params.n_threads;
|
||||
cparams.n_threads_batch = params.n_threads_batch;
|
||||
cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
|
||||
cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
|
||||
cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
|
||||
cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
|
||||
cparams.embeddings = params.embeddings;
|
||||
cparams.embeddings_pre_norm = false;
|
||||
cparams.embeddings_pre_norm_masked = false;
|
||||
cparams.offload_kqv = params.offload_kqv;
|
||||
cparams.no_perf = params.no_perf;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
cparams.warmup = false;
|
||||
cparams.n_threads = params.n_threads;
|
||||
cparams.n_threads_batch = params.n_threads_batch;
|
||||
cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
|
||||
cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
|
||||
cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
|
||||
cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
|
||||
cparams.embeddings = params.embeddings;
|
||||
cparams.embeddings_nextn = false;
|
||||
cparams.embeddings_nextn_masked = false;
|
||||
cparams.offload_kqv = params.offload_kqv;
|
||||
cparams.no_perf = params.no_perf;
|
||||
cparams.pooling_type = params.pooling_type;
|
||||
cparams.warmup = false;
|
||||
|
||||
|
||||
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
||||
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
|
||||
@@ -889,34 +890,34 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
|
||||
return it->second.data();
|
||||
}
|
||||
|
||||
float * llama_context::get_embeddings_pre_norm() {
|
||||
float * llama_context::get_embeddings_nextn() {
|
||||
output_reorder();
|
||||
|
||||
return embd_pre_norm.data;
|
||||
return embd_nextn.data;
|
||||
}
|
||||
|
||||
float * llama_context::get_embeddings_pre_norm_ith(int32_t i) {
|
||||
float * llama_context::get_embeddings_nextn_ith(int32_t i) {
|
||||
output_reorder();
|
||||
|
||||
try {
|
||||
if (embd_pre_norm.data == nullptr) {
|
||||
throw std::runtime_error("no pre-norm embeddings");
|
||||
if (embd_nextn.data == nullptr) {
|
||||
throw std::runtime_error("no nextn embeddings");
|
||||
}
|
||||
|
||||
const uint32_t n_embd = model.hparams.n_embd;
|
||||
|
||||
if (!cparams.embeddings_pre_norm_masked) {
|
||||
// unmasked: pre-norm rows are stored densely, indexed by raw token position.
|
||||
if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) {
|
||||
throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd));
|
||||
if (!cparams.embeddings_nextn_masked) {
|
||||
// unmasked: nextn rows are stored densely, indexed by raw token position.
|
||||
if (i < 0 || (size_t)(i + 1) * n_embd > embd_nextn.size) {
|
||||
throw std::runtime_error(format("out of range [0, %zu)", embd_nextn.size / n_embd));
|
||||
}
|
||||
return embd_pre_norm.data + (size_t) i * n_embd;
|
||||
return embd_nextn.data + (size_t) i * n_embd;
|
||||
}
|
||||
|
||||
const int64_t j = output_resolve_row(i);
|
||||
return embd_pre_norm.data + j*n_embd;
|
||||
return embd_nextn.data + j*n_embd;
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what());
|
||||
LLAMA_LOG_ERROR("%s: invalid nextn embeddings id %d, reason: %s\n", __func__, i, err.what());
|
||||
#ifndef NDEBUG
|
||||
GGML_ABORT("fatal error");
|
||||
#else
|
||||
@@ -1105,11 +1106,11 @@ void llama_context::set_embeddings(bool value) {
|
||||
//sched_need_reserve = true;
|
||||
}
|
||||
|
||||
void llama_context::set_embeddings_pre_norm(bool value, bool masked) {
|
||||
void llama_context::set_embeddings_nextn(bool value, bool masked) {
|
||||
LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked);
|
||||
|
||||
cparams.embeddings_pre_norm = value;
|
||||
cparams.embeddings_pre_norm_masked = masked;
|
||||
cparams.embeddings_nextn = value;
|
||||
cparams.embeddings_nextn_masked = masked;
|
||||
}
|
||||
|
||||
void llama_context::set_causal_attn(bool value) {
|
||||
@@ -1326,7 +1327,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
|
||||
}
|
||||
|
||||
int llama_context::encode(const llama_batch & batch_inp) {
|
||||
// MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
|
||||
// MTP hook batches carry both token (next-token id) and embd (h_nextn row),
|
||||
// so accept either present rather than requiring exactly one.
|
||||
GGML_ASSERT(batch_inp.token || batch_inp.embd);
|
||||
|
||||
@@ -1399,9 +1400,9 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
}
|
||||
}
|
||||
|
||||
auto * t_logits = res->get_logits();
|
||||
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
|
||||
auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr;
|
||||
auto * t_logits = res->get_logits();
|
||||
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
|
||||
auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn() : nullptr;
|
||||
|
||||
// extract logits
|
||||
if (logits.data && t_logits) {
|
||||
@@ -1467,14 +1468,14 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
}
|
||||
}
|
||||
|
||||
// extract pre-norm embeddings (hidden state before the final output norm)
|
||||
if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
|
||||
// extract nextn embeddings (hidden state before the final output norm)
|
||||
if (embd_nextn.data && t_h_nextn && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
|
||||
GGML_ASSERT(backend_h != nullptr);
|
||||
|
||||
const uint32_t n_embd = hparams.n_embd;
|
||||
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size);
|
||||
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float));
|
||||
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_nextn.size);
|
||||
ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn.data, 0, n_tokens*n_embd*sizeof(float));
|
||||
}
|
||||
|
||||
// TODO: hacky solution
|
||||
@@ -1629,7 +1630,7 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_s
|
||||
}
|
||||
|
||||
int llama_context::decode(const llama_batch & batch_inp) {
|
||||
// MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
|
||||
// MTP hook batches carry both token (next-token id) and embd (h_nextn row),
|
||||
// so accept either present rather than requiring exactly one.
|
||||
GGML_ASSERT(batch_inp.token || batch_inp.embd);
|
||||
|
||||
@@ -1829,9 +1830,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||
//}
|
||||
|
||||
auto * t_logits = res->get_logits();
|
||||
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
||||
auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr;
|
||||
auto * t_logits = res->get_logits();
|
||||
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
||||
auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn() : nullptr;
|
||||
|
||||
if (t_embd && res->get_embd_pooled()) {
|
||||
t_embd = res->get_embd_pooled();
|
||||
@@ -1912,22 +1913,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
}
|
||||
}
|
||||
|
||||
// extract pre-norm embeddings (hidden state before the final output norm)
|
||||
// extract nextn embeddings before
|
||||
// only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
|
||||
{
|
||||
const bool masked = cparams.embeddings_pre_norm_masked;
|
||||
const bool masked = cparams.embeddings_nextn_masked;
|
||||
const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens;
|
||||
const int64_t offset = masked ? n_outputs_prev : n_tokens_prev;
|
||||
|
||||
if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
|
||||
if (embd_nextn.data && t_h_nextn && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
|
||||
GGML_ASSERT(backend_h != nullptr);
|
||||
|
||||
const uint32_t n_embd = hparams.n_embd;
|
||||
float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd;
|
||||
const uint32_t n_embd = hparams.n_embd;
|
||||
float * embd_nextn_out = embd_nextn.data + offset*n_embd;
|
||||
|
||||
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size);
|
||||
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float));
|
||||
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_nextn.size);
|
||||
ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn_out, 0, n_rows*n_embd*sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2019,9 +2020,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto n_embd_out = hparams.n_embd_out();
|
||||
|
||||
bool has_logits = true;
|
||||
bool has_embd = cparams.embeddings;
|
||||
bool has_embd_pre_norm = cparams.embeddings_pre_norm;
|
||||
bool has_logits = true;
|
||||
bool has_embd = cparams.embeddings;
|
||||
bool has_embd_nextn = cparams.embeddings_nextn;
|
||||
|
||||
// TODO: hacky enc-dec support
|
||||
if (model.arch == LLM_ARCH_T5) {
|
||||
@@ -2033,14 +2034,14 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
size_t backend_float_count = 0;
|
||||
size_t backend_token_count = 0;
|
||||
|
||||
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
|
||||
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
|
||||
embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0;
|
||||
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
|
||||
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
|
||||
embd_nextn.size = has_embd_nextn ? n_embd*n_outputs_max : 0;
|
||||
|
||||
if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) {
|
||||
// unmasked: pre-norm row exists for every token in the batch, not just
|
||||
if (has_embd_nextn && !cparams.embeddings_nextn_masked) {
|
||||
// unmasked: nextn row exists for every token in the batch, not just
|
||||
// those flagged via batch.logits[i] -> size by token count instead.
|
||||
embd_pre_norm.size = (size_t) n_embd * n_batch;
|
||||
embd_nextn.size = (size_t) n_embd * n_batch;
|
||||
}
|
||||
|
||||
// Allocate backend sampling output buffers if there are backend samplers configured.
|
||||
@@ -2057,7 +2058,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
|
||||
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
|
||||
const size_t new_size =
|
||||
(logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) +
|
||||
(logits.size + embd.size + embd_nextn.size + backend_float_count) * sizeof(float) +
|
||||
( backend_token_count) * sizeof(llama_token);
|
||||
|
||||
// alloc only when more than the current capacity is required
|
||||
@@ -2074,7 +2075,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
buf_output = nullptr;
|
||||
logits.data = nullptr;
|
||||
embd.data = nullptr;
|
||||
embd_pre_norm.data = nullptr;
|
||||
embd_nextn.data = nullptr;
|
||||
}
|
||||
|
||||
auto * buft = ggml_backend_cpu_buffer_type();
|
||||
@@ -2103,8 +2104,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
|
||||
offset += embd.size * sizeof(float);
|
||||
|
||||
embd_pre_norm = has_embd_pre_norm ? buffer_view<float>{(float *) (base + offset), embd_pre_norm.size} : buffer_view<float>{nullptr, 0};
|
||||
offset += embd_pre_norm.size * sizeof(float);
|
||||
embd_nextn = has_embd_nextn ? buffer_view<float>{(float *) (base + offset), embd_nextn.size} : buffer_view<float>{nullptr, 0};
|
||||
offset += embd_nextn.size * sizeof(float);
|
||||
|
||||
if (has_sampling) {
|
||||
sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
|
||||
@@ -2172,9 +2173,9 @@ void llama_context::output_reorder() {
|
||||
}
|
||||
}
|
||||
|
||||
if (embd_pre_norm.size > 0) {
|
||||
if (embd_nextn.size > 0) {
|
||||
for (uint64_t k = 0; k < n_embd; k++) {
|
||||
std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]);
|
||||
std::swap(embd_nextn.data[i0*n_embd + k], embd_nextn.data[i1*n_embd + k]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3588,20 +3589,20 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
|
||||
return ctx->get_embeddings_seq(seq_id);
|
||||
}
|
||||
|
||||
void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) {
|
||||
ctx->set_embeddings_pre_norm(value, masked);
|
||||
void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) {
|
||||
ctx->set_embeddings_nextn(value, masked);
|
||||
}
|
||||
|
||||
float * llama_get_embeddings_pre_norm(llama_context * ctx) {
|
||||
float * llama_get_embeddings_nextn(llama_context * ctx) {
|
||||
ctx->synchronize();
|
||||
|
||||
return ctx->get_embeddings_pre_norm();
|
||||
return ctx->get_embeddings_nextn();
|
||||
}
|
||||
|
||||
float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) {
|
||||
float * llama_get_embeddings_nextn_ith(llama_context * ctx, int32_t i) {
|
||||
ctx->synchronize();
|
||||
|
||||
return ctx->get_embeddings_pre_norm_ith(i);
|
||||
return ctx->get_embeddings_nextn_ith(i);
|
||||
}
|
||||
|
||||
bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
|
||||
|
||||
+7
-7
@@ -84,8 +84,8 @@ struct llama_context {
|
||||
float * get_embeddings_ith(int32_t i);
|
||||
float * get_embeddings_seq(llama_seq_id seq_id);
|
||||
|
||||
float * get_embeddings_pre_norm();
|
||||
float * get_embeddings_pre_norm_ith(int32_t i);
|
||||
float * get_embeddings_nextn();
|
||||
float * get_embeddings_nextn_ith(int32_t i);
|
||||
|
||||
llama_token * get_sampled_tokens() const;
|
||||
llama_token get_sampled_token_ith(int32_t idx);
|
||||
@@ -110,7 +110,7 @@ struct llama_context {
|
||||
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
|
||||
|
||||
void set_embeddings (bool value);
|
||||
void set_embeddings_pre_norm(bool value, bool masked);
|
||||
void set_embeddings_nextn(bool value, bool masked);
|
||||
void set_causal_attn(bool value);
|
||||
void set_warmup(bool value);
|
||||
|
||||
@@ -282,10 +282,10 @@ private:
|
||||
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
||||
buffer_view<float> embd = {nullptr, 0};
|
||||
|
||||
// hidden state before the final output norm (2-dimensional array: [n_outputs][n_embd])
|
||||
// populated only when cparams.embeddings_pre_norm is enabled and the model graph
|
||||
// sets llm_graph_result::t_h_pre_norm
|
||||
buffer_view<float> embd_pre_norm = {nullptr, 0};
|
||||
// hidden state required by the nextn layers (2-dimensional array: [n_outputs][n_embd])
|
||||
// populated only when cparams.embeddings_nextn is enabled and the model graph
|
||||
// sets llm_graph_result::t_h_nextn
|
||||
buffer_view<float> embd_nextn = {nullptr, 0};
|
||||
|
||||
struct sampling_info {
|
||||
// !samplers.empty() to check if any samplers are active
|
||||
|
||||
+2
-2
@@ -29,8 +29,8 @@ struct llama_cparams {
|
||||
float yarn_beta_slow;
|
||||
|
||||
bool embeddings;
|
||||
bool embeddings_pre_norm; // also extract the hidden state before the final output norm
|
||||
bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0
|
||||
bool embeddings_nextn; // also extract the hidden state before the final output norm
|
||||
bool embeddings_nextn_masked; // extract for only rows where batch.logits != 0
|
||||
bool causal_attn;
|
||||
bool offload_kqv;
|
||||
bool flash_attn;
|
||||
|
||||
+4
-8
@@ -89,18 +89,14 @@ LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * m
|
||||
|
||||
LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx);
|
||||
|
||||
//
|
||||
// pre-norm embeddings (hidden state before the final output norm)
|
||||
//
|
||||
|
||||
// Set whether the context outputs pre-norm embeddings or not
|
||||
// Set whether the context outputs nextn embeddings or not
|
||||
// If masked == true, output the embeddings only for the tokens with batch.logits != 0
|
||||
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
|
||||
LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked);
|
||||
LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked);
|
||||
|
||||
// mirrors:
|
||||
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||
LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx);
|
||||
LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);
|
||||
|
||||
// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
|
||||
LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i);
|
||||
LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i);
|
||||
|
||||
+2
-2
@@ -929,8 +929,8 @@ void llm_graph_result::set_outputs() {
|
||||
if (t_embd_pooled != nullptr) {
|
||||
ggml_set_output(t_embd_pooled);
|
||||
}
|
||||
if (t_h_pre_norm != nullptr) {
|
||||
ggml_set_output(t_h_pre_norm);
|
||||
if (t_h_nextn != nullptr) {
|
||||
ggml_set_output(t_h_nextn);
|
||||
}
|
||||
for (auto & [seq_id, t] : t_sampled) {
|
||||
if (t != nullptr) {
|
||||
|
||||
+2
-2
@@ -703,7 +703,7 @@ public:
|
||||
ggml_tensor * get_logits() const { return t_logits; }
|
||||
ggml_tensor * get_embd() const { return t_embd; }
|
||||
ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
|
||||
ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; }
|
||||
ggml_tensor * get_h_nextn() const { return t_h_nextn; }
|
||||
|
||||
ggml_cgraph * get_gf() const { return gf; }
|
||||
ggml_context * get_ctx() const { return ctx_compute.get(); }
|
||||
@@ -732,7 +732,7 @@ public:
|
||||
ggml_tensor * t_logits = nullptr;
|
||||
ggml_tensor * t_embd = nullptr;
|
||||
ggml_tensor * t_embd_pooled = nullptr;
|
||||
ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm
|
||||
ggml_tensor * t_h_nextn = nullptr; // [n_embd, n_outputs] hidden state before final output norm
|
||||
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_candidates;
|
||||
|
||||
+11
-14
@@ -177,7 +177,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
|
||||
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
|
||||
}
|
||||
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) {
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
@@ -209,16 +209,15 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
|
||||
}
|
||||
cur = inpL;
|
||||
|
||||
cb(cur, "h_pre_norm", -1);
|
||||
res->t_h_pre_norm = cur;
|
||||
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
|
||||
|
||||
if (!cparams.embeddings_pre_norm_masked && inp_out_ids) {
|
||||
cb(cur, "h_nextn", -1);
|
||||
res->t_h_nextn = cur;
|
||||
|
||||
if (!cparams.embeddings_nextn_masked && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
}
|
||||
|
||||
// Final norm
|
||||
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
@@ -625,18 +624,16 @@ llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
|
||||
cur = ggml_add(ctx0, cur, ffn_residual);
|
||||
cb(cur, "mtp_post_ffn", il);
|
||||
|
||||
// Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
|
||||
// (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.)
|
||||
cb(cur, "h_pre_norm", -1);
|
||||
res->t_h_pre_norm = cur;
|
||||
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
|
||||
ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
|
||||
? layer.nextn.shared_head_norm
|
||||
: model.output_norm;
|
||||
GGML_ASSERT(head_norm_w && "QWEN35 MTP: missing both nextn.shared_head_norm and output_norm");
|
||||
cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "h_nextn", -1);
|
||||
res->t_h_nextn = cur;
|
||||
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
cb(cur, "mtp_shared_head_norm", -1);
|
||||
|
||||
ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output;
|
||||
|
||||
+12
-13
@@ -200,7 +200,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
|
||||
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
|
||||
}
|
||||
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) {
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
@@ -232,16 +232,16 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
|
||||
}
|
||||
cur = inpL;
|
||||
|
||||
cb(cur, "h_pre_norm", -1);
|
||||
res->t_h_pre_norm = cur;
|
||||
// post-norm hidden state feeds both the LM head and the MTP seed below
|
||||
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
|
||||
|
||||
if (!cparams.embeddings_pre_norm_masked && inp_out_ids) {
|
||||
cb(cur, "h_nextn", -1);
|
||||
res->t_h_nextn = cur;
|
||||
|
||||
if (!cparams.embeddings_nextn_masked && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
}
|
||||
|
||||
// Final norm
|
||||
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
@@ -721,17 +721,16 @@ llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm
|
||||
cur = ggml_add(ctx0, cur, ffn_residual);
|
||||
cb(cur, "mtp_post_ffn", il);
|
||||
|
||||
// Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
|
||||
cb(cur, "h_pre_norm", -1);
|
||||
res->t_h_pre_norm = cur;
|
||||
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
|
||||
ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
|
||||
? layer.nextn.shared_head_norm
|
||||
: model.output_norm;
|
||||
GGML_ASSERT(head_norm_w && "QWEN35MOE MTP: missing both nextn.shared_head_norm and output_norm");
|
||||
cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "h_nextn", -1);
|
||||
res->t_h_nextn= cur;
|
||||
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
cb(cur, "mtp_shared_head_norm", -1);
|
||||
|
||||
ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output;
|
||||
|
||||
@@ -294,7 +294,7 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para
|
||||
cb(cur, "attn_proj", il);
|
||||
}
|
||||
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) {
|
||||
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
@@ -353,10 +353,10 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cb(cur, "h_pre_norm", -1);
|
||||
res->t_h_pre_norm = cur;
|
||||
cb(cur, "h_nextn", -1);
|
||||
res->t_h_nextn = cur;
|
||||
|
||||
if (!cparams.embeddings_pre_norm_masked && inp_out_ids) {
|
||||
if (!cparams.embeddings_nextn_masked && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
}
|
||||
|
||||
@@ -541,8 +541,8 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
|
||||
cb(cur, "mtp_post_ffn", il);
|
||||
|
||||
// Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
|
||||
cb(cur, "h_pre_norm", -1);
|
||||
res->t_h_pre_norm = cur;
|
||||
cb(cur, "h_nextn", -1);
|
||||
res->t_h_nextn = cur;
|
||||
|
||||
ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
|
||||
? layer.nextn.shared_head_norm
|
||||
|
||||
@@ -259,9 +259,9 @@ struct server_slot {
|
||||
return task->need_embd() || (spec && common_speculative_need_embd(spec));
|
||||
}
|
||||
|
||||
bool need_embd_pre_norm() const {
|
||||
bool need_embd_nextn() const {
|
||||
GGML_ASSERT(task);
|
||||
return spec && common_speculative_need_embd_pre_norm(spec);
|
||||
return spec && common_speculative_need_embd_nextn(spec);
|
||||
}
|
||||
|
||||
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
|
||||
@@ -3013,7 +3013,7 @@ private:
|
||||
|
||||
// embedding requires all tokens in the batch to be output;
|
||||
// MTP also wants logits at every prompt position so the
|
||||
// streaming hook can mirror t_h_pre_norm into ctx_dft.
|
||||
// streaming hook can mirror t_h_nextn into ctx_dft.
|
||||
common_batch_add(batch,
|
||||
cur_tok,
|
||||
slot.prompt.tokens.pos_next(),
|
||||
|
||||
Reference in New Issue
Block a user