graph: guard iswa kq_mask on its own buffer (#24294)

A SWA-only draft head (e.g. StepFun MTP) leaves the base sub-cache
empty, so its kq_mask buffer stays null and asserts at load. Guard
each mask on its own buffer in set_input and can_reuse, base and swa.

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Pascal
2026-06-08 19:20:28 +02:00
committed by GitHub
parent 1705d434f6
commit a66d50588b
+13 -4
View File
@@ -567,7 +567,10 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
}
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
// the kq mask guards on its own buffer: shared cells leave idxs unbacked while the mask stays live
if (self_kq_mask && self_kq_mask->buffer) {
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
// swa tensors may not be allocated if there are no SWA attention layers
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
@@ -575,7 +578,9 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
}
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
if (self_k_rot) {
mctx->get_base()->set_input_k_rot(self_k_rot);
@@ -607,7 +612,9 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
}
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
if (self_kq_mask && self_kq_mask->buffer) {
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
}
// swa tensors may not be allocated if there are no SWA attention layers
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
@@ -615,7 +622,9 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
}
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
}
return res;
}