mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-09 07:16:44 +02:00
graph: guard iswa kq_mask on its own buffer (#24294)
A SWA-only draft head (e.g. StepFun MTP) leaves the base sub-cache empty, so its kq_mask buffer stays null and asserts at load. Guard each mask on its own buffer in set_input and can_reuse, base and swa. Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
+13
-4
@@ -567,7 +567,10 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
|
||||
}
|
||||
|
||||
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
// the kq mask guards on its own buffer: shared cells leave idxs unbacked while the mask stays live
|
||||
if (self_kq_mask && self_kq_mask->buffer) {
|
||||
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
// swa tensors may not be allocated if there are no SWA attention layers
|
||||
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
|
||||
@@ -575,7 +578,9 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
|
||||
}
|
||||
|
||||
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
|
||||
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
if (self_k_rot) {
|
||||
mctx->get_base()->set_input_k_rot(self_k_rot);
|
||||
@@ -607,7 +612,9 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
||||
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
}
|
||||
|
||||
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
|
||||
if (self_kq_mask && self_kq_mask->buffer) {
|
||||
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
|
||||
}
|
||||
|
||||
// swa tensors may not be allocated if there are no SWA attention layers
|
||||
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
|
||||
@@ -615,7 +622,9 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
||||
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
}
|
||||
|
||||
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
|
||||
if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
|
||||
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user