llama: avoid copying logits during prompt decode in MTP (#23198)

* llama: avoid copying logits during prompt decode in MTP

* review: update comment

* llama-graph: call set_output for t_h_pre_norm
This commit is contained in:
Aman Gupta
2026-05-17 23:30:25 +08:00
committed by GitHub
parent 39cf5d6191
commit 3e12fbdea5
10 changed files with 91 additions and 27 deletions
+24 -3
View File
@@ -146,8 +146,11 @@ struct common_speculative_impl {
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted) = 0;
// true if this implementation requires the target context to extract embeddings
// true if this implementation requires the target context to extract post-norm embeddings
virtual bool need_embd() const = 0;
// true if this implementation requires the target context to extract pre-norm embeddings
virtual bool need_embd_pre_norm() const { return false; }
};
struct common_speculative_impl_draft_simple : public common_speculative_impl {
@@ -429,8 +432,8 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
}
llama_set_embeddings_pre_norm(ctx_tgt, true);
llama_set_embeddings_pre_norm(ctx_dft, true);
llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false);
llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true);
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
@@ -691,6 +694,10 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
}
bool need_embd() const override {
return false;
}
bool need_embd_pre_norm() const override {
return true;
}
};
@@ -1408,6 +1415,20 @@ bool common_speculative_need_embd(common_speculative * spec) {
return false;
}
bool common_speculative_need_embd_pre_norm(common_speculative * spec) {
if (spec == nullptr) {
return false;
}
for (auto & impl : spec->impls) {
if (impl->need_embd_pre_norm()) {
return true;
}
}
return false;
}
void common_speculative_draft(common_speculative * spec) {
if (spec == nullptr) {
return;
+4 -1
View File
@@ -53,9 +53,12 @@ void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, co
// process the batch and update the internal state of the speculative context
bool common_speculative_process(common_speculative * spec, const llama_batch & batch);
// true if any implementation requires target embeddings to be extracted
// true if any implementation requires target post-norm embeddings to be extracted
bool common_speculative_need_embd(common_speculative * spec);
// true if any implementation requires target pre-norm embeddings to be extracted
bool common_speculative_need_embd_pre_norm(common_speculative * spec);
// generate drafts for the sequences specified with `common_speculative_get_draft_params`
void common_speculative_draft(common_speculative * spec);