From 04eb4c446d22b63449d5dc41c038987d4d8cc3a6 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sun, 7 Jun 2026 20:50:54 +0800 Subject: [PATCH] llama : add Gemma4 MTP (#23398) --- common/speculative.cpp | 93 +++++++------- conversion/__init__.py | 2 + conversion/gemma.py | 10 ++ gguf-py/gguf/constants.py | 24 ++++ gguf-py/gguf/tensor_mapping.py | 8 ++ include/llama.h | 4 + src/llama-arch.cpp | 5 + src/llama-arch.h | 3 + src/llama-context.cpp | 55 ++++++--- src/llama-context.h | 3 +- src/llama-cparams.h | 2 + src/llama-ext.h | 2 + src/llama-graph.cpp | 31 +++-- src/llama-graph.h | 1 + src/llama-hparams.cpp | 4 + src/llama-hparams.h | 4 + src/llama-kv-cache-dsa.cpp | 4 +- src/llama-kv-cache-iswa.cpp | 18 ++- src/llama-kv-cache-iswa.h | 4 +- src/llama-kv-cache.cpp | 134 +++++++++++++++++---- src/llama-kv-cache.h | 9 +- src/llama-memory-hybrid-iswa.cpp | 2 + src/llama-memory-hybrid.cpp | 2 + src/llama-memory.h | 4 + src/llama-model.cpp | 87 ++++++++++---- src/llama-model.h | 5 + src/models/gemma4-assistant.cpp | 200 +++++++++++++++++++++++++++++++ src/models/gemma4.cpp | 22 +++- src/models/models.h | 13 ++ tests/test-llama-archs.cpp | 6 +- tools/server/server-context.cpp | 25 ++-- 31 files changed, 644 insertions(+), 142 deletions(-) create mode 100644 src/models/gemma4-assistant.cpp diff --git a/common/speculative.cpp b/common/speculative.cpp index 628ded45ca..86c1e6a429 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -3,13 +3,14 @@ #include "common.h" #include "ggml.h" #include "llama.h" -#include "../src/llama-ext.h" // staging API: llama_set_embeddings_nextn / llama_get_embeddings_nextn_ith (used by MTP) #include "log.h" #include "ngram-cache.h" #include "ngram-map.h" #include "ngram-mod.h" #include "sampling.h" +#include "../src/llama-ext.h" // staging API: llama_set_embeddings_nextn / llama_get_embeddings_nextn_ith (used by MTP) + #include #include #include @@ -418,6 +419,8 @@ struct common_speculative_impl_draft_mtp : public 
common_speculative_impl { int32_t n_embd = 0; + bool is_mem_shared = false; + // Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1. // The last h-row of one process() call needs the first token of the NEXT // call to pair with, so it's stashed here until that next call fires. @@ -444,7 +447,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { auto * ctx_dft = this->params.ctx_dft; GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set"); - n_embd = llama_model_n_embd(llama_get_model(ctx_dft)); + n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft)); + GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) && + "MTP input row width must match the target h_nextn width"); LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__); LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling); @@ -490,6 +495,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false); llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true); + is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt; + pending_h.assign(n_seq, std::vector(n_embd, 0.0f)); i_batch_beg.assign(n_seq, -1); @@ -526,9 +533,11 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { if (N <= 0) { return; } + auto * ctx_dft = this->params.ctx_dft; const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); - if (pos_max < N - 1) { + + if (pos_max < N - 1 && !is_mem_shared) { LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - " "process() hook may not have run on every prefill ubatch " "(need_embd / logits=1 on every prompt position?). 
" @@ -571,48 +580,42 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { const size_t row_bytes = (size_t) n_embd * sizeof(float); - common_batch_clear(batch); + // if kv is shared with target (e.g Gemma4), then we can skip this catch-up decode + if (!is_mem_shared) { + common_batch_clear(batch); - for (int k = 0; k < n_tokens; ++k) { - common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0); - } - - // shift the tgt embeddings to the right by one position - // assumes that the tokens in the batch are sequential for each sequence - // i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1] - // ^--- this is a problem - // TODO:this is generally true, but would be nice to assert it - { - const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt); - std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1)); - - //{ - // // string with seq_ids in the batch - // std::stringstream ss; - // for (int i = 0; i < n_tokens; ++i) { - // ss << batch_in.seq_id[i][0] << ","; - // } - // LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str()); - //} - } - - // fill the pending embeddings from a previous run - auto set_h = [&](int idx, const float * h_row) { - std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes); - }; - - for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { - if (i_batch_beg[seq_id] < 0) { - continue; + for (int k = 0; k < n_tokens; ++k) { + common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0); } - set_h(i_batch_beg[seq_id], pending_h[seq_id].data()); - } + // shift the tgt embeddings to the right by one position + // assumes that the tokens in the batch are sequential for each sequence + // i.e. 
we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1] + // ^--- this is a problem + // TODO:this is generally true, but would be nice to assert it + { + const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt); + std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1)); + } - const int32_t rc = llama_decode(ctx_dft, batch); - if (rc != 0) { - LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]); - return false; + // fill the pending embeddings from a previous run + auto set_h = [&](int idx, const float * h_row) { + std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes); + }; + + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (i_batch_beg[seq_id] < 0) { + continue; + } + + set_h(i_batch_beg[seq_id], pending_h[seq_id].data()); + } + + const int32_t rc = llama_decode(ctx_dft, batch); + if (rc != 0) { + LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]); + return false; + } } for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { @@ -721,7 +724,13 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { continue; } - common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true); + if (is_mem_shared) { + // note: with shared memory (e.g. 
Gemma4 assistants) we use the same position for all draft tokens + // ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37 + common_batch_add(batch, id, dp.n_past, { seq_id }, true); + } else { + common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true); + } std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); } diff --git a/conversion/__init__.py b/conversion/__init__.py index c670798fc2..18162976f4 100644 --- a/conversion/__init__.py +++ b/conversion/__init__.py @@ -75,9 +75,11 @@ TEXT_MODEL_MAP: dict[str, str] = { "Gemma3TextModel": "gemma", "Gemma3nForCausalLM": "gemma", "Gemma3nForConditionalGeneration": "gemma", + "Gemma4AssistantForCausalLM": "gemma", "Gemma4ForConditionalGeneration": "gemma", "Gemma4ForCausalLM": "gemma", "Gemma4UnifiedForConditionalGeneration": "gemma", + "Gemma4UnifiedAssistantForCausalLM": "gemma", "GemmaForCausalLM": "gemma", "Glm4ForCausalLM": "glm", "Glm4MoeForCausalLM": "glm", diff --git a/conversion/gemma.py b/conversion/gemma.py index 1258428b04..d8cf8be575 100644 --- a/conversion/gemma.py +++ b/conversion/gemma.py @@ -785,6 +785,16 @@ class Gemma4UnifiedModel(Gemma4Model): self.gguf_writer.add_suppress_tokens(suppress_tokens) +@ModelBase.register("Gemma4AssistantForCausalLM", "Gemma4UnifiedAssistantForCausalLM") +class Gemma4AssistantModel(Gemma4Model): + model_arch = gguf.MODEL_ARCH.GEMMA4_ASSISTANT + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_embedding_length_out(self.hparams["backbone_hidden_size"]) + self.gguf_writer.add_nextn_predict_layers(self.block_count) + + @ModelBase.register("Gemma4ForConditionalGeneration") class Gemma4VisionAudioModel(MmprojModel): has_audio_encoder = True diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 814980ce50..bd6246137b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ 
-440,6 +440,7 @@ class MODEL_ARCH(IntEnum): GEMMA3 = auto() GEMMA3N = auto() GEMMA4 = auto() + GEMMA4_ASSISTANT = auto() GEMMA_EMBEDDING = auto() STARCODER2 = auto() RWKV6 = auto() @@ -897,6 +898,8 @@ class MODEL_TENSOR(IntEnum): A_PER_DIM_K_SCALE = auto() # gemma4 A_PER_DIM_SCALE = auto() # gemma4 # nextn/mtp + NEXTN_PROJ_PRE = auto() + NEXTN_PROJ_POST = auto() NEXTN_EH_PROJ = auto() NEXTN_EMBED_TOKENS = auto() NEXTN_ENORM = auto() @@ -986,6 +989,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.GEMMA3: "gemma3", MODEL_ARCH.GEMMA3N: "gemma3n", MODEL_ARCH.GEMMA4: "gemma4", + MODEL_ARCH.GEMMA4_ASSISTANT: "gemma4-assistant", MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding", MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.RWKV6: "rwkv6", @@ -1471,6 +1475,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.A_QF_FFN_DOWN: "a.proj_blk.{bid}.ffn_down", MODEL_TENSOR.A_QF_FFN_NORM: "a.proj_blk.{bid}.ffn_norm", # NextN/MTP + MODEL_TENSOR.NEXTN_PROJ_PRE: "nextn.pre_projection", + MODEL_TENSOR.NEXTN_PROJ_POST: "nextn.post_projection", MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj", MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens", MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.nextn.enorm", @@ -2577,6 +2583,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.PER_LAYER_PROJ_NORM, MODEL_TENSOR.PER_LAYER_POST_NORM, ], + MODEL_ARCH.GEMMA4_ASSISTANT: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.NEXTN_PROJ_PRE, + MODEL_TENSOR.NEXTN_PROJ_POST, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_PRE_NORM, + MODEL_TENSOR.FFN_POST_NORM, + MODEL_TENSOR.LAYER_OUT_SCALE, + ], MODEL_ARCH.GEMMA_EMBEDDING: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT, diff --git a/gguf-py/gguf/tensor_mapping.py 
b/gguf-py/gguf/tensor_mapping.py index 3e63b21650..a9537983de 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -2367,6 +2367,14 @@ class TensorNameMap: ), # NextN/MTP tensors + MODEL_TENSOR.NEXTN_PROJ_PRE: ( + "pre_projection", + ), + + MODEL_TENSOR.NEXTN_PROJ_POST: ( + "post_projection", + ), + MODEL_TENSOR.NEXTN_EH_PROJ: ( "model.layers.{bid}.eh_proj", ), diff --git a/include/llama.h b/include/llama.h index 9f78aa9a05..27e4806742 100644 --- a/include/llama.h +++ b/include/llama.h @@ -388,6 +388,10 @@ extern "C" { // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init) struct llama_sampler_seq_config * samplers; size_t n_samplers; + + // a source/target/parent context + // can be utilized in various ways, for example by sharing results or llama_memory between 2 contexts + struct llama_context * ctx_other; }; struct llama_model_tensor_override { diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 52963f8f1e..6a5d5f8d2a 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -57,6 +57,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_GEMMA3N, "gemma3n" }, { LLM_ARCH_GEMMA4, "gemma4" }, + { LLM_ARCH_GEMMA4_ASSISTANT, "gemma4-assistant" }, { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, @@ -453,6 +454,8 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" }, { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, + { LLM_TENSOR_NEXTN_PROJ_PRE, "nextn.pre_projection" }, + { LLM_TENSOR_NEXTN_PROJ_POST, "nextn.post_projection" }, { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" }, { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" }, { LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" }, @@ -765,6 +768,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, 
GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_PROJ_PRE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_PROJ_POST, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the // last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so // the model loader doesn't fault on the block index. diff --git a/src/llama-arch.h b/src/llama-arch.h index dc9bca9bfc..03b1a265d6 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -61,6 +61,7 @@ enum llm_arch { LLM_ARCH_GEMMA3, LLM_ARCH_GEMMA3N, LLM_ARCH_GEMMA4, + LLM_ARCH_GEMMA4_ASSISTANT, LLM_ARCH_GEMMA_EMBEDDING, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, @@ -557,6 +558,8 @@ enum llm_tensor { LLM_TENSOR_INDEXER_PROJ, LLM_TENSOR_INDEXER_ATTN_K, LLM_TENSOR_INDEXER_ATTN_Q_B, + LLM_TENSOR_NEXTN_PROJ_PRE, + LLM_TENSOR_NEXTN_PROJ_POST, LLM_TENSOR_NEXTN_EH_PROJ, LLM_TENSOR_NEXTN_EMBED_TOKENS, LLM_TENSOR_NEXTN_ENORM, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index d0c314199b..9a40c4366a 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -69,9 +69,10 @@ llama_context::llama_context( cparams.embeddings_nextn_masked = false; cparams.offload_kqv = params.offload_kqv; cparams.no_perf = params.no_perf; - cparams.pooling_type = params.pooling_type; cparams.warmup = false; + cparams.ctx_type = params.ctx_type; + cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; cparams.rope_freq_base = params.rope_freq_base == 0.0f ? 
hparams.rope_freq_base_train : params.rope_freq_base; @@ -84,7 +85,17 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; - cparams.ctx_type = params.ctx_type; + cparams.ctx_other = nullptr; + + // TODO: more generic + if (model.arch == LLM_ARCH_GEMMA4_ASSISTANT) { + if (params.ctx_other == nullptr) { + // TODO: change from runtime_error to llama_exception to avoid printing error message + throw std::runtime_error("Gemma4Assistant requires ctx_other to be set (this is normal during memory fitting)"); + } + + cparams.ctx_other = params.ctx_other; + } // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later @@ -300,10 +311,11 @@ llama_context::llama_context( // init the memory module if (!hparams.vocab_only) { llama_memory_params params_mem = { - /*.type_k =*/ params.type_k, - /*.type_v =*/ params.type_v, - /*.swa_full =*/ params.swa_full, - /*.ctx_type= */ cparams.ctx_type, + /*.type_k =*/ params.type_k, + /*.type_v =*/ params.type_v, + /*.swa_full =*/ params.swa_full, + /*.ctx_type =*/ cparams.ctx_type, + /*.mem_other =*/ llama_get_memory(cparams.ctx_other), }; memory.reset(model.create_memory(params_mem, cparams)); @@ -904,7 +916,7 @@ float * llama_context::get_embeddings_nextn_ith(int32_t i) { throw std::runtime_error("no nextn embeddings"); } - const uint32_t n_embd = model.hparams.n_embd; + const uint32_t n_embd = model.hparams.n_embd_out(); if (!cparams.embeddings_nextn_masked) { // unmasked: nextn rows are stored densely, indexed by raw token position. 
@@ -1473,7 +1485,7 @@ int llama_context::encode(const llama_batch & batch_inp) { ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn); GGML_ASSERT(backend_h != nullptr); - const uint32_t n_embd = hparams.n_embd; + const uint32_t n_embd = hparams.n_embd_out(); GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_nextn.size); ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn.data, 0, n_tokens*n_embd*sizeof(float)); } @@ -1924,7 +1936,7 @@ int llama_context::decode(const llama_batch & batch_inp) { ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn); GGML_ASSERT(backend_h != nullptr); - const uint32_t n_embd = hparams.n_embd; + const uint32_t n_embd = hparams.n_embd_out(); float * embd_nextn_out = embd_nextn.data + offset*n_embd; GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_nextn.size); @@ -2017,7 +2029,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); - const auto n_embd = hparams.n_embd; const auto n_embd_out = hparams.n_embd_out(); bool has_logits = true; @@ -2036,12 +2047,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { logits.size = has_logits ? n_vocab*n_outputs_max : 0; embd.size = has_embd ? n_embd_out*n_outputs_max : 0; - embd_nextn.size = has_embd_nextn ? n_embd*n_outputs_max : 0; + embd_nextn.size = has_embd_nextn ? n_embd_out*n_outputs_max : 0; if (has_embd_nextn && !cparams.embeddings_nextn_masked) { // unmasked: nextn row exists for every token in the batch, not just // those flagged via batch.logits[i] -> size by token count instead. - embd_nextn.size = (size_t) n_embd * n_batch; + embd_nextn.size = (size_t) n_embd_out * n_batch; } // Allocate backend sampling output buffers if there are backend samplers configured. 
@@ -3375,6 +3386,7 @@ llama_context_params llama_context_default_params() { /*.kv_unified =*/ false, /*.sampler =*/ nullptr, /*.n_sampler =*/ 0, + /*.ctx_other =*/ nullptr, }; return result; @@ -3454,7 +3466,6 @@ llama_context * llama_init_from_model( return nullptr; } - try { auto * ctx = new llama_context(*model, params); return ctx; @@ -3593,6 +3604,14 @@ void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) { ctx->set_embeddings_nextn(value, masked); } +llama_memory_t llama_get_memory(const struct llama_context * ctx) { + if (!ctx) { + return nullptr; + } + + return ctx->get_memory(); +} + float * llama_get_embeddings_nextn(llama_context * ctx) { ctx->synchronize(); @@ -3656,7 +3675,7 @@ struct ggml_cgraph * llama_graph_reserve( uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs) { - auto * memory = ctx->get_memory(); + auto memory = ctx->get_memory(); llama_memory_context_ptr mctx; if (memory) { mctx = memory->init_full(); @@ -3696,10 +3715,6 @@ int32_t llama_set_adapter_cvec( // memory // -llama_memory_t llama_get_memory(const struct llama_context * ctx) { - return ctx->get_memory(); -} - void llama_memory_clear(llama_memory_t mem, bool data) { if (!mem) { return; @@ -4010,3 +4025,7 @@ void llama_opt_epoch( llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx) { return ctx->memory_breakdown(); } + +llama_context * llama_get_ctx_other(struct llama_context * ctx) { + return ctx->get_cparams().ctx_other; +} diff --git a/src/llama-context.h b/src/llama-context.h index 2af92b0f09..6f8f59a22a 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -6,6 +6,7 @@ #include "llama-graph.h" #include "llama-adapter.h" #include "llama-impl.h" +#include "llama-memory.h" #include "ggml-cpp.h" #include "ggml-opt.h" @@ -273,7 +274,7 @@ private: llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably - std::unique_ptr memory; + llama_memory_ptr memory; // decode output 
(2-dimensional array: [n_outputs][n_vocab]) buffer_view logits = {nullptr, 0}; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index fd227ee5a2..8a35d389ef 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -49,4 +49,6 @@ struct llama_cparams { ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; + + llama_context * ctx_other; }; diff --git a/src/llama-ext.h b/src/llama-ext.h index 7ad6125fad..bd74544129 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -100,3 +100,5 @@ LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx); // LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i); + +LLAMA_API llama_context * llama_get_ctx_other(struct llama_context * ctx); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 3b8125cde7..da7a929556 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -397,7 +397,7 @@ static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break; }; - LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); + LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__); LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__); @@ -565,18 +565,18 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { if (self_k_idxs && self_k_idxs->buffer) { mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); - - mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + // swa tensors may not be allocated 
if there are no SWA attention layers if (self_k_idxs_swa && self_k_idxs_swa->buffer) { mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); - - mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + if (self_k_rot) { mctx->get_base()->set_input_k_rot(self_k_rot); } @@ -605,18 +605,18 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { if (self_k_idxs && self_k_idxs->buffer) { res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - - res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); } + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + // swa tensors may not be allocated if there are no SWA attention layers if (self_k_idxs_swa && self_k_idxs_swa->buffer) { res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - - res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); } + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + return res; } @@ -756,7 +756,9 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch); + } + if (inp_attn->self_kq_mask && inp_attn->self_kq_mask->buffer) { attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); } @@ -764,7 +766,9 @@ void 
llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch); attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch); + } + if (inp_attn->self_kq_mask_swa && inp_attn->self_kq_mask_swa->buffer) { attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn); } @@ -810,18 +814,18 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - - res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams); } + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams); + // swa tensors may not be allocated if there are no SWA attention layers if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - - res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams); } + res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams); + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; @@ -1006,6 +1010,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : ubatch (params.ubatch), n_embd (hparams.n_embd), n_layer (hparams.n_layer()), + n_layer_nextn (hparams.n_layer_nextn), n_rot (hparams.n_rot()), n_ctx (cparams.n_ctx), n_head 
(hparams.n_head()), diff --git a/src/llama-graph.h b/src/llama-graph.h index bf5be09ac7..6793846e3e 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -784,6 +784,7 @@ struct llm_graph_context { const int64_t n_embd; const int64_t n_layer; + const int64_t n_layer_nextn; const int64_t n_rot; const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train) const int64_t n_head; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index e1e49d1cc1..2bf5768738 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -91,6 +91,10 @@ uint32_t llama_hparams::n_rot(uint32_t il) const { } uint32_t llama_hparams::n_embd_inp() const { + if (n_embd_inp_impl > 0) { + return n_embd_inp_impl; + } + uint32_t n_embd_inp = n_embd; if (n_deepstack_layers > 0) { diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 87db4a0dd3..032944cb48 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -185,6 +185,9 @@ struct llama_hparams { // for Classifiers uint32_t n_cls_out = 1; + // input embedding dimension (0 = use n_embd) + uint32_t n_embd_inp_impl = 0; + // output embedding dimension (0 = use n_embd) uint32_t n_embd_out_impl = 0; @@ -224,6 +227,7 @@ struct llama_hparams { // complex mapping. If using deepstack_mapping_arr, also make sure to set // n_deepstack_layers to the number of unique deepstack layers so that // n_embd_imp is accurate (see granite.cpp). 
+ // TODO: this can be expressed via the new `n_embd_inp_impl`, after which this param can be removed uint32_t n_deepstack_layers = 0; // deepstack layer array (Granite4 Vision) diff --git a/src/llama-kv-cache-dsa.cpp b/src/llama-kv-cache-dsa.cpp index e44004b558..916ab65375 100644 --- a/src/llama-kv-cache-dsa.cpp +++ b/src/llama-kv-cache-dsa.cpp @@ -32,7 +32,7 @@ llama_kv_cache_dsa::llama_kv_cache_dsa( kv_mla = std::make_unique( model, model.hparams, type_k, type_v, v_trans, offload, unified, kv_size, n_seq_max, n_pad, - n_swa, swa_type, filter, reuse); + n_swa, swa_type, nullptr, filter, reuse, nullptr); // we use llama_kv_cache for caching indexer keys // by hand-tweaking some hparams we fool it to create @@ -49,7 +49,7 @@ llama_kv_cache_dsa::llama_kv_cache_dsa( kv_lid = std::make_unique( model, hparams_lid, type_k, type_v, v_trans, offload, unified, kv_size, n_seq_max, n_pad, - n_swa, swa_type, filter, reuse); + n_swa, swa_type, nullptr, filter, reuse, nullptr); } void llama_kv_cache_dsa::clear(bool data) { diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp index 9b9f179036..aa1b1b72eb 100644 --- a/src/llama-kv-cache-iswa.cpp +++ b/src/llama-kv-cache-iswa.cpp @@ -23,8 +23,10 @@ llama_kv_cache_iswa::llama_kv_cache_iswa( uint32_t n_seq_max, uint32_t n_ubatch, uint32_t n_pad, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) { + const layer_reuse_cb & reuse, + const layer_share_cb & share) : hparams(model.hparams), unified(unified) { // chain filters const layer_filter_cb filter_base = [&](int32_t il) { @@ -59,17 +61,27 @@ llama_kv_cache_iswa::llama_kv_cache_iswa( LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); + llama_memory_t mem_other_base = nullptr; + if (mem_other) { + mem_other_base = static_cast(mem_other)->get_base(); + } + + llama_memory_t mem_other_swa = nullptr; + if (mem_other) { + mem_other_swa = 
static_cast(mem_other)->get_swa(); + } + kv_base = std::make_unique( model, hparams, type_k, type_v, v_trans, offload, unified, size_base, n_seq_max, n_pad, - 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse); + 0, LLAMA_SWA_TYPE_NONE, mem_other_base, filter_base, reuse, share); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); kv_swa = std::make_unique( model, hparams, type_k, type_v, v_trans, offload, unified, size_swa, n_seq_max, n_pad, - hparams.n_swa, hparams.swa_type, filter_swa, reuse); + hparams.n_swa, hparams.swa_type, mem_other_swa, filter_swa, reuse, share); } void llama_kv_cache_iswa::clear(bool data) { diff --git a/src/llama-kv-cache-iswa.h b/src/llama-kv-cache-iswa.h index 70ab22f0d6..dfafc1ef51 100644 --- a/src/llama-kv-cache-iswa.h +++ b/src/llama-kv-cache-iswa.h @@ -25,8 +25,10 @@ public: uint32_t n_seq_max, uint32_t n_ubatch, uint32_t n_pad, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse); + const layer_reuse_cb & reuse, + const layer_share_cb & share); ~llama_kv_cache_iswa() = default; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 60ae42e378..cad7eb984c 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -90,8 +90,10 @@ llama_kv_cache::llama_kv_cache( uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse) : + const layer_reuse_cb & reuse, + const layer_share_cb & share) : model(model), hparams(hparams), v_trans(v_trans), n_seq_max(n_seq_max), n_stream(unified ? 
1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { @@ -160,6 +162,8 @@ llama_kv_cache::llama_kv_cache( const bool is_mla = hparams.is_mla(); + other = static_cast(mem_other); + for (uint32_t il = 0; il < n_layer; il++) { if (!hparams.has_kv(il)) { LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il); @@ -171,6 +175,24 @@ llama_kv_cache::llama_kv_cache( continue; } + if (share && other) { + const int32_t il_share = share(il); + + if (il_share >= 0) { + const auto & layer_share = other->layers[other->map_layer_ids[il_share]]; + + LLAMA_LOG_WARN("%s: layer %3d: sharing with layer %d. k = %p, v = %p\n", __func__, il, il_share, + layer_share.k->data, layer_share.v->data); + + map_layer_ids[il] = layers.size(); + + layers.push_back(layer_share); + layers.back().il = il; + + continue; + } + } + if (n_embd_head_k_all == 0) { n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il); } else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) { @@ -282,29 +304,38 @@ llama_kv_cache::llama_kv_cache( ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } - const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE"); - const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false; - if (attn_rot_disable) { - LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__); + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + n_embd_head_k_all = other->n_embd_head_k_all; + n_embd_head_v_all = other->n_embd_head_v_all; + + attn_rot_k = other->attn_rot_k; + attn_rot_v = other->attn_rot_v; + } else { + const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE"); + const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? 
atoi(LLAMA_ATTN_ROT_DISABLE) : false; + if (attn_rot_disable) { + LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__); + } + + attn_rot_k = + !attn_rot_disable && + n_embd_head_k_all > 0 && + ggml_is_quantized(type_k) && + hparams.n_embd_head_k() % 64 == 0; + + // always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer + if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) { + attn_rot_k = true; + } + + attn_rot_v = + !attn_rot_disable && + n_embd_head_v_all > 0 && + ggml_is_quantized(type_v) && + hparams.n_embd_head_v() % 64 == 0; } - attn_rot_k = - !attn_rot_disable && - n_embd_head_k_all > 0 && - ggml_is_quantized(type_k) && - hparams.n_embd_head_k() % 64 == 0; - - // always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer - if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) { - attn_rot_k = true; - } - - attn_rot_v = - !attn_rot_disable && - n_embd_head_v_all > 0 && - ggml_is_quantized(type_v) && - hparams.n_embd_head_v() % 64 == 0; - LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all); - LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all); + LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_v_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all); @@ -347,6 +378,11 @@ void llama_kv_cache::clear(bool data) { } bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return true; + } + GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); if (p0 < 0) { @@ -410,6 +446,11 @@ bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { } void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id_src >= 0 && (size_t) 
seq_id_src < seq_to_stream.size()); GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size()); @@ -497,6 +538,11 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll } void llama_kv_cache::seq_keep(llama_seq_id seq_id) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -519,6 +565,11 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) { } void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1"); @@ -564,6 +615,11 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll } void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1"); @@ -598,6 +654,11 @@ void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, in } llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return other->seq_pos_min(seq_id); + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); const auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -606,6 +667,11 @@ llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const { } llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return other->seq_pos_max(seq_id); + } + 
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); const auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -746,6 +812,11 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vectorget_sched(); @@ -1021,6 +1092,12 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, } void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + v_cells = other->v_cells; + return; + } + // keep track of the max sequence position that we would overwrite with this ubatch // for non-SWA cache, this would be always empty llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; @@ -1815,6 +1892,9 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { } ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + GGML_ASSERT(!other); + auto * ctx = res->get_ctx(); auto * gf = res->get_gf(); @@ -1860,6 +1940,11 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co } void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_UNUSED(flags); io.write(&n_stream, sizeof(n_stream)); @@ -1925,6 +2010,11 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, lla } void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_UNUSED(flags); GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 649269af6d..f5ace6ae35 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -98,7 +98,7 @@ public: // likely through `struct 
llama_memory_params` llama_kv_cache( const llama_model & model, - const llama_hparams & hparams, + const llama_hparams & hparams, ggml_type type_k, ggml_type type_v, bool v_trans, @@ -109,8 +109,10 @@ public: uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse); + const layer_reuse_cb & reuse, + const layer_share_cb & share); ~llama_kv_cache() = default; @@ -264,6 +266,9 @@ private: // note: this is not part of the KV state and it's only used to speed-up the find_slot() method std::vector v_heads; + // TODO: temporary until we refactor to be able to share the same cells between 2 kv caches [TAG_KV_CACHE_SHARE_CELLS] + llama_kv_cache * other; + std::vector v_cells; // maps from a sequence id to a stream id diff --git a/src/llama-memory-hybrid-iswa.cpp b/src/llama-memory-hybrid-iswa.cpp index a242079b40..c7d4bcd413 100644 --- a/src/llama-memory-hybrid-iswa.cpp +++ b/src/llama-memory-hybrid-iswa.cpp @@ -43,9 +43,11 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( n_seq_max, n_ubatch, n_pad, + nullptr, filter_attn == nullptr ? [&](int32_t il) { return !hparams.is_recr(il); } : filter_attn, + nullptr, nullptr )), mem_recr(new llama_memory_recurrent( diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 66ec3fd6d5..f2d49cbce5 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -44,9 +44,11 @@ llama_memory_hybrid::llama_memory_hybrid( n_pad, n_swa, swa_type, + nullptr, filter_attn == nullptr ? 
[&](int32_t il) { return !hparams.is_recr(il); } : filter_attn, + nullptr, nullptr )), mem_recr(new llama_memory_recurrent( diff --git a/src/llama-memory.h b/src/llama-memory.h index 4ad1612e45..db82539664 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -23,6 +23,8 @@ struct llama_memory_params { bool swa_full; llama_context_type ctx_type; + + llama_memory_t mem_other; }; enum llama_memory_status { @@ -76,6 +78,8 @@ struct llama_memory_i { // return negative value to indicate that the layer il should not reuse memory using layer_reuse_cb = std::function; + using layer_share_cb = std::function; + virtual ~llama_memory_i() = default; // split the input batch into a set of ubatches and verify that they can fit into the cache diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 0d23a605ee..4f12e0949a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -139,6 +139,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_gemma3n(params); case LLM_ARCH_GEMMA4: return new llama_model_gemma4(params); + case LLM_ARCH_GEMMA4_ASSISTANT: + return new llama_model_gemma4_assistant(params); case LLM_ARCH_GEMMA_EMBEDDING: return new llama_model_gemma_embedding(params); case LLM_ARCH_STARCODER2: @@ -1717,19 +1719,21 @@ void llama_model::print_info() const { if (!hparams.vocab_only) { LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); - LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_embd_out = %u\n", __func__, hparams.n_embd_out()); LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer()); - LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer()).c_str()); - LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { 
return hparams.n_head_kv(il); }, hparams.n_layer()).c_str()); + LLAMA_LOG_INFO("%s: n_layer_all = %u\n", __func__, hparams.n_layer_all); + LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer_all).c_str()); + LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer_all).c_str()); LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full); LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full); LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full); - LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer()).c_str()); - LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer()).c_str()); - LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer()).c_str()); + LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer_all).c_str()); + LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer_all).c_str()); + LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer_all).c_str()); LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); @@ -1737,7 +1741,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, 
hparams.f_logit_scale); LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); LLAMA_LOG_INFO("%s: f_attn_value_scale = %.4f\n", __func__, hparams.f_attn_value_scale); - LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer()).c_str()); + LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer_all).c_str()); LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); @@ -1764,7 +1768,7 @@ void llama_model::print_info() const { [](const auto & entry) { return entry >= 0; })) { LLAMA_LOG_INFO("%s: deepstack_mapping_arr = %s\n", __func__, print_f([&](uint32_t il) { return hparams.deepstack_mapping_arr[il]; }, - hparams.n_layer()).c_str()); + hparams.n_layer_all).c_str()); } // MRoPE (Multi-axis Rotary Position Embedding) sections if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) { @@ -2113,8 +2117,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* filter_recr */ std::move(filter_recr)); } } else { - llama_memory_i::layer_reuse_cb reuse = nullptr; llama_kv_cache::layer_filter_cb filter = nullptr; + llama_memory_i::layer_reuse_cb reuse = nullptr; + llama_kv_cache::layer_share_cb share = nullptr; if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { reuse = [&](uint32_t il) { @@ -2143,20 +2148,53 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(hparams.is_swa_any()); - res = new llama_kv_cache_iswa( - *this, - params.type_k, - params.type_v, - !cparams.flash_attn, - cparams.offload_kqv, - params.swa_full, - cparams.kv_unified, - cparams.n_ctx_seq, - cparams.n_seq_max, - cparams.n_ubatch, - 1, - filter, - 
reuse); + if (arch == LLM_ARCH_GEMMA4_ASSISTANT) { + llama_memory_t mem_other = llama_get_memory(cparams.ctx_other); + + share = [&](int32_t il) { + const llama_model * model_other = llama_get_model(cparams.ctx_other); + + if (hparams.is_swa(il)) { + return llama_model_n_layer(model_other) - 2; + } + + return llama_model_n_layer(model_other) - 1; + }; + + res = new llama_kv_cache_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + cparams.n_ubatch, + 1, + mem_other, + filter, + reuse, + share); + } else { + res = new llama_kv_cache_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + cparams.n_ubatch, + 1, + nullptr, + filter, + reuse, + share); + } } else { GGML_ASSERT(!hparams.is_swa_any()); @@ -2173,7 +2211,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, 1, hparams.n_swa, hparams.swa_type, + nullptr, filter, + nullptr, nullptr); } } @@ -2406,6 +2446,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3N: case LLM_ARCH_GEMMA4: + case LLM_ARCH_GEMMA4_ASSISTANT: case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: diff --git a/src/llama-model.h b/src/llama-model.h index 884cfdf5c3..992c8d9c8f 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -548,6 +548,10 @@ struct llama_model { struct ggml_tensor * output_s = nullptr; struct ggml_tensor * output_in_s = nullptr; + // NextN/MTP model-level projections + struct ggml_tensor * nextn_proj_pre = nullptr; + struct ggml_tensor * nextn_proj_post = nullptr; + // classifier struct ggml_tensor * cls = nullptr; struct ggml_tensor * cls_b = nullptr; @@ -702,6 +706,7 @@ const char * llm_type_name(llm_type type); #define LLAMA_LOAD_LOCALS \ const 
int n_layer = hparams.n_layer(); GGML_UNUSED(n_layer); \ const int n_layer_all = hparams.n_layer_all; GGML_UNUSED(n_layer_all); \ + const int n_layer_nextn = hparams.n_layer_nextn; GGML_UNUSED(n_layer_nextn); \ const int64_t n_head = hparams.n_head(); GGML_UNUSED(n_head); \ const int64_t n_head_kv = hparams.n_head_kv(); GGML_UNUSED(n_head_kv); \ const int64_t n_embd = hparams.n_embd; GGML_UNUSED(n_embd); \ diff --git a/src/models/gemma4-assistant.cpp b/src/models/gemma4-assistant.cpp new file mode 100644 index 0000000000..5b7a25a5ab --- /dev/null +++ b/src/models/gemma4-assistant.cpp @@ -0,0 +1,200 @@ +#include "models.h" + +void llama_model_gemma4_assistant::load_arch_hparams(llama_model_loader & ml) { + hparams.n_embd_inp_impl = hparams.n_embd_out(); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); + + uint32_t n_kv_shared_layers = 0; + ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); + + hparams.f_attention_scale = 1.0f; + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn == hparams.n_layer_all && "n_layer_nextn must be == n_layer_all"); + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); +} + +void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_embd_head_k != n_embd_head_v) { + throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k == n_embd_head_v"); + } + if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) { + throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k_swa == 
n_embd_head_v_swa"); + } + if (hparams.n_embd_out() == n_embd) { + throw std::runtime_error("Gemma 4 assistant requires embedding_length_out to carry the target hidden size"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + + const int64_t n_embd_backbone = hparams.n_embd_inp(); + nextn_proj_post = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_POST, "weight"), { n_embd, n_embd_backbone }, 0); + + int rope_freqs_flag = 0; + + for (int i = 0; i < n_layer_nextn; ++i) { + auto & layer = layers[i]; + + const int64_t n_head = hparams.n_head(i); + const int64_t n_embd_head = hparams.n_embd_head_k(i); + const int64_t n_ff = hparams.n_ff(i); + + if (i == 0) { + nextn_proj_pre = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_PRE, "weight", i), { 2*n_embd_backbone, n_embd }, 0); + } + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head*n_head }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head*n_head, n_embd }, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), { 1u }, 0); + + if (!hparams.is_swa(i)) { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_embd_head/2 }, rope_freqs_flag); + rope_freqs_flag = TENSOR_DUPLICATED; + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_up = 
create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), { n_embd }, 0); + } +} + +std::unique_ptr llama_model_gemma4_assistant::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const int64_t n_embd_backbone = hparams.n_embd_inp(); + + ggml_tensor * inp_tokens; + ggml_tensor * inp_h; + { + auto inp = std::make_unique(n_embd_backbone); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + cb(inp->tokens, "inp_tokens", -1); + ggml_set_input(inp->tokens); + inp_tokens = inp->tokens; + res->t_inp_tokens = inp->tokens; + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_backbone, ubatch.n_tokens); + cb(inp->embd, "inp_h", -1); + ggml_set_input(inp->embd); + inp_h = inp->embd; + res->t_inp_embd = inp->embd; + + res->add_input(std::move(inp)); + } + + GGML_ASSERT(cparams.ctx_other != nullptr); + const auto * model_other = llama_get_model(cparams.ctx_other); + + ggml_tensor * x = ggml_get_rows(ctx0, model_other->tok_embd, inp_tokens); + x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone)); + cb(x, "inp_embd_target", -1); + + ggml_tensor * xh = ggml_concat(ctx0, x, inp_h, 0); + cb(xh, "inp_xh", -1); + + ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_proj_pre, xh); + cb(cur, "pre_proj", -1); + + auto * inp_attn = build_attn_inp_kv_iswa(); + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer_nextn; ++il) { + const bool is_swa = hparams.is_swa(il); + + const int64_t n_embd_head = hparams.n_embd_head_k(il); + const int64_t n_head = hparams.n_head(il); + + 
const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + const int n_rot_l = hparams.n_rot(il); + + ggml_tensor * cur_norm = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur_norm, "attn_norm", il); + + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur_norm); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * freq_factors = is_swa ? nullptr : model.layers[il].rope_freqs; + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, + freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur_pos", il); + + cur = build_attn(inp_attn, model.layers[il].wo, nullptr, nullptr, + Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); + + if (il == n_layer_nextn - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + cur = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL); + cb(attn_out, "attn_out", il); + + cur = build_norm(attn_out, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, nullptr, nullptr, + model.layers[il].ffn_gate, nullptr, nullptr, + model.layers[il].ffn_down, nullptr, nullptr, + nullptr, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = build_norm(cur, model.layers[il].ffn_post_norm, nullptr, LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", il); + + cur = ggml_add(ctx0, cur, attn_out); + + cur = ggml_mul(ctx0, cur, model.layers[il].out_scale); + cb(cur, "out_scaled", il); + + inpL = cur; + } + cur = inpL; + + cur = 
build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + ggml_tensor * logits = build_lora_mm(model.output, cur); + cb(logits, "result_output", -1); + res->t_logits = logits; + + ggml_tensor * h_next = ggml_mul_mat(ctx0, model.nextn_proj_post, cur); + cb(h_next, "h_nextn", -1); + res->t_h_nextn = h_next; + + ggml_build_forward_expand(gf, logits); + ggml_build_forward_expand(gf, h_next); +} diff --git a/src/models/gemma4.cpp b/src/models/gemma4.cpp index 7198e54111..6f7fcd645c 100644 --- a/src/models/gemma4.cpp +++ b/src/models/gemma4.cpp @@ -155,12 +155,14 @@ public: } virtual ~llm_graph_input_logits_bias() = default; - void set_input(const llama_ubatch *) override { + void set_input(const llama_ubatch * /*ubatch*/) override { const int64_t n_vocab = arr.size(); ggml_backend_tensor_set(logits_bias, arr.data(), 0, n_vocab*ggml_element_size(logits_bias)); } - // bool can_reuse(const llm_graph_params & params) override; + bool can_reuse(const llm_graph_params & /*params*/) override { + return true; + } ggml_tensor * logits_bias = nullptr; // F32 [n_vocab] @@ -270,7 +272,8 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para } // TODO @ngxson : strip unused token right after the last KV layer to speed up prompt processing - if (il == n_layer - 1 && inp_out_ids) { + // keep all rows when extracting unmasked nextn embeddings (MTP target needs the hidden state for every token) + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } @@ -370,7 +373,7 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens] // TODO @ngxson : improve this - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { 
inp_this_layer = ggml_get_rows(ctx0, inp_this_layer, inp_out_ids); } @@ -401,6 +404,17 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para model.output_norm, nullptr, LLM_NORM_RMS, -1); + // Expose the post-output-norm hidden state (the LM-head input feature) so that + // MTP draft contexts can read it via llama_get_embeddings_nextn_ith() as the + // recurrent h input. This matches the reference (transformers/vLLM/SGLang), + // which feeds the drafter the target's post-final-norm hidden state. + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + if (!cparams.embeddings_nextn_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + cb(cur, "result_norm", -1); res->t_embd = cur; diff --git a/src/models/models.h b/src/models/models.h index 866e0d0be3..c137e32e8f 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -822,6 +822,19 @@ struct llama_model_gemma4 : public llama_model_base { }; +struct llama_model_gemma4_assistant : public llama_model_base { + llama_model_gemma4_assistant(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + struct llama_model_gemma_embedding : public llama_model_base { llama_model_gemma_embedding(const struct llama_model_params & params) : llama_model_base(params) {} void load_arch_hparams(llama_model_loader & ml) override; diff --git a/tests/test-llama-archs.cpp b/tests/test-llama-archs.cpp index 4fe585e29a..8037a11398 100644 --- a/tests/test-llama-archs.cpp +++ b/tests/test-llama-archs.cpp @@ -392,7 +392,7 @@ static bool arch_supported(const llm_arch arch) { if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { return false; // FIXME CUDA 
backend crashes. } - if (arch == LLM_ARCH_GEMMA4) { + if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) { return false; // FIXME @ngxson } if (arch == LLM_ARCH_LLAMA_EMBED || arch == LLM_ARCH_GEMMA_EMBEDDING || arch == LLM_ARCH_T5ENCODER) { @@ -447,7 +447,7 @@ static int save_models(const llm_arch target_arch, const size_t seed, const ggml if (target_arch != LLM_ARCH_UNKNOWN && arch != target_arch) { continue; } - if (arch == LLM_ARCH_GEMMA4) { + if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) { continue; // FIXME: ISWA KV cache initialization needs more fixture params } for (bool moe : {false, true}) { @@ -550,7 +550,7 @@ static int test_backends(const llm_arch target_arch, const size_t seed, const gg if (target_arch != LLM_ARCH_UNKNOWN && arch != target_arch) { continue; } - if (arch == LLM_ARCH_GEMMA4) { + if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) { continue; // FIXME: ISWA KV cache initialization needs more fixture params } diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 5d546d09c2..07759f4170 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1,4 +1,3 @@ - #include "server-context.h" #include "server-chat.h" #include "server-common.h" @@ -16,6 +15,11 @@ #include "mtmd.h" #include "mtmd-helper.h" +#include "ggml-cpp.h" + +// TODO: tmp until the mtmd draft processing is refactored [TAG_MTMD_DRAFT_PROCESSING] +#include "../../src/llama-ext.h" + #include #include #include @@ -884,7 +888,7 @@ private: has_draft ? "draft model" : "MTP context", total / (1024.0 * 1024.0)); } catch (const std::exception & e) { - SRV_ERR("[spec] failed to measure %s memory: %s\n", + SRV_WRN("[spec] failed to measure %s memory: %s\n", has_draft ? 
"draft model" : "MTP context", e.what()); } } @@ -940,16 +944,17 @@ private: const bool spec_mtp = std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(), COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end(); + if (spec_mtp) { cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP; } // note: for small models maybe we can set this to the maximum possible draft from all speculative types // the extra memory for small models is likely negligible? - cparams.n_rs_seq = 0; - ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); + cparams.n_rs_seq = 0; + cparams.ctx_other = ctx_tgt; - ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); + ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); params_base.speculative.draft.ctx_tgt = ctx_tgt; params_base.speculative.draft.ctx_dft = ctx_dft.get(); @@ -964,6 +969,7 @@ private: cparams_mtp.type_v = params_base.speculative.draft.cache_type_v; cparams_mtp.n_rs_seq = 0; cparams_mtp.n_outputs_max = params_base.n_parallel; + cparams_mtp.ctx_other = ctx_tgt; ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp)); if (ctx_dft == nullptr) { @@ -971,8 +977,6 @@ private: return false; } - ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); - params_base.speculative.draft.ctx_tgt = ctx_tgt; params_base.speculative.draft.ctx_dft = ctx_dft.get(); } @@ -1060,6 +1064,10 @@ private: } } + if (ctx_dft) { + ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); + } + if (spec) { SRV_INF("%s", "speculative decoding context initialized\n"); } else { @@ -2974,10 +2982,11 @@ private: continue; } - if (ctx_dft) { + if (ctx_dft && llama_get_ctx_other(ctx_dft.get()) != ctx_tgt) { // TODO: in the future, figure out how to infuse target embeddings to the images // for now, we skip this for simplicity // maybe we simply need to call `common_speculative_process()` on the mtmd batches in the `process_chunk` above? 
+ // [TAG_MTMD_DRAFT_PROCESSING] res = input_tokens.process_chunk(ctx_dft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); if (res != 0) { GGML_ABORT("failed to process multi-modal data on draft context\n");