qwen35: use post-norm hidden state for MTP (#24025)

* qwen35: use post-norm hidden state for MTP

* rename pre_norm to nextn

* fix step35
This commit is contained in:
Aman Gupta
2026-06-04 01:29:09 +08:00
committed by GitHub
parent c8d6a00636
commit 166fe29492
12 changed files with 132 additions and 139 deletions
+10 -10
View File
@@ -3,7 +3,7 @@
#include "common.h"
#include "ggml.h"
#include "llama.h"
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_pre_norm / llama_get_embeddings_pre_norm_ith (used by MTP)
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_nextn / llama_get_embeddings_nextn_ith (used by MTP)
#include "log.h"
#include "ngram-cache.h"
#include "ngram-map.h"
@@ -162,7 +162,7 @@ struct common_speculative_impl {
virtual bool need_embd() const = 0;
// true if this implementation requires the target context to extract pre-norm embeddings
virtual bool need_embd_pre_norm() const { return false; }
virtual bool need_embd_nextn() const { return false; }
};
struct common_speculative_impl_draft_simple : public common_speculative_impl {
@@ -487,8 +487,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
}
}
llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false);
llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true);
llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false);
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
@@ -583,7 +583,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
// ^--- this is a problem
// TODO:this is generally true, but would be nice to assert it
{
const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt);
const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt);
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
//{
@@ -625,7 +625,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
verify_h[seq_id].resize((size_t) n_rows * n_embd);
for (int32_t i = 0; i < n_rows; ++i) {
const float * h = llama_get_embeddings_pre_norm_ith(ctx_tgt, i_batch_beg[seq_id] + i);
const float * h = llama_get_embeddings_nextn_ith(ctx_tgt, i_batch_beg[seq_id] + i);
std::memcpy(verify_h[seq_id].data() + (size_t) i * n_embd, h, row_bytes);
}
@@ -686,7 +686,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
auto * smpl = smpls[seq_id].get();
common_sampler_sample(smpl, ctx_dft, i_batch, true);
h_row = llama_get_embeddings_pre_norm_ith(ctx_dft, i_batch);
h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_batch);
++i_batch;
const auto * cur_p = common_sampler_get_candidates(smpl, true);
@@ -772,7 +772,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
return false;
}
bool need_embd_pre_norm() const override {
bool need_embd_nextn() const override {
return true;
}
};
@@ -1539,13 +1539,13 @@ bool common_speculative_need_embd(common_speculative * spec) {
return false;
}
bool common_speculative_need_embd_pre_norm(common_speculative * spec) {
bool common_speculative_need_embd_nextn(common_speculative * spec) {
if (spec == nullptr) {
return false;
}
for (auto & impl : spec->impls) {
if (impl->need_embd_pre_norm()) {
if (impl->need_embd_nextn()) {
return true;
}
}
+2 -2
View File
@@ -59,8 +59,8 @@ bool common_speculative_process(common_speculative * spec, const llama_batch & b
// true if any implementation requires target post-norm embeddings to be extracted
bool common_speculative_need_embd(common_speculative * spec);
// true if any implementation requires target pre-norm embeddings to be extracted
bool common_speculative_need_embd_pre_norm(common_speculative * spec);
// true if any implementation requires target nextn embeddings to be extracted
bool common_speculative_need_embd_nextn(common_speculative * spec);
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
void common_speculative_draft(common_speculative * spec);
+71 -70
View File
@@ -58,19 +58,20 @@ llama_context::llama_context(
cparams.n_rs_seq = 0;
}
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
cparams.embeddings = params.embeddings;
cparams.embeddings_pre_norm = false;
cparams.embeddings_pre_norm_masked = false;
cparams.offload_kqv = params.offload_kqv;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.warmup = false;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor;
cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow;
cparams.embeddings = params.embeddings;
cparams.embeddings_nextn = false;
cparams.embeddings_nextn_masked = false;
cparams.offload_kqv = params.offload_kqv;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.warmup = false;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@@ -889,34 +890,34 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
return it->second.data();
}
float * llama_context::get_embeddings_pre_norm() {
float * llama_context::get_embeddings_nextn() {
output_reorder();
return embd_pre_norm.data;
return embd_nextn.data;
}
float * llama_context::get_embeddings_pre_norm_ith(int32_t i) {
float * llama_context::get_embeddings_nextn_ith(int32_t i) {
output_reorder();
try {
if (embd_pre_norm.data == nullptr) {
throw std::runtime_error("no pre-norm embeddings");
if (embd_nextn.data == nullptr) {
throw std::runtime_error("no nextn embeddings");
}
const uint32_t n_embd = model.hparams.n_embd;
if (!cparams.embeddings_pre_norm_masked) {
// unmasked: pre-norm rows are stored densely, indexed by raw token position.
if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) {
throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd));
if (!cparams.embeddings_nextn_masked) {
// unmasked: nextn rows are stored densely, indexed by raw token position.
if (i < 0 || (size_t)(i + 1) * n_embd > embd_nextn.size) {
throw std::runtime_error(format("out of range [0, %zu)", embd_nextn.size / n_embd));
}
return embd_pre_norm.data + (size_t) i * n_embd;
return embd_nextn.data + (size_t) i * n_embd;
}
const int64_t j = output_resolve_row(i);
return embd_pre_norm.data + j*n_embd;
return embd_nextn.data + j*n_embd;
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what());
LLAMA_LOG_ERROR("%s: invalid nextn embeddings id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG
GGML_ABORT("fatal error");
#else
@@ -1105,11 +1106,11 @@ void llama_context::set_embeddings(bool value) {
//sched_need_reserve = true;
}
void llama_context::set_embeddings_pre_norm(bool value, bool masked) {
void llama_context::set_embeddings_nextn(bool value, bool masked) {
LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked);
cparams.embeddings_pre_norm = value;
cparams.embeddings_pre_norm_masked = masked;
cparams.embeddings_nextn = value;
cparams.embeddings_nextn_masked = masked;
}
void llama_context::set_causal_attn(bool value) {
@@ -1326,7 +1327,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
}
int llama_context::encode(const llama_batch & batch_inp) {
// MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
// MTP hook batches carry both token (next-token id) and embd (h_nextn row),
// so accept either present rather than requiring exactly one.
GGML_ASSERT(batch_inp.token || batch_inp.embd);
@@ -1399,9 +1400,9 @@ int llama_context::encode(const llama_batch & batch_inp) {
}
}
auto * t_logits = res->get_logits();
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr;
auto * t_logits = res->get_logits();
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn() : nullptr;
// extract logits
if (logits.data && t_logits) {
@@ -1467,14 +1468,14 @@ int llama_context::encode(const llama_batch & batch_inp) {
}
}
// extract pre-norm embeddings (hidden state before the final output norm)
if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
// extract nextn embeddings (hidden state before the final output norm)
if (embd_nextn.data && t_h_nextn && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
GGML_ASSERT(backend_h != nullptr);
const uint32_t n_embd = hparams.n_embd;
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size);
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float));
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_nextn.size);
ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn.data, 0, n_tokens*n_embd*sizeof(float));
}
// TODO: hacky solution
@@ -1629,7 +1630,7 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_s
}
int llama_context::decode(const llama_batch & batch_inp) {
// MTP hook batches carry both token (next-token id) and embd (h_pre_norm row),
// MTP hook batches carry both token (next-token id) and embd (h_nextn row),
// so accept either present rather than requiring exactly one.
GGML_ASSERT(batch_inp.token || batch_inp.embd);
@@ -1829,9 +1830,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}
auto * t_logits = res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr;
auto * t_logits = res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn() : nullptr;
if (t_embd && res->get_embd_pooled()) {
t_embd = res->get_embd_pooled();
@@ -1912,22 +1913,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
}
}
// extract pre-norm embeddings (hidden state before the final output norm)
// extract nextn embeddings before
// only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
{
const bool masked = cparams.embeddings_pre_norm_masked;
const bool masked = cparams.embeddings_nextn_masked;
const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens;
const int64_t offset = masked ? n_outputs_prev : n_tokens_prev;
if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
if (embd_nextn.data && t_h_nextn && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
GGML_ASSERT(backend_h != nullptr);
const uint32_t n_embd = hparams.n_embd;
float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd;
const uint32_t n_embd = hparams.n_embd;
float * embd_nextn_out = embd_nextn.data + offset*n_embd;
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size);
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float));
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_nextn.size);
ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn_out, 0, n_rows*n_embd*sizeof(float));
}
}
@@ -2019,9 +2020,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
const auto n_embd = hparams.n_embd;
const auto n_embd_out = hparams.n_embd_out();
bool has_logits = true;
bool has_embd = cparams.embeddings;
bool has_embd_pre_norm = cparams.embeddings_pre_norm;
bool has_logits = true;
bool has_embd = cparams.embeddings;
bool has_embd_nextn = cparams.embeddings_nextn;
// TODO: hacky enc-dec support
if (model.arch == LLM_ARCH_T5) {
@@ -2033,14 +2034,14 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
size_t backend_float_count = 0;
size_t backend_token_count = 0;
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0;
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
embd_nextn.size = has_embd_nextn ? n_embd*n_outputs_max : 0;
if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) {
// unmasked: pre-norm row exists for every token in the batch, not just
if (has_embd_nextn && !cparams.embeddings_nextn_masked) {
// unmasked: nextn row exists for every token in the batch, not just
// those flagged via batch.logits[i] -> size by token count instead.
embd_pre_norm.size = (size_t) n_embd * n_batch;
embd_nextn.size = (size_t) n_embd * n_batch;
}
// Allocate backend sampling output buffers if there are backend samplers configured.
@@ -2057,7 +2058,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
const size_t new_size =
(logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) +
(logits.size + embd.size + embd_nextn.size + backend_float_count) * sizeof(float) +
( backend_token_count) * sizeof(llama_token);
// alloc only when more than the current capacity is required
@@ -2074,7 +2075,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
buf_output = nullptr;
logits.data = nullptr;
embd.data = nullptr;
embd_pre_norm.data = nullptr;
embd_nextn.data = nullptr;
}
auto * buft = ggml_backend_cpu_buffer_type();
@@ -2103,8 +2104,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
offset += embd.size * sizeof(float);
embd_pre_norm = has_embd_pre_norm ? buffer_view<float>{(float *) (base + offset), embd_pre_norm.size} : buffer_view<float>{nullptr, 0};
offset += embd_pre_norm.size * sizeof(float);
embd_nextn = has_embd_nextn ? buffer_view<float>{(float *) (base + offset), embd_nextn.size} : buffer_view<float>{nullptr, 0};
offset += embd_nextn.size * sizeof(float);
if (has_sampling) {
sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
@@ -2172,9 +2173,9 @@ void llama_context::output_reorder() {
}
}
if (embd_pre_norm.size > 0) {
if (embd_nextn.size > 0) {
for (uint64_t k = 0; k < n_embd; k++) {
std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]);
std::swap(embd_nextn.data[i0*n_embd + k], embd_nextn.data[i1*n_embd + k]);
}
}
@@ -3588,20 +3589,20 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
return ctx->get_embeddings_seq(seq_id);
}
void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) {
ctx->set_embeddings_pre_norm(value, masked);
void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) {
ctx->set_embeddings_nextn(value, masked);
}
float * llama_get_embeddings_pre_norm(llama_context * ctx) {
float * llama_get_embeddings_nextn(llama_context * ctx) {
ctx->synchronize();
return ctx->get_embeddings_pre_norm();
return ctx->get_embeddings_nextn();
}
float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) {
float * llama_get_embeddings_nextn_ith(llama_context * ctx, int32_t i) {
ctx->synchronize();
return ctx->get_embeddings_pre_norm_ith(i);
return ctx->get_embeddings_nextn_ith(i);
}
bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
+7 -7
View File
@@ -84,8 +84,8 @@ struct llama_context {
float * get_embeddings_ith(int32_t i);
float * get_embeddings_seq(llama_seq_id seq_id);
float * get_embeddings_pre_norm();
float * get_embeddings_pre_norm_ith(int32_t i);
float * get_embeddings_nextn();
float * get_embeddings_nextn_ith(int32_t i);
llama_token * get_sampled_tokens() const;
llama_token get_sampled_token_ith(int32_t idx);
@@ -110,7 +110,7 @@ struct llama_context {
void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
void set_embeddings (bool value);
void set_embeddings_pre_norm(bool value, bool masked);
void set_embeddings_nextn(bool value, bool masked);
void set_causal_attn(bool value);
void set_warmup(bool value);
@@ -282,10 +282,10 @@ private:
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
buffer_view<float> embd = {nullptr, 0};
// hidden state before the final output norm (2-dimensional array: [n_outputs][n_embd])
// populated only when cparams.embeddings_pre_norm is enabled and the model graph
// sets llm_graph_result::t_h_pre_norm
buffer_view<float> embd_pre_norm = {nullptr, 0};
// hidden state required by the nextn layers (2-dimensional array: [n_outputs][n_embd])
// populated only when cparams.embeddings_nextn is enabled and the model graph
// sets llm_graph_result::t_h_nextn
buffer_view<float> embd_nextn = {nullptr, 0};
struct sampling_info {
// !samplers.empty() to check if any samplers are active
+2 -2
View File
@@ -29,8 +29,8 @@ struct llama_cparams {
float yarn_beta_slow;
bool embeddings;
bool embeddings_pre_norm; // also extract the hidden state before the final output norm
bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0
bool embeddings_nextn; // also extract the hidden state before the final output norm
bool embeddings_nextn_masked; // extract for only rows where batch.logits != 0
bool causal_attn;
bool offload_kqv;
bool flash_attn;
+4 -8
View File
@@ -89,18 +89,14 @@ LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * m
LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx);
//
// pre-norm embeddings (hidden state before the final output norm)
//
// Set whether the context outputs pre-norm embeddings or not
// Set whether the context outputs nextn embeddings or not
// If masked == true, output the embeddings only for the tokens with batch.logits != 0
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked);
LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked);
// mirrors:
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx);
LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);
// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i);
LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i);
+2 -2
View File
@@ -929,8 +929,8 @@ void llm_graph_result::set_outputs() {
if (t_embd_pooled != nullptr) {
ggml_set_output(t_embd_pooled);
}
if (t_h_pre_norm != nullptr) {
ggml_set_output(t_h_pre_norm);
if (t_h_nextn != nullptr) {
ggml_set_output(t_h_nextn);
}
for (auto & [seq_id, t] : t_sampled) {
if (t != nullptr) {
+2 -2
View File
@@ -703,7 +703,7 @@ public:
ggml_tensor * get_logits() const { return t_logits; }
ggml_tensor * get_embd() const { return t_embd; }
ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; }
ggml_tensor * get_h_nextn() const { return t_h_nextn; }
ggml_cgraph * get_gf() const { return gf; }
ggml_context * get_ctx() const { return ctx_compute.get(); }
@@ -732,7 +732,7 @@ public:
ggml_tensor * t_logits = nullptr;
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm
ggml_tensor * t_h_nextn = nullptr; // [n_embd, n_outputs] hidden state before final output norm
std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
std::map<llama_seq_id, ggml_tensor*> t_candidates;
+11 -14
View File
@@ -177,7 +177,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
}
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) {
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
@@ -209,16 +209,15 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
}
cur = inpL;
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
if (!cparams.embeddings_pre_norm_masked && inp_out_ids) {
cb(cur, "h_nextn", -1);
res->t_h_nextn = cur;
if (!cparams.embeddings_nextn_masked && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}
// Final norm
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
@@ -625,18 +624,16 @@ llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
cur = ggml_add(ctx0, cur, ffn_residual);
cb(cur, "mtp_post_ffn", il);
// Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
// (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.)
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
? layer.nextn.shared_head_norm
: model.output_norm;
GGML_ASSERT(head_norm_w && "QWEN35 MTP: missing both nextn.shared_head_norm and output_norm");
cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1);
cb(cur, "h_nextn", -1);
res->t_h_nextn = cur;
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
cb(cur, "mtp_shared_head_norm", -1);
ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output;
+12 -13
View File
@@ -200,7 +200,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
}
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) {
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
@@ -232,16 +232,16 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
}
cur = inpL;
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;
// post-norm hidden state feeds both the LM head and the MTP seed below
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
if (!cparams.embeddings_pre_norm_masked && inp_out_ids) {
cb(cur, "h_nextn", -1);
res->t_h_nextn = cur;
if (!cparams.embeddings_nextn_masked && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}
// Final norm
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
@@ -721,17 +721,16 @@ llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm
cur = ggml_add(ctx0, cur, ffn_residual);
cb(cur, "mtp_post_ffn", il);
// Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
? layer.nextn.shared_head_norm
: model.output_norm;
GGML_ASSERT(head_norm_w && "QWEN35MOE MTP: missing both nextn.shared_head_norm and output_norm");
cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1);
cb(cur, "h_nextn", -1);
res->t_h_nextn= cur;
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
cb(cur, "mtp_shared_head_norm", -1);
ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output;
+6 -6
View File
@@ -294,7 +294,7 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para
cb(cur, "attn_proj", il);
}
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) {
if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
@@ -353,10 +353,10 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para
cur = inpL;
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;
cb(cur, "h_nextn", -1);
res->t_h_nextn = cur;
if (!cparams.embeddings_pre_norm_masked && inp_out_ids) {
if (!cparams.embeddings_nextn_masked && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}
@@ -541,8 +541,8 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
cb(cur, "mtp_post_ffn", il);
// Pre-norm hidden state: used by the AR draft loop to seed the next MTP step.
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;
cb(cur, "h_nextn", -1);
res->t_h_nextn = cur;
ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
? layer.nextn.shared_head_norm
+3 -3
View File
@@ -259,9 +259,9 @@ struct server_slot {
return task->need_embd() || (spec && common_speculative_need_embd(spec));
}
bool need_embd_pre_norm() const {
bool need_embd_nextn() const {
GGML_ASSERT(task);
return spec && common_speculative_need_embd_pre_norm(spec);
return spec && common_speculative_need_embd_nextn(spec);
}
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
@@ -3013,7 +3013,7 @@ private:
// embedding requires all tokens in the batch to be output;
// MTP also wants logits at every prompt position so the
// streaming hook can mirror t_h_pre_norm into ctx_dft.
// streaming hook can mirror t_h_nextn into ctx_dft.
common_batch_add(batch,
cur_tok,
slot.prompt.tokens.pos_next(),