mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-09 07:16:44 +02:00
kv-cache : avoid kv cells copies (#24277)
This commit is contained in:
@@ -95,13 +95,16 @@ llama_kv_cache::llama_kv_cache(
|
||||
const layer_reuse_cb & reuse,
|
||||
const layer_share_cb & share) :
|
||||
model(model), hparams(hparams), v_trans(v_trans),
|
||||
n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
|
||||
n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type),
|
||||
other(static_cast<llama_kv_cache *>(mem_other)),
|
||||
v_cells_impl(other ? other->v_cells_impl : std::make_shared<llama_kv_cells_vec>()),
|
||||
v_cells(*v_cells_impl) {
|
||||
|
||||
// shared cells view the source cache's K/V tensors, so the cell count
|
||||
// follows the source allocation: a fitted target can be smaller than the
|
||||
// draft default and oversized views would overflow the source tensors
|
||||
if (mem_other) {
|
||||
const uint32_t size_other = static_cast<llama_kv_cache *>(mem_other)->get_size();
|
||||
if (other) {
|
||||
const uint32_t size_other = other->get_size();
|
||||
if (kv_size != size_other) {
|
||||
LLAMA_LOG_WARN("%s: kv_size = %u overridden to %u to match the shared source cache\n", __func__, kv_size, size_other);
|
||||
kv_size = size_other;
|
||||
@@ -173,8 +176,6 @@ llama_kv_cache::llama_kv_cache(
|
||||
|
||||
const bool is_mla = hparams.is_mla();
|
||||
|
||||
other = static_cast<llama_kv_cache *>(mem_other);
|
||||
|
||||
for (uint32_t il = 0; il < n_layer; il++) {
|
||||
if (!hparams.has_kv(il)) {
|
||||
LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
|
||||
@@ -1105,7 +1106,6 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
|
||||
void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
|
||||
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
|
||||
if (other) {
|
||||
v_cells = other->v_cells;
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -269,7 +269,9 @@ private:
|
||||
// TODO: temporary until we refactor to be able to share the same cells between 2 kv caches [TAG_KV_CACHE_SHARE_CELLS]
|
||||
llama_kv_cache * other;
|
||||
|
||||
std::vector<llama_kv_cells> v_cells;
|
||||
std::shared_ptr<llama_kv_cells_vec> v_cells_impl;
|
||||
|
||||
llama_kv_cells_vec & v_cells;
|
||||
|
||||
// maps from a sequence id to a stream id
|
||||
std::vector<uint32_t> seq_to_stream;
|
||||
|
||||
@@ -531,3 +531,5 @@ private:
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using llama_kv_cells_vec = std::vector<llama_kv_cells>;
|
||||
|
||||
Reference in New Issue
Block a user