Compare commits

...

5 Commits

Author SHA1 Message Date
Adrien Gallouët ac4cddeb0d vendor : update LibreSSL to 4.3.2 (#24397)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-10 22:28:03 +02:00
Gaurav Garg e95dae18d6 Remove padding and multiple D2D copies for MTP (#24086)
* Make ggml_gated_delta_net take only the initial recurrent state (D, 1, n_seqs) and passes the snapshot count K as an op parameter instead of inferring it from state->ne[1].

Remove the padding hack and copy all emitted snapshots into the recurrent cache with a single strided ggml_cpy

* Make GDN changes in all backends. Address review comments.

* Fix CI build errors
2026-06-10 23:21:16 +05:30
Tarek Dakhran d2462f8f7a chat: fix LFM2/LFM2.5 ignoring json_schema (#24377)
The LFM2 specialized template handler only built a grammar for tool-calling,
silently ignoring json_schema from response_format.
2026-06-10 14:41:41 +02:00
Oliver Simons fb83cc9a07 CUDA: Fix ssm_scan_f32 data-races (#24360)
* Add missing syncthreads before resuing cub_temp_storage

__syncthreads() is required before being allowed to resue TempStorage
smem:
https://nvidia.github.io/cccl/unstable/cub/api/classcub_1_1BlockLoad.html#_CPPv4I0EN3cub9BlockLoad4LoadEv20RandomAccessIteratorRA14ItemsPerThread_1Ti

* Add one more missing __syncthreads

Could also double-buffer, but alternative is to simply ensure all
threads have read smem* before writing to it again in the next loop
iteration

* Remove unused smem from ssm_scan_f32
2026-06-10 14:27:08 +02:00
Sigbjørn Skjæret 039e20a2db ci : bump komac version (#24396) 2026-06-10 09:45:20 +02:00
24 changed files with 135 additions and 113 deletions
+1 -1
View File
@@ -17,7 +17,7 @@ jobs:
- name: Install komac
run: |
cargo binstall komac@2.15.0 -y
cargo binstall komac@2.16.0 -y
- name: Find latest release
id: find_latest_release
+12 -3
View File
@@ -1647,11 +1647,12 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
data.thinking_start_tag = THINK_START;
data.thinking_end_tag = THINK_END;
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
// Gate by reasoning format and whether the template supports <think>
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE &&
tmpl.source().find(THINK_START) != std::string::npos;
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
if (inputs.has_continuation()) {
const auto & msg = inputs.continue_msg;
@@ -1674,6 +1675,10 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
}
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
if (has_response_format) {
auto response_format = p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema));
return generation_prompt + reasoning + response_format + end;
}
return generation_prompt + reasoning + p.content(p.rest()) + end;
}
auto tool_calls = p.rule("tool-calls",
@@ -1692,13 +1697,17 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
data.parser = parser.save();
if (include_grammar) {
data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
auto schema = function.at("parameters");
builder.resolve_refs(schema);
});
if (has_response_format) {
auto schema = inputs.json_schema;
builder.resolve_refs(schema);
}
parser.build_grammar(builder, data.grammar_lazy);
});
+12 -5
View File
@@ -2553,10 +2553,16 @@ extern "C" {
// TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST]
// ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306
//
// state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs):
// K == 1: output carries the final state only.
// K > 1: output carries K snapshot slots; the kernel writes the last min(n_tokens, K)
// per-token snapshots into the trailing slots
// tensor shapes (S_k == S_v, H_v % H_k == 0):
// q, k : [S_k, H_k, n_tokens, n_seqs]
// v : [S_v, H_v, n_tokens, n_seqs]
// g : [1, H_v, n_tokens, n_seqs] (scalar gate) or [S_v, H_v, n_tokens, n_seqs] (KDA)
// beta : [1, H_v, n_tokens, n_seqs]
// state : [S_v, S_v, H_v, n_seqs] -- initial recurrent state s0
//
// the output packs the attention scores [S_v, H_v, n_tokens, n_seqs] followed by K state
// snapshots, most-recent first (slot 0 = final state, slot s = state s tokens back). K == 1
// keeps only the final state; when n_tokens < K only slots 0..n_tokens-1 are written.
GGML_API struct ggml_tensor * ggml_gated_delta_net(
struct ggml_context * ctx,
struct ggml_tensor * q,
@@ -2564,7 +2570,8 @@ extern "C" {
struct ggml_tensor * v,
struct ggml_tensor * g,
struct ggml_tensor * beta,
struct ggml_tensor * state);
struct ggml_tensor * state,
int64_t K);
// custom operators
+2 -2
View File
@@ -776,8 +776,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(
GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1);
GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1);
GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1);
// state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0,
// so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2).
// state shape is [S_v, S_v, H_v, n_seqs] (s0 only); the heads dim is its own axis 2,
// so a head-aligned split on the input cache lands on axis 2 here.
GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0);
return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1};
};
+1 -1
View File
@@ -2948,7 +2948,7 @@ struct ggml_cplan ggml_graph_plan(
case GGML_OP_GATED_DELTA_NET:
{
const int64_t S_v = node->src[2]->ne[0];
const int64_t K = node->src[5]->ne[1]; // state is (D, K, n_seqs)
const int64_t K = ggml_get_op_params_i32(node, 0);
const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0);
cur = per_thread * sizeof(float) * n_tasks;
} break;
+8 -9
View File
@@ -10624,11 +10624,11 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const bool kda = (neg0 == S_v);
// state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
const int64_t K = src_state->ne[1];
// K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs].
const int64_t K = ggml_get_op_params_i32(dst, 0);
GGML_ASSERT(K >= 1);
// per-seq stride in floats (slot 0 of seq s lives at state + s * seq_stride)
const int64_t state_seq_stride = src_state->nb[2] / sizeof(float);
// per-seq stride in floats (seq s starts at state + s * seq_stride)
const int64_t state_seq_stride = src_state->nb[3] / sizeof(float);
const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0);
const int ith = params->ith;
@@ -10644,9 +10644,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
float * attn_out_base = (float *)dst->data;
float * state_out_base = (float *)dst->data + attn_score_elems;
// snapshot slot mapping: target_slot = t - shift. When n_tokens < K only the last
// n_tokens slots are written; earlier slots are left untouched (caller-owned).
const int64_t shift = n_tokens - K;
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
// When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned.
const float * state_in_base = (const float *)src_state->data;
@@ -10674,7 +10673,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
: state_out_base + (iv3 * H + iv1) * S_v * S_v;
// copy input state into the working buffer and operate in-place
// state layout (D, K, n_seqs): slot 0 of seq iv3 starts at iv3 * state_seq_stride.
// state layout [S_v, S_v, H, n_seqs]: seq iv3 starts at iv3 * state_seq_stride.
const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v;
memcpy(s_out, s_in, S_v * S_v * sizeof(float));
@@ -10727,7 +10726,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
attn_data += S_v * H; // advance to next token
if (K > 1) {
const int64_t target_slot = t - shift;
const int64_t target_slot = n_tokens - 1 - t;
if (target_slot >= 0 && target_slot < K) {
float * curr_state_o = state_out_base + target_slot * state_size_per_snap +
(iv3 * H + iv1) * S_v * S_v;
+7 -9
View File
@@ -39,9 +39,9 @@ gated_delta_net_cuda(const float * q,
float * attn_data = dst;
float * state = dst + attn_score_elems;
// input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v.
// input state holds s0 only: [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v.
// output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before.
const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v;
const int64_t state_in_offset = sequence * H * S_v * S_v + h_idx * S_v * S_v;
const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v;
state += state_out_offset;
curr_state += state_in_offset + col * S_v;
@@ -143,12 +143,10 @@ gated_delta_net_cuda(const float * q,
attn_data += S_v * H;
if constexpr (keep_rs_t) {
// slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots
// are written; earlier slots are left untouched (caller-owned).
const int shift = (int) n_tokens - K;
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
// When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned.
const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output
const int target_slot = t - shift;
const int target_slot = (int) n_tokens - 1 - t;
if (target_slot >= 0 && target_slot < K) {
float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
#pragma unroll
@@ -286,8 +284,8 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
cudaStream_t stream = ctx.stream();
// state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
const int K = (int) src_state->ne[1];
// K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs].
const int K = ggml_get_op_params_i32(dst, 0);
const bool keep_rs = K > 1;
if (kda) {
+3 -2
View File
@@ -67,6 +67,7 @@ __global__ void __launch_bounds__(splitD, 1)
__shared__ CubTempStorage cub_temp_storage;
BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
__syncthreads();
BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
#else
const int stride_s0 = src0_nb2 / sizeof(float);
@@ -105,6 +106,7 @@ __global__ void __launch_bounds__(splitD, 1)
regs0[n] = state;
}
y_block[i * stride_y + threadIdx.x] = sumf;
__syncthreads();
}
#ifdef USE_CUB
@@ -249,9 +251,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
GGML_ASSERT(head_dim == 1);
GGML_ASSERT(n_group == 1);
const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
if (d_state == 16) {
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, smem_size, stream);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream);
switch (n_tok)
{
case 1:
+3 -2
View File
@@ -2538,7 +2538,7 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses
const int64_t H = v->ne[1];
const int64_t n_tokens = v->ne[2];
const int64_t n_seqs = v->ne[3];
const int64_t K = state->ne[1];
const int64_t K = ggml_get_op_params_i32(op, 0);
if (S_v <= 0 || S_v > 128 || H <= 0 || n_tokens <= 0 || n_seqs <= 0) {
return false;
@@ -2551,7 +2551,8 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses
if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) {
return false;
}
if (ggml_nelements(state) != S_v * S_v * H * n_seqs * K) {
// state holds s0 only [S_v, S_v, H, n_seqs]; K is op param 0.
if (ggml_nelements(state) != S_v * S_v * H * n_seqs) {
return false;
}
if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) {
+16 -13
View File
@@ -584,7 +584,7 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo
const uint32_t H = v->ne[1];
const uint32_t n_tokens = v->ne[2];
const uint32_t n_seqs = v->ne[3];
const uint32_t K = state->ne[1];
const uint32_t K = octx->op_params[0];
const uint32_t total_rows = H * n_seqs;
if (ith >= total_rows) {
@@ -618,9 +618,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo
struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3);
struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3);
const uint64_t state_seq_stride = state->nb[2] / sizeof(float);
const uint64_t state_seq_stride = state->nb[3] / sizeof(float);
const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs;
const int64_t shift = (int64_t) n_tokens - (int64_t) K;
uint32_t ir_prefetch = ith;
int spad_idx = 0;
@@ -630,7 +629,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo
const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H);
const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H);
const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v;
float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v;
// final state lands in snapshot slot 0 (most-recent-first ordering)
float * ps_out = state_out_base + ((uint64_t) piv3 * H + piv1) * S_v * S_v;
// Push dummy write-back
dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]),
@@ -661,7 +661,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo
const uint32_t iq3 = fastdiv(iv3, &fd_rq3);
const uint32_t ik3 = fastdiv(iv3, &fd_rk3);
float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
// final state lands in snapshot slot 0 (most-recent-first ordering)
float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v;
@@ -792,7 +793,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo
}
if (K > 1) {
const int64_t target_slot = (int64_t) t - shift;
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
const int64_t target_slot = (int64_t) n_tokens - 1 - (int64_t) t;
if (target_slot >= 0 && target_slot < (int64_t) K) {
float * curr_state_o = state_out_base + (uint64_t) target_slot * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
if (curr_state_o != s_out) {
@@ -844,7 +846,6 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo
const uint32_t S_v = v->ne[0];
const uint32_t H = v->ne[1];
const uint32_t n_seqs = v->ne[3];
const uint32_t K = state->ne[1];
const uint32_t total_rows = H * n_seqs;
if (ith >= total_rows) {
@@ -878,8 +879,7 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo
struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3);
struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3);
const uint64_t state_seq_stride = state->nb[2] / sizeof(float);
const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs;
const uint64_t state_seq_stride = state->nb[3] / sizeof(float);
uint32_t ir_prefetch = ith;
int spad_idx = 0;
@@ -889,7 +889,8 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo
const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H);
const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H);
const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v;
float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v;
// final state lands in snapshot slot 0 (most-recent-first ordering)
float * ps_out = state_out_base + ((uint64_t) piv3 * H + piv1) * S_v * S_v;
// Push dummy write-back
dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]),
@@ -920,7 +921,8 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo
const uint32_t iq3 = fastdiv(iv3, &fd_rq3);
const uint32_t ik3 = fastdiv(iv3, &fd_rk3);
float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
// final state lands in snapshot slot 0 (most-recent-first ordering)
float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v;
@@ -1097,7 +1099,7 @@ int op_gated_delta_net(struct htp_ops_context * octx) {
const uint32_t H = v->ne[1];
const uint32_t n_tokens = v->ne[2];
const uint32_t n_seqs = v->ne[3];
const uint32_t K = state->ne[1];
const uint32_t K = octx->op_params[0];
if (S_v == 0 || S_v > HTP_GDN_MAX_SV || H == 0 || n_tokens == 0 || n_seqs == 0) {
return HTP_STATUS_NO_SUPPORT;
@@ -1110,7 +1112,8 @@ int op_gated_delta_net(struct htp_ops_context * octx) {
(n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) {
return HTP_STATUS_NO_SUPPORT;
}
if (state->ne[0] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) {
// state holds s0 only: [S_v, S_v, H, n_seqs]
if (state->ne[0] != S_v || state->ne[1] != S_v || state->ne[2] != H || state->ne[3] != n_seqs) {
return HTP_STATUS_NO_SUPPORT;
}
if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) {
+2 -2
View File
@@ -590,8 +590,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(
const int ne20 = op->src[2]->ne[0]; // S_v
const int ne21 = op->src[2]->ne[1]; // H
const int ne30 = op->src[3]->ne[0]; // G
// state is src[5], 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
const int K = op->src[5]->ne[1];
// state is src[5], 4D [S_v, S_v, H_v, n_seqs] (s0 only); K is op param 0.
const int K = ggml_get_op_params_i32(op, 0);
const int nsg = op->src[2]->ne[0]/32;
+5 -6
View File
@@ -2599,9 +2599,9 @@ kernel void kernel_gated_delta_net_impl(
const float scale = 1.0f / sqrt((float)S_v);
// input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0.
// input state layout [S_v, S_v, H, n_seqs] (s0 only): per-seq stride is H*D.
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v;
const uint state_in_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
device const float * s_ptr = (device const float *) (s) + state_in_base;
float ls[NSG];
@@ -2620,9 +2620,8 @@ kernel void kernel_gated_delta_net_impl(
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
// snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last
// n_tokens slots are written; earlier slots are left untouched (caller-owned).
const int shift = (int)args.ne22 - (int)K;
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
// When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned.
// output state base offset: after attention scores
const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23;
@@ -2680,7 +2679,7 @@ kernel void kernel_gated_delta_net_impl(
g_ptr += args.ne21*G;
if (K > 1) {
const int target_slot = (int)t - shift;
const int target_slot = (int)args.ne22 - 1 - (int)t;
if (target_slot >= 0 && target_slot < (int)K) {
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
FOR_UNROLL (short j = 0; j < NSG; j++) {
+1 -1
View File
@@ -17750,7 +17750,7 @@ static void ggml_cl_gated_delta_net(ggml_backend_t backend, ggml_tensor * dst) {
const cl_uint H_v = (cl_uint) src_v->ne[1];
const cl_uint n_tokens = (cl_uint) src_v->ne[2];
const cl_uint n_seqs = (cl_uint) src_v->ne[3];
const cl_uint K = (cl_uint) src_state->ne[1];
const cl_uint K = (cl_uint) ggml_get_op_params_i32(dst, 0);
int si;
switch (S_v) {
@@ -123,7 +123,8 @@ kernel void kernel_gated_delta_net(
const uint iq3 = seq_id / rq3; // seq index for Q and K
const uint state_size = S_V * S_V;
const uint state_base = (seq_id * K * H_v + head_id) * state_size;
// input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D.
const uint state_base = (seq_id * H_v + head_id) * state_size;
const uint q_off_base = iq3 * sq3 + iq1 * sq1;
const uint v_off_base = seq_id * sv3 + head_id * sv1;
const uint gb_off_base = seq_id * sb3 + head_id * sb1;
@@ -143,7 +144,8 @@ kernel void kernel_gated_delta_net(
}
}
const int shift = (int)n_tokens - (int)K;
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
// When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned.
uint attn_off = (seq_id * n_tokens * H_v + head_id) * S_V;
for (uint t = 0; t < n_tokens; t++) {
@@ -219,7 +221,7 @@ kernel void kernel_gated_delta_net(
attn_off += S_V * H_v;
if (K > 1u) {
const int target_slot = (int)t - shift;
const int target_slot = (int)n_tokens - 1 - (int)t;
if (target_slot >= 0 && target_slot < (int)K) {
#pragma unroll
for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) {
+7 -8
View File
@@ -44,9 +44,9 @@ void gated_delta_net_sycl(const float * q,
float * attn_data = dst;
float * state = dst + attn_score_elems;
// input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v.
// input state holds s0 only [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v.
// output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before.
const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v;
const int64_t state_in_offset = sequence * H * S_v * S_v + h_idx * S_v * S_v;
const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v;
const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output
state += state_out_offset;
@@ -63,9 +63,8 @@ void gated_delta_net_sycl(const float * q,
s_shard[r] = curr_state[i];
}
// slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots
// are written; earlier slots are left untouched (caller-owned).
const int shift = (int) n_tokens - K;
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
// When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned.
for (int t = 0; t < n_tokens; t++) {
const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
@@ -144,7 +143,7 @@ void gated_delta_net_sycl(const float * q,
// Write state back to global memory
if constexpr (keep_rs_t) {
const int target_slot = t - shift;
const int target_slot = (int) n_tokens - 1 - t;
if (target_slot >= 0 && target_slot < K) {
float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
#pragma unroll
@@ -315,8 +314,8 @@ void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor *
dpct::queue_ptr stream = ctx.stream();
// state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
const int K = (int) src_state->ne[1];
// K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs].
const int K = ggml_get_op_params_i32(dst, 0);
const bool keep_rs = K > 1;
if (kda) {
+4 -4
View File
@@ -11528,7 +11528,6 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
const ggml_tensor * src_q = dst->src[0];
const ggml_tensor * src_v = dst->src[2];
const ggml_tensor * src_beta = dst->src[4];
const ggml_tensor * src_state = dst->src[5];
GGML_ASSERT(dst->buffer != nullptr);
@@ -11537,8 +11536,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
const uint32_t n_tokens = (uint32_t)src_v->ne[2];
const uint32_t n_seqs = (uint32_t)src_v->ne[3];
// state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
const uint32_t K = (uint32_t)src_state->ne[1];
// K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs].
const uint32_t K = (uint32_t)ggml_get_op_params_i32(dst, 0);
const uint32_t s_off = S_v * H * n_tokens * n_seqs;
@@ -17954,7 +17953,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
src_clone[4], src_clone[5], src_clone[6]);
} else if (tensor->op == GGML_OP_GATED_DELTA_NET) {
tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1],
src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
src_clone[2], src_clone[3], src_clone[4], src_clone[5],
ggml_get_op_params_i32(tensor, 0));
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
src_clone[0]->flags = tensor->src[0]->flags;
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
@@ -102,8 +102,8 @@ void main() {
const uint iq3 = seq_id / rq3;
const uint state_size = S_V * S_V;
// input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0.
const uint state_in_base = (seq_id * K * H + head_id) * state_size;
// input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D.
const uint state_in_base = (seq_id * H + head_id) * state_size;
// output state layout per slot: same per-(seq,head) offset as the single-slot case.
const uint state_out_base = (seq_id * H + head_id) * state_size;
const uint state_size_per_snap = state_size * H * n_seqs;
@@ -113,9 +113,8 @@ void main() {
s_shard[r] = FLOAT_TYPE(data_state[state_in_base + col * S_V + r * LANES_PER_COLUMN + lane]);
}
// snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last
// n_tokens slots are written; earlier slots are left untouched (caller-owned).
const int shift = int(n_tokens) - int(K);
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
// When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned.
uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
@@ -172,7 +171,7 @@ void main() {
attn_off += S_V * H;
if (K > 1u) {
const int target_slot = int(t) - shift;
const int target_slot = int(n_tokens) - 1 - int(t);
if (target_slot >= 0 && target_slot < int(K)) {
const uint slot_base = s_off + uint(target_slot) * state_size_per_snap + state_out_base;
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
+1 -1
View File
@@ -1245,7 +1245,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx,
const uint32_t h = (uint32_t) src2->ne[1];
const uint32_t n_tokens = (uint32_t) src2->ne[2];
const uint32_t n_seqs = (uint32_t) src2->ne[3];
const uint32_t K = (uint32_t) src5->ne[1];
const uint32_t K = (uint32_t) ggml_get_op_params_i32(dst, 0);
const float scale = 1.0f / sqrtf((float) s_v);
uint32_t scale_u32;
memcpy(&scale_u32, &scale, sizeof(scale_u32));
@@ -63,10 +63,10 @@ fn main(
let iq3 = seq_id / params.rq3;
let state_size = S_V * S_V;
let state_in_base = (seq_id * params.K * params.h + head_id) * state_size;
// input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D.
let state_in_base = (seq_id * params.h + head_id) * state_size;
let state_out_base = (seq_id * params.h + head_id) * state_size;
let state_size_per_snap = state_size * params.h * params.n_seqs;
let shift = i32(params.n_tokens) - i32(params.K);
var state: array<f32, S_V>;
for (var i = 0u; i < S_V; i++) {
@@ -128,7 +128,8 @@ fn main(
attn_off += S_V * params.h;
if (params.K > 1u) {
let target_slot = i32(t) - shift;
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
let target_slot = i32(params.n_tokens) - 1 - i32(t);
if (target_slot >= 0 && target_slot < i32(params.K)) {
let slot_base = params.s_off + u32(target_slot) * state_size_per_snap + state_out_base;
for (var i = 0u; i < S_V; i++) {
+10 -6
View File
@@ -6223,7 +6223,8 @@ struct ggml_tensor * ggml_gated_delta_net(
struct ggml_tensor * v,
struct ggml_tensor * g,
struct ggml_tensor * beta,
struct ggml_tensor * state) {
struct ggml_tensor * state,
int64_t K) {
GGML_ASSERT(ggml_is_contiguous_rows(q));
GGML_ASSERT(ggml_is_contiguous_rows(k));
GGML_ASSERT(ggml_is_contiguous_rows(v));
@@ -6247,15 +6248,18 @@ struct ggml_tensor * ggml_gated_delta_net(
GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v);
GGML_ASSERT(beta->ne[0] == 1);
// state is a 3D tensor (S_v*S_v*H, K, n_seqs). K is the snapshot slot count.
GGML_ASSERT(state->ne[0] == S_v * S_v * H);
GGML_ASSERT(state->ne[2] == n_seqs);
GGML_ASSERT(state->ne[3] == 1);
const int64_t K = state->ne[1];
// state holds the initial state s0 only: [S_v, S_v, H, n_seqs]. K (snapshot slot count) is an op param.
GGML_ASSERT(state->ne[0] == S_v);
GGML_ASSERT(state->ne[1] == S_v);
GGML_ASSERT(state->ne[2] == H);
GGML_ASSERT(state->ne[3] == n_seqs);
GGML_ASSERT(K >= 1);
const int64_t state_rows = K * S_v * n_seqs;
const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
ggml_set_op_params_i32(result, 0, (int32_t) K);
result->op = GGML_OP_GATED_DELTA_NET;
result->src[0] = q;
result->src[1] = k;
+20 -21
View File
@@ -398,9 +398,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
// K=1 (final state only): reshape to 3D (S_v*S_v*H_v, 1, n_seqs) for ggml_gated_delta_net.
ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, S_v * S_v * H_v, 1, n_seqs);
ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d);
// K=1: output carries the final state only. state s is 4D [S_v, S_v, H_v, n_seqs].
ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s, /*K=*/1);
if (n_tokens == 1) {
cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il);
} else {
@@ -564,11 +563,8 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
const int64_t D = S_v * S_v * H_v;
const int64_t K = cparams.n_rs_seq + 1;
// TODO: remove pad + simplify
ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs);
ggml_tensor * s_3d_pad = ggml_pad (ctx0, s_3d, 0, K - 1, 0, 0);
ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d_pad);
// state s is 4D [S_v, S_v, H_v, n_seqs]; K snapshot slots are written into the output.
ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s, K);
if (n_seq_tokens > 1) {
cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il);
} else {
@@ -587,21 +583,24 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
cb(output, "attn_output", il);
const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all);
for (int64_t k_i = 0; k_i < K; ++k_i) {
const uint32_t cache_slot = (uint32_t) (K - 1 - k_i);
ggml_tensor * src = ggml_view_4d(ctx0, gdn_out,
S_v, S_v, H_v, n_seqs,
ggml_row_size(gdn_out->type, S_v),
ggml_row_size(gdn_out->type, S_v * S_v),
ggml_row_size(gdn_out->type, S_v * S_v * H_v),
ggml_row_size(gdn_out->type, attn_score_elems + k_i * state_size_per_snap));
ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all,
hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
((size_t) cache_slot * mem_size + kv_head) * row_size);
// op writes the last min(n_seq_tokens, K) snapshots; trailing slots are left unwritten
const int64_t n_written = std::min<int64_t>(n_seq_tokens, K);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst));
}
// write the produced snapshots into the recurrent cache (snapshot slot i -> rollback group i)
ggml_tensor * src = ggml_view_3d(ctx0, gdn_out,
D, n_seqs, n_written,
ggml_row_size(gdn_out->type, D),
ggml_row_size(gdn_out->type, state_size_per_snap),
ggml_row_size(gdn_out->type, attn_score_elems));
ggml_tensor * dst = ggml_view_3d(ctx0, ssm_states_all,
D, n_seqs, n_written,
ssm_states_all->nb[1],
(size_t) mem_size * row_size,
(size_t) kv_head * row_size);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst));
return output;
}
+1 -1
View File
@@ -46,7 +46,7 @@ struct llm_build_delta_net_base : public llm_graph_context {
ggml_tensor * s,
int il);
// use the ggml_gated_delta_net fused operator (K=1; state has shape (D, 1, n_seqs))
// use the ggml_gated_delta_net fused operator (K=1; state has shape [S_v, S_v, H_v, n_seqs])
std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_fused(
ggml_tensor * q,
ggml_tensor * k,
+4 -3
View File
@@ -3896,14 +3896,14 @@ struct test_gated_delta_net : public test_case {
const int64_t g_ne0 = kda ? head_size : 1;
ggml_tensor * g = ggml_new_tensor_4d(ctx, type, g_ne0, head_count * v_repeat, n_seq_tokens, n_seqs);
ggml_tensor * beta = ggml_new_tensor_4d(ctx, type, 1, head_count * v_repeat, n_seq_tokens, n_seqs);
ggml_tensor * state = ggml_new_tensor_3d(ctx, type, head_size * v_repeat * head_size * head_count, K, n_seqs);
ggml_tensor * state = ggml_new_tensor_4d(ctx, type, head_size, head_size, head_count * v_repeat, n_seqs);
ggml_set_name(g, "g");
ggml_set_name(beta, "beta");
ggml_set_name(state, "state");
// q/k are L2-normalised in qwen35/kimi-linear before delta_net
q = ggml_l2_norm(ctx, q, 1e-6f);
k = ggml_l2_norm(ctx, k, 1e-6f);
ggml_tensor * out = ggml_gated_delta_net(ctx, q, k, v, g, beta, state);
ggml_tensor * out = ggml_gated_delta_net(ctx, q, k, v, g, beta, state, K);
return out;
}
@@ -9190,7 +9190,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 33, 1, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 100, 1, 1, false, true));
// K > 1: output keeps the last min(n_tokens, K) per-token snapshots in the trailing K-token region.
// K > 1: output keeps the last min(n_tokens, K) per-token snapshots, ordered most-recent-first
// (slot 0 = final state, slot s = state s tokens back).
// exact-match cases (K == n_seq_tokens):
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 2, 1, 1, false, false, /*K=*/2));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 4, 1, 1, false, false, /*K=*/4));
+1 -1
View File
@@ -81,7 +81,7 @@ if (LLAMA_BUILD_BORINGSSL)
target_link_libraries(${TARGET} PUBLIC ssl crypto)
elseif (LLAMA_BUILD_LIBRESSL)
set(LIBRESSL_VERSION "4.3.1" CACHE STRING "LibreSSL version")
set(LIBRESSL_VERSION "4.3.2" CACHE STRING "LibreSSL version")
message(STATUS "Fetching LibreSSL version ${LIBRESSL_VERSION}")