kv-cache : avoid kv cells copies (#24277)

This commit is contained in:
Georgi Gerganov
2026-06-07 21:42:54 +03:00
committed by GitHub
parent f0156d1401
commit 379ac6673b
3 changed files with 11 additions and 7 deletions
+6 -6
View File
@@ -95,13 +95,16 @@ llama_kv_cache::llama_kv_cache(
const layer_reuse_cb & reuse,
const layer_share_cb & share) :
model(model), hparams(hparams), v_trans(v_trans),
n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type),
other(static_cast<llama_kv_cache *>(mem_other)),
v_cells_impl(other ? other->v_cells_impl : std::make_shared<llama_kv_cells_vec>()),
v_cells(*v_cells_impl) {
// shared cells view the source cache's K/V tensors, so the cell count
// follows the source allocation: a fitted target can be smaller than the
// draft default and oversized views would overflow the source tensors
if (mem_other) {
const uint32_t size_other = static_cast<llama_kv_cache *>(mem_other)->get_size();
if (other) {
const uint32_t size_other = other->get_size();
if (kv_size != size_other) {
LLAMA_LOG_WARN("%s: kv_size = %u overridden to %u to match the shared source cache\n", __func__, kv_size, size_other);
kv_size = size_other;
@@ -173,8 +176,6 @@ llama_kv_cache::llama_kv_cache(
const bool is_mla = hparams.is_mla();
other = static_cast<llama_kv_cache *>(mem_other);
for (uint32_t il = 0; il < n_layer; il++) {
if (!hparams.has_kv(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
@@ -1105,7 +1106,6 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
v_cells = other->v_cells;
return;
}
+3 -1
View File
@@ -269,7 +269,9 @@ private:
// TODO: temporary until we refactor to be able to share the same cells between 2 kv caches [TAG_KV_CACHE_SHARE_CELLS]
llama_kv_cache * other;
std::vector<llama_kv_cells> v_cells;
std::shared_ptr<llama_kv_cells_vec> v_cells_impl;
llama_kv_cells_vec & v_cells;
// maps from a sequence id to a stream id
std::vector<uint32_t> seq_to_stream;
+2
View File
@@ -531,3 +531,5 @@ private:
}
}
};
using llama_kv_cells_vec = std::vector<llama_kv_cells>;