mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-01 01:57:43 +02:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 35715657cb | |||
| f75c4e8bf5 | |||
| 99156f3a5f | |||
| a0c91e8f9f | |||
| 07968d53e4 | |||
| ba3b9c8844 | |||
| 94b0200a01 | |||
| b908baf182 |
@@ -1,8 +1,8 @@
|
||||
ARG UBUNTU_VERSION=24.04
|
||||
|
||||
# This needs to generally match the container host's environment.
|
||||
ARG ROCM_VERSION=7.0
|
||||
ARG AMDGPU_VERSION=7.0
|
||||
ARG ROCM_VERSION=7.2
|
||||
ARG AMDGPU_VERSION=7.2
|
||||
|
||||
# Target the ROCm build image
|
||||
ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
|
||||
@@ -11,13 +11,12 @@ ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-co
|
||||
FROM ${BASE_ROCM_DEV_CONTAINER} AS build
|
||||
|
||||
# Unless otherwise specified, we make a fat build.
|
||||
# List from https://github.com/ggml-org/llama.cpp/pull/1087#issuecomment-1682807878
|
||||
# This is mostly tied to rocBLAS supported archs.
|
||||
# gfx803, gfx900, gfx906, gfx1032, gfx1101, gfx1102,not officialy supported
|
||||
# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.1/reference/system-requirements.html
|
||||
# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-7.2.0/reference/system-requirements.html
|
||||
# check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityrad/native_linux/native_linux_compatibility.html
|
||||
# check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityryz/native_linux/native_linux_compatibility.html
|
||||
|
||||
ARG ROCM_DOCKER_ARCH='gfx803;gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1010;gfx1030;gfx1032;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx1151'
|
||||
#ARG ROCM_DOCKER_ARCH='gfx1151'
|
||||
ARG ROCM_DOCKER_ARCH='gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201'
|
||||
|
||||
# Set ROCm architectures
|
||||
ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH}
|
||||
|
||||
@@ -516,6 +516,102 @@ jobs:
|
||||
path: llama-bin-win-sycl-x64.zip
|
||||
name: llama-bin-win-sycl-x64.zip
|
||||
|
||||
ubuntu-22-rocm:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- ROCM_VERSION: "7.2"
|
||||
gpu_targets: "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201"
|
||||
build: 'x64'
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-rocm-cmake-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt install -y build-essential git cmake wget
|
||||
|
||||
- name: Setup Legacy ROCm
|
||||
if: matrix.ROCM_VERSION == '7.2'
|
||||
id: legacy_env
|
||||
run: |
|
||||
sudo mkdir --parents --mode=0755 /etc/apt/keyrings
|
||||
wget https://repo.radeon.com/rocm/rocm.gpg.key -O - | \
|
||||
gpg --dearmor | sudo tee /etc/apt/keyrings/rocm.gpg > /dev/null
|
||||
|
||||
sudo tee /etc/apt/sources.list.d/rocm.list << EOF
|
||||
deb [arch=amd64 signed-by=/etc/apt/keyrings/rocm.gpg] https://repo.radeon.com/rocm/apt/${{ matrix.ROCM_VERSION }} jammy main
|
||||
EOF
|
||||
|
||||
sudo tee /etc/apt/preferences.d/rocm-pin-600 << EOF
|
||||
Package: *
|
||||
Pin: release o=repo.radeon.com
|
||||
Pin-Priority: 600
|
||||
EOF
|
||||
|
||||
sudo apt update
|
||||
sudo apt-get install -y libssl-dev rocm-hip-sdk
|
||||
|
||||
- name: Setup TheRock
|
||||
if: matrix.ROCM_VERSION != '7.2'
|
||||
id: therock_env
|
||||
run: |
|
||||
wget https://repo.amd.com/rocm/tarball/therock-dist-linux-gfx1151-${{ matrix.ROCM_VERSION }}.tar.gz
|
||||
mkdir install
|
||||
tar -xf *.tar.gz -C install
|
||||
export ROCM_PATH=$(pwd)/install
|
||||
echo ROCM_PATH=$ROCM_PATH >> $GITHUB_ENV
|
||||
echo PATH=$PATH:$ROCM_PATH/bin >> $GITHUB_ENV
|
||||
echo LD_LIBRARY_PATH=$ROCM_PATH/lib:$ROCM_PATH/llvm/lib:$ROCM_PATH/lib/rocprofiler-systems >> $GITHUB_ENV
|
||||
|
||||
- name: Build with native CMake HIP support
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -S . \
|
||||
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
|
||||
-DCMAKE_HIP_FLAGS="-mllvm --amdgpu-unroll-threshold-local=600" \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_BACKEND_DL=ON \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DCMAKE_INSTALL_RPATH='$ORIGIN' \
|
||||
-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
|
||||
-DGGML_CPU_ALL_VARIANTS=ON \
|
||||
-DGPU_TARGETS="${{ matrix.gpu_targets }}" \
|
||||
-DGGML_HIP=ON \
|
||||
-DHIP_PLATFORM=amd \
|
||||
-DGGML_HIP_ROCWMMA_FATTN=ON \
|
||||
${{ env.CMAKE_ARGS }}
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
- name: Determine tag name
|
||||
id: tag
|
||||
uses: ./.github/actions/get-tag-name
|
||||
|
||||
- name: Pack artifacts
|
||||
id: pack_artifacts
|
||||
run: |
|
||||
cp LICENSE ./build/bin/
|
||||
tar -czvf llama-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz
|
||||
name: llama-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz
|
||||
|
||||
windows-hip:
|
||||
runs-on: windows-2022
|
||||
|
||||
@@ -784,6 +880,7 @@ jobs:
|
||||
- windows-cuda
|
||||
- windows-sycl
|
||||
- windows-hip
|
||||
- ubuntu-22-rocm
|
||||
- ubuntu-22-cpu
|
||||
- ubuntu-22-vulkan
|
||||
- macOS-arm64
|
||||
@@ -868,6 +965,7 @@ jobs:
|
||||
**Linux:**
|
||||
- [Ubuntu x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.tar.gz)
|
||||
- [Ubuntu x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz)
|
||||
- [Ubuntu x64 (ROCm 7.2)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-x64.tar.gz)
|
||||
- [Ubuntu s390x (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-s390x.tar.gz)
|
||||
|
||||
**Windows:**
|
||||
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
|
||||
cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories.
|
||||
project("llama.cpp" C CXX)
|
||||
include(CheckIncludeFileCXX)
|
||||
|
||||
|
||||
@@ -893,23 +893,6 @@ static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
|
||||
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
|
||||
}
|
||||
|
||||
static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
form.scope_start = "<tool_call>";
|
||||
form.tool_start = "<function=";
|
||||
form.tool_sep = ">";
|
||||
form.key_start = "<parameter=";
|
||||
form.key_val_sep = ">";
|
||||
form.val_end = "</parameter>";
|
||||
form.tool_end = "</function>";
|
||||
form.scope_end = "</tool_call>";
|
||||
form.trim_raw_argval = true;
|
||||
return form;
|
||||
})();
|
||||
builder.consume_reasoning_with_xml_tool_calls(form);
|
||||
}
|
||||
|
||||
static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
|
||||
static const xml_tool_call_format form = ([]() {
|
||||
xml_tool_call_format form {};
|
||||
@@ -1590,9 +1573,6 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||
case COMMON_CHAT_FORMAT_KIMI_K2:
|
||||
common_chat_parse_kimi_k2(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML:
|
||||
common_chat_parse_qwen3_coder_xml(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_APRIEL_1_5:
|
||||
common_chat_parse_apriel_1_5(builder);
|
||||
break;
|
||||
|
||||
+11
-45
@@ -736,7 +736,6 @@ const char * common_chat_format_name(common_chat_format format) {
|
||||
case COMMON_CHAT_FORMAT_MINIMAX_M2: return "MiniMax-M2";
|
||||
case COMMON_CHAT_FORMAT_GLM_4_5: return "GLM 4.5";
|
||||
case COMMON_CHAT_FORMAT_KIMI_K2: return "Kimi K2";
|
||||
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder";
|
||||
case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
|
||||
case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
|
||||
case COMMON_CHAT_FORMAT_SOLAR_OPEN: return "Solar Open";
|
||||
@@ -1522,14 +1521,17 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_nemotron_v3(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
static common_chat_params common_chat_params_init_qwen3_coder(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_CONSTRUCTED;
|
||||
|
||||
// Nemotron Nano 3 and Step-3.5-Flash use the Qwen3 Coder tool calling with thinking
|
||||
bool supports_reasoning = (tmpl.source().find("<think>") != std::string::npos);
|
||||
|
||||
// Handle thinking tags appropriately based on inputs.enable_thinking
|
||||
if (string_ends_with(data.prompt, "<think>\n")) {
|
||||
if (supports_reasoning && string_ends_with(data.prompt, "<think>\n")) {
|
||||
if (!inputs.enable_thinking) {
|
||||
data.prompt += "</think>";
|
||||
} else {
|
||||
@@ -1538,19 +1540,21 @@ static common_chat_params common_chat_params_init_nemotron_v3(const common_chat_
|
||||
}
|
||||
|
||||
data.preserved_tokens = {
|
||||
"<think>",
|
||||
"</think>",
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
};
|
||||
|
||||
if (supports_reasoning) {
|
||||
data.preserved_tokens.insert(data.preserved_tokens.end(), {"<think>", "</think>"});
|
||||
}
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = true;
|
||||
|
||||
auto parser = build_chat_peg_constructed_parser([&](auto & p) {
|
||||
auto reasoning = p.eps();
|
||||
if (inputs.enable_thinking && extract_reasoning) {
|
||||
if (supports_reasoning && inputs.enable_thinking && extract_reasoning) {
|
||||
auto reasoning_content = p.reasoning(p.until("</think>")) + ("</think>" | p.end());
|
||||
if (data.thinking_forced_open) {
|
||||
reasoning = reasoning_content;
|
||||
@@ -1888,38 +1892,6 @@ static common_chat_params common_chat_params_init_minimax_m2(const common_chat_t
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
|
||||
data.prompt = apply(tmpl, params);
|
||||
data.format = COMMON_CHAT_FORMAT_QWEN3_CODER_XML;
|
||||
|
||||
data.preserved_tokens = {
|
||||
"<tool_call>",
|
||||
"</tool_call>",
|
||||
"<function=",
|
||||
"</function>",
|
||||
"<parameter=",
|
||||
"</parameter>",
|
||||
};
|
||||
|
||||
// build grammar for tool call
|
||||
static const xml_tool_call_format form {
|
||||
/* form.scope_start = */ "<tool_call>\n",
|
||||
/* form.tool_start = */ "<function=",
|
||||
/* form.tool_sep = */ ">\n",
|
||||
/* form.key_start = */ "<parameter=",
|
||||
/* form.key_val_sep = */ ">\n",
|
||||
/* form.val_end = */ "\n</parameter>\n",
|
||||
/* form.tool_end = */ "</function>\n",
|
||||
/* form.scope_end = */ "</tool_call>",
|
||||
};
|
||||
build_grammar_xml_tool_call(data, params.tools, form);
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, const struct templates_params & params) {
|
||||
common_chat_params data;
|
||||
data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
@@ -3147,13 +3119,7 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
src.find("<function=") != std::string::npos &&
|
||||
src.find("<parameter=") != std::string::npos) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
// Models with <think> support (Step-3.5-Flash, Nemotron 3 Nano) use the
|
||||
// Nemotron v3 PEG parser for streaming and schema-aware parameter parsing.
|
||||
// Qwen3-Coder has no <think> in its template.
|
||||
if (src.find("<think>") != std::string::npos) {
|
||||
return common_chat_params_init_nemotron_v3(tmpl, params);
|
||||
}
|
||||
return common_chat_params_init_qwen3_coder_xml(tmpl, params);
|
||||
return common_chat_params_init_qwen3_coder(tmpl, params);
|
||||
}
|
||||
|
||||
// Xiaomi MiMo format detection (must come before Hermes 2 Pro)
|
||||
|
||||
@@ -128,7 +128,6 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_GLM_4_5,
|
||||
COMMON_CHAT_FORMAT_MINIMAX_M2,
|
||||
COMMON_CHAT_FORMAT_KIMI_K2,
|
||||
COMMON_CHAT_FORMAT_QWEN3_CODER_XML,
|
||||
COMMON_CHAT_FORMAT_APRIEL_1_5,
|
||||
COMMON_CHAT_FORMAT_XIAOMI_MIMO,
|
||||
COMMON_CHAT_FORMAT_SOLAR_OPEN,
|
||||
|
||||
@@ -171,15 +171,9 @@
|
||||
#elif defined(__riscv)
|
||||
// quants.c
|
||||
#define quantize_row_q8_K_generic quantize_row_q8_K
|
||||
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
|
||||
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
|
||||
#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
|
||||
#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
|
||||
#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K
|
||||
#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K
|
||||
#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
|
||||
#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
|
||||
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
|
||||
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
|
||||
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
|
||||
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
|
||||
|
||||
@@ -1954,3 +1954,773 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
||||
#endif
|
||||
}
|
||||
|
||||
static const uint8_t sign_gather_indices_arr[64] = {
|
||||
0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3,
|
||||
4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7
|
||||
};
|
||||
|
||||
static const uint8_t sign_bit_masks_arr[64] = {
|
||||
1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128,
|
||||
1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128
|
||||
};
|
||||
|
||||
static void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);
|
||||
|
||||
const block_iq2_s * GGML_RESTRICT x = vx;
|
||||
const block_q8_K * GGML_RESTRICT y = vy;
|
||||
|
||||
const int nb = n / QK_K;
|
||||
const uint64_t * grid64 = (const uint64_t *)iq2s_grid;
|
||||
|
||||
// --- Pre-load Constants ---
|
||||
uint16_t gather_qh_arr[8] = {0, 0, 0, 0, 1, 1, 1, 1};
|
||||
vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 8);
|
||||
uint16_t shift_qh_arr[8] = {11, 9, 7, 5, 11, 9, 7, 5};
|
||||
vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 8);
|
||||
|
||||
// Constants for sign extraction
|
||||
vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64);
|
||||
vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||
|
||||
const uint8_t * GGML_RESTRICT qs = x[i].qs;
|
||||
const uint8_t * GGML_RESTRICT qh = x[i].qh;
|
||||
const uint8_t * GGML_RESTRICT scales = x[i].scales;
|
||||
const int8_t * GGML_RESTRICT q8 = y[i].qs;
|
||||
|
||||
const uint8_t * signs_ptr = qs + 32;
|
||||
|
||||
float sum_block = 0.0f;
|
||||
|
||||
for (int ib = 0; ib < 4; ++ib) {
|
||||
// Combine low + high bits
|
||||
vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 8);
|
||||
qs += 8;
|
||||
uint16_t qh_val;
|
||||
memcpy(&qh_val, qh, 2);
|
||||
qh += 2;
|
||||
vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8((const uint8_t*)&qh_val, 2);
|
||||
vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 2);
|
||||
vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16);
|
||||
vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 8);
|
||||
v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 8);
|
||||
|
||||
// Mask: We want bits 11-12. 0x1800 = 0001 1000 0000 0000
|
||||
v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 8);
|
||||
vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 8);
|
||||
|
||||
// Multiply by 8 to get byte offset, instead of element offset
|
||||
v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 8);
|
||||
vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 8);
|
||||
|
||||
// Lookup Grid using Byte Offsets
|
||||
vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 8);
|
||||
|
||||
vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals);
|
||||
vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8);
|
||||
|
||||
// Load signs and generate sign mask
|
||||
vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 8);
|
||||
signs_ptr += 8;
|
||||
|
||||
vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);
|
||||
vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64);
|
||||
|
||||
vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64);
|
||||
vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64);
|
||||
|
||||
vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64);
|
||||
q8 += 64;
|
||||
|
||||
vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64);
|
||||
vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 64);
|
||||
|
||||
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);
|
||||
|
||||
int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
|
||||
__riscv_vget_v_i16m4_i16m1(v_dot, 0), v_zero, 16));
|
||||
int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
|
||||
__riscv_vget_v_i16m4_i16m1(v_dot, 1), v_zero, 16));
|
||||
int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
|
||||
__riscv_vget_v_i16m4_i16m1(v_dot, 2), v_zero, 16));
|
||||
int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
|
||||
__riscv_vget_v_i16m4_i16m1(v_dot, 3), v_zero, 16));
|
||||
|
||||
uint8_t sc0 = scales[0];
|
||||
uint8_t sc1 = scales[1];
|
||||
scales += 2;
|
||||
|
||||
sum_block += s0 * (2 * (sc0 & 0xF) + 1);
|
||||
sum_block += s1 * (2 * (sc0 >> 4) + 1);
|
||||
sum_block += s2 * (2 * (sc1 & 0xF) + 1);
|
||||
sum_block += s3 * (2 * (sc1 >> 4) + 1);
|
||||
}
|
||||
sumf += sum_block * combined_scale;
|
||||
}
|
||||
*s = 0.125f * sumf;
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);
|
||||
|
||||
const block_iq2_s * GGML_RESTRICT x = vx;
|
||||
const block_q8_K * GGML_RESTRICT y = vy;
|
||||
|
||||
const int nb = n / QK_K;
|
||||
const uint64_t * grid64 = (const uint64_t *)iq2s_grid;
|
||||
|
||||
// Pre-load Constants
|
||||
vuint8m2_t v_ids = __riscv_vid_v_u8m2(32);
|
||||
vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 32);
|
||||
vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 32);
|
||||
vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 32);
|
||||
vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 32);
|
||||
uint16_t shift_qh_arr[4] = {11, 9, 7, 5};
|
||||
vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 4);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||
|
||||
const uint8_t * GGML_RESTRICT qs = x[i].qs;
|
||||
const uint8_t * GGML_RESTRICT qh = x[i].qh;
|
||||
const uint8_t * GGML_RESTRICT scales = x[i].scales;
|
||||
const int8_t * GGML_RESTRICT q8 = y[i].qs;
|
||||
|
||||
const uint8_t * signs_ptr = qs + 32;
|
||||
float sum_block = 0.0f;
|
||||
|
||||
for (int ib = 0; ib < 8; ++ib) {
|
||||
|
||||
// Load Low Bits [4 bytes]
|
||||
vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 4);
|
||||
qs += 4;
|
||||
|
||||
// Load 1 byte. It contains bits for 4 mini-blocks.
|
||||
uint8_t qh_val = *qh++;
|
||||
|
||||
// Combine Low + High bits of 10bit indices
|
||||
vuint8mf4_t v_qh_raw = __riscv_vmv_v_x_u8mf4(qh_val, 4);
|
||||
vuint16mf2_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qh_raw, 4);
|
||||
vuint16mf2_t v_qh_mf2 = __riscv_vsll_vv_u16mf2(v_qh_u16, v_shift_qh, 4);
|
||||
v_qh_mf2 = __riscv_vand_vx_u16mf2(v_qh_mf2, 0x1800, 4);
|
||||
vuint16mf2_t v_qs_u16_mf2 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 4);
|
||||
vuint16mf2_t v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16_mf2, 3, 4);
|
||||
vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_mf2, 4);
|
||||
|
||||
// Lookup Grid
|
||||
vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(__riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 4)));
|
||||
|
||||
vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 4);
|
||||
signs_ptr += 4;
|
||||
vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);
|
||||
vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 32);
|
||||
|
||||
// generating sign mask
|
||||
vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 32);
|
||||
vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 32);
|
||||
|
||||
vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 32);
|
||||
q8 += 32;
|
||||
|
||||
// apply signs
|
||||
vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative,v_q8, v_q8, 0, 32);
|
||||
vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 32);
|
||||
|
||||
// Reduction
|
||||
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);
|
||||
|
||||
// Reduce 0-15 (First Half)
|
||||
int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(
|
||||
__riscv_vget_v_i16m4_i16m2(v_dot, 0), v_zero, 16));
|
||||
|
||||
// Reduce 16-31 (Second Half)
|
||||
int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(
|
||||
__riscv_vget_v_i16m4_i16m2(v_dot, 1), v_zero, 16));
|
||||
|
||||
// Apply sub Scales
|
||||
uint8_t sc = *scales++;
|
||||
|
||||
sum_block += s0 * (2 * (sc & 0xF) + 1);
|
||||
sum_block += s1 * (2 * (sc >> 4) + 1);
|
||||
}
|
||||
sumf += sum_block * combined_scale;
|
||||
}
|
||||
*s = 0.125f * sumf;
|
||||
}
|
||||
|
||||
void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
switch (__riscv_vlenb() * 8) {
|
||||
case 128:
|
||||
ggml_vec_dot_iq2_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
case 256:
|
||||
ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
default:
|
||||
ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
}
|
||||
#else
|
||||
ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_iq3_s * GGML_RESTRICT x = vx;
|
||||
const block_q8_K * GGML_RESTRICT y = vy;
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
const uint64_t * grid64 = (const uint64_t *)iq3s_grid;
|
||||
|
||||
// --- Pre-load Constants ---
|
||||
const uint16_t qh_bit_shifts_arr[16] = {
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
|
||||
};
|
||||
vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64);
|
||||
vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64);
|
||||
vuint16m1_t v_qh_shifts = __riscv_vle16_v_u16m1(qh_bit_shifts_arr, 16);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float d = GGML_CPU_FP16_TO_FP32(x[i].d);
|
||||
const float combined_scale = d * y[i].d;
|
||||
|
||||
const uint8_t * GGML_RESTRICT qs = x[i].qs;
|
||||
const uint8_t * GGML_RESTRICT qh = x[i].qh;
|
||||
const uint8_t * GGML_RESTRICT scales = x[i].scales;
|
||||
const uint8_t * GGML_RESTRICT signs = x[i].signs;
|
||||
const int8_t * GGML_RESTRICT q8 = y[i].qs;
|
||||
|
||||
float sum_block = 0.0f;
|
||||
|
||||
// Loop: Process 64 weights (16 mini-blocks of 4) per iteration
|
||||
for (int ib = 0; ib < 4; ++ib) {
|
||||
|
||||
vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 16);
|
||||
qs += 16;
|
||||
|
||||
uint16_t qh_val;
|
||||
memcpy(&qh_val, qh, 2);
|
||||
qh += 2;
|
||||
|
||||
vuint16m1_t v_qh_val = __riscv_vmv_v_x_u16m1(qh_val, 16);
|
||||
// Extract bits: (qh >> i) & 1
|
||||
v_qh_val = __riscv_vsrl_vv_u16m1(v_qh_val, v_qh_shifts, 16);
|
||||
v_qh_val = __riscv_vand_vx_u16m1(v_qh_val, 1, 16);
|
||||
|
||||
vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 16);
|
||||
v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 16);
|
||||
v_qh_val = __riscv_vsll_vx_u16m1(v_qh_val, 10, 16);
|
||||
vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_val, 16);
|
||||
|
||||
// Grid value is 4xuint8
|
||||
vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2((const uint32_t *)grid64, v_grid_offsets, 16);
|
||||
vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed);
|
||||
vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs, 8);
|
||||
signs += 8;
|
||||
|
||||
// Generate sign mask
|
||||
vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);
|
||||
vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64);
|
||||
vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64);
|
||||
vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64);
|
||||
|
||||
vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64);
|
||||
q8 += 64;
|
||||
|
||||
// Apply Signs
|
||||
vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64);
|
||||
vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 64);
|
||||
|
||||
// Reduction
|
||||
vint16m2_t v_dot_lo = __riscv_vget_v_i16m4_i16m2(v_dot, 0);
|
||||
vint16m2_t v_dot_hi = __riscv_vget_v_i16m4_i16m2(v_dot, 1);
|
||||
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);
|
||||
|
||||
int32_t s_lo = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_lo, v_zero, 32));
|
||||
int32_t s_hi = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_hi, v_zero, 32));
|
||||
|
||||
// Apply sub-scales
|
||||
uint8_t sc_byte = *scales++;
|
||||
int sc_lo = (sc_byte & 0xF) * 2 + 1;
|
||||
int sc_hi = (sc_byte >> 4) * 2 + 1;
|
||||
|
||||
sum_block += s_lo * sc_lo + s_hi * sc_hi;
|
||||
}
|
||||
sumf += sum_block * combined_scale;
|
||||
}
|
||||
*s = 0.125f * sumf;
|
||||
}
|
||||
|
||||
void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
switch (__riscv_vlenb() * 8) {
|
||||
case 256:
|
||||
ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
default:
|
||||
ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
}
|
||||
#else
|
||||
ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_tq1_0 * GGML_RESTRICT x = vx;
|
||||
const block_q8_K * GGML_RESTRICT y = vy;
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
float sumf = 0.0f;
|
||||
uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
// First loop.
|
||||
vint32m4_t suml1;
|
||||
{
|
||||
const int vl = 32;
|
||||
vuint8m1_t tq = __riscv_vle8_v_u8m1(x[i].qs, vl);
|
||||
|
||||
vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(tq, 3, vl), 8, vl);
|
||||
vuint16m2_t tq1 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 3, vl), 3, vl), 8, vl);
|
||||
vuint16m2_t tq2 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 9, vl), 3, vl), 8, vl);
|
||||
vuint16m2_t tq3 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 27, vl), 3, vl), 8, vl);
|
||||
vuint16m2_t tq4 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 81, vl), 3, vl), 8, vl);
|
||||
|
||||
vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 0, vl), vl);
|
||||
vint16m2_t q81 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 32, vl), vl);
|
||||
vint16m2_t q82 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 64, vl), vl);
|
||||
vint16m2_t q83 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 96, vl), vl);
|
||||
vint16m2_t q84 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 128, vl), vl);
|
||||
|
||||
vint16m2_t sum0 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl);
|
||||
vint16m2_t sum1 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq1, 1, vl)), q81, vl);
|
||||
vint16m2_t sum2 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq2, 1, vl)), q82, vl);
|
||||
vint16m2_t sum3 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq3, 1, vl)), q83, vl);
|
||||
vint16m2_t sum4 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq4, 1, vl)), q84, vl);
|
||||
|
||||
vint32m4_t sumi0 = __riscv_vwadd_vv_i32m4(sum0, sum1, vl);
|
||||
vint32m4_t sumi1 = __riscv_vwadd_vv_i32m4(sum2, sum3, vl);
|
||||
suml1 = __riscv_vadd_vv_i32m4(__riscv_vwcvt_x_x_v_i32m4(sum4, vl), __riscv_vadd_vv_i32m4(sumi0, sumi1, vl), vl);
|
||||
}
|
||||
|
||||
// Second loop.
|
||||
vint32m2_t suml2;
|
||||
{
|
||||
const int vl = 16;
|
||||
vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs + 32, vl);
|
||||
|
||||
vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(tq, 3 * 1, vl), 8, vl);
|
||||
vuint16m1_t tq1 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 3, vl), 3, vl), 8, vl);
|
||||
vuint16m1_t tq2 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 9, vl), 3, vl), 8, vl);
|
||||
vuint16m1_t tq3 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 27, vl), 3, vl), 8, vl);
|
||||
vuint16m1_t tq4 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 81, vl), 3, vl), 8, vl);
|
||||
|
||||
vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 160, vl), vl);
|
||||
vint16m1_t q81 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 176, vl), vl);
|
||||
vint16m1_t q82 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 192, vl), vl);
|
||||
vint16m1_t q83 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 208, vl), vl);
|
||||
vint16m1_t q84 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 224, vl), vl);
|
||||
|
||||
vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl);
|
||||
vint16m1_t sum1 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq1, 1, vl)), q81, vl);
|
||||
vint16m1_t sum2 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq2, 1, vl)), q82, vl);
|
||||
vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl);
|
||||
vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl);
|
||||
|
||||
vint32m2_t sumi0 = __riscv_vwadd_vv_i32m2(sum0, sum1, vl);
|
||||
vint32m2_t sumi1 = __riscv_vwadd_vv_i32m2(sum2, sum3, vl);
|
||||
suml2 = __riscv_vadd_vv_i32m2(__riscv_vwcvt_x_x_v_i32m2(sum4, vl), __riscv_vadd_vv_i32m2(sumi0, sumi1, vl), vl);
|
||||
}
|
||||
|
||||
// Third loop.
|
||||
vint32m2_t suml3;
|
||||
{
|
||||
const int vl = 16;
|
||||
|
||||
uint32_t qh;
|
||||
memcpy(&qh, &x[i].qh[0], 4);
|
||||
// Prevent fusion with vmv.
|
||||
__asm__ __volatile__("" : "+r"(qh));
|
||||
vuint8mf2_t tq = __riscv_vreinterpret_v_u32mf2_u8mf2(__riscv_vmv_v_x_u32mf2(qh, vl / 4));
|
||||
|
||||
vuint8mf2_t p = __riscv_vle8_v_u8mf2(pow, vl);
|
||||
|
||||
vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vv_u8mf2(tq, p, vl), 3, vl), 8, vl);
|
||||
|
||||
vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 240, vl), vl);
|
||||
|
||||
vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl);
|
||||
suml3 = __riscv_vwcvt_x_x_v_i32m2(sum0, vl);
|
||||
}
|
||||
|
||||
vint32m2_t sumb = __riscv_vadd_vv_i32m2(__riscv_vget_v_i32m4_i32m2(suml1, 0), __riscv_vget_v_i32m4_i32m2(suml1, 1), 16);
|
||||
sumb = __riscv_vadd_vv_i32m2(sumb, suml2, 16);
|
||||
sumb = __riscv_vadd_vv_i32m2(sumb, suml3, 16);
|
||||
|
||||
vint32m1_t sum = __riscv_vredsum_vs_i32m2_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16);
|
||||
sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
|
||||
void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
switch (__riscv_vlenb() * 8) {
|
||||
case 256:
|
||||
ggml_vec_dot_tq1_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
default:
|
||||
ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
}
|
||||
#else
|
||||
ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_tq2_0 * GGML_RESTRICT x = vx;
|
||||
const block_q8_K * GGML_RESTRICT y = vy;
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
float sumf = 0.0f;
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
int32_t sumi = 0;
|
||||
|
||||
for (size_t j = 0; j < sizeof(x[0].qs); j += 32) {
|
||||
const int8_t * py0 = &y[i].qs[j * 4 + 0 * 32];
|
||||
const int8_t * py1 = &y[i].qs[j * 4 + 1 * 32];
|
||||
const int8_t * py2 = &y[i].qs[j * 4 + 2 * 32];
|
||||
const int8_t * py3 = &y[i].qs[j * 4 + 3 * 32];
|
||||
const uint8_t* px = &x[i].qs[j];
|
||||
|
||||
size_t vlmax_16m2 = __riscv_vsetvl_e16m2(32);
|
||||
vint16m2_t vacc16 = __riscv_vmv_v_x_i16m2(0, vlmax_16m2);
|
||||
|
||||
size_t vl = __riscv_vsetvl_e8m1(32);
|
||||
|
||||
vuint8m1_t vx_u8 = __riscv_vle8_v_u8m1(px, vl);
|
||||
|
||||
vint8m1_t vy0 = __riscv_vle8_v_i8m1(py0 , vl);
|
||||
vint8m1_t vy1 = __riscv_vle8_v_i8m1(py1, vl);
|
||||
vint8m1_t vy2 = __riscv_vle8_v_i8m1(py2, vl);
|
||||
vint8m1_t vy3 = __riscv_vle8_v_i8m1(py3, vl);
|
||||
|
||||
// l=0 (bits 1:0)
|
||||
vuint8m1_t t0 = __riscv_vand_vx_u8m1(vx_u8, 0x03, vl);
|
||||
vint8m1_t vq0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t0), 1, vl);
|
||||
|
||||
// l=1 (bits 3:2)
|
||||
vuint8m1_t t1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 2, vl), 0x03, vl);
|
||||
vint8m1_t vq1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t1), 1, vl);
|
||||
|
||||
// l=2 (bits 5:4)
|
||||
vuint8m1_t t2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 4, vl), 0x03, vl);
|
||||
vint8m1_t vq2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t2), 1, vl);
|
||||
|
||||
// l=3 (bits 7:6)
|
||||
vuint8m1_t t3 = __riscv_vsrl_vx_u8m1(vx_u8, 6, vl); // No final AND needed as vsrl shifts in zeros
|
||||
vint8m1_t vq3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t3), 1, vl);
|
||||
|
||||
// 4. Multiply and accumulate
|
||||
vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq0, vy0, vl);
|
||||
vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq1, vy1, vl);
|
||||
vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq2, vy2, vl);
|
||||
vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq3, vy3, vl);
|
||||
|
||||
vlmax_16m2 = __riscv_vsetvl_e16m2(32);
|
||||
vint32m1_t vzero32 = __riscv_vmv_v_x_i32m1(0, 1);
|
||||
vint32m1_t vred32 = __riscv_vwredsum_vs_i16m2_i32m1(vacc16, vzero32, vlmax_16m2);
|
||||
|
||||
sumi += __riscv_vmv_x_s_i32m1_i32(vred32);
|
||||
}
|
||||
const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
|
||||
sumf += (float)sumi * d;
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
|
||||
void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
switch (__riscv_vlenb() * 8) {
|
||||
case 256:
|
||||
ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
default:
|
||||
ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
}
|
||||
#else
|
||||
ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_iq1_s * GGML_RESTRICT x = vx;
|
||||
const block_q8_K * GGML_RESTRICT y = vy;
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
// Load qh once for the entire superblock.
|
||||
vuint16mf2_t qh = __riscv_vle16_v_u16mf2(x[i].qh, 8);
|
||||
|
||||
// Calculate ls.
|
||||
vuint16mf2_t temp = __riscv_vsrl_vx_u16mf2(qh, 12, 8);
|
||||
temp = __riscv_vand_vx_u16mf2(temp, 7, 8);
|
||||
vint32m1_t ls = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vwmulu_vx_u32m1(temp, 2, 8));
|
||||
ls = __riscv_vadd_vx_i32m1(ls, 1, 8);
|
||||
|
||||
// Calculate delta.
|
||||
vbool32_t mask = __riscv_vmseq_vx_u16mf2_b32(__riscv_vand_vx_u16mf2(qh, 0x8000, 8), 0, 8);
|
||||
vint32m1_t delta_neg = __riscv_vmv_v_x_i32m1(-1, 8);
|
||||
vint32m1_t delta_pos = __riscv_vmv_v_x_i32m1(1, 8);
|
||||
vint32m1_t delta = __riscv_vmerge_vvm_i32m1(delta_neg, delta_pos, mask, 8);
|
||||
|
||||
// Load qs.
|
||||
vuint8m1_t qs = __riscv_vle8_v_u8m1(x[i].qs, 32);
|
||||
|
||||
// Prepare the indices.
|
||||
const uint64_t shift = 0x0009000600030000;
|
||||
vuint16m2_t qh_shift = __riscv_vreinterpret_v_u64m2_u16m2(__riscv_vmv_v_x_u64m2(shift, 8));
|
||||
vuint16m2_t qh_gather_index = __riscv_vreinterpret_v_i16m2_u16m2(
|
||||
__riscv_vdiv_vx_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vid_v_u16m2(32)), 4, 32));
|
||||
vuint16m2_t qh_ext = __riscv_vlmul_ext_v_u16m1_u16m2(__riscv_vlmul_ext_v_u16mf2_u16m1(qh));
|
||||
vuint16m2_t qh_index = __riscv_vrgather_vv_u16m2(qh_ext, qh_gather_index, 32);
|
||||
qh_index = __riscv_vsrl_vv_u16m2(qh_index, qh_shift, 32);
|
||||
qh_index = __riscv_vand_vx_u16m2(qh_index, 7, 32);
|
||||
qh_index = __riscv_vsll_vx_u16m2(qh_index, 8, 32);
|
||||
qh_index = __riscv_vor_vv_u16m2(qh_index, __riscv_vzext_vf2_u16m2(qs, 32), 32);
|
||||
vuint16m2_t index = __riscv_vsll_vx_u16m2(qh_index, 3, 32);
|
||||
|
||||
// Final lsums.
|
||||
int32_t lsums_s[8];
|
||||
vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1);
|
||||
|
||||
// Sub-blocks 1-4
|
||||
{
|
||||
vuint16m1_t grid_index0 = __riscv_vget_v_u16m2_u16m1(index, 0);
|
||||
vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 16));
|
||||
vint8m4_t q80 = __riscv_vle8_v_i8m4(y[i].qs, 128);
|
||||
vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128);
|
||||
lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 0), one_scalar, 32));
|
||||
lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 1), one_scalar, 32));
|
||||
lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 2), one_scalar, 32));
|
||||
lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 3), one_scalar, 32));
|
||||
}
|
||||
__asm__ __volatile__("" ::: "memory");
|
||||
// Sub-blocks 5-8
|
||||
{
|
||||
vuint16m1_t grid_index1 = __riscv_vget_v_u16m2_u16m1(index, 1);
|
||||
vint8m4_t grid1 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index1, 16));
|
||||
vint8m4_t q81 = __riscv_vle8_v_i8m4(&y[i].qs[128], 128);
|
||||
vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(grid1, q81, 128);
|
||||
lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 0), one_scalar, 32));
|
||||
lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 1), one_scalar, 32));
|
||||
lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 2), one_scalar, 32));
|
||||
lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 3), one_scalar, 32));
|
||||
}
|
||||
__asm__ __volatile__("" ::: "memory");
|
||||
vint32m1_t lsums = __riscv_vle32_v_i32m1(&lsums_s[0], 8);
|
||||
|
||||
// Calculate the bsums.
|
||||
vint16m1_t bsums_0 = __riscv_vle16_v_i16m1(y[i].bsums, 16);
|
||||
const vuint32m1_t bsums_i32 = __riscv_vreinterpret_v_u16m1_u32m1(__riscv_vreinterpret_v_i16m1_u16m1(bsums_0));
|
||||
const vint16mf2_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 0, 8));
|
||||
const vint16mf2_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 16, 8));
|
||||
const vint32m1_t bsums = __riscv_vwadd_vv_i32m1(bsums_i32_0, bsums_i32_1, 8);
|
||||
|
||||
// Accumulation.
|
||||
vint32m1_t sumi_v = __riscv_vmul_vv_i32m1(ls, lsums, 8);
|
||||
vint32m1_t sumi1_v = __riscv_vmul_vv_i32m1(__riscv_vmul_vv_i32m1(ls, delta, 8), bsums, 8);
|
||||
|
||||
// Update sumf.
|
||||
int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8));
|
||||
int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8));
|
||||
sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1);
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
|
||||
void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
switch (__riscv_vlenb() * 8) {
|
||||
case 256:
|
||||
ggml_vec_dot_iq1_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
default:
|
||||
ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
}
|
||||
#else
|
||||
ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_iq1_m * GGML_RESTRICT x = vx;
|
||||
const block_q8_K * GGML_RESTRICT y = vy;
|
||||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
iq1m_scale_t scale;
|
||||
float sumf = 0.0f;
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const int8_t * q8 = y[i].qs;
|
||||
const uint8_t * qs = x[i].qs;
|
||||
const uint8_t * qh = x[i].qh;
|
||||
const uint16_t * sc = (const uint16_t *)x[i].scales;
|
||||
|
||||
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
||||
|
||||
// Accumulators.
|
||||
vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 16);
|
||||
vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 16);
|
||||
|
||||
// We process 4 sub-blocks together.
|
||||
for (int ib = 0; ib < QK_K/128; ib++) {
|
||||
// Load qh for 4 sub-blocks.
|
||||
const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 8);
|
||||
const vuint16mf2_t qh_16_lo = __riscv_vzext_vf2_u16mf2(qh_8, 8);
|
||||
const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 8);
|
||||
const vuint16m1_t qhb = __riscv_vzext_vf2_u16m1(
|
||||
__riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 8)), 16);
|
||||
qh += 8;
|
||||
|
||||
// Prepare grid indices.
|
||||
const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 16), 16);
|
||||
const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 8));
|
||||
vuint16m1_t index = __riscv_vor_vv_u16m1(qsb, __riscv_vand_vx_u16m1(__riscv_vsll_vv_u16m1(qhb, shift, 16), 0x700, 16), 16);
|
||||
index = __riscv_vsll_vx_u16m1(index, 3, 16);
|
||||
qs += 16;
|
||||
|
||||
// Load the grid.
|
||||
const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4(
|
||||
__riscv_vluxei16_v_u64m4(iq1s_grid, index, 16)));
|
||||
|
||||
// Prepare the deltas.
|
||||
const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16(
|
||||
__riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 8)), 16), 0, 16);
|
||||
const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 16);
|
||||
const vint64m4_t delta_neg = __riscv_vmv_v_x_i64m4(0xffffffffffffffff, 16);
|
||||
const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4(
|
||||
__riscv_vmerge_vvm_i64m4(delta_pos, delta_neg, mask, 16));
|
||||
|
||||
// Load q8 for sub-blocks.
|
||||
const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128);
|
||||
q8 += 128;
|
||||
|
||||
// Calculate the lsums.
|
||||
const vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(iq1b, q8b, 128);
|
||||
const vint16m8_t lsum2 = __riscv_vwmul_vv_i16m8(delta, q8b, 128);
|
||||
|
||||
// Prepare the scales.
|
||||
const int16_t ls_0_0 = 2*((sc[0] >> 0) & 0x7) + 1;
|
||||
const int16_t ls_0_1 = 2*((sc[0] >> 3) & 0x7) + 1;
|
||||
const int16_t ls_1_0 = 2*((sc[0] >> 6) & 0x7) + 1;
|
||||
const int16_t ls_1_1 = 2*((sc[0] >> 9) & 0x7) + 1;
|
||||
const int16_t ls_2_0 = 2*((sc[1] >> 0) & 0x7) + 1;
|
||||
const int16_t ls_2_1 = 2*((sc[1] >> 3) & 0x7) + 1;
|
||||
const int16_t ls_3_0 = 2*((sc[1] >> 6) & 0x7) + 1;
|
||||
const int16_t ls_3_1 = 2*((sc[1] >> 9) & 0x7) + 1;
|
||||
sc += 2;
|
||||
|
||||
// Accumulate in acc0 and acc1 for each sub-block.
|
||||
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 16);
|
||||
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 16);
|
||||
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 16);
|
||||
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 16);
|
||||
//
|
||||
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 16);
|
||||
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 16);
|
||||
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 16);
|
||||
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 16);
|
||||
//
|
||||
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 16);
|
||||
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 16);
|
||||
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 16);
|
||||
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 16);
|
||||
//
|
||||
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 16);
|
||||
acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 16);
|
||||
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16);
|
||||
acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16);
|
||||
}
|
||||
|
||||
// Reduce and accumulate in `sumf`.
|
||||
vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1);
|
||||
int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 16));
|
||||
int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 16));
|
||||
sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2);
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
|
||||
void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
#if defined __riscv_v_intrinsic
|
||||
switch (__riscv_vlenb() * 8) {
|
||||
case 256:
|
||||
ggml_vec_dot_iq1_m_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
default:
|
||||
ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
break;
|
||||
}
|
||||
#else
|
||||
ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -1149,8 +1149,7 @@ struct ggml_cuda_graph {
|
||||
size_t num_nodes = 0;
|
||||
std::vector<cudaGraphNode_t> nodes;
|
||||
bool disable_due_to_gpu_arch = false;
|
||||
bool disable_due_to_too_many_updates = false;
|
||||
int number_consecutive_updates = 0;
|
||||
bool warmup_complete = false;
|
||||
std::vector<ggml_cuda_graph_node_properties> props;
|
||||
|
||||
// these are extra tensors (inputs) that participate in the ggml graph but are not nodes
|
||||
@@ -1159,21 +1158,9 @@ struct ggml_cuda_graph {
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/19165
|
||||
std::vector<ggml_cuda_graph_node_properties> extra;
|
||||
|
||||
void record_update(bool use_graph, bool update_required) {
|
||||
if (use_graph && update_required) {
|
||||
number_consecutive_updates++;
|
||||
} else {
|
||||
number_consecutive_updates = 0;
|
||||
}
|
||||
if (number_consecutive_updates >= 4) {
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
|
||||
disable_due_to_too_many_updates = true;
|
||||
}
|
||||
}
|
||||
|
||||
bool is_enabled() const {
|
||||
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
|
||||
return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates);
|
||||
return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
@@ -2979,10 +2979,6 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
|
||||
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
|
||||
if (graph->instance == nullptr) {
|
||||
res = true;
|
||||
}
|
||||
|
||||
// Check if the graph size has changed
|
||||
if (graph->props.size() != (size_t)cgraph->n_nodes) {
|
||||
res = true;
|
||||
@@ -3931,14 +3927,35 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
graph_key = ggml_cuda_graph_get_key(cgraph);
|
||||
|
||||
use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
|
||||
ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
|
||||
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
if (graph->is_enabled()) {
|
||||
cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
|
||||
use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);
|
||||
const bool graph_compatible = ggml_cuda_graph_check_compability(cgraph);
|
||||
if (graph_compatible) {
|
||||
const bool properties_changed = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
|
||||
|
||||
graph->record_update(use_cuda_graph, cuda_graph_update_required);
|
||||
if (!graph->warmup_complete) {
|
||||
// Warmup: need at least 2 calls with no property change on the 2nd call
|
||||
if (!properties_changed) {
|
||||
graph->warmup_complete = true;
|
||||
GGML_LOG_DEBUG("%s: CUDA graph warmup complete\n", __func__);
|
||||
use_cuda_graph = true;
|
||||
cuda_graph_update_required = true;
|
||||
}
|
||||
// else: properties changed or first call - execute directly (use_cuda_graph stays false)
|
||||
} else {
|
||||
// Post-warmup: normal CUDA graph operation
|
||||
if (properties_changed) {
|
||||
// Properties changed - reset warmup, execute directly until stable again
|
||||
graph->warmup_complete = false;
|
||||
GGML_LOG_DEBUG("%s: CUDA graph warmup reset\n", __func__);
|
||||
} else {
|
||||
use_cuda_graph = true;
|
||||
cuda_graph_update_required = graph->instance == nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
HTTPLIB_VERSION = "d4180e923f846b44a3d30acd938438d6e64fc9f6"
|
||||
HTTPLIB_VERSION = "refs/tags/v0.33.1"
|
||||
|
||||
vendor = {
|
||||
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
|
||||
|
||||
+143
-536
@@ -229,6 +229,20 @@ common_chat_tool python_tool {
|
||||
"required": ["code"]
|
||||
})",
|
||||
};
|
||||
common_chat_tool todo_list_tool {
|
||||
/* .name = */ "todo_list",
|
||||
/* .description = */ "Create or update the todo list",
|
||||
/* .parameters = */ R"({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"todos": {
|
||||
"type": "array",
|
||||
"description": "List of TODO list items"
|
||||
}
|
||||
},
|
||||
"required": ["todos"]
|
||||
})",
|
||||
};
|
||||
common_chat_tool code_interpreter_tool {
|
||||
/* .name = */ "code_interpreter",
|
||||
/* .description = */ "an ipython interpreter",
|
||||
@@ -3018,542 +3032,6 @@ Hey there!<|im_end|>
|
||||
);
|
||||
}
|
||||
|
||||
// Test Qwen3-Coder XML format
|
||||
{
|
||||
// Basic XML tool call parsing
|
||||
assert_msg_equals(
|
||||
message_assist_call,
|
||||
test_chat_parse(
|
||||
"<tool_call>\n"
|
||||
" <function=special_function>\n"
|
||||
" <parameter=arg1>\n"
|
||||
" 1\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
/* is_partial= */ false,
|
||||
{COMMON_CHAT_FORMAT_QWEN3_CODER_XML}));
|
||||
|
||||
// Multiple parameters with different types
|
||||
common_chat_msg expected_multi_param;
|
||||
expected_multi_param.role = "assistant";
|
||||
expected_multi_param.tool_calls = {
|
||||
{ "complex_function", "{\"name\":\"John Doe\",\"age\":30,\"active\":true,\"score\":95.5}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(expected_multi_param,
|
||||
"<tool_call>\n"
|
||||
" <function=complex_function>\n"
|
||||
" <parameter=name>\n"
|
||||
" John Doe\n"
|
||||
" </parameter>\n"
|
||||
" <parameter=age>\n"
|
||||
" 30\n"
|
||||
" </parameter>\n"
|
||||
" <parameter=active>\n"
|
||||
" true\n"
|
||||
" </parameter>\n"
|
||||
" <parameter=score>\n"
|
||||
" 95.5\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Special characters and Unicode
|
||||
common_chat_msg expected_special_chars;
|
||||
expected_special_chars.role = "assistant";
|
||||
expected_special_chars.tool_calls = {
|
||||
{ "unicode_function", "{\"message\":\"Hello 世界! 🌍 Special chars: @#$%^&*()\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(expected_special_chars,
|
||||
"<tool_call>\n"
|
||||
" <function=unicode_function>\n"
|
||||
" <parameter=message>\n"
|
||||
" Hello 世界! 🌍 Special chars: @#$%^&*()\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Multiline content with newlines and indentation
|
||||
common_chat_msg expected_multiline;
|
||||
expected_multiline.role = "assistant";
|
||||
expected_multiline.tool_calls = {
|
||||
{ "code_function", "{\"code\":\"def hello():\\n print(\\\"Hello, World!\\\")\\n return True\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(expected_multiline,
|
||||
"<tool_call>\n"
|
||||
" <function=code_function>\n"
|
||||
" <parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Hello, World!\")\n"
|
||||
" return True\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// JSON object as parameter value
|
||||
common_chat_msg expected_json_param;
|
||||
expected_json_param.role = "assistant";
|
||||
expected_json_param.tool_calls = {
|
||||
{ "json_function", "{\"config\":{\"host\":\"localhost\",\"port\":8080,\"ssl\":false}}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_json_param,
|
||||
"<tool_call>\n"
|
||||
" <function=json_function>\n"
|
||||
" <parameter=config>\n"
|
||||
" {\"host\": \"localhost\", \"port\": 8080, \"ssl\": false}\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Array as parameter value
|
||||
common_chat_msg expected_array_param;
|
||||
expected_array_param.role = "assistant";
|
||||
expected_array_param.tool_calls = {
|
||||
{ "array_function", "{\"items\":[\"apple\",\"banana\",\"cherry\"]}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_array_param,
|
||||
"<tool_call>\n"
|
||||
" <function=array_function>\n"
|
||||
" <parameter=items>\n"
|
||||
" [\"apple\", \"banana\", \"cherry\"]\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Empty parameter
|
||||
common_chat_msg expected_empty_param;
|
||||
expected_empty_param.role = "assistant";
|
||||
expected_empty_param.tool_calls = {
|
||||
{ "empty_function", "{\"empty_param\":\"\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_empty_param,
|
||||
"<tool_call>\n"
|
||||
" <function=empty_function>\n"
|
||||
" <parameter=empty_param>\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Boolean values (true/false)
|
||||
common_chat_msg expected_boolean;
|
||||
expected_boolean.role = "assistant";
|
||||
expected_boolean.tool_calls = {
|
||||
{ "boolean_function", "{\"enabled\":true,\"debug\":false}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_boolean,
|
||||
"<tool_call>\n"
|
||||
" <function=boolean_function>\n"
|
||||
" <parameter=enabled>\n"
|
||||
" true\n"
|
||||
" </parameter>\n"
|
||||
" <parameter=debug>\n"
|
||||
" false\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Null value
|
||||
common_chat_msg expected_null;
|
||||
expected_null.role = "assistant";
|
||||
expected_null.tool_calls = {
|
||||
{ "null_function", "{\"optional_param\":null}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_null,
|
||||
"<tool_call>\n"
|
||||
" <function=null_function>\n"
|
||||
" <parameter=optional_param>\n"
|
||||
" null\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Negative numbers and scientific notation
|
||||
common_chat_msg expected_numbers;
|
||||
expected_numbers.role = "assistant";
|
||||
expected_numbers.tool_calls = {
|
||||
{ "math_function", "{\"negative\":-42,\"decimal\":-3.14,\"scientific\":1.23e-4}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_numbers,
|
||||
"<tool_call>\n"
|
||||
" <function=math_function>\n"
|
||||
" <parameter=negative>\n"
|
||||
" -42\n"
|
||||
" </parameter>\n"
|
||||
" <parameter=decimal>\n"
|
||||
" -3.14\n"
|
||||
" </parameter>\n"
|
||||
" <parameter=scientific>\n"
|
||||
" 1.23e-4\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// XML-like content in parameters (should be escaped)
|
||||
common_chat_msg expected_xml_content;
|
||||
expected_xml_content.role = "assistant";
|
||||
expected_xml_content.tool_calls = {
|
||||
{ "xml_function", "{\"xml_content\":\"<root><item>value</item></root>\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_xml_content,
|
||||
"<tool_call>\n"
|
||||
" <function=xml_function>\n"
|
||||
" <parameter=xml_content>\n"
|
||||
" <root><item>value</item></root>\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Quotes and escape characters
|
||||
common_chat_msg expected_quotes;
|
||||
expected_quotes.role = "assistant";
|
||||
expected_quotes.tool_calls = {
|
||||
{ "quote_function", "{\"message\":\"She said \\\"Hello!\\\" and left.\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_quotes,
|
||||
"<tool_call>\n"
|
||||
" <function=quote_function>\n"
|
||||
" <parameter=message>\n"
|
||||
" She said \"Hello!\" and left.\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Long parameter value (simplified)
|
||||
std::string long_text = "This is a long text parameter that should test the parser's ability to handle larger amounts of text data.";
|
||||
|
||||
common_chat_msg expected_long_text;
|
||||
expected_long_text.role = "assistant";
|
||||
expected_long_text.tool_calls = {
|
||||
{ "long_function", "{\"long_text\":\"" + long_text + "\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_long_text,
|
||||
"<tool_call>\n"
|
||||
" <function=long_function>\n"
|
||||
" <parameter=long_text>\n"
|
||||
" " + long_text + "\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Mixed content with text before and after tool call
|
||||
common_chat_msg expected_mixed_content;
|
||||
expected_mixed_content.role = "assistant";
|
||||
expected_mixed_content.content = "I'll help you search for products. ";
|
||||
expected_mixed_content.tool_calls = {
|
||||
{ "search_function", "{\"query\":\"laptops\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_mixed_content,
|
||||
"I'll help you search for products. <tool_call>\n"
|
||||
" <function=search_function>\n"
|
||||
" <parameter=query>\n"
|
||||
" laptops\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Compact format (no extra whitespace)
|
||||
common_chat_msg expected_compact;
|
||||
expected_compact.role = "assistant";
|
||||
expected_compact.tool_calls = {
|
||||
{ "compact_function", "{\"param\":\"value\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_compact,
|
||||
"<tool_call><function=compact_function><parameter=param>value</parameter></function></tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Function name with underscores and numbers
|
||||
common_chat_msg expected_complex_name;
|
||||
expected_complex_name.role = "assistant";
|
||||
expected_complex_name.tool_calls = {
|
||||
{ "get_user_data_v2", "{\"user_id\":12345}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_complex_name,
|
||||
"<tool_call>\n"
|
||||
" <function=get_user_data_v2>\n"
|
||||
" <parameter=user_id>\n"
|
||||
" 12345\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Parameter names with underscores and numbers
|
||||
common_chat_msg expected_complex_params;
|
||||
expected_complex_params.role = "assistant";
|
||||
expected_complex_params.tool_calls = {
|
||||
{ "test_function", "{\"param_1\":\"value1\",\"param_2_name\":\"value2\",\"param3\":123}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_complex_params,
|
||||
"<tool_call>\n"
|
||||
" <function=test_function>\n"
|
||||
" <parameter=param_1>\n"
|
||||
" value1\n"
|
||||
" </parameter>\n"
|
||||
" <parameter=param_2_name>\n"
|
||||
" value2\n"
|
||||
" </parameter>\n"
|
||||
" <parameter=param3>\n"
|
||||
" 123\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Very deeply nested XML content in parameter
|
||||
common_chat_msg expected_deep_xml;
|
||||
expected_deep_xml.role = "assistant";
|
||||
expected_deep_xml.tool_calls = {
|
||||
{ "xml_parser", "{\"xml\":\"<root><level1><level2><level3>deep content</level3></level2></level1></root>\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_deep_xml,
|
||||
"<tool_call>\n"
|
||||
" <function=xml_parser>\n"
|
||||
" <parameter=xml>\n"
|
||||
" <root><level1><level2><level3>deep content</level3></level2></level1></root>\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Parameter with only whitespace
|
||||
common_chat_msg expected_whitespace_param;
|
||||
expected_whitespace_param.role = "assistant";
|
||||
expected_whitespace_param.tool_calls = {
|
||||
{ "whitespace_function", "{\"spaces\":\"\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_whitespace_param,
|
||||
"<tool_call>\n"
|
||||
" <function=whitespace_function>\n"
|
||||
" <parameter=spaces>\n"
|
||||
" \n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Parameter with tabs and mixed whitespace
|
||||
common_chat_msg expected_mixed_whitespace;
|
||||
expected_mixed_whitespace.role = "assistant";
|
||||
expected_mixed_whitespace.tool_calls = {
|
||||
{ "tab_function", "{\"content\":\"line1\\n\\tindented line\\n spaces\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_mixed_whitespace,
|
||||
"<tool_call>\n"
|
||||
" <function=tab_function>\n"
|
||||
" <parameter=content>\n"
|
||||
"line1\n"
|
||||
"\tindented line\n"
|
||||
" spaces\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Control characters and special Unicode
|
||||
common_chat_msg expected_control_chars;
|
||||
expected_control_chars.role = "assistant";
|
||||
expected_control_chars.tool_calls = {
|
||||
{ "control_function", "{\"text\":\"Line1\\nLine2\\tTabbed\\rCarriage return\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_control_chars,
|
||||
"<tool_call>\n"
|
||||
" <function=control_function>\n"
|
||||
" <parameter=text>\n"
|
||||
"Line1\nLine2\tTabbed\rCarriage return\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Emoji and extended Unicode characters
|
||||
common_chat_msg expected_emoji;
|
||||
expected_emoji.role = "assistant";
|
||||
expected_emoji.tool_calls = {
|
||||
{ "emoji_function", "{\"message\":\"Hello! 👋 🌟 🚀 Testing emojis: 😀😃😄😁 and symbols: ∑∏∆∇\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_emoji,
|
||||
"<tool_call>\n"
|
||||
" <function=emoji_function>\n"
|
||||
" <parameter=message>\n"
|
||||
" Hello! 👋 🌟 🚀 Testing emojis: 😀😃😄😁 and symbols: ∑∏∆∇\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Mathematical expressions and formulas
|
||||
common_chat_msg expected_math;
|
||||
expected_math.role = "assistant";
|
||||
expected_math.tool_calls = {
|
||||
{ "math_function", "{\"formula\":\"E = mc² and ∫f(x)dx = F(x) + C\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_math,
|
||||
"<tool_call>\n"
|
||||
" <function=math_function>\n"
|
||||
" <parameter=formula>\n"
|
||||
" E = mc² and ∫f(x)dx = F(x) + C\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// SQL injection-like content (should be safely escaped)
|
||||
common_chat_msg expected_sql;
|
||||
expected_sql.role = "assistant";
|
||||
expected_sql.tool_calls = {
|
||||
{ "sql_function", "{\"query\":\"SELECT * FROM users WHERE id = 1; DROP TABLE users; --\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_sql,
|
||||
"<tool_call>\n"
|
||||
" <function=sql_function>\n"
|
||||
" <parameter=query>\n"
|
||||
" SELECT * FROM users WHERE id = 1; DROP TABLE users; --\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// HTML/XML injection content
|
||||
common_chat_msg expected_html;
|
||||
expected_html.role = "assistant";
|
||||
expected_html.tool_calls = {
|
||||
{ "html_function", "{\"content\":\"<script>alert('xss')</script><img src=x onerror=alert(1)>\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_html,
|
||||
"<tool_call>\n"
|
||||
" <function=html_function>\n"
|
||||
" <parameter=content>\n"
|
||||
" <script>alert('xss')</script><img src=x onerror=alert(1)>\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Binary-like content (base64)
|
||||
common_chat_msg expected_binary;
|
||||
expected_binary.role = "assistant";
|
||||
expected_binary.tool_calls = {
|
||||
{ "binary_function", "{\"data\":\"SGVsbG8gV29ybGQhIFRoaXMgaXMgYmFzZTY0IGVuY29kZWQgdGV4dC4=\"}", "" }
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_binary,
|
||||
"<tool_call>\n"
|
||||
" <function=binary_function>\n"
|
||||
" <parameter=data>\n"
|
||||
" SGVsbG8gV29ybGQhIFRoaXMgaXMgYmFzZTY0IGVuY29kZWQgdGV4dC4=\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
|
||||
// Very large numbers (should be parsed as scientific notation)
|
||||
common_chat_msg expected_large_numbers;
|
||||
expected_large_numbers.role = "assistant";
|
||||
expected_large_numbers.tool_calls = {
|
||||
{ "number_function", "{\"big_int\":1e+60}", "" } // Large number becomes scientific notation
|
||||
};
|
||||
|
||||
test_parser_with_streaming(
|
||||
expected_large_numbers,
|
||||
"<tool_call>\n"
|
||||
" <function=number_function>\n"
|
||||
" <parameter=big_int>\n"
|
||||
" 999999999999999999999999999999999999999999999999999999999999\n"
|
||||
" </parameter>\n"
|
||||
" </function>\n"
|
||||
"</tool_call>",
|
||||
[&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); });
|
||||
}
|
||||
|
||||
{
|
||||
// Qwen3-Coder template
|
||||
auto tmpls = read_templates("models/templates/Qwen3-Coder.jinja");
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.messages = { message_user };
|
||||
|
||||
common_chat_tool qwen_union_tool {
|
||||
/* .name = */ "qwen_union",
|
||||
/* .description = */ "Test tool for union/anyOf handling",
|
||||
/* .parameters = */ R"({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"priority": { "type": ["number", "null"] },
|
||||
"maybe_text": { "anyOf": [ { "type": "string" } ] },
|
||||
"config": { "anyOf": [ { "type": "object" }, { "type": "null" } ] }
|
||||
},
|
||||
"required": []
|
||||
})",
|
||||
};
|
||||
inputs.tools = { qwen_union_tool };
|
||||
|
||||
auto params = common_chat_templates_apply(tmpls.get(), inputs);
|
||||
assert_equals(COMMON_CHAT_FORMAT_QWEN3_CODER_XML, params.format);
|
||||
assert_equals(false, params.grammar.empty());
|
||||
|
||||
// Grammar should compile successfully
|
||||
auto grammar = build_grammar(params.grammar);
|
||||
GGML_ASSERT(grammar && "Failed to build Qwen3-Coder grammar with union types");
|
||||
}
|
||||
|
||||
{
|
||||
// Step-3.5-Flash template: uses same XML output format as Qwen3-Coder and Nemotron v3,
|
||||
// but with <think> support. Routes to the Nemotron v3 PEG parser for streaming and
|
||||
@@ -3665,6 +3143,135 @@ static void test_template_output_peg_parsers() {
|
||||
});
|
||||
}
|
||||
|
||||
{
|
||||
// Qwen3-Coder
|
||||
auto tmpls = read_templates("models/templates/Qwen3-Coder.jinja");
|
||||
|
||||
// Test basic message
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input = "Hello, world!\nWhat's up?";
|
||||
t.expect = message_assist;
|
||||
});
|
||||
|
||||
// Test tool call
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input =
|
||||
"<tool_call>\n"
|
||||
"<function=special_function>\n"
|
||||
"<parameter=arg1>\n"
|
||||
"1\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>";
|
||||
t.params.tools = {special_function_tool};
|
||||
t.expect = message_assist_call;
|
||||
});
|
||||
|
||||
// Test parallel tool calls
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input =
|
||||
"<tool_call>\n"
|
||||
"<function=special_function>\n"
|
||||
"<parameter=arg1>\n"
|
||||
"1\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>\n"
|
||||
"<tool_call>\n"
|
||||
"<function=special_function_with_opt>\n"
|
||||
"<parameter=arg1>\n"
|
||||
"1\n"
|
||||
"</parameter>\n"
|
||||
"<parameter=arg2>\n"
|
||||
"2\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>";
|
||||
t.params.parallel_tool_calls = true;
|
||||
t.params.tools = {special_function_tool, special_function_tool_with_optional_param};
|
||||
|
||||
t.expect.tool_calls = {{
|
||||
/* .name = */ "special_function",
|
||||
/* .arguments = */ R"({"arg1": 1})",
|
||||
/* .id = */ {},
|
||||
}, {
|
||||
/* .name = */ "special_function_with_opt",
|
||||
/* .arguments = */ R"({"arg1": 1, "arg2": 2})",
|
||||
/* .id = */ {},
|
||||
}};
|
||||
});
|
||||
|
||||
// Test tool call with string parameter
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input =
|
||||
"<tool_call>\n"
|
||||
"<function=python>\n"
|
||||
"<parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Hello, world!\")\n"
|
||||
"\n"
|
||||
"hello()\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>";
|
||||
t.params.tools = {python_tool};
|
||||
|
||||
t.expect.tool_calls = {{
|
||||
/* .name = */ "python",
|
||||
/* .arguments = */ "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}",
|
||||
/* .id = */ {},
|
||||
}};
|
||||
});
|
||||
|
||||
// Test tool call with JSON parameter
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input =
|
||||
"<tool_call>\n"
|
||||
"<function=todo_list>\n"
|
||||
"<parameter=todos>\n"
|
||||
"[{\"item\": \"Check stuff\", \"selected\": false}, {\"item\": \"Prepare stuff\", \"selected\": true}]\n"
|
||||
"</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>";
|
||||
t.params.tools = {todo_list_tool};
|
||||
|
||||
t.expect.tool_calls = {{
|
||||
/* .name = */ "todo_list",
|
||||
/* .arguments = */ "{\"todos\": [{\"item\": \"Check stuff\", \"selected\": false}, {\"item\": \"Prepare stuff\", \"selected\": true}]}",
|
||||
/* .id = */ {},
|
||||
}};
|
||||
});
|
||||
|
||||
// Test tool call with string parameter and no closing </parameter> tag
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input =
|
||||
"<tool_call>\n"
|
||||
"<function=python>\n"
|
||||
"<parameter=code>\n"
|
||||
"def hello():\n"
|
||||
" print(\"Hello, world!\")\n"
|
||||
"\n"
|
||||
"hello()\n"
|
||||
"</function>\n"
|
||||
"</tool_call>";
|
||||
t.params.tools = {python_tool};
|
||||
|
||||
t.expect.tool_calls = {{
|
||||
/* .name = */ "python",
|
||||
/* .arguments = */ "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}",
|
||||
/* .id = */ {},
|
||||
}};
|
||||
});
|
||||
|
||||
// Test response format
|
||||
test_peg_parser(tmpls.get(), [&](auto & t) {
|
||||
t.input = R"({"amount": 123.45, "date": "2025-12-03"})";
|
||||
t.params.json_schema = invoice_schema;
|
||||
|
||||
t.expect.content = R"({"amount": 123.45, "date": "2025-12-03"})";
|
||||
});
|
||||
}
|
||||
|
||||
{
|
||||
// NVIDIA Nemotron-3 Nano
|
||||
auto tmpls = read_templates("models/templates/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja");
|
||||
|
||||
Binary file not shown.
@@ -251,9 +251,6 @@
|
||||
return options.find((option) => option.id === activeId);
|
||||
}
|
||||
|
||||
if (options.length === 1) {
|
||||
return options[0];
|
||||
}
|
||||
// No selection - return undefined to show "Select model"
|
||||
return undefined;
|
||||
}
|
||||
|
||||
@@ -306,6 +306,16 @@ class ModelsStore {
|
||||
const response = await ModelsService.listRouter();
|
||||
this.routerModels = response.data;
|
||||
await this.fetchModalitiesForLoadedModels();
|
||||
|
||||
const o = this.models.filter((option) => {
|
||||
const modelProps = this.getModelProps(option.model);
|
||||
|
||||
return modelProps?.webui !== false;
|
||||
});
|
||||
|
||||
if (o.length === 1 && this.isModelLoaded(o[0].model)) {
|
||||
this.selectModelById(o[0].id);
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('Failed to fetch router models:', error);
|
||||
this.routerModels = [];
|
||||
|
||||
Vendored
+2398
-132
File diff suppressed because it is too large
Load Diff
Vendored
+447
-18
@@ -8,8 +8,8 @@
|
||||
#ifndef CPPHTTPLIB_HTTPLIB_H
|
||||
#define CPPHTTPLIB_HTTPLIB_H
|
||||
|
||||
#define CPPHTTPLIB_VERSION "0.32.0"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x002000"
|
||||
#define CPPHTTPLIB_VERSION "0.33.1"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x002101"
|
||||
|
||||
/*
|
||||
* Platform compatibility check
|
||||
@@ -185,6 +185,14 @@
|
||||
: 0))
|
||||
#endif
|
||||
|
||||
#ifndef CPPHTTPLIB_THREAD_POOL_MAX_COUNT
|
||||
#define CPPHTTPLIB_THREAD_POOL_MAX_COUNT (CPPHTTPLIB_THREAD_POOL_COUNT * 4)
|
||||
#endif
|
||||
|
||||
#ifndef CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT
|
||||
#define CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT 3 // seconds
|
||||
#endif
|
||||
|
||||
#ifndef CPPHTTPLIB_RECV_FLAGS
|
||||
#define CPPHTTPLIB_RECV_FLAGS 0
|
||||
#endif
|
||||
@@ -201,6 +209,22 @@
|
||||
#define CPPHTTPLIB_MAX_LINE_LENGTH 32768
|
||||
#endif
|
||||
|
||||
#ifndef CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH
|
||||
#define CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH 16777216
|
||||
#endif
|
||||
|
||||
#ifndef CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND
|
||||
#define CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND 300
|
||||
#endif
|
||||
|
||||
#ifndef CPPHTTPLIB_WEBSOCKET_CLOSE_TIMEOUT_SECOND
|
||||
#define CPPHTTPLIB_WEBSOCKET_CLOSE_TIMEOUT_SECOND 5
|
||||
#endif
|
||||
|
||||
#ifndef CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND
|
||||
#define CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND 30
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Headers
|
||||
*/
|
||||
@@ -310,6 +334,7 @@ using socket_t = int;
|
||||
#include <errno.h>
|
||||
#include <exception>
|
||||
#include <fcntl.h>
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
@@ -328,6 +353,9 @@ using socket_t = int;
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#if __cplusplus >= 201703L
|
||||
#include <any>
|
||||
#endif
|
||||
|
||||
#if defined(CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO) || \
|
||||
defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
||||
@@ -415,10 +443,46 @@ using socket_t = int;
|
||||
|
||||
#endif // CPPHTTPLIB_MBEDTLS_SUPPORT
|
||||
|
||||
#ifdef CPPHTTPLIB_WOLFSSL_SUPPORT
|
||||
#include <wolfssl/options.h>
|
||||
|
||||
#include <wolfssl/openssl/x509v3.h>
|
||||
|
||||
// Fallback definitions for older wolfSSL versions (e.g., 5.6.6)
|
||||
#ifndef WOLFSSL_GEN_EMAIL
|
||||
#define WOLFSSL_GEN_EMAIL 1
|
||||
#endif
|
||||
#ifndef WOLFSSL_GEN_DNS
|
||||
#define WOLFSSL_GEN_DNS 2
|
||||
#endif
|
||||
#ifndef WOLFSSL_GEN_URI
|
||||
#define WOLFSSL_GEN_URI 6
|
||||
#endif
|
||||
#ifndef WOLFSSL_GEN_IPADD
|
||||
#define WOLFSSL_GEN_IPADD 7
|
||||
#endif
|
||||
|
||||
#include <wolfssl/ssl.h>
|
||||
#include <wolfssl/wolfcrypt/hash.h>
|
||||
#include <wolfssl/wolfcrypt/md5.h>
|
||||
#include <wolfssl/wolfcrypt/sha256.h>
|
||||
#include <wolfssl/wolfcrypt/sha512.h>
|
||||
#ifdef _WIN32
|
||||
#include <wincrypt.h>
|
||||
#ifdef _MSC_VER
|
||||
#pragma comment(lib, "crypt32.lib")
|
||||
#endif
|
||||
#endif // _WIN32
|
||||
#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
|
||||
#if TARGET_OS_MAC
|
||||
#include <Security/Security.h>
|
||||
#endif
|
||||
#endif // CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN
|
||||
#endif // CPPHTTPLIB_WOLFSSL_SUPPORT
|
||||
|
||||
// Define CPPHTTPLIB_SSL_ENABLED if any SSL backend is available
|
||||
// This simplifies conditional compilation when adding new backends (e.g.,
|
||||
// wolfSSL)
|
||||
#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) || defined(CPPHTTPLIB_MBEDTLS_SUPPORT)
|
||||
#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) || \
|
||||
defined(CPPHTTPLIB_MBEDTLS_SUPPORT) || defined(CPPHTTPLIB_WOLFSSL_SUPPORT)
|
||||
#define CPPHTTPLIB_SSL_ENABLED
|
||||
#endif
|
||||
|
||||
@@ -440,6 +504,10 @@ using socket_t = int;
|
||||
*/
|
||||
namespace httplib {
|
||||
|
||||
namespace ws {
|
||||
class WebSocket;
|
||||
} // namespace ws
|
||||
|
||||
namespace detail {
|
||||
|
||||
/*
|
||||
@@ -711,6 +779,143 @@ using Match = std::smatch;
|
||||
using DownloadProgress = std::function<bool(size_t current, size_t total)>;
|
||||
using UploadProgress = std::function<bool(size_t current, size_t total)>;
|
||||
|
||||
|
||||
#if __cplusplus >= 201703L
|
||||
|
||||
using any = std::any;
|
||||
using bad_any_cast = std::bad_any_cast;
|
||||
|
||||
template <typename T> T any_cast(const any &a) { return std::any_cast<T>(a); }
|
||||
template <typename T> T any_cast(any &a) { return std::any_cast<T>(a); }
|
||||
template <typename T> T any_cast(any &&a) {
|
||||
return std::any_cast<T>(std::move(a));
|
||||
}
|
||||
template <typename T> const T *any_cast(const any *a) noexcept {
|
||||
return std::any_cast<T>(a);
|
||||
}
|
||||
template <typename T> T *any_cast(any *a) noexcept {
|
||||
return std::any_cast<T>(a);
|
||||
}
|
||||
|
||||
#else // C++11/14 implementation
|
||||
|
||||
class bad_any_cast : public std::bad_cast {
|
||||
public:
|
||||
const char *what() const noexcept override { return "bad any_cast"; }
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
using any_type_id = const void *;
|
||||
|
||||
// Returns a unique per-type ID without RTTI.
|
||||
// The static address is stable across TUs because function templates are
|
||||
// implicitly inline and the ODR merges their statics into one.
|
||||
template <typename T> any_type_id any_typeid() noexcept {
|
||||
static const char id = 0;
|
||||
return &id;
|
||||
}
|
||||
|
||||
struct any_storage {
|
||||
virtual ~any_storage() = default;
|
||||
virtual std::unique_ptr<any_storage> clone() const = 0;
|
||||
virtual any_type_id type_id() const noexcept = 0;
|
||||
};
|
||||
|
||||
template <typename T> struct any_value final : any_storage {
|
||||
T value;
|
||||
template <typename U> explicit any_value(U &&v) : value(std::forward<U>(v)) {}
|
||||
std::unique_ptr<any_storage> clone() const override {
|
||||
return std::unique_ptr<any_storage>(new any_value<T>(value));
|
||||
}
|
||||
any_type_id type_id() const noexcept override { return any_typeid<T>(); }
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
class any {
|
||||
std::unique_ptr<detail::any_storage> storage_;
|
||||
|
||||
public:
|
||||
any() noexcept = default;
|
||||
any(const any &o) : storage_(o.storage_ ? o.storage_->clone() : nullptr) {}
|
||||
any(any &&) noexcept = default;
|
||||
any &operator=(const any &o) {
|
||||
storage_ = o.storage_ ? o.storage_->clone() : nullptr;
|
||||
return *this;
|
||||
}
|
||||
any &operator=(any &&) noexcept = default;
|
||||
|
||||
template <
|
||||
typename T, typename D = typename std::decay<T>::type,
|
||||
typename std::enable_if<!std::is_same<D, any>::value, int>::type = 0>
|
||||
any(T &&v) : storage_(new detail::any_value<D>(std::forward<T>(v))) {}
|
||||
|
||||
template <
|
||||
typename T, typename D = typename std::decay<T>::type,
|
||||
typename std::enable_if<!std::is_same<D, any>::value, int>::type = 0>
|
||||
any &operator=(T &&v) {
|
||||
storage_.reset(new detail::any_value<D>(std::forward<T>(v)));
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool has_value() const noexcept { return storage_ != nullptr; }
|
||||
void reset() noexcept { storage_.reset(); }
|
||||
|
||||
template <typename T> friend T *any_cast(any *a) noexcept;
|
||||
template <typename T> friend const T *any_cast(const any *a) noexcept;
|
||||
};
|
||||
|
||||
template <typename T> T *any_cast(any *a) noexcept {
|
||||
if (!a || !a->storage_) { return nullptr; }
|
||||
if (a->storage_->type_id() != detail::any_typeid<T>()) { return nullptr; }
|
||||
return &static_cast<detail::any_value<T> *>(a->storage_.get())->value;
|
||||
}
|
||||
|
||||
template <typename T> const T *any_cast(const any *a) noexcept {
|
||||
if (!a || !a->storage_) { return nullptr; }
|
||||
if (a->storage_->type_id() != detail::any_typeid<T>()) { return nullptr; }
|
||||
return &static_cast<const detail::any_value<T> *>(a->storage_.get())->value;
|
||||
}
|
||||
|
||||
template <typename T> T any_cast(const any &a) {
|
||||
using U =
|
||||
typename std::remove_cv<typename std::remove_reference<T>::type>::type;
|
||||
const U *p = any_cast<U>(&a);
|
||||
#ifndef CPPHTTPLIB_NO_EXCEPTIONS
|
||||
if (!p) { throw bad_any_cast{}; }
|
||||
#else
|
||||
if (!p) { std::abort(); }
|
||||
#endif
|
||||
return static_cast<T>(*p);
|
||||
}
|
||||
|
||||
template <typename T> T any_cast(any &a) {
|
||||
using U =
|
||||
typename std::remove_cv<typename std::remove_reference<T>::type>::type;
|
||||
U *p = any_cast<U>(&a);
|
||||
#ifndef CPPHTTPLIB_NO_EXCEPTIONS
|
||||
if (!p) { throw bad_any_cast{}; }
|
||||
#else
|
||||
if (!p) { std::abort(); }
|
||||
#endif
|
||||
return static_cast<T>(*p);
|
||||
}
|
||||
|
||||
template <typename T> T any_cast(any &&a) {
|
||||
using U =
|
||||
typename std::remove_cv<typename std::remove_reference<T>::type>::type;
|
||||
U *p = any_cast<U>(&a);
|
||||
#ifndef CPPHTTPLIB_NO_EXCEPTIONS
|
||||
if (!p) { throw bad_any_cast{}; }
|
||||
#else
|
||||
if (!p) { std::abort(); }
|
||||
#endif
|
||||
return static_cast<T>(std::move(*p));
|
||||
}
|
||||
|
||||
#endif // __cplusplus >= 201703L
|
||||
|
||||
struct Response;
|
||||
using ResponseHandler = std::function<bool(const Response &response)>;
|
||||
|
||||
@@ -805,6 +1010,34 @@ struct FormDataProvider {
|
||||
};
|
||||
using FormDataProviderItems = std::vector<FormDataProvider>;
|
||||
|
||||
inline FormDataProvider
|
||||
make_file_provider(const std::string &name, const std::string &filepath,
|
||||
const std::string &filename = std::string(),
|
||||
const std::string &content_type = std::string()) {
|
||||
FormDataProvider fdp;
|
||||
fdp.name = name;
|
||||
fdp.filename = filename.empty() ? filepath : filename;
|
||||
fdp.content_type = content_type;
|
||||
fdp.provider = [filepath](size_t offset, DataSink &sink) -> bool {
|
||||
std::ifstream f(filepath, std::ios::binary);
|
||||
if (!f) { return false; }
|
||||
if (offset > 0) {
|
||||
f.seekg(static_cast<std::streamoff>(offset));
|
||||
if (!f.good()) {
|
||||
sink.done();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
char buf[8192];
|
||||
f.read(buf, sizeof(buf));
|
||||
auto n = static_cast<size_t>(f.gcount());
|
||||
if (n > 0) { return sink.write(buf, n); }
|
||||
sink.done(); // EOF
|
||||
return true;
|
||||
};
|
||||
return fdp;
|
||||
}
|
||||
|
||||
using ContentReceiverWithProgress = std::function<bool(
|
||||
const char *data, size_t data_length, size_t offset, size_t total_length)>;
|
||||
|
||||
@@ -1010,6 +1243,10 @@ struct Response {
|
||||
std::string body;
|
||||
std::string location; // Redirect location
|
||||
|
||||
// User-defined context — set by pre-routing/pre-request handlers and read
|
||||
// by route handlers to pass arbitrary data (e.g. decoded auth tokens).
|
||||
std::map<std::string, any> user_data;
|
||||
|
||||
bool has_header(const std::string &key) const;
|
||||
std::string get_header_value(const std::string &key, const char *def = "",
|
||||
size_t id = 0) const;
|
||||
@@ -1124,6 +1361,11 @@ public:
|
||||
|
||||
virtual time_t duration() const = 0;
|
||||
|
||||
virtual void set_read_timeout(time_t sec, time_t usec = 0) {
|
||||
(void)sec;
|
||||
(void)usec;
|
||||
}
|
||||
|
||||
ssize_t write(const char *ptr);
|
||||
ssize_t write(const std::string &s);
|
||||
|
||||
@@ -1146,7 +1388,7 @@ public:
|
||||
|
||||
class ThreadPool final : public TaskQueue {
|
||||
public:
|
||||
explicit ThreadPool(size_t n, size_t mqr = 0);
|
||||
explicit ThreadPool(size_t n, size_t max_n = 0, size_t mqr = 0);
|
||||
ThreadPool(const ThreadPool &) = delete;
|
||||
~ThreadPool() override = default;
|
||||
|
||||
@@ -1154,20 +1396,22 @@ public:
|
||||
void shutdown() override;
|
||||
|
||||
private:
|
||||
struct worker {
|
||||
explicit worker(ThreadPool &pool);
|
||||
void worker(bool is_dynamic);
|
||||
void move_to_finished(std::thread::id id);
|
||||
void cleanup_finished_threads();
|
||||
|
||||
void operator()();
|
||||
|
||||
ThreadPool &pool_;
|
||||
};
|
||||
friend struct worker;
|
||||
|
||||
std::vector<std::thread> threads_;
|
||||
std::list<std::function<void()>> jobs_;
|
||||
size_t base_thread_count_;
|
||||
size_t max_thread_count_;
|
||||
size_t max_queued_requests_;
|
||||
size_t idle_thread_count_;
|
||||
|
||||
bool shutdown_;
|
||||
size_t max_queued_requests_ = 0;
|
||||
|
||||
std::list<std::function<void()>> jobs_;
|
||||
std::vector<std::thread> threads_; // base threads
|
||||
std::list<std::thread> dynamic_threads_; // dynamic threads
|
||||
std::vector<std::thread>
|
||||
finished_threads_; // exited dynamic threads awaiting join
|
||||
|
||||
std::condition_variable cond_;
|
||||
std::mutex mutex_;
|
||||
@@ -1294,6 +1538,11 @@ public:
|
||||
using Expect100ContinueHandler =
|
||||
std::function<int(const Request &, Response &)>;
|
||||
|
||||
using WebSocketHandler =
|
||||
std::function<void(const Request &, ws::WebSocket &)>;
|
||||
using SubProtocolSelector =
|
||||
std::function<std::string(const std::vector<std::string> &protocols)>;
|
||||
|
||||
Server();
|
||||
|
||||
virtual ~Server();
|
||||
@@ -1311,6 +1560,10 @@ public:
|
||||
Server &Delete(const std::string &pattern, HandlerWithContentReader handler);
|
||||
Server &Options(const std::string &pattern, Handler handler);
|
||||
|
||||
Server &WebSocket(const std::string &pattern, WebSocketHandler handler);
|
||||
Server &WebSocket(const std::string &pattern, WebSocketHandler handler,
|
||||
SubProtocolSelector sub_protocol_selector);
|
||||
|
||||
bool set_base_dir(const std::string &dir,
|
||||
const std::string &mount_point = std::string());
|
||||
bool set_mount_point(const std::string &mount_point, const std::string &dir,
|
||||
@@ -1386,7 +1639,8 @@ protected:
|
||||
int remote_port, const std::string &local_addr,
|
||||
int local_port, bool close_connection,
|
||||
bool &connection_closed,
|
||||
const std::function<void(Request &)> &setup_request);
|
||||
const std::function<void(Request &)> &setup_request,
|
||||
bool *websocket_upgraded = nullptr);
|
||||
|
||||
std::atomic<socket_t> svr_sock_{INVALID_SOCKET};
|
||||
|
||||
@@ -1488,6 +1742,14 @@ private:
|
||||
HandlersForContentReader delete_handlers_for_content_reader_;
|
||||
Handlers options_handlers_;
|
||||
|
||||
struct WebSocketHandlerEntry {
|
||||
std::unique_ptr<detail::MatcherBase> matcher;
|
||||
WebSocketHandler handler;
|
||||
SubProtocolSelector sub_protocol_selector;
|
||||
};
|
||||
using WebSocketHandlers = std::vector<WebSocketHandlerEntry>;
|
||||
WebSocketHandlers websocket_handlers_;
|
||||
|
||||
HandlerWithResponse error_handler_;
|
||||
ExceptionHandler exception_handler_;
|
||||
HandlerWithResponse pre_routing_handler_;
|
||||
@@ -2970,6 +3232,36 @@ struct MbedTlsContext {
|
||||
} // namespace tls
|
||||
#endif
|
||||
|
||||
#ifdef CPPHTTPLIB_WOLFSSL_SUPPORT
|
||||
namespace tls {
|
||||
namespace impl {
|
||||
|
||||
// wolfSSL context wrapper (holds WOLFSSL_CTX and related state).
|
||||
// This struct is accessible via tls::impl for use in SSL context
|
||||
// setup callbacks (cast ctx_t to tls::impl::WolfSSLContext*).
|
||||
struct WolfSSLContext {
|
||||
WOLFSSL_CTX *ctx = nullptr;
|
||||
bool is_server = false;
|
||||
bool verify_client = false;
|
||||
bool has_verify_callback = false;
|
||||
std::string ca_pem_data_; // accumulated PEM for get_ca_names/get_ca_certs
|
||||
|
||||
WolfSSLContext();
|
||||
~WolfSSLContext();
|
||||
|
||||
WolfSSLContext(const WolfSSLContext &) = delete;
|
||||
WolfSSLContext &operator=(const WolfSSLContext &) = delete;
|
||||
};
|
||||
|
||||
// CA store for wolfSSL: holds raw PEM bytes to allow reloading into any ctx
|
||||
struct WolfSSLCAStore {
|
||||
std::string pem_data;
|
||||
};
|
||||
|
||||
} // namespace impl
|
||||
} // namespace tls
|
||||
#endif
|
||||
|
||||
#endif // CPPHTTPLIB_SSL_ENABLED
|
||||
|
||||
namespace stream {
|
||||
@@ -3335,6 +3627,143 @@ private:
|
||||
|
||||
} // namespace sse
|
||||
|
||||
namespace ws {
|
||||
|
||||
enum class Opcode : uint8_t {
|
||||
Continuation = 0x0,
|
||||
Text = 0x1,
|
||||
Binary = 0x2,
|
||||
Close = 0x8,
|
||||
Ping = 0x9,
|
||||
Pong = 0xA,
|
||||
};
|
||||
|
||||
enum class CloseStatus : uint16_t {
|
||||
Normal = 1000,
|
||||
GoingAway = 1001,
|
||||
ProtocolError = 1002,
|
||||
UnsupportedData = 1003,
|
||||
NoStatus = 1005,
|
||||
Abnormal = 1006,
|
||||
InvalidPayload = 1007,
|
||||
PolicyViolation = 1008,
|
||||
MessageTooBig = 1009,
|
||||
MandatoryExtension = 1010,
|
||||
InternalError = 1011,
|
||||
};
|
||||
|
||||
enum ReadResult : int { Fail = 0, Text = 1, Binary = 2 };
|
||||
|
||||
class WebSocket {
|
||||
public:
|
||||
WebSocket(const WebSocket &) = delete;
|
||||
WebSocket &operator=(const WebSocket &) = delete;
|
||||
~WebSocket();
|
||||
|
||||
ReadResult read(std::string &msg);
|
||||
bool send(const std::string &data);
|
||||
bool send(const char *data, size_t len);
|
||||
void close(CloseStatus status = CloseStatus::Normal,
|
||||
const std::string &reason = "");
|
||||
const Request &request() const;
|
||||
bool is_open() const;
|
||||
|
||||
private:
|
||||
friend class httplib::Server;
|
||||
friend class WebSocketClient;
|
||||
|
||||
WebSocket(Stream &strm, const Request &req, bool is_server)
|
||||
: strm_(strm), req_(req), is_server_(is_server) {
|
||||
start_heartbeat();
|
||||
}
|
||||
|
||||
WebSocket(std::unique_ptr<Stream> &&owned_strm, const Request &req,
|
||||
bool is_server)
|
||||
: strm_(*owned_strm), owned_strm_(std::move(owned_strm)), req_(req),
|
||||
is_server_(is_server) {
|
||||
start_heartbeat();
|
||||
}
|
||||
|
||||
void start_heartbeat();
|
||||
bool send_frame(Opcode op, const char *data, size_t len, bool fin = true);
|
||||
|
||||
Stream &strm_;
|
||||
std::unique_ptr<Stream> owned_strm_;
|
||||
Request req_;
|
||||
bool is_server_;
|
||||
std::atomic<bool> closed_{false};
|
||||
std::mutex write_mutex_;
|
||||
std::thread ping_thread_;
|
||||
std::mutex ping_mutex_;
|
||||
std::condition_variable ping_cv_;
|
||||
};
|
||||
|
||||
class WebSocketClient {
|
||||
public:
|
||||
explicit WebSocketClient(const std::string &scheme_host_port_path,
|
||||
const Headers &headers = {});
|
||||
|
||||
~WebSocketClient();
|
||||
WebSocketClient(const WebSocketClient &) = delete;
|
||||
WebSocketClient &operator=(const WebSocketClient &) = delete;
|
||||
|
||||
bool is_valid() const;
|
||||
|
||||
bool connect();
|
||||
ReadResult read(std::string &msg);
|
||||
bool send(const std::string &data);
|
||||
bool send(const char *data, size_t len);
|
||||
void close(CloseStatus status = CloseStatus::Normal,
|
||||
const std::string &reason = "");
|
||||
bool is_open() const;
|
||||
const std::string &subprotocol() const;
|
||||
void set_read_timeout(time_t sec, time_t usec = 0);
|
||||
void set_write_timeout(time_t sec, time_t usec = 0);
|
||||
|
||||
#ifdef CPPHTTPLIB_SSL_ENABLED
|
||||
void set_ca_cert_path(const std::string &path);
|
||||
void set_ca_cert_store(tls::ca_store_t store);
|
||||
void enable_server_certificate_verification(bool enabled);
|
||||
#endif
|
||||
|
||||
private:
|
||||
void shutdown_and_close();
|
||||
bool create_stream(std::unique_ptr<Stream> &strm);
|
||||
|
||||
std::string host_;
|
||||
int port_;
|
||||
std::string path_;
|
||||
Headers headers_;
|
||||
std::string subprotocol_;
|
||||
bool is_valid_ = false;
|
||||
socket_t sock_ = INVALID_SOCKET;
|
||||
std::unique_ptr<WebSocket> ws_;
|
||||
time_t read_timeout_sec_ = CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND;
|
||||
time_t read_timeout_usec_ = 0;
|
||||
time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND;
|
||||
time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND;
|
||||
|
||||
#ifdef CPPHTTPLIB_SSL_ENABLED
|
||||
bool is_ssl_ = false;
|
||||
tls::ctx_t tls_ctx_ = nullptr;
|
||||
tls::session_t tls_session_ = nullptr;
|
||||
std::string ca_cert_file_path_;
|
||||
tls::ca_store_t ca_cert_store_ = nullptr;
|
||||
bool server_certificate_verification_ = true;
|
||||
#endif
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
|
||||
bool is_valid_utf8(const std::string &s);
|
||||
|
||||
bool read_websocket_frame(Stream &strm, Opcode &opcode, std::string &payload,
|
||||
bool &fin, bool expect_masked, size_t max_len);
|
||||
|
||||
} // namespace impl
|
||||
|
||||
} // namespace ws
|
||||
|
||||
|
||||
} // namespace httplib
|
||||
|
||||
|
||||
Reference in New Issue
Block a user