mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-13 01:06:45 +02:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 18ef86ecec | |||
| 1bfbdb134e | |||
| 68f30663cf | |||
| db94854ff5 | |||
| ac4cddeb0d | |||
| e95dae18d6 | |||
| d2462f8f7a | |||
| fb83cc9a07 | |||
| 039e20a2db | |||
| d2e22ed975 | |||
| 76da2450a4 | |||
| d73cd07674 | |||
| e25a32e98c | |||
| 483609509d | |||
| 49f3542190 | |||
| d6d0ce8215 |
@@ -504,7 +504,7 @@ jobs:
|
||||
needs: [check-release]
|
||||
if: ${{ needs.check-release.outputs.should_release == 'true' }}
|
||||
|
||||
runs-on: windows-2025
|
||||
runs-on: windows-2025-vs2026
|
||||
|
||||
permissions:
|
||||
actions: write
|
||||
@@ -535,12 +535,12 @@ jobs:
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.21
|
||||
with:
|
||||
key: release-windows-2025-${{ matrix.arch }}-cpu
|
||||
key: release-windows-2025-vs2026-${{ matrix.arch }}-cpu
|
||||
|
||||
- name: Build
|
||||
shell: cmd
|
||||
run: |
|
||||
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" ${{ matrix.arch == 'x64' && 'x64' || 'amd64_arm64' }}
|
||||
call "C:\Program Files\Microsoft Visual Studio\18\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" ${{ matrix.arch == 'x64' && 'x64' || 'amd64_arm64' }}
|
||||
cmake -S . -B build -G "Ninja Multi-Config" ^
|
||||
-D CMAKE_TOOLCHAIN_FILE=cmake/${{ matrix.arch }}-windows-llvm.cmake ^
|
||||
-DLLAMA_BUILD_BORINGSSL=ON ^
|
||||
@@ -554,12 +554,12 @@ jobs:
|
||||
- name: ccache-clear
|
||||
uses: ./.github/actions/ccache-clear
|
||||
with:
|
||||
key: release-windows-2025-${{ matrix.arch }}-cpu
|
||||
key: release-windows-2025-vs2026-${{ matrix.arch }}-cpu
|
||||
|
||||
- name: Pack artifacts
|
||||
id: pack_artifacts
|
||||
run: |
|
||||
Copy-Item "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Redist\MSVC\14.44.35112\debug_nonredist\${{ matrix.arch }}\Microsoft.VC143.OpenMP.LLVM\libomp140.${{ matrix.arch == 'x64' && 'x86_64' || 'aarch64' }}.dll" .\build\bin\Release\
|
||||
Copy-Item "C:\Program Files\Microsoft Visual Studio\18\Enterprise\VC\Redist\MSVC\14.51.36231\debug_nonredist\${{ matrix.arch }}\Microsoft.VC145.OpenMP.LLVM\libomp140.${{ matrix.arch == 'x64' && 'x86_64' || 'aarch64' }}.dll" .\build\bin\Release\
|
||||
7z a -snl llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\*
|
||||
|
||||
- name: Upload artifacts
|
||||
|
||||
@@ -17,7 +17,7 @@ jobs:
|
||||
|
||||
- name: Install komac
|
||||
run: |
|
||||
cargo binstall komac@2.15.0 -y
|
||||
cargo binstall komac@2.16.0 -y
|
||||
|
||||
- name: Find latest release
|
||||
id: find_latest_release
|
||||
|
||||
+12
-3
@@ -1647,11 +1647,12 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
||||
data.thinking_start_tag = THINK_START;
|
||||
data.thinking_end_tag = THINK_END;
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
|
||||
// Gate by reasoning format and whether the template supports <think>
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE &&
|
||||
tmpl.source().find(THINK_START) != std::string::npos;
|
||||
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
@@ -1674,6 +1675,10 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
||||
}
|
||||
|
||||
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
if (has_response_format) {
|
||||
auto response_format = p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema));
|
||||
return generation_prompt + reasoning + response_format + end;
|
||||
}
|
||||
return generation_prompt + reasoning + p.content(p.rest()) + end;
|
||||
}
|
||||
auto tool_calls = p.rule("tool-calls",
|
||||
@@ -1692,13 +1697,17 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
||||
data.parser = parser.save();
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
auto schema = function.at("parameters");
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
if (has_response_format) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
}
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
|
||||
@@ -843,7 +843,8 @@ struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
|
||||
common_speculative_impl_ngram_map_k(
|
||||
const common_ngram_map & config,
|
||||
uint32_t n_seq)
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, n_seq)
|
||||
: common_speculative_impl(config.key_only ? COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K
|
||||
: COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, n_seq)
|
||||
{
|
||||
for (uint32_t i = 0; i < n_seq; i++) {
|
||||
this->config.push_back(config);
|
||||
|
||||
+12
-5
@@ -2553,10 +2553,16 @@ extern "C" {
|
||||
// TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST]
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306
|
||||
//
|
||||
// state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs):
|
||||
// K == 1: output carries the final state only.
|
||||
// K > 1: output carries K snapshot slots; the kernel writes the last min(n_tokens, K)
|
||||
// per-token snapshots into the trailing slots
|
||||
// tensor shapes (S_k == S_v, H_v % H_k == 0):
|
||||
// q, k : [S_k, H_k, n_tokens, n_seqs]
|
||||
// v : [S_v, H_v, n_tokens, n_seqs]
|
||||
// g : [1, H_v, n_tokens, n_seqs] (scalar gate) or [S_v, H_v, n_tokens, n_seqs] (KDA)
|
||||
// beta : [1, H_v, n_tokens, n_seqs]
|
||||
// state : [S_v, S_v, H_v, n_seqs] -- initial recurrent state s0
|
||||
//
|
||||
// the output packs the attention scores [S_v, H_v, n_tokens, n_seqs] followed by K state
|
||||
// snapshots, most-recent first (slot 0 = final state, slot s = state s tokens back). K == 1
|
||||
// keeps only the final state; when n_tokens < K only slots 0..n_tokens-1 are written.
|
||||
GGML_API struct ggml_tensor * ggml_gated_delta_net(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * q,
|
||||
@@ -2564,7 +2570,8 @@ extern "C" {
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * g,
|
||||
struct ggml_tensor * beta,
|
||||
struct ggml_tensor * state);
|
||||
struct ggml_tensor * state,
|
||||
int64_t K);
|
||||
|
||||
// custom operators
|
||||
|
||||
|
||||
@@ -776,8 +776,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(
|
||||
GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1);
|
||||
GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1);
|
||||
GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1);
|
||||
// state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0,
|
||||
// so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2).
|
||||
// state shape is [S_v, S_v, H_v, n_seqs] (s0 only); the heads dim is its own axis 2,
|
||||
// so a head-aligned split on the input cache lands on axis 2 here.
|
||||
GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0);
|
||||
return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1};
|
||||
};
|
||||
|
||||
@@ -2948,7 +2948,7 @@ struct ggml_cplan ggml_graph_plan(
|
||||
case GGML_OP_GATED_DELTA_NET:
|
||||
{
|
||||
const int64_t S_v = node->src[2]->ne[0];
|
||||
const int64_t K = node->src[5]->ne[1]; // state is (D, K, n_seqs)
|
||||
const int64_t K = ggml_get_op_params_i32(node, 0);
|
||||
const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0);
|
||||
cur = per_thread * sizeof(float) * n_tasks;
|
||||
} break;
|
||||
|
||||
@@ -10624,11 +10624,11 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
|
||||
|
||||
const bool kda = (neg0 == S_v);
|
||||
|
||||
// state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
|
||||
const int64_t K = src_state->ne[1];
|
||||
// K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs].
|
||||
const int64_t K = ggml_get_op_params_i32(dst, 0);
|
||||
GGML_ASSERT(K >= 1);
|
||||
// per-seq stride in floats (slot 0 of seq s lives at state + s * seq_stride)
|
||||
const int64_t state_seq_stride = src_state->nb[2] / sizeof(float);
|
||||
// per-seq stride in floats (seq s starts at state + s * seq_stride)
|
||||
const int64_t state_seq_stride = src_state->nb[3] / sizeof(float);
|
||||
|
||||
const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0);
|
||||
const int ith = params->ith;
|
||||
@@ -10644,9 +10644,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
|
||||
float * attn_out_base = (float *)dst->data;
|
||||
float * state_out_base = (float *)dst->data + attn_score_elems;
|
||||
|
||||
// snapshot slot mapping: target_slot = t - shift. When n_tokens < K only the last
|
||||
// n_tokens slots are written; earlier slots are left untouched (caller-owned).
|
||||
const int64_t shift = n_tokens - K;
|
||||
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
|
||||
// When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned.
|
||||
|
||||
const float * state_in_base = (const float *)src_state->data;
|
||||
|
||||
@@ -10674,7 +10673,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
|
||||
: state_out_base + (iv3 * H + iv1) * S_v * S_v;
|
||||
|
||||
// copy input state into the working buffer and operate in-place
|
||||
// state layout (D, K, n_seqs): slot 0 of seq iv3 starts at iv3 * state_seq_stride.
|
||||
// state layout [S_v, S_v, H, n_seqs]: seq iv3 starts at iv3 * state_seq_stride.
|
||||
const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v;
|
||||
memcpy(s_out, s_in, S_v * S_v * sizeof(float));
|
||||
|
||||
@@ -10727,7 +10726,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
|
||||
attn_data += S_v * H; // advance to next token
|
||||
|
||||
if (K > 1) {
|
||||
const int64_t target_slot = t - shift;
|
||||
const int64_t target_slot = n_tokens - 1 - t;
|
||||
if (target_slot >= 0 && target_slot < K) {
|
||||
float * curr_state_o = state_out_base + target_slot * state_size_per_snap +
|
||||
(iv3 * H + iv1) * S_v * S_v;
|
||||
|
||||
@@ -39,9 +39,9 @@ gated_delta_net_cuda(const float * q,
|
||||
float * attn_data = dst;
|
||||
float * state = dst + attn_score_elems;
|
||||
|
||||
// input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v.
|
||||
// input state holds s0 only: [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v.
|
||||
// output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before.
|
||||
const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v;
|
||||
const int64_t state_in_offset = sequence * H * S_v * S_v + h_idx * S_v * S_v;
|
||||
const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v;
|
||||
state += state_out_offset;
|
||||
curr_state += state_in_offset + col * S_v;
|
||||
@@ -143,12 +143,10 @@ gated_delta_net_cuda(const float * q,
|
||||
attn_data += S_v * H;
|
||||
|
||||
if constexpr (keep_rs_t) {
|
||||
// slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots
|
||||
// are written; earlier slots are left untouched (caller-owned).
|
||||
const int shift = (int) n_tokens - K;
|
||||
|
||||
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
|
||||
// When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned.
|
||||
const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output
|
||||
const int target_slot = t - shift;
|
||||
const int target_slot = (int) n_tokens - 1 - t;
|
||||
if (target_slot >= 0 && target_slot < K) {
|
||||
float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
|
||||
#pragma unroll
|
||||
@@ -286,8 +284,8 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
// state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
|
||||
const int K = (int) src_state->ne[1];
|
||||
// K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs].
|
||||
const int K = ggml_get_op_params_i32(dst, 0);
|
||||
const bool keep_rs = K > 1;
|
||||
|
||||
if (kda) {
|
||||
|
||||
@@ -67,6 +67,7 @@ __global__ void __launch_bounds__(splitD, 1)
|
||||
__shared__ CubTempStorage cub_temp_storage;
|
||||
|
||||
BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
|
||||
__syncthreads();
|
||||
BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
|
||||
#else
|
||||
const int stride_s0 = src0_nb2 / sizeof(float);
|
||||
@@ -105,6 +106,7 @@ __global__ void __launch_bounds__(splitD, 1)
|
||||
regs0[n] = state;
|
||||
}
|
||||
y_block[i * stride_y + threadIdx.x] = sumf;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#ifdef USE_CUB
|
||||
@@ -249,9 +251,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
|
||||
GGML_ASSERT(head_dim == 1);
|
||||
GGML_ASSERT(n_group == 1);
|
||||
const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
|
||||
const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
|
||||
if (d_state == 16) {
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, smem_size, stream);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream);
|
||||
switch (n_tok)
|
||||
{
|
||||
case 1:
|
||||
|
||||
@@ -2538,7 +2538,7 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses
|
||||
const int64_t H = v->ne[1];
|
||||
const int64_t n_tokens = v->ne[2];
|
||||
const int64_t n_seqs = v->ne[3];
|
||||
const int64_t K = state->ne[1];
|
||||
const int64_t K = ggml_get_op_params_i32(op, 0);
|
||||
|
||||
if (S_v <= 0 || S_v > 128 || H <= 0 || n_tokens <= 0 || n_seqs <= 0) {
|
||||
return false;
|
||||
@@ -2551,7 +2551,8 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses
|
||||
if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) {
|
||||
return false;
|
||||
}
|
||||
if (ggml_nelements(state) != S_v * S_v * H * n_seqs * K) {
|
||||
// state holds s0 only [S_v, S_v, H, n_seqs]; K is op param 0.
|
||||
if (ggml_nelements(state) != S_v * S_v * H * n_seqs) {
|
||||
return false;
|
||||
}
|
||||
if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) {
|
||||
|
||||
@@ -584,7 +584,7 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo
|
||||
const uint32_t H = v->ne[1];
|
||||
const uint32_t n_tokens = v->ne[2];
|
||||
const uint32_t n_seqs = v->ne[3];
|
||||
const uint32_t K = state->ne[1];
|
||||
const uint32_t K = octx->op_params[0];
|
||||
|
||||
const uint32_t total_rows = H * n_seqs;
|
||||
if (ith >= total_rows) {
|
||||
@@ -618,9 +618,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo
|
||||
struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3);
|
||||
struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3);
|
||||
|
||||
const uint64_t state_seq_stride = state->nb[2] / sizeof(float);
|
||||
const uint64_t state_seq_stride = state->nb[3] / sizeof(float);
|
||||
const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs;
|
||||
const int64_t shift = (int64_t) n_tokens - (int64_t) K;
|
||||
|
||||
uint32_t ir_prefetch = ith;
|
||||
int spad_idx = 0;
|
||||
@@ -630,7 +629,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo
|
||||
const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H);
|
||||
const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H);
|
||||
const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v;
|
||||
float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v;
|
||||
// final state lands in snapshot slot 0 (most-recent-first ordering)
|
||||
float * ps_out = state_out_base + ((uint64_t) piv3 * H + piv1) * S_v * S_v;
|
||||
|
||||
// Push dummy write-back
|
||||
dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]),
|
||||
@@ -661,7 +661,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo
|
||||
const uint32_t iq3 = fastdiv(iv3, &fd_rq3);
|
||||
const uint32_t ik3 = fastdiv(iv3, &fd_rk3);
|
||||
|
||||
float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
|
||||
// final state lands in snapshot slot 0 (most-recent-first ordering)
|
||||
float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
|
||||
|
||||
float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v;
|
||||
|
||||
@@ -792,7 +793,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo
|
||||
}
|
||||
|
||||
if (K > 1) {
|
||||
const int64_t target_slot = (int64_t) t - shift;
|
||||
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
|
||||
const int64_t target_slot = (int64_t) n_tokens - 1 - (int64_t) t;
|
||||
if (target_slot >= 0 && target_slot < (int64_t) K) {
|
||||
float * curr_state_o = state_out_base + (uint64_t) target_slot * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
|
||||
if (curr_state_o != s_out) {
|
||||
@@ -844,7 +846,6 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo
|
||||
const uint32_t S_v = v->ne[0];
|
||||
const uint32_t H = v->ne[1];
|
||||
const uint32_t n_seqs = v->ne[3];
|
||||
const uint32_t K = state->ne[1];
|
||||
|
||||
const uint32_t total_rows = H * n_seqs;
|
||||
if (ith >= total_rows) {
|
||||
@@ -878,8 +879,7 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo
|
||||
struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3);
|
||||
struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3);
|
||||
|
||||
const uint64_t state_seq_stride = state->nb[2] / sizeof(float);
|
||||
const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs;
|
||||
const uint64_t state_seq_stride = state->nb[3] / sizeof(float);
|
||||
|
||||
uint32_t ir_prefetch = ith;
|
||||
int spad_idx = 0;
|
||||
@@ -889,7 +889,8 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo
|
||||
const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H);
|
||||
const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H);
|
||||
const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v;
|
||||
float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v;
|
||||
// final state lands in snapshot slot 0 (most-recent-first ordering)
|
||||
float * ps_out = state_out_base + ((uint64_t) piv3 * H + piv1) * S_v * S_v;
|
||||
|
||||
// Push dummy write-back
|
||||
dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]),
|
||||
@@ -920,7 +921,8 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo
|
||||
const uint32_t iq3 = fastdiv(iv3, &fd_rq3);
|
||||
const uint32_t ik3 = fastdiv(iv3, &fd_rk3);
|
||||
|
||||
float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
|
||||
// final state lands in snapshot slot 0 (most-recent-first ordering)
|
||||
float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v;
|
||||
|
||||
float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v;
|
||||
|
||||
@@ -1097,7 +1099,7 @@ int op_gated_delta_net(struct htp_ops_context * octx) {
|
||||
const uint32_t H = v->ne[1];
|
||||
const uint32_t n_tokens = v->ne[2];
|
||||
const uint32_t n_seqs = v->ne[3];
|
||||
const uint32_t K = state->ne[1];
|
||||
const uint32_t K = octx->op_params[0];
|
||||
|
||||
if (S_v == 0 || S_v > HTP_GDN_MAX_SV || H == 0 || n_tokens == 0 || n_seqs == 0) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
@@ -1110,7 +1112,8 @@ int op_gated_delta_net(struct htp_ops_context * octx) {
|
||||
(n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
if (state->ne[0] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) {
|
||||
// state holds s0 only: [S_v, S_v, H, n_seqs]
|
||||
if (state->ne[0] != S_v || state->ne[1] != S_v || state->ne[2] != H || state->ne[3] != n_seqs) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) {
|
||||
|
||||
@@ -590,8 +590,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(
|
||||
const int ne20 = op->src[2]->ne[0]; // S_v
|
||||
const int ne21 = op->src[2]->ne[1]; // H
|
||||
const int ne30 = op->src[3]->ne[0]; // G
|
||||
// state is src[5], 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
|
||||
const int K = op->src[5]->ne[1];
|
||||
// state is src[5], 4D [S_v, S_v, H_v, n_seqs] (s0 only); K is op param 0.
|
||||
const int K = ggml_get_op_params_i32(op, 0);
|
||||
|
||||
const int nsg = op->src[2]->ne[0]/32;
|
||||
|
||||
|
||||
@@ -2599,9 +2599,9 @@ kernel void kernel_gated_delta_net_impl(
|
||||
|
||||
const float scale = 1.0f / sqrt((float)S_v);
|
||||
|
||||
// input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0.
|
||||
// input state layout [S_v, S_v, H, n_seqs] (s0 only): per-seq stride is H*D.
|
||||
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
|
||||
const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
||||
const uint state_in_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
||||
device const float * s_ptr = (device const float *) (s) + state_in_base;
|
||||
|
||||
float ls[NSG];
|
||||
@@ -2620,9 +2620,8 @@ kernel void kernel_gated_delta_net_impl(
|
||||
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
|
||||
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
|
||||
|
||||
// snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last
|
||||
// n_tokens slots are written; earlier slots are left untouched (caller-owned).
|
||||
const int shift = (int)args.ne22 - (int)K;
|
||||
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
|
||||
// When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned.
|
||||
|
||||
// output state base offset: after attention scores
|
||||
const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23;
|
||||
@@ -2680,7 +2679,7 @@ kernel void kernel_gated_delta_net_impl(
|
||||
g_ptr += args.ne21*G;
|
||||
|
||||
if (K > 1) {
|
||||
const int target_slot = (int)t - shift;
|
||||
const int target_slot = (int)args.ne22 - 1 - (int)t;
|
||||
if (target_slot >= 0 && target_slot < (int)K) {
|
||||
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
|
||||
@@ -17750,7 +17750,7 @@ static void ggml_cl_gated_delta_net(ggml_backend_t backend, ggml_tensor * dst) {
|
||||
const cl_uint H_v = (cl_uint) src_v->ne[1];
|
||||
const cl_uint n_tokens = (cl_uint) src_v->ne[2];
|
||||
const cl_uint n_seqs = (cl_uint) src_v->ne[3];
|
||||
const cl_uint K = (cl_uint) src_state->ne[1];
|
||||
const cl_uint K = (cl_uint) ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
int si;
|
||||
switch (S_v) {
|
||||
|
||||
@@ -123,7 +123,8 @@ kernel void kernel_gated_delta_net(
|
||||
const uint iq3 = seq_id / rq3; // seq index for Q and K
|
||||
|
||||
const uint state_size = S_V * S_V;
|
||||
const uint state_base = (seq_id * K * H_v + head_id) * state_size;
|
||||
// input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D.
|
||||
const uint state_base = (seq_id * H_v + head_id) * state_size;
|
||||
const uint q_off_base = iq3 * sq3 + iq1 * sq1;
|
||||
const uint v_off_base = seq_id * sv3 + head_id * sv1;
|
||||
const uint gb_off_base = seq_id * sb3 + head_id * sb1;
|
||||
@@ -143,7 +144,8 @@ kernel void kernel_gated_delta_net(
|
||||
}
|
||||
}
|
||||
|
||||
const int shift = (int)n_tokens - (int)K;
|
||||
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
|
||||
// When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned.
|
||||
uint attn_off = (seq_id * n_tokens * H_v + head_id) * S_V;
|
||||
|
||||
for (uint t = 0; t < n_tokens; t++) {
|
||||
@@ -219,7 +221,7 @@ kernel void kernel_gated_delta_net(
|
||||
attn_off += S_V * H_v;
|
||||
|
||||
if (K > 1u) {
|
||||
const int target_slot = (int)t - shift;
|
||||
const int target_slot = (int)n_tokens - 1 - (int)t;
|
||||
if (target_slot >= 0 && target_slot < (int)K) {
|
||||
#pragma unroll
|
||||
for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) {
|
||||
|
||||
@@ -44,9 +44,9 @@ void gated_delta_net_sycl(const float * q,
|
||||
float * attn_data = dst;
|
||||
float * state = dst + attn_score_elems;
|
||||
|
||||
// input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v.
|
||||
// input state holds s0 only [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v.
|
||||
// output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before.
|
||||
const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v;
|
||||
const int64_t state_in_offset = sequence * H * S_v * S_v + h_idx * S_v * S_v;
|
||||
const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v;
|
||||
const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output
|
||||
state += state_out_offset;
|
||||
@@ -63,9 +63,8 @@ void gated_delta_net_sycl(const float * q,
|
||||
s_shard[r] = curr_state[i];
|
||||
}
|
||||
|
||||
// slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots
|
||||
// are written; earlier slots are left untouched (caller-owned).
|
||||
const int shift = (int) n_tokens - K;
|
||||
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
|
||||
// When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned.
|
||||
|
||||
for (int t = 0; t < n_tokens; t++) {
|
||||
const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
|
||||
@@ -144,7 +143,7 @@ void gated_delta_net_sycl(const float * q,
|
||||
|
||||
// Write state back to global memory
|
||||
if constexpr (keep_rs_t) {
|
||||
const int target_slot = t - shift;
|
||||
const int target_slot = (int) n_tokens - 1 - t;
|
||||
if (target_slot >= 0 && target_slot < K) {
|
||||
float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
|
||||
#pragma unroll
|
||||
@@ -315,8 +314,8 @@ void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor *
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
// state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
|
||||
const int K = (int) src_state->ne[1];
|
||||
// K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs].
|
||||
const int K = ggml_get_op_params_i32(dst, 0);
|
||||
const bool keep_rs = K > 1;
|
||||
|
||||
if (kda) {
|
||||
|
||||
@@ -3394,7 +3394,9 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
||||
switch (src0_type) {
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
lut_size = 2*2048 + 4*2048;
|
||||
// Regular matmul uses the compact uint16_t IQ1 grid; the expanded
|
||||
// uint32_t grid is only enabled for the q8_1/int-dot vector path.
|
||||
lut_size = 2*2048;
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
lut_size = 8*256;
|
||||
@@ -11526,7 +11528,6 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
|
||||
const ggml_tensor * src_q = dst->src[0];
|
||||
const ggml_tensor * src_v = dst->src[2];
|
||||
const ggml_tensor * src_beta = dst->src[4];
|
||||
const ggml_tensor * src_state = dst->src[5];
|
||||
|
||||
GGML_ASSERT(dst->buffer != nullptr);
|
||||
|
||||
@@ -11535,8 +11536,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
|
||||
const uint32_t n_tokens = (uint32_t)src_v->ne[2];
|
||||
const uint32_t n_seqs = (uint32_t)src_v->ne[3];
|
||||
|
||||
// state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
|
||||
const uint32_t K = (uint32_t)src_state->ne[1];
|
||||
// K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs].
|
||||
const uint32_t K = (uint32_t)ggml_get_op_params_i32(dst, 0);
|
||||
|
||||
const uint32_t s_off = S_v * H * n_tokens * n_seqs;
|
||||
|
||||
@@ -17952,7 +17953,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||
src_clone[4], src_clone[5], src_clone[6]);
|
||||
} else if (tensor->op == GGML_OP_GATED_DELTA_NET) {
|
||||
tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1],
|
||||
src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
|
||||
src_clone[2], src_clone[3], src_clone[4], src_clone[5],
|
||||
ggml_get_op_params_i32(tensor, 0));
|
||||
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
|
||||
src_clone[0]->flags = tensor->src[0]->flags;
|
||||
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
|
||||
|
||||
@@ -102,8 +102,8 @@ void main() {
|
||||
const uint iq3 = seq_id / rq3;
|
||||
|
||||
const uint state_size = S_V * S_V;
|
||||
// input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0.
|
||||
const uint state_in_base = (seq_id * K * H + head_id) * state_size;
|
||||
// input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D.
|
||||
const uint state_in_base = (seq_id * H + head_id) * state_size;
|
||||
// output state layout per slot: same per-(seq,head) offset as the single-slot case.
|
||||
const uint state_out_base = (seq_id * H + head_id) * state_size;
|
||||
const uint state_size_per_snap = state_size * H * n_seqs;
|
||||
@@ -113,9 +113,8 @@ void main() {
|
||||
s_shard[r] = FLOAT_TYPE(data_state[state_in_base + col * S_V + r * LANES_PER_COLUMN + lane]);
|
||||
}
|
||||
|
||||
// snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last
|
||||
// n_tokens slots are written; earlier slots are left untouched (caller-owned).
|
||||
const int shift = int(n_tokens) - int(K);
|
||||
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
|
||||
// When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned.
|
||||
|
||||
uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
|
||||
|
||||
@@ -172,7 +171,7 @@ void main() {
|
||||
attn_off += S_V * H;
|
||||
|
||||
if (K > 1u) {
|
||||
const int target_slot = int(t) - shift;
|
||||
const int target_slot = int(n_tokens) - 1 - int(t);
|
||||
if (target_slot >= 0 && target_slot < int(K)) {
|
||||
const uint slot_base = s_off + uint(target_slot) * state_size_per_snap + state_out_base;
|
||||
[[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) {
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#extension GL_EXT_integer_dot_product : require
|
||||
|
||||
#define MMQ
|
||||
#define NEEDS_IQ1S_GRID_GPU
|
||||
#define B_TYPE block_q8_1_x4
|
||||
|
||||
#include "mul_mat_vec_base.glsl"
|
||||
|
||||
@@ -598,9 +598,10 @@ const uint[1024] iq1s_grid_const = {
|
||||
0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
|
||||
};
|
||||
|
||||
#if defined(NEEDS_IQ1S_GRID_GPU)
|
||||
// Same content as iq1s_grid_const except each 2-bit value is expanded to 4-bit
|
||||
// and has 1 added to it (allows packed values to be extracted with & 0x0F0F0F0F
|
||||
// and 0xF0F0F0F0).
|
||||
// and 0xF0F0F0F0). This is only used by the q8_1/int-dot vector path.
|
||||
const uint32_t[2048] iq1s_grid_gpu_const = {
|
||||
0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,
|
||||
0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,
|
||||
@@ -859,9 +860,12 @@ const uint32_t[2048] iq1s_grid_gpu_const = {
|
||||
0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020,
|
||||
0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,
|
||||
};
|
||||
#endif
|
||||
|
||||
shared uint16_t iq1s_grid[2048];
|
||||
#if defined(NEEDS_IQ1S_GRID_GPU)
|
||||
shared uint32_t iq1s_grid_gpu[2048];
|
||||
#endif
|
||||
|
||||
#define NEEDS_INIT_IQ_SHMEM
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
@@ -875,12 +879,14 @@ void init_iq_shmem(uvec3 wgsize)
|
||||
iq1s_grid[2*idx+1] = g.y;
|
||||
}
|
||||
}
|
||||
#if defined(NEEDS_IQ1S_GRID_GPU)
|
||||
[[unroll]] for (uint i = 0; i < iq1s_grid_gpu_const.length(); i += wgsize.x) {
|
||||
uint idx = i + gl_LocalInvocationIndex.x;
|
||||
if (iq1s_grid_gpu_const.length() % wgsize.x == 0 || idx < iq1s_grid_gpu_const.length()) {
|
||||
iq1s_grid_gpu[idx] = iq1s_grid_gpu_const[idx];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
barrier();
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -1245,7 +1245,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx,
|
||||
const uint32_t h = (uint32_t) src2->ne[1];
|
||||
const uint32_t n_tokens = (uint32_t) src2->ne[2];
|
||||
const uint32_t n_seqs = (uint32_t) src2->ne[3];
|
||||
const uint32_t K = (uint32_t) src5->ne[1];
|
||||
const uint32_t K = (uint32_t) ggml_get_op_params_i32(dst, 0);
|
||||
const float scale = 1.0f / sqrtf((float) s_v);
|
||||
uint32_t scale_u32;
|
||||
memcpy(&scale_u32, &scale, sizeof(scale_u32));
|
||||
|
||||
@@ -63,10 +63,10 @@ fn main(
|
||||
let iq3 = seq_id / params.rq3;
|
||||
|
||||
let state_size = S_V * S_V;
|
||||
let state_in_base = (seq_id * params.K * params.h + head_id) * state_size;
|
||||
// input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D.
|
||||
let state_in_base = (seq_id * params.h + head_id) * state_size;
|
||||
let state_out_base = (seq_id * params.h + head_id) * state_size;
|
||||
let state_size_per_snap = state_size * params.h * params.n_seqs;
|
||||
let shift = i32(params.n_tokens) - i32(params.K);
|
||||
|
||||
var state: array<f32, S_V>;
|
||||
for (var i = 0u; i < S_V; i++) {
|
||||
@@ -128,7 +128,8 @@ fn main(
|
||||
attn_off += S_V * params.h;
|
||||
|
||||
if (params.K > 1u) {
|
||||
let target_slot = i32(t) - shift;
|
||||
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
|
||||
let target_slot = i32(params.n_tokens) - 1 - i32(t);
|
||||
if (target_slot >= 0 && target_slot < i32(params.K)) {
|
||||
let slot_base = params.s_off + u32(target_slot) * state_size_per_snap + state_out_base;
|
||||
for (var i = 0u; i < S_V; i++) {
|
||||
|
||||
+10
-6
@@ -6223,7 +6223,8 @@ struct ggml_tensor * ggml_gated_delta_net(
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * g,
|
||||
struct ggml_tensor * beta,
|
||||
struct ggml_tensor * state) {
|
||||
struct ggml_tensor * state,
|
||||
int64_t K) {
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(q));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(k));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(v));
|
||||
@@ -6247,15 +6248,18 @@ struct ggml_tensor * ggml_gated_delta_net(
|
||||
GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v);
|
||||
GGML_ASSERT(beta->ne[0] == 1);
|
||||
|
||||
// state is a 3D tensor (S_v*S_v*H, K, n_seqs). K is the snapshot slot count.
|
||||
GGML_ASSERT(state->ne[0] == S_v * S_v * H);
|
||||
GGML_ASSERT(state->ne[2] == n_seqs);
|
||||
GGML_ASSERT(state->ne[3] == 1);
|
||||
const int64_t K = state->ne[1];
|
||||
// state holds the initial state s0 only: [S_v, S_v, H, n_seqs]. K (snapshot slot count) is an op param.
|
||||
GGML_ASSERT(state->ne[0] == S_v);
|
||||
GGML_ASSERT(state->ne[1] == S_v);
|
||||
GGML_ASSERT(state->ne[2] == H);
|
||||
GGML_ASSERT(state->ne[3] == n_seqs);
|
||||
GGML_ASSERT(K >= 1);
|
||||
const int64_t state_rows = K * S_v * n_seqs;
|
||||
const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 };
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, (int32_t) K);
|
||||
|
||||
result->op = GGML_OP_GATED_DELTA_NET;
|
||||
result->src[0] = q;
|
||||
result->src[1] = k;
|
||||
|
||||
@@ -272,7 +272,8 @@ class Keys:
|
||||
CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}"
|
||||
CHAT_TEMPLATES = "tokenizer.chat_templates"
|
||||
# Normalizer constants
|
||||
NORMALIZER_LOWERCASE = "tokenizer.ggml.normalizer.lowercase"
|
||||
NORMALIZER_LOWERCASE = "tokenizer.ggml.normalizer.lowercase"
|
||||
NORMALIZER_STRIP_ACCENTS = "tokenizer.ggml.normalizer.strip_accents"
|
||||
# FIM/Infill special tokens constants
|
||||
FIM_PRE_ID = "tokenizer.ggml.fim_pre_token_id"
|
||||
FIM_SUF_ID = "tokenizer.ggml.fim_suf_token_id"
|
||||
|
||||
@@ -1124,6 +1124,9 @@ class GGUFWriter:
|
||||
def add_normalizer_lowercase(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Tokenizer.NORMALIZER_LOWERCASE, value)
|
||||
|
||||
def add_normalizer_strip_accents(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Tokenizer.NORMALIZER_STRIP_ACCENTS, value)
|
||||
|
||||
def add_eot_token_id(self, id: int) -> None:
|
||||
self.add_uint32(Keys.Tokenizer.EOT_ID, id)
|
||||
|
||||
|
||||
+19
-4
@@ -53,6 +53,7 @@ class SpecialVocab:
|
||||
special_token_ids: dict[str, int]
|
||||
chat_template: str | Sequence[Mapping[str, str]] | None
|
||||
normalizer_lowercase: bool | None
|
||||
normalizer_strip_accents: bool | None
|
||||
|
||||
def __init__(
|
||||
self, path: str | os.PathLike[str], load_merges: bool = False,
|
||||
@@ -66,6 +67,7 @@ class SpecialVocab:
|
||||
self.merges = []
|
||||
self.chat_template = None
|
||||
self.normalizer_lowercase = None
|
||||
self.normalizer_strip_accents = None
|
||||
if special_token_types is not None:
|
||||
self.special_token_types = special_token_types
|
||||
else:
|
||||
@@ -108,6 +110,10 @@ class SpecialVocab:
|
||||
if not quiet:
|
||||
logger.info(f'Setting normalizer_lowercase to {self.normalizer_lowercase}')
|
||||
gw.add_normalizer_lowercase(self.normalizer_lowercase)
|
||||
if self.normalizer_strip_accents is not None:
|
||||
if not quiet:
|
||||
logger.info(f'Setting normalizer_strip_accents to {self.normalizer_strip_accents}')
|
||||
gw.add_normalizer_strip_accents(self.normalizer_strip_accents)
|
||||
|
||||
def _load(self, path: Path) -> None:
|
||||
self._try_load_from_tokenizer_json(path)
|
||||
@@ -155,17 +161,21 @@ class SpecialVocab:
|
||||
def _parse_normalizer(self, normalizer: dict) -> None:
|
||||
# ref: https://huggingface.co/docs/tokenizers/api/normalizers
|
||||
#
|
||||
# Detects lowercase normalization in three possible formats:
|
||||
# 1. Standalone: {"type": "Lowercase"}
|
||||
# 2. BertNormalizer attribute: {"type": "BertNormalizer", "lowercase": true, ...}
|
||||
# 3. Nested in Sequence: {"type": "Sequence", "normalizers": [...]}
|
||||
# Extracts normalizer flags from three possible formats:
|
||||
# 1. Standalone: {"type": "Lowercase"}
|
||||
# 2. BertNormalizer attrs: {"type": "BertNormalizer", ...}
|
||||
# 3. Nested in Sequence: {"type": "Sequence", "normalizers": [...]}
|
||||
|
||||
normalizer_type = normalizer.get('type')
|
||||
if normalizer_type == 'Lowercase':
|
||||
self.normalizer_lowercase = True
|
||||
elif normalizer_type == 'StripAccents':
|
||||
self.normalizer_strip_accents = True
|
||||
elif normalizer_type == 'BertNormalizer':
|
||||
if 'lowercase' in normalizer:
|
||||
self.normalizer_lowercase = normalizer['lowercase']
|
||||
if 'strip_accents' in normalizer:
|
||||
self.normalizer_strip_accents = normalizer['strip_accents']
|
||||
elif normalizer_type == 'Sequence':
|
||||
for norm in normalizer.get('normalizers', []):
|
||||
self._parse_normalizer(norm)
|
||||
@@ -246,6 +256,11 @@ class SpecialVocab:
|
||||
if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
|
||||
if not tokenizer_config:
|
||||
special_bos = special_first
|
||||
elif special_first not in (special_bos, special_cls):
|
||||
if not special_bos:
|
||||
tokenizer_config['bos_token'] = special_bos = special_first
|
||||
if not special_cls:
|
||||
tokenizer_config['cls_token'] = special_cls = special_first
|
||||
self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
|
||||
if special_first not in (special_bos, special_cls):
|
||||
logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
|
||||
|
||||
+34
-33
@@ -299,39 +299,40 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" },
|
||||
{ LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" },
|
||||
|
||||
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
|
||||
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
|
||||
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
|
||||
{ LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" },
|
||||
{ LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" },
|
||||
{ LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" },
|
||||
{ LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" },
|
||||
{ LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" },
|
||||
{ LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" },
|
||||
{ LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" },
|
||||
{ LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" },
|
||||
{ LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" },
|
||||
{ LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" },
|
||||
{ LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" },
|
||||
{ LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" },
|
||||
{ LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" },
|
||||
{ LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
|
||||
{ LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
|
||||
{ LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" },
|
||||
{ LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
|
||||
{ LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" },
|
||||
{ LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },
|
||||
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
|
||||
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
|
||||
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" },
|
||||
{ LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, "tokenizer.ggml.normalizer.lowercase" },
|
||||
{ LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" },
|
||||
{ LLM_KV_TOKENIZER_SUPPRESS_TOKENS, "tokenizer.ggml.suppress_tokens" },
|
||||
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
|
||||
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
|
||||
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
|
||||
{ LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" },
|
||||
{ LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" },
|
||||
{ LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" },
|
||||
{ LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" },
|
||||
{ LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" },
|
||||
{ LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" },
|
||||
{ LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" },
|
||||
{ LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" },
|
||||
{ LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" },
|
||||
{ LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" },
|
||||
{ LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" },
|
||||
{ LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" },
|
||||
{ LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" },
|
||||
{ LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
|
||||
{ LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
|
||||
{ LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" },
|
||||
{ LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
|
||||
{ LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" },
|
||||
{ LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },
|
||||
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
|
||||
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
|
||||
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" },
|
||||
{ LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, "tokenizer.ggml.normalizer.lowercase" },
|
||||
{ LLM_KV_TOKENIZER_NORMALIZER_STRIP_ACCENTS, "tokenizer.ggml.normalizer.strip_accents" },
|
||||
{ LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" },
|
||||
{ LLM_KV_TOKENIZER_SUPPRESS_TOKENS, "tokenizer.ggml.suppress_tokens" },
|
||||
|
||||
{ LLM_KV_ADAPTER_TYPE, "adapter.type" },
|
||||
{ LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
|
||||
|
||||
@@ -314,6 +314,7 @@ enum llm_kv {
|
||||
LLM_KV_TOKENIZER_RWKV,
|
||||
LLM_KV_TOKENIZER_CHAT_TEMPLATE,
|
||||
LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE,
|
||||
LLM_KV_TOKENIZER_NORMALIZER_STRIP_ACCENTS,
|
||||
LLM_KV_TOKENIZER_FIM_PRE_ID,
|
||||
LLM_KV_TOKENIZER_FIM_SUF_ID,
|
||||
LLM_KV_TOKENIZER_FIM_MID_ID,
|
||||
|
||||
+3
-3
@@ -1873,9 +1873,9 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
||||
res->t_inp_embd = cur;
|
||||
|
||||
// For Granite architecture
|
||||
// NOTE: Only apply scale to token inputs. Raw embeddings are assumed to be
|
||||
// multimodal inputs that should not be scaled.
|
||||
if (ubatch.token && hparams.f_embedding_scale != 0.0f) {
|
||||
// NOTE: For deepstack models, only apply scale to token inputs (ie text-only input).
|
||||
// Raw embeddings are assumed to be multimodal inputs that should not be scaled.
|
||||
if (hparams.f_embedding_scale != 0.0f && (ubatch.token || hparams.n_deepstack_layers == 0)) {
|
||||
if (!ggml_is_contiguous(cur)) {
|
||||
cur = ggml_cont(ctx0, cur);
|
||||
}
|
||||
|
||||
+23
-12
@@ -764,7 +764,7 @@ struct llm_tokenizer_wpm_session {
|
||||
|
||||
void tokenize(const std::string & text, std::vector<llama_token> & output) {
|
||||
// normalize and split by whitespace
|
||||
std::vector<std::string> words = preprocess(text, vocab.get_normalizer_lowercase());
|
||||
std::vector<std::string> words = preprocess(text, vocab.get_normalizer_opts());
|
||||
// bos token prepended already
|
||||
|
||||
// find the longest tokens that form the words
|
||||
@@ -809,11 +809,14 @@ struct llm_tokenizer_wpm_session {
|
||||
}
|
||||
|
||||
// TODO: reduce string copies by using cpts_offs array
|
||||
static std::vector<std::string> preprocess(const std::string & text, bool lowercase) {
|
||||
const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
|
||||
static std::vector<std::string> preprocess(const std::string & text, const llama_vocab::normalizer_options & normalizer_opts) {
|
||||
std::vector<uint32_t> cpts = unicode_cpts_from_utf8(text);
|
||||
if (normalizer_opts.strip_accents) {
|
||||
cpts = unicode_cpts_normalize_nfd(cpts);
|
||||
}
|
||||
std::vector<std::string> words(1, "");
|
||||
|
||||
for (const uint32_t cpt : cpts_nfd) {
|
||||
for (const uint32_t cpt : cpts) {
|
||||
const auto flags = unicode_cpt_flags_from_cpt(cpt);
|
||||
|
||||
if (flags.is_whitespace) {
|
||||
@@ -828,7 +831,11 @@ struct llm_tokenizer_wpm_session {
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::string s = unicode_cpt_to_utf8(lowercase ? unicode_tolower(cpt) : cpt);
|
||||
if (normalizer_opts.strip_accents && flags.is_accent_mark) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::string s = unicode_cpt_to_utf8(normalizer_opts.lowercase ? unicode_tolower(cpt) : cpt);
|
||||
if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
|
||||
if (words.back().size()) { // finish previous word if any
|
||||
words.emplace_back();
|
||||
@@ -1692,7 +1699,7 @@ struct llm_tokenizer_whitespace_session : llm_tokenizer_bpe_session {
|
||||
llm_tokenizer_whitespace_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : llm_tokenizer_bpe_session{vocab, tokenizer}, vocab{vocab} {}
|
||||
|
||||
void tokenize(const std::string & text, std::vector<llama_token> & output) override {
|
||||
const bool lowercase = vocab.get_normalizer_lowercase();
|
||||
const bool lowercase = vocab.get_normalizer_opts().lowercase;
|
||||
|
||||
std::string segment;
|
||||
auto flush = [&]() {
|
||||
@@ -1797,7 +1804,9 @@ struct llama_vocab::impl {
|
||||
bool remove_extra_whitespaces = false;
|
||||
bool escape_whitespaces = true;
|
||||
bool treat_whitespace_as_suffix = false;
|
||||
bool normalizer_lowercase = true; // Lowercase normalizer (tokenizer.json)
|
||||
|
||||
// BertNormalizer options
|
||||
llama_vocab::normalizer_options normalizer_opts;
|
||||
|
||||
std::unordered_map<std::string, llama_token> token_to_id;
|
||||
std::vector<token_data> id_to_token;
|
||||
@@ -2172,7 +2181,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
} else if (
|
||||
tokenizer_pre == "whitespace") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_WHITESPACE;
|
||||
normalizer_lowercase = false;
|
||||
normalizer_opts.lowercase = false;
|
||||
} else if (
|
||||
tokenizer_pre == "refact") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT;
|
||||
@@ -2532,8 +2541,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
}
|
||||
}
|
||||
|
||||
// Lowercase normalizer flag (consulted by WPM / whitespace BPE)
|
||||
ml.get_key(LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, normalizer_lowercase, false);
|
||||
// BertNormalizer options
|
||||
ml.get_key(LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, normalizer_opts.lowercase, false);
|
||||
normalizer_opts.strip_accents = normalizer_opts.lowercase;
|
||||
ml.get_key(LLM_KV_TOKENIZER_NORMALIZER_STRIP_ACCENTS, normalizer_opts.strip_accents, false);
|
||||
|
||||
// suppress tokens
|
||||
{
|
||||
@@ -3969,8 +3980,8 @@ bool llama_vocab::get_treat_whitespace_as_suffix() const {
|
||||
return pimpl->treat_whitespace_as_suffix;
|
||||
}
|
||||
|
||||
bool llama_vocab::get_normalizer_lowercase() const {
|
||||
return pimpl->normalizer_lowercase;
|
||||
const llama_vocab::normalizer_options & llama_vocab::get_normalizer_opts() const {
|
||||
return pimpl->normalizer_opts;
|
||||
}
|
||||
|
||||
const std::vector<llama_token> & llama_vocab::get_suppress_tokens() const {
|
||||
|
||||
+7
-1
@@ -76,6 +76,12 @@ struct llama_vocab {
|
||||
llama_token_attr attr;
|
||||
};
|
||||
|
||||
struct normalizer_options {
|
||||
bool lowercase = true;
|
||||
bool strip_accents = true;
|
||||
// TODO: clean_text, handle_chinese_chars
|
||||
};
|
||||
|
||||
llama_vocab();
|
||||
~llama_vocab();
|
||||
|
||||
@@ -141,7 +147,7 @@ struct llama_vocab {
|
||||
bool get_remove_extra_whitespaces () const;
|
||||
bool get_escape_whitespaces () const;
|
||||
bool get_treat_whitespace_as_suffix() const;
|
||||
bool get_normalizer_lowercase () const;
|
||||
const normalizer_options & get_normalizer_opts() const;
|
||||
|
||||
const std::vector<llama_token> & get_suppress_tokens() const;
|
||||
|
||||
|
||||
@@ -398,9 +398,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
|
||||
GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
|
||||
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
|
||||
|
||||
// K=1 (final state only): reshape to 3D (S_v*S_v*H_v, 1, n_seqs) for ggml_gated_delta_net.
|
||||
ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, S_v * S_v * H_v, 1, n_seqs);
|
||||
ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d);
|
||||
// K=1: output carries the final state only. state s is 4D [S_v, S_v, H_v, n_seqs].
|
||||
ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s, /*K=*/1);
|
||||
if (n_tokens == 1) {
|
||||
cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il);
|
||||
} else {
|
||||
@@ -564,11 +563,8 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
|
||||
const int64_t D = S_v * S_v * H_v;
|
||||
const int64_t K = cparams.n_rs_seq + 1;
|
||||
|
||||
// TODO: remove pad + simplify
|
||||
ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs);
|
||||
ggml_tensor * s_3d_pad = ggml_pad (ctx0, s_3d, 0, K - 1, 0, 0);
|
||||
|
||||
ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d_pad);
|
||||
// state s is 4D [S_v, S_v, H_v, n_seqs]; K snapshot slots are written into the output.
|
||||
ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s, K);
|
||||
if (n_seq_tokens > 1) {
|
||||
cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il);
|
||||
} else {
|
||||
@@ -587,21 +583,24 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn(
|
||||
cb(output, "attn_output", il);
|
||||
|
||||
const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all);
|
||||
for (int64_t k_i = 0; k_i < K; ++k_i) {
|
||||
const uint32_t cache_slot = (uint32_t) (K - 1 - k_i);
|
||||
ggml_tensor * src = ggml_view_4d(ctx0, gdn_out,
|
||||
S_v, S_v, H_v, n_seqs,
|
||||
ggml_row_size(gdn_out->type, S_v),
|
||||
ggml_row_size(gdn_out->type, S_v * S_v),
|
||||
ggml_row_size(gdn_out->type, S_v * S_v * H_v),
|
||||
ggml_row_size(gdn_out->type, attn_score_elems + k_i * state_size_per_snap));
|
||||
|
||||
ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all,
|
||||
hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
|
||||
((size_t) cache_slot * mem_size + kv_head) * row_size);
|
||||
// op writes the last min(n_seq_tokens, K) snapshots; trailing slots are left unwritten
|
||||
const int64_t n_written = std::min<int64_t>(n_seq_tokens, K);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst));
|
||||
}
|
||||
// write the produced snapshots into the recurrent cache (snapshot slot i -> rollback group i)
|
||||
ggml_tensor * src = ggml_view_3d(ctx0, gdn_out,
|
||||
D, n_seqs, n_written,
|
||||
ggml_row_size(gdn_out->type, D),
|
||||
ggml_row_size(gdn_out->type, state_size_per_snap),
|
||||
ggml_row_size(gdn_out->type, attn_score_elems));
|
||||
|
||||
ggml_tensor * dst = ggml_view_3d(ctx0, ssm_states_all,
|
||||
D, n_seqs, n_written,
|
||||
ssm_states_all->nb[1],
|
||||
(size_t) mem_size * row_size,
|
||||
(size_t) kv_head * row_size);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst));
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
+1
-1
@@ -46,7 +46,7 @@ struct llm_build_delta_net_base : public llm_graph_context {
|
||||
ggml_tensor * s,
|
||||
int il);
|
||||
|
||||
// use the ggml_gated_delta_net fused operator (K=1; state has shape (D, 1, n_seqs))
|
||||
// use the ggml_gated_delta_net fused operator (K=1; state has shape [S_v, S_v, H_v, n_seqs])
|
||||
std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_fused(
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
|
||||
@@ -3896,14 +3896,14 @@ struct test_gated_delta_net : public test_case {
|
||||
const int64_t g_ne0 = kda ? head_size : 1;
|
||||
ggml_tensor * g = ggml_new_tensor_4d(ctx, type, g_ne0, head_count * v_repeat, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * beta = ggml_new_tensor_4d(ctx, type, 1, head_count * v_repeat, n_seq_tokens, n_seqs);
|
||||
ggml_tensor * state = ggml_new_tensor_3d(ctx, type, head_size * v_repeat * head_size * head_count, K, n_seqs);
|
||||
ggml_tensor * state = ggml_new_tensor_4d(ctx, type, head_size, head_size, head_count * v_repeat, n_seqs);
|
||||
ggml_set_name(g, "g");
|
||||
ggml_set_name(beta, "beta");
|
||||
ggml_set_name(state, "state");
|
||||
// q/k are L2-normalised in qwen35/kimi-linear before delta_net
|
||||
q = ggml_l2_norm(ctx, q, 1e-6f);
|
||||
k = ggml_l2_norm(ctx, k, 1e-6f);
|
||||
ggml_tensor * out = ggml_gated_delta_net(ctx, q, k, v, g, beta, state);
|
||||
ggml_tensor * out = ggml_gated_delta_net(ctx, q, k, v, g, beta, state, K);
|
||||
return out;
|
||||
}
|
||||
|
||||
@@ -9190,7 +9190,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 33, 1, 1, false, true));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 100, 1, 1, false, true));
|
||||
|
||||
// K > 1: output keeps the last min(n_tokens, K) per-token snapshots in the trailing K-token region.
|
||||
// K > 1: output keeps the last min(n_tokens, K) per-token snapshots, ordered most-recent-first
|
||||
// (slot 0 = final state, slot s = state s tokens back).
|
||||
// exact-match cases (K == n_seq_tokens):
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 2, 1, 1, false, false, /*K=*/2));
|
||||
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 4, 1, 1, false, false, /*K=*/4));
|
||||
|
||||
+30
-15
@@ -314,11 +314,17 @@ ggml_tensor * clip_graph::build_vit(
|
||||
std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos,
|
||||
const build_vit_opts & opts
|
||||
) {
|
||||
// batch dim: inp is [n_embd, n_pos] (B==1) or [n_embd, n_pos, B] (multi-tile encode)
|
||||
const int64_t B = inp->ne[2];
|
||||
|
||||
if (learned_pos_embd) {
|
||||
inp = ggml_add(ctx0, inp, learned_pos_embd);
|
||||
cb(inp, "pos_embed", -1);
|
||||
}
|
||||
|
||||
// flatten batch; unflatten again in attention
|
||||
inp = ggml_reshape_2d(ctx0, inp, n_embd, n_pos * B);
|
||||
|
||||
ggml_tensor * inpL = inp;
|
||||
|
||||
// pre-layernorm
|
||||
@@ -348,20 +354,24 @@ ggml_tensor * clip_graph::build_vit(
|
||||
cur = ggml_add(ctx0, cur, layer.qkv_b);
|
||||
}
|
||||
|
||||
Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
|
||||
/* nb1 */ ggml_row_size(cur->type, d_head),
|
||||
/* nb2 */ cur->nb[1],
|
||||
/* offset */ 0);
|
||||
// Q/K/V as [d_head, n_head, n_pos, B], the batch stride is cur->nb[1]*n_pos.
|
||||
Qcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B,
|
||||
/* nb1 */ ggml_row_size(cur->type, d_head),
|
||||
/* nb2 */ cur->nb[1],
|
||||
/* nb3 */ cur->nb[1] * n_pos,
|
||||
/* offset */ 0);
|
||||
|
||||
Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
|
||||
/* nb1 */ ggml_row_size(cur->type, d_head),
|
||||
/* nb2 */ cur->nb[1],
|
||||
/* offset */ ggml_row_size(cur->type, n_embd));
|
||||
Kcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B,
|
||||
/* nb1 */ ggml_row_size(cur->type, d_head),
|
||||
/* nb2 */ cur->nb[1],
|
||||
/* nb3 */ cur->nb[1] * n_pos,
|
||||
/* offset */ ggml_row_size(cur->type, n_embd));
|
||||
|
||||
Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
|
||||
/* nb1 */ ggml_row_size(cur->type, d_head),
|
||||
/* nb2 */ cur->nb[1],
|
||||
/* offset */ ggml_row_size(cur->type, 2 * n_embd));
|
||||
Vcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B,
|
||||
/* nb1 */ ggml_row_size(cur->type, d_head),
|
||||
/* nb2 */ cur->nb[1],
|
||||
/* nb3 */ cur->nb[1] * n_pos,
|
||||
/* offset */ ggml_row_size(cur->type, 2 * n_embd));
|
||||
|
||||
if (layer.q_norm) {
|
||||
GGML_ASSERT(layer.q_norm->ne[0] == Qcur->ne[0]);
|
||||
@@ -406,9 +416,9 @@ ggml_tensor * clip_graph::build_vit(
|
||||
}
|
||||
}
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head_kv, n_pos);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head_kv, n_pos);
|
||||
Qcur = ggml_reshape_4d(ctx0, Qcur, d_head, n_head, n_pos, B);
|
||||
Kcur = ggml_reshape_4d(ctx0, Kcur, d_head, n_head_kv, n_pos, B);
|
||||
Vcur = ggml_reshape_4d(ctx0, Vcur, d_head, n_head_kv, n_pos, B);
|
||||
|
||||
if (norm_per_head) {
|
||||
if (layer.q_norm) {
|
||||
@@ -438,6 +448,7 @@ ggml_tensor * clip_graph::build_vit(
|
||||
cb(Vcur, "Vcur_normed", il);
|
||||
}
|
||||
|
||||
// build_attn returns a flat 2D [n_embd, n_pos*B]
|
||||
cur = build_attn(layer.o_w, layer.o_b,
|
||||
Qcur, Kcur, Vcur, opts.attn_mask, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
@@ -509,6 +520,10 @@ ggml_tensor * clip_graph::build_vit(
|
||||
if (model.post_ln_w) {
|
||||
inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1);
|
||||
}
|
||||
|
||||
// restore the batch dim
|
||||
GGML_ASSERT(inpL->ne[1] % B == 0);
|
||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, inpL->ne[1] / B, B);
|
||||
return inpL;
|
||||
}
|
||||
|
||||
|
||||
@@ -91,7 +91,6 @@ add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0"
|
||||
add_test_vision "ggml-org/DeepSeek-OCR-GGUF:Q8_0" -p "Free OCR." --chat-template deepseek-ocr
|
||||
add_test_vision "ggml-org/dots.ocr-GGUF:Q8_0" -p "OCR"
|
||||
add_test_vision "ggml-org/HunyuanOCR-GGUF:Q8_0" -p "OCR"
|
||||
add_test_vision "ggml-org/HunyuanVL-4B-GGUF:Q8_0"
|
||||
add_test_vision "ggml-org/gemma-4-E2B-it-GGUF:Q8_0" --jinja
|
||||
|
||||
add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
|
||||
|
||||
@@ -2046,6 +2046,9 @@ private:
|
||||
|
||||
auto & cur = slot.prompt.checkpoints.emplace_back();
|
||||
|
||||
// [TAG_CHECKPOINTS_FIX_POS_MIN]
|
||||
// TODO: here we incorrectly deterimne that the saved checkpoint data covers the [pos_min, pos_max] range
|
||||
// this is not true for SWA models: https://github.com/ggml-org/llama.cpp/pull/24411#issuecomment-4677983225
|
||||
cur.update_pos(slot.prompt.n_tokens() - n_tokens_cur, pos_min, pos_max);
|
||||
|
||||
cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
@@ -2860,6 +2863,10 @@ private:
|
||||
// guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
|
||||
LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12,
|
||||
func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold);
|
||||
// workaround for [TAG_CHECKPOINTS_FIX_POS_MIN]
|
||||
if (cur.pos_max > pos_next) {
|
||||
return false;
|
||||
}
|
||||
return cur.pos_min < pos_min_thold || cur.pos_min == 0;
|
||||
}
|
||||
);
|
||||
|
||||
+14
-12
@@ -94,20 +94,22 @@ int llama_server(int argc, char ** argv) {
|
||||
const bool is_router_server = params.model.path.empty();
|
||||
common_params_print_info(params, !is_router_server);
|
||||
|
||||
// validate batch size for embeddings
|
||||
// embeddings require all tokens to be processed in a single ubatch
|
||||
// see https://github.com/ggml-org/llama.cpp/issues/12836
|
||||
if (params.embedding && params.n_batch > params.n_ubatch) {
|
||||
SRV_WRN("embeddings enabled with n_batch (%d) > n_ubatch (%d)\n", params.n_batch, params.n_ubatch);
|
||||
SRV_WRN("setting n_batch = n_ubatch = %d to avoid assertion failure\n", params.n_ubatch);
|
||||
params.n_batch = params.n_ubatch;
|
||||
}
|
||||
if (!is_router_server) {
|
||||
// validate batch size for embeddings
|
||||
// embeddings require all tokens to be processed in a single ubatch
|
||||
// see https://github.com/ggml-org/llama.cpp/issues/12836
|
||||
if (params.embedding && params.n_batch > params.n_ubatch) {
|
||||
SRV_WRN("embeddings enabled with n_batch (%d) > n_ubatch (%d)\n", params.n_batch, params.n_ubatch);
|
||||
SRV_WRN("setting n_batch = n_ubatch = %d to avoid assertion failure\n", params.n_ubatch);
|
||||
params.n_batch = params.n_ubatch;
|
||||
}
|
||||
|
||||
if (params.n_parallel < 0) {
|
||||
SRV_INF("%s", "n_parallel is set to auto, using n_parallel = 4 and kv_unified = true\n");
|
||||
if (params.n_parallel < 0) {
|
||||
SRV_INF("%s", "n_parallel is set to auto, using n_parallel = 4 and kv_unified = true\n");
|
||||
|
||||
params.n_parallel = 4;
|
||||
params.kv_unified = true;
|
||||
params.n_parallel = 4;
|
||||
params.kv_unified = true;
|
||||
}
|
||||
}
|
||||
|
||||
// for consistency between server router mode and single-model mode, we set the same model name as alias
|
||||
|
||||
@@ -46,7 +46,14 @@ export default ts.config(
|
||||
},
|
||||
{
|
||||
// Exclude generated build output and Storybook files from ESLint
|
||||
ignores: ['dist/**', 'build/**', '.svelte-kit/**', 'test-results/**', '.storybook/**/*']
|
||||
ignores: [
|
||||
'dist/**',
|
||||
'build/**',
|
||||
'.svelte-kit/**',
|
||||
'test-results/**',
|
||||
'.storybook/**/*',
|
||||
'src/lib/services/sandbox-worker.js'
|
||||
]
|
||||
},
|
||||
storybook.configs['flat/recommended']
|
||||
);
|
||||
|
||||
+48
-4
@@ -1,7 +1,7 @@
|
||||
<script lang="ts">
|
||||
import { goto } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
import { Trash2, Pencil, X } from '@lucide/svelte';
|
||||
import { Trash2, Pencil, Pin, X } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { DialogConfirmation } from '$lib/components/app';
|
||||
import SidebarNavigationActions from './SidebarNavigationActions.svelte';
|
||||
@@ -52,6 +52,14 @@
|
||||
|
||||
let conversationTree = $derived(buildConversationTree(filteredConversations));
|
||||
|
||||
let pinnedConversations = $derived.by(() => {
|
||||
return conversationTree.filter(({ conversation }) => conversation.pinned);
|
||||
});
|
||||
|
||||
let unpinnedConversations = $derived.by(() => {
|
||||
return conversationTree.filter(({ conversation }) => !conversation.pinned);
|
||||
});
|
||||
|
||||
let selectedConversationHasDescendants = $derived.by(() => {
|
||||
if (!selectedConversation) return false;
|
||||
|
||||
@@ -199,6 +207,41 @@
|
||||
/>
|
||||
</Sidebar.Header>
|
||||
|
||||
{#if !isSearchModeActive && pinnedConversations.length > 0}
|
||||
<Sidebar.Group class="p-0 px-4">
|
||||
<Sidebar.GroupLabel>
|
||||
<div class="flex items-center gap-1">
|
||||
<Pin class="h-3.5 w-3.5" />
|
||||
<span>Pinned</span>
|
||||
</div>
|
||||
</Sidebar.GroupLabel>
|
||||
<Sidebar.GroupContent>
|
||||
<Sidebar.Menu>
|
||||
{#each pinnedConversations as { conversation, depth } (conversation.id)}
|
||||
<Sidebar.MenuItem class="mb-1 p-0">
|
||||
<SidebarNavigationConversationItem
|
||||
conversation={{
|
||||
id: conversation.id,
|
||||
name: conversation.name,
|
||||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode,
|
||||
forkedFromConversationId: conversation.forkedFromConversationId,
|
||||
pinned: conversation.pinned
|
||||
}}
|
||||
{depth}
|
||||
isActive={currentChatId === conversation.id}
|
||||
onSelect={selectConversation}
|
||||
onEdit={handleEditConversation}
|
||||
onDelete={handleDeleteConversation}
|
||||
onStop={handleStopGeneration}
|
||||
/>
|
||||
</Sidebar.MenuItem>
|
||||
{/each}
|
||||
</Sidebar.Menu>
|
||||
</Sidebar.GroupContent>
|
||||
</Sidebar.Group>
|
||||
{/if}
|
||||
|
||||
<Sidebar.Group class="mt-2 h-[calc(100vh-21rem)] space-y-2 p-0 px-3">
|
||||
{#if (filteredConversations.length > 0 && isSearchModeActive) || !isSearchModeActive}
|
||||
<Sidebar.GroupLabel>
|
||||
@@ -208,7 +251,7 @@
|
||||
|
||||
<Sidebar.GroupContent>
|
||||
<Sidebar.Menu>
|
||||
{#each conversationTree as { conversation, depth } (conversation.id)}
|
||||
{#each isSearchModeActive ? conversationTree : unpinnedConversations as { conversation, depth } (conversation.id)}
|
||||
<Sidebar.MenuItem class="mb-1 p-0">
|
||||
<SidebarNavigationConversationItem
|
||||
conversation={{
|
||||
@@ -216,7 +259,8 @@
|
||||
name: conversation.name,
|
||||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode,
|
||||
forkedFromConversationId: conversation.forkedFromConversationId
|
||||
forkedFromConversationId: conversation.forkedFromConversationId,
|
||||
pinned: conversation.pinned
|
||||
}}
|
||||
{depth}
|
||||
isActive={currentChatId === conversation.id}
|
||||
@@ -228,7 +272,7 @@
|
||||
</Sidebar.MenuItem>
|
||||
{/each}
|
||||
|
||||
{#if conversationTree.length === 0}
|
||||
{#if (isSearchModeActive ? conversationTree : unpinnedConversations).length === 0}
|
||||
<div class="px-2 py-4 text-center">
|
||||
<p class="mb-4 p-4 text-sm text-muted-foreground">
|
||||
{searchQuery.length > 0
|
||||
|
||||
+15
-1
@@ -6,7 +6,9 @@
|
||||
Download,
|
||||
Loader2,
|
||||
Square,
|
||||
GitBranch
|
||||
GitBranch,
|
||||
Pin,
|
||||
PinOff
|
||||
} from '@lucide/svelte';
|
||||
import { DropdownMenuActions } from '$lib/components/app';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
@@ -57,6 +59,10 @@
|
||||
onStop?.(conversation.id);
|
||||
}
|
||||
|
||||
function handleTogglePin() {
|
||||
conversationsStore.toggleConversationPin(conversation.id);
|
||||
}
|
||||
|
||||
function handleGlobalEditEvent(event: Event) {
|
||||
const customEvent = event as CustomEvent<{ conversationId: string }>;
|
||||
|
||||
@@ -170,6 +176,14 @@
|
||||
triggerTooltip="More actions"
|
||||
bind:open={dropdownOpen}
|
||||
actions={[
|
||||
{
|
||||
icon: conversation.pinned ? PinOff : Pin,
|
||||
label: conversation.pinned ? 'Unpin' : 'Pin',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
handleTogglePin();
|
||||
}
|
||||
},
|
||||
{
|
||||
icon: Pencil,
|
||||
label: 'Edit',
|
||||
|
||||
@@ -37,6 +37,7 @@ export * from './model-id';
|
||||
export * from './precision';
|
||||
export * from './processing-info';
|
||||
export * from './routes';
|
||||
export * from './sandbox';
|
||||
export * from './settings-keys';
|
||||
export * from './settings-registry';
|
||||
export * from './supported-file-types';
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
import { JsonSchemaType, ToolCallType } from '$lib/enums';
|
||||
import type { OpenAIToolDefinition } from '$lib/types';
|
||||
|
||||
export const SANDBOX_TOOL_NAME = 'run_javascript';
|
||||
|
||||
export const SANDBOX_TIMEOUT_MS_DEFAULT = 10000;
|
||||
|
||||
export const SANDBOX_TIMEOUT_MS_MAX = 30000;
|
||||
|
||||
export const SANDBOX_OUTPUT_MAX_CHARS = 8192;
|
||||
|
||||
export const SANDBOX_EMPTY_OUTPUT = '(no output)';
|
||||
|
||||
export const SANDBOX_TRUNCATION_NOTICE = '[output truncated]';
|
||||
|
||||
export const SANDBOX_TOOL_DEFINITION: OpenAIToolDefinition = {
|
||||
type: ToolCallType.FUNCTION,
|
||||
function: {
|
||||
name: SANDBOX_TOOL_NAME,
|
||||
description:
|
||||
'Execute JavaScript in a sandboxed browser worker (no DOM, no page access). ' +
|
||||
'Top level await is supported. Use console.log to print intermediate values; ' +
|
||||
'a top level return statement is captured as the result.',
|
||||
parameters: {
|
||||
type: JsonSchemaType.OBJECT,
|
||||
properties: {
|
||||
code: {
|
||||
type: JsonSchemaType.STRING,
|
||||
description: 'JavaScript source to execute'
|
||||
},
|
||||
timeout_ms: {
|
||||
type: JsonSchemaType.NUMBER,
|
||||
description: `Execution timeout in milliseconds, default ${SANDBOX_TIMEOUT_MS_DEFAULT}, max ${SANDBOX_TIMEOUT_MS_MAX}`
|
||||
}
|
||||
},
|
||||
required: ['code']
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -69,6 +69,7 @@ export const SETTINGS_KEYS = {
|
||||
ENABLE_THINKING: 'enableThinking',
|
||||
SHOW_RAW_OUTPUT_SWITCH: 'showRawOutputSwitch',
|
||||
// PY_INTERPRETER_ENABLED: 'pyInterpreterEnabled',
|
||||
JS_SANDBOX_ENABLED: 'jsSandboxEnabled',
|
||||
CUSTOM_JSON: 'customJson',
|
||||
CUSTOM_CSS: 'customCss'
|
||||
} as const;
|
||||
|
||||
@@ -690,6 +690,14 @@ const SETTINGS_REGISTRY: Record<string, SettingsSectionEntry> = {
|
||||
paramType: SyncableParameterType.BOOLEAN
|
||||
}
|
||||
},
|
||||
{
|
||||
key: SETTINGS_KEYS.JS_SANDBOX_ENABLED,
|
||||
label: 'JavaScript sandbox tool',
|
||||
help: 'Expose a run_javascript tool to the model. Code runs in a Web Worker inside a sandboxed iframe with an opaque origin, isolated from the WebUI and its API, with a hard timeout.',
|
||||
defaultValue: false,
|
||||
type: SettingsFieldType.CHECKBOX,
|
||||
section: SETTINGS_SECTION_SLUGS.DEVELOPER
|
||||
},
|
||||
{
|
||||
key: SETTINGS_KEYS.CUSTOM_JSON,
|
||||
label: 'Custom JSON',
|
||||
|
||||
@@ -2,10 +2,12 @@ import { ToolSource } from '$lib/enums/tools.enums';
|
||||
|
||||
export const TOOL_GROUP_LABELS = {
|
||||
[ToolSource.BUILTIN]: 'Built-in',
|
||||
[ToolSource.CUSTOM]: 'JSON Schema'
|
||||
[ToolSource.CUSTOM]: 'JSON Schema',
|
||||
[ToolSource.FRONTEND]: 'Browser'
|
||||
} as const;
|
||||
|
||||
export const TOOL_SERVER_LABELS = {
|
||||
[ToolSource.BUILTIN]: 'Built-in Tools',
|
||||
[ToolSource.CUSTOM]: 'Custom Tools'
|
||||
[ToolSource.CUSTOM]: 'Custom Tools',
|
||||
[ToolSource.FRONTEND]: 'Browser Tools'
|
||||
} as const;
|
||||
|
||||
@@ -54,7 +54,9 @@ export enum MCPContentType {
|
||||
* JSON Schema types used in MCP tool definitions
|
||||
*/
|
||||
export enum JsonSchemaType {
|
||||
OBJECT = 'object'
|
||||
OBJECT = 'object',
|
||||
STRING = 'string',
|
||||
NUMBER = 'number'
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
export enum ToolSource {
|
||||
BUILTIN = 'builtin',
|
||||
MCP = 'mcp',
|
||||
CUSTOM = 'custom'
|
||||
CUSTOM = 'custom',
|
||||
FRONTEND = 'frontend'
|
||||
}
|
||||
|
||||
export enum ToolPermissionDecision {
|
||||
|
||||
@@ -344,6 +344,22 @@ export class DatabaseService {
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* Toggles the pinned status of a conversation.
|
||||
*
|
||||
* @param id - Conversation ID
|
||||
* @returns The new pinned status
|
||||
*/
|
||||
static async toggleConversationPin(id: string): Promise<boolean> {
|
||||
const conversation = await db.conversations.get(id);
|
||||
if (!conversation) {
|
||||
throw new Error(`Conversation ${id} not found`);
|
||||
}
|
||||
const newPinnedState = !conversation.pinned;
|
||||
await this.updateConversation(id, { pinned: newPinnedState });
|
||||
return newPinnedState;
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the conversation's current node (active branch).
|
||||
* This determines which conversation path is currently being viewed.
|
||||
|
||||
@@ -261,6 +261,26 @@ export { ParameterSyncService } from './parameter-sync.service';
|
||||
*/
|
||||
export { MCPService } from './mcp.service';
|
||||
|
||||
/**
|
||||
* **SandboxService** - Frontend JavaScript execution in a browser sandbox
|
||||
*
|
||||
* Stateless executor for the run_javascript frontend tool. Model generated
|
||||
* code runs in a Web Worker spawned inside a sandboxed iframe with an opaque
|
||||
* origin: no access to the app origin, its storage or its API, and outgoing
|
||||
* requests carry a null origin. The code never touches a main thread, so the
|
||||
* parent enforces the timeout by removing the iframe, which terminates the
|
||||
* worker at the browser level.
|
||||
*
|
||||
* **Architecture & Relationships:**
|
||||
* - **SandboxService** (this class): Stateless sandbox execution
|
||||
* - **toolsStore**: Exposes the tool definition when the sandbox is enabled
|
||||
* - **agenticStore**: Dispatches ToolSource.FRONTEND calls here
|
||||
*
|
||||
* @see SANDBOX_TOOL_DEFINITION in constants/sandbox.ts - tool schema sent to the LLM
|
||||
* @see agenticStore in stores/agentic.svelte.ts - tool dispatch
|
||||
*/
|
||||
export { SandboxService } from './sandbox.service';
|
||||
|
||||
/**
|
||||
* **RouterService** — Dynamic route URL construction utility
|
||||
*
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
import WORKER_SHIM from './sandbox-worker.js?raw';
|
||||
|
||||
/**
|
||||
* Harness loaded as srcdoc into a sandboxed iframe (allow-scripts only).
|
||||
* The opaque origin is the security boundary: no access to the app origin,
|
||||
* its storage or its API. The harness spawns a worker so model code never
|
||||
* runs on a main thread, which makes the parent timeout enforceable by
|
||||
* removing the iframe.
|
||||
*/
|
||||
export const SANDBOX_HARNESS_HTML = `<!doctype html><script>
|
||||
const SHIM = ${JSON.stringify(WORKER_SHIM)};
|
||||
addEventListener('message', (event) => {
|
||||
const respond = (payload) => parent.postMessage(payload, '*');
|
||||
let worker;
|
||||
try {
|
||||
worker = new Worker(URL.createObjectURL(new Blob([SHIM], { type: 'text/javascript' })));
|
||||
} catch (err) {
|
||||
respond({ logs: [], result: null, error: 'Worker creation failed: ' + err });
|
||||
return;
|
||||
}
|
||||
worker.onmessage = (msg) => respond(msg.data);
|
||||
worker.onerror = (err) => respond({ logs: [], result: null, error: String(err.message || err) });
|
||||
worker.postMessage({ code: event.data.code });
|
||||
});
|
||||
</script>`;
|
||||
@@ -0,0 +1,30 @@
|
||||
const logs = [];
|
||||
const fmt = (value) => {
|
||||
if (typeof value === 'string') return value;
|
||||
try {
|
||||
return JSON.stringify(value);
|
||||
} catch {
|
||||
return String(value);
|
||||
}
|
||||
};
|
||||
const capture =
|
||||
(level, prefix) =>
|
||||
(...args) => {
|
||||
logs.push(prefix + args.map(fmt).join(' '));
|
||||
};
|
||||
console.log = capture('log', '');
|
||||
console.info = capture('info', '');
|
||||
console.debug = capture('debug', '');
|
||||
console.warn = capture('warn', 'warn: ');
|
||||
console.error = capture('error', 'error: ');
|
||||
self.onmessage = async (event) => {
|
||||
const reply = { logs, result: null, error: null };
|
||||
try {
|
||||
const AsyncFunction = Object.getPrototypeOf(async function () {}).constructor;
|
||||
const value = await new AsyncFunction(event.data.code)();
|
||||
if (value !== undefined) reply.result = fmt(value);
|
||||
} catch (err) {
|
||||
reply.error = err instanceof Error ? err.stack || err.message : String(err);
|
||||
}
|
||||
self.postMessage(reply);
|
||||
};
|
||||
@@ -0,0 +1,112 @@
|
||||
import {
|
||||
NEWLINE_SEPARATOR,
|
||||
SANDBOX_EMPTY_OUTPUT,
|
||||
SANDBOX_OUTPUT_MAX_CHARS,
|
||||
SANDBOX_TIMEOUT_MS_DEFAULT,
|
||||
SANDBOX_TIMEOUT_MS_MAX,
|
||||
SANDBOX_TOOL_NAME,
|
||||
SANDBOX_TRUNCATION_NOTICE
|
||||
} from '$lib/constants';
|
||||
import { SANDBOX_HARNESS_HTML } from './sandbox-harness';
|
||||
import type { ToolExecutionResult } from '$lib/types';
|
||||
|
||||
interface SandboxReply {
|
||||
logs?: unknown;
|
||||
result?: unknown;
|
||||
error?: unknown;
|
||||
}
|
||||
|
||||
function formatReply(reply: SandboxReply): ToolExecutionResult {
|
||||
const lines: string[] = [];
|
||||
|
||||
if (Array.isArray(reply.logs)) {
|
||||
for (const line of reply.logs) lines.push(String(line));
|
||||
}
|
||||
|
||||
if (reply.error != null) {
|
||||
lines.push(`Error: ${String(reply.error)}`);
|
||||
} else if (reply.result != null) {
|
||||
lines.push(`=> ${String(reply.result)}`);
|
||||
}
|
||||
|
||||
let content = lines.join(NEWLINE_SEPARATOR);
|
||||
if (!content) content = SANDBOX_EMPTY_OUTPUT;
|
||||
if (content.length > SANDBOX_OUTPUT_MAX_CHARS) {
|
||||
content = `${content.slice(0, SANDBOX_OUTPUT_MAX_CHARS)}${NEWLINE_SEPARATOR}${SANDBOX_TRUNCATION_NOTICE}`;
|
||||
}
|
||||
|
||||
return { content, isError: reply.error != null };
|
||||
}
|
||||
|
||||
export class SandboxService {
|
||||
/**
|
||||
* Execute a frontend sandbox tool call and return its output.
|
||||
* One disposable iframe per execution, removed on completion,
|
||||
* timeout or abort. Removing the iframe terminates the worker
|
||||
* at the browser level, so runaway code cannot outlive it.
|
||||
*/
|
||||
static executeTool(
|
||||
toolName: string,
|
||||
params: Record<string, unknown>,
|
||||
signal?: AbortSignal
|
||||
): Promise<ToolExecutionResult> {
|
||||
if (toolName !== SANDBOX_TOOL_NAME) {
|
||||
return Promise.resolve({ content: `Unknown frontend tool: ${toolName}`, isError: true });
|
||||
}
|
||||
|
||||
const code = typeof params.code === 'string' ? params.code : '';
|
||||
if (!code) {
|
||||
return Promise.resolve({ content: 'Missing required parameter: code', isError: true });
|
||||
}
|
||||
|
||||
const requested = Number(params.timeout_ms);
|
||||
const timeoutMs =
|
||||
Number.isFinite(requested) && requested > 0
|
||||
? Math.min(requested, SANDBOX_TIMEOUT_MS_MAX)
|
||||
: SANDBOX_TIMEOUT_MS_DEFAULT;
|
||||
|
||||
return new Promise<ToolExecutionResult>((resolve, reject) => {
|
||||
const iframe = document.createElement('iframe');
|
||||
iframe.setAttribute('sandbox', 'allow-scripts');
|
||||
iframe.style.display = 'none';
|
||||
iframe.srcdoc = SANDBOX_HARNESS_HTML;
|
||||
|
||||
let settled = false;
|
||||
|
||||
const cleanup = () => {
|
||||
settled = true;
|
||||
clearTimeout(timer);
|
||||
window.removeEventListener('message', onMessage);
|
||||
signal?.removeEventListener('abort', onAbort);
|
||||
iframe.remove();
|
||||
};
|
||||
|
||||
const finish = (result: ToolExecutionResult) => {
|
||||
if (settled) return;
|
||||
cleanup();
|
||||
resolve(result);
|
||||
};
|
||||
|
||||
const onAbort = () => {
|
||||
if (settled) return;
|
||||
cleanup();
|
||||
reject(new DOMException('Sandbox execution aborted', 'AbortError'));
|
||||
};
|
||||
|
||||
const onMessage = (event: MessageEvent) => {
|
||||
if (event.source !== iframe.contentWindow) return;
|
||||
finish(formatReply((event.data ?? {}) as SandboxReply));
|
||||
};
|
||||
|
||||
const timer = setTimeout(
|
||||
() => finish({ content: `Execution timed out after ${timeoutMs} ms`, isError: true }),
|
||||
timeoutMs
|
||||
);
|
||||
|
||||
window.addEventListener('message', onMessage);
|
||||
signal?.addEventListener('abort', onAbort);
|
||||
iframe.onload = () => iframe.contentWindow?.postMessage({ code }, '*');
|
||||
document.body.appendChild(iframe);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -29,6 +29,7 @@ import { permissionsStore } from '$lib/stores/permissions.svelte';
|
||||
import { ToolSource, ToolPermissionDecision } from '$lib/enums';
|
||||
import { SvelteMap } from 'svelte/reactivity';
|
||||
import { ToolsService } from '$lib/services/tools.service';
|
||||
import { SandboxService } from '$lib/services/sandbox.service';
|
||||
import { isAbortError } from '$lib/utils';
|
||||
import { DEFAULT_AGENTIC_CONFIG, NEWLINE_SEPARATOR } from '$lib/constants';
|
||||
import {
|
||||
@@ -784,6 +785,13 @@ class AgenticStore {
|
||||
|
||||
result = executionResult.content;
|
||||
|
||||
if (executionResult.isError) toolSuccess = false;
|
||||
} else if (toolSource === ToolSource.FRONTEND) {
|
||||
const args = this.parseToolArguments(toolCall.function.arguments);
|
||||
const executionResult = await SandboxService.executeTool(toolName, args, signal);
|
||||
|
||||
result = executionResult.content;
|
||||
|
||||
if (executionResult.isError) toolSuccess = false;
|
||||
} else {
|
||||
const mcpCall: MCPToolCall = {
|
||||
|
||||
@@ -506,6 +506,33 @@ class ConversationsStore {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Toggles the pinned status of a conversation.
|
||||
* @param convId - The conversation ID to toggle
|
||||
* @returns The new pinned status
|
||||
*/
|
||||
async toggleConversationPin(convId: string): Promise<boolean> {
|
||||
try {
|
||||
const newPinnedState = await DatabaseService.toggleConversationPin(convId);
|
||||
|
||||
const convIndex = this.conversations.findIndex((c) => c.id === convId);
|
||||
|
||||
if (convIndex !== -1) {
|
||||
this.conversations[convIndex].pinned = newPinnedState;
|
||||
this.conversations = [...this.conversations];
|
||||
}
|
||||
|
||||
if (this.activeConversation?.id === convId) {
|
||||
this.activeConversation = { ...this.activeConversation, pinned: newPinnedState };
|
||||
}
|
||||
|
||||
return newPinnedState;
|
||||
} catch (error) {
|
||||
console.error('Failed to toggle conversation pin:', error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates conversation title with optional confirmation dialog based on settings
|
||||
* @param convId - The conversation ID to update
|
||||
@@ -1057,6 +1084,14 @@ export const isConversationsInitialized = () => conversationsStore.isInitialized
|
||||
* Builds a flat tree of conversations with depth levels for nested forks.
|
||||
* Accepts a pre-filtered list so search filtering stays in the component.
|
||||
*/
|
||||
|
||||
// Pinned conversations first, then by lastModified descending
|
||||
const comparePinnedThenRecent = (a: DatabaseConversation, b: DatabaseConversation) => {
|
||||
if (a.pinned && !b.pinned) return -1;
|
||||
if (!a.pinned && b.pinned) return 1;
|
||||
return b.lastModified - a.lastModified;
|
||||
};
|
||||
|
||||
export function buildConversationTree(convs: DatabaseConversation[]): ConversationTreeItem[] {
|
||||
const childrenByParent = new SvelteMap<string, DatabaseConversation[]>();
|
||||
const forkIds = new SvelteSet<string>();
|
||||
@@ -1081,7 +1116,7 @@ export function buildConversationTree(convs: DatabaseConversation[]): Conversati
|
||||
|
||||
const children = childrenByParent.get(conv.id);
|
||||
if (children) {
|
||||
children.sort((a, b) => b.lastModified - a.lastModified);
|
||||
children.sort(comparePinnedThenRecent);
|
||||
|
||||
for (const child of children) {
|
||||
walk(child, depth + 1);
|
||||
@@ -1089,7 +1124,7 @@ export function buildConversationTree(convs: DatabaseConversation[]): Conversati
|
||||
}
|
||||
}
|
||||
|
||||
const roots = convs.filter((c) => !forkIds.has(c.id));
|
||||
const roots = convs.filter((c) => !forkIds.has(c.id)).sort(comparePinnedThenRecent);
|
||||
for (const root of roots) {
|
||||
walk(root, 0);
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import { HealthCheckStatus, JsonSchemaType, ToolCallType, ToolSource } from '$li
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import {
|
||||
DISABLED_TOOL_KEYS_LOCALSTORAGE_KEY,
|
||||
SANDBOX_TOOL_DEFINITION,
|
||||
TOOL_GROUP_LABELS,
|
||||
TOOL_SERVER_LABELS
|
||||
} from '$lib/constants';
|
||||
@@ -18,6 +19,8 @@ function toolKey(source: ToolSource, name: string, serverId?: string): string {
|
||||
return serverId ? `mcp-${serverId}:${name}` : `mcp:${name}`;
|
||||
case ToolSource.CUSTOM:
|
||||
return `custom:${name}`;
|
||||
case ToolSource.FRONTEND:
|
||||
return `frontend:${name}`;
|
||||
default:
|
||||
return `builtin:${name}`;
|
||||
}
|
||||
@@ -82,6 +85,10 @@ class ToolsStore {
|
||||
return mcpStore.getToolDefinitionsForLLM();
|
||||
}
|
||||
|
||||
get frontendTools(): OpenAIToolDefinition[] {
|
||||
return config().jsSandboxEnabled ? [SANDBOX_TOOL_DEFINITION] : [];
|
||||
}
|
||||
|
||||
get customTools(): OpenAIToolDefinition[] {
|
||||
const raw = config().customJson;
|
||||
if (!raw || typeof raw !== 'string') return [];
|
||||
@@ -156,6 +163,15 @@ class ToolsStore {
|
||||
push({ source: ToolSource.BUILTIN, key: toolKey(ToolSource.BUILTIN, name), definition: def });
|
||||
}
|
||||
|
||||
for (const def of this.frontendTools) {
|
||||
const name = def.function.name;
|
||||
push({
|
||||
source: ToolSource.FRONTEND,
|
||||
key: toolKey(ToolSource.FRONTEND, name),
|
||||
definition: def
|
||||
});
|
||||
}
|
||||
|
||||
for (const { serverId, serverName, definition } of this.mcpEntries()) {
|
||||
const name = definition.function.name;
|
||||
push({
|
||||
@@ -208,6 +224,8 @@ class ToolsStore {
|
||||
return entry.serverName ?? '';
|
||||
case ToolSource.CUSTOM:
|
||||
return TOOL_GROUP_LABELS[ToolSource.CUSTOM];
|
||||
case ToolSource.FRONTEND:
|
||||
return TOOL_GROUP_LABELS[ToolSource.FRONTEND];
|
||||
default:
|
||||
return TOOL_GROUP_LABELS[ToolSource.BUILTIN];
|
||||
}
|
||||
@@ -237,6 +255,7 @@ class ToolsStore {
|
||||
};
|
||||
|
||||
for (const def of this._builtinTools) take(def);
|
||||
for (const def of this.frontendTools) take(def);
|
||||
for (const def of mcpStore.getToolDefinitionsForLLM()) take(def);
|
||||
for (const def of this.customTools) take(def);
|
||||
|
||||
@@ -346,6 +365,7 @@ class ToolsStore {
|
||||
if (entry.serverName) return mcpStore.getServerDisplayName(entry.serverName);
|
||||
if (entry.source === ToolSource.BUILTIN) return TOOL_SERVER_LABELS[ToolSource.BUILTIN];
|
||||
if (entry.source === ToolSource.CUSTOM) return TOOL_SERVER_LABELS[ToolSource.CUSTOM];
|
||||
if (entry.source === ToolSource.FRONTEND) return TOOL_SERVER_LABELS[ToolSource.FRONTEND];
|
||||
return '';
|
||||
}
|
||||
|
||||
|
||||
Vendored
+1
@@ -15,6 +15,7 @@ export interface DatabaseConversation {
|
||||
thinkingEnabled?: boolean;
|
||||
reasoningEffort?: ReasoningEffort;
|
||||
forkedFromConversationId?: string;
|
||||
pinned?: boolean;
|
||||
}
|
||||
|
||||
export interface DatabaseMessageExtraAudioFile {
|
||||
|
||||
@@ -24,7 +24,8 @@
|
||||
"tests/**/*.svelte",
|
||||
".storybook/**/*.ts",
|
||||
".storybook/**/*.svelte"
|
||||
]
|
||||
],
|
||||
"exclude": ["src/lib/services/sandbox-worker.js"]
|
||||
// Path aliases are handled by https://svelte.dev/docs/kit/configuration#alias
|
||||
// except $lib which is handled by https://svelte.dev/docs/kit/configuration#files
|
||||
//
|
||||
|
||||
Vendored
+1
-1
@@ -81,7 +81,7 @@ if (LLAMA_BUILD_BORINGSSL)
|
||||
target_link_libraries(${TARGET} PUBLIC ssl crypto)
|
||||
|
||||
elseif (LLAMA_BUILD_LIBRESSL)
|
||||
set(LIBRESSL_VERSION "4.3.1" CACHE STRING "LibreSSL version")
|
||||
set(LIBRESSL_VERSION "4.3.2" CACHE STRING "LibreSSL version")
|
||||
|
||||
message(STATUS "Fetching LibreSSL version ${LIBRESSL_VERSION}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user