forked from wylab/llama.cpp
Compare commits
12 Commits
b6516
..
llama-pull
| Author | SHA1 | Date | |
|---|---|---|---|
| 17ca6ed540 | |||
| 7f766929ca | |||
| 405921dcef | |||
| fa6383ca7e | |||
| 803dac2e48 | |||
| 459c0c2c1a | |||
| be79d9fdd9 | |||
| f432d8d83e | |||
| 4067f07fc5 | |||
| 4b8560ab56 | |||
| 0dd58b6877 | |||
| 69ffd89163 |
+27
-15
@@ -1741,10 +1741,12 @@ static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
|
||||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ json(), json {
|
||||
const std::optional<json> tools_override = json();
|
||||
const std::optional<json> additional_context = json {
|
||||
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
|
||||
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||
});
|
||||
};
|
||||
data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, tools_override, additional_context);
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
@@ -2230,15 +2232,28 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
|
||||
|
||||
static void common_chat_parse_granite(common_chat_msg_parser & builder) {
|
||||
// Parse thinking tags
|
||||
static const common_regex start_think_regex(regex_escape("<think>"));
|
||||
static const common_regex end_think_regex(regex_escape("</think>"));
|
||||
// Granite models output partial tokens such as "<" and "<think".
|
||||
// By leveraging try_consume_regex()/try_find_regex() throwing
|
||||
// common_chat_msg_partial_exception for these partial tokens,
|
||||
// processing is interrupted and the tokens are not passed to add_content().
|
||||
if (auto res = builder.try_consume_regex(start_think_regex)) {
|
||||
// Restore position for try_parse_reasoning()
|
||||
builder.move_to(res->groups[0].begin);
|
||||
builder.try_find_regex(end_think_regex, std::string::npos, false);
|
||||
// Restore position for try_parse_reasoning()
|
||||
builder.move_to(res->groups[0].begin);
|
||||
}
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
|
||||
// Parse response tags using regex
|
||||
static const common_regex response_regex("<response>([\\s\\S]*?)</response>");
|
||||
if (auto res = builder.try_find_regex(response_regex)) {
|
||||
// Extract the content between the tags (capture group 1)
|
||||
auto content = builder.str(res->groups[1]);
|
||||
builder.add_content(content);
|
||||
builder.move_to(res->groups[0].end);
|
||||
// Parse response tags
|
||||
static const common_regex start_response_regex(regex_escape("<response>"));
|
||||
static const common_regex end_response_regex(regex_escape("</response>"));
|
||||
// Granite models output partial tokens such as "<" and "<response".
|
||||
// Same hack as reasoning parsing.
|
||||
if (builder.try_consume_regex(start_response_regex)) {
|
||||
builder.try_find_regex(end_response_regex);
|
||||
}
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
@@ -2252,13 +2267,10 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) {
|
||||
builder.move_to(res->groups[0].end);
|
||||
|
||||
// Expect JSON array of tool calls
|
||||
auto tool_calls_data = builder.consume_json();
|
||||
if (tool_calls_data.json.is_array()) {
|
||||
if (!builder.add_tool_calls(tool_calls_data.json)) {
|
||||
builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
|
||||
if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) {
|
||||
if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) {
|
||||
throw common_chat_msg_partial_exception("incomplete tool call");
|
||||
}
|
||||
} else {
|
||||
builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
|
||||
}
|
||||
} else {
|
||||
builder.add_content(builder.consume_rest());
|
||||
|
||||
+37
-21
@@ -1,5 +1,41 @@
|
||||
cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
|
||||
project("ggml" C CXX ASM)
|
||||
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 9)
|
||||
set(GGML_VERSION_PATCH 0)
|
||||
set(GGML_VERSION_DEV "-dev") # "-dev" for development, "" for releases
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
|
||||
if(GIT_EXE)
|
||||
# Get current git commit hash
|
||||
execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
OUTPUT_VARIABLE GGML_BUILD_COMMIT
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
ERROR_QUIET
|
||||
)
|
||||
|
||||
# Check if the working directory is dirty (i.e., has uncommitted changes)
|
||||
execute_process(COMMAND ${GIT_EXE} diff-index --quiet HEAD -- .
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
RESULT_VARIABLE GGML_GIT_DIRTY
|
||||
ERROR_QUIET
|
||||
)
|
||||
endif()
|
||||
|
||||
# Build the version string with optional -dev suffix and dirty flag
|
||||
set(GGML_VERSION "${GGML_VERSION_BASE}${GGML_VERSION_DEV}")
|
||||
if(GGML_GIT_DIRTY AND NOT GGML_GIT_DIRTY EQUAL 0)
|
||||
set(GGML_VERSION "${GGML_VERSION}-dirty")
|
||||
endif()
|
||||
|
||||
if(NOT GGML_BUILD_COMMIT)
|
||||
set(GGML_BUILD_COMMIT "unknown")
|
||||
endif()
|
||||
|
||||
include(CheckIncludeFileCXX)
|
||||
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
@@ -300,26 +336,6 @@ endif()
|
||||
# Create CMake package
|
||||
#
|
||||
|
||||
# Generate version info based on git commit.
|
||||
|
||||
if(NOT DEFINED GGML_BUILD_NUMBER)
|
||||
find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH)
|
||||
execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
OUTPUT_VARIABLE GGML_BUILD_NUMBER
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
|
||||
if(GGML_BUILD_NUMBER EQUAL 1)
|
||||
message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.")
|
||||
endif()
|
||||
|
||||
execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
OUTPUT_VARIABLE GGML_BUILD_COMMIT
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
# Capture variables prefixed with GGML_.
|
||||
@@ -348,7 +364,7 @@ set(GGML_VARIABLES_EXPANDED ${variable_set_statements})
|
||||
|
||||
# Create the CMake package and set install location.
|
||||
|
||||
set(GGML_INSTALL_VERSION 0.0.${GGML_BUILD_NUMBER})
|
||||
set(GGML_INSTALL_VERSION ${GGML_VERSION})
|
||||
set(GGML_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files")
|
||||
set(GGML_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files")
|
||||
set(GGML_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files")
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include "ggml-cpu.h"
|
||||
#include "traits.h"
|
||||
|
||||
#if defined(__gnu_linux__)
|
||||
#if defined(__linux__)
|
||||
#include <sys/syscall.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
@@ -186,7 +186,7 @@ static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_ty
|
||||
#define XFEATURE_XTILEDATA 18
|
||||
|
||||
static bool ggml_amx_init() {
|
||||
#if defined(__gnu_linux__)
|
||||
#if defined(__linux__)
|
||||
if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
|
||||
fprintf(stderr, "AMX is not ready to be used!\n");
|
||||
return false;
|
||||
@@ -194,6 +194,8 @@ static bool ggml_amx_init() {
|
||||
return true;
|
||||
#elif defined(_WIN32)
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -28,6 +28,14 @@ static inline float bf16_to_f32(ggml_bf16_t x) {
|
||||
return GGML_BF16_TO_FP32(x);
|
||||
}
|
||||
|
||||
static inline float i32_to_f32(int32_t x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline int32_t f32_to_i32(float x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline float f32_to_f32(float x) {
|
||||
return x;
|
||||
}
|
||||
@@ -54,6 +62,12 @@ struct type_conversion_table<ggml_bf16_t> {
|
||||
static constexpr ggml_bf16_t (*from_f32)(float) = f32_to_bf16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_conversion_table<int32_t> {
|
||||
static constexpr float (*to_f32)(int32_t) = i32_to_f32;
|
||||
static constexpr int32_t (*from_f32)(float) = f32_to_i32;
|
||||
};
|
||||
|
||||
static std::pair<int64_t, int64_t> get_thread_range(const struct ggml_compute_params * params, const struct ggml_tensor * src0) {
|
||||
const int64_t ith = params->ith;
|
||||
const int64_t nth = params->nth;
|
||||
|
||||
+75
-852
File diff suppressed because it is too large
Load Diff
@@ -25,10 +25,14 @@ if (CUDAToolkit_FOUND)
|
||||
if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
|
||||
set(CMAKE_CUDA_ARCHITECTURES "native")
|
||||
else()
|
||||
if (CUDAToolkit_VERSION VERSION_LESS "13")
|
||||
list(APPEND CMAKE_CUDA_ARCHITECTURES 50-virtual 61-virtual 70-virtual)
|
||||
endif ()
|
||||
|
||||
list(APPEND CMAKE_CUDA_ARCHITECTURES 75-virtual 80-virtual 86-real)
|
||||
|
||||
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
|
||||
set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real")
|
||||
else()
|
||||
set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real")
|
||||
list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -31,10 +31,10 @@
|
||||
#include "types.comp"
|
||||
|
||||
#ifndef LOAD_VEC_A
|
||||
#define LOAD_VEC_A 1
|
||||
#define LOAD_VEC_A 2
|
||||
#endif
|
||||
#ifndef LOAD_VEC_B
|
||||
#define LOAD_VEC_B 1
|
||||
#define LOAD_VEC_B 2
|
||||
#endif
|
||||
|
||||
#if !defined(TO_FLOAT_TYPE)
|
||||
@@ -98,13 +98,13 @@ layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
|
||||
layout (constant_id = 10) const uint WARP = 32;
|
||||
|
||||
#ifdef COOPMAT
|
||||
#define SHMEM_STRIDE (BK + 8)
|
||||
#define SHMEM_STRIDE (BK / 2 + 4)
|
||||
#else
|
||||
#define SHMEM_STRIDE (BK + 1)
|
||||
#define SHMEM_STRIDE (BK / 2 + 1)
|
||||
#endif
|
||||
|
||||
shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
|
||||
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
|
||||
shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE];
|
||||
shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
|
||||
@@ -302,8 +302,8 @@ void main() {
|
||||
}
|
||||
#else
|
||||
ACC_TYPE sums[WMITER * TM * WNITER * TN];
|
||||
FLOAT_TYPE cache_a[WMITER * TM];
|
||||
FLOAT_TYPE cache_b[TN];
|
||||
FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
|
||||
FLOAT_TYPE_VEC2 cache_b[TN];
|
||||
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = ACC_TYPE(0.0f);
|
||||
@@ -312,13 +312,13 @@ void main() {
|
||||
|
||||
for (uint block = start_k; block < end_k; block += BK) {
|
||||
[[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
|
||||
load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block + loadr_a, end_k);
|
||||
load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block, end_k);
|
||||
}
|
||||
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
|
||||
#if !defined(MUL_MAT_ID)
|
||||
load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block + loadr_b, end_k);
|
||||
load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block, end_k);
|
||||
#else
|
||||
load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block + loadr_b, end_k);
|
||||
load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block, end_k);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -331,17 +331,17 @@ void main() {
|
||||
[[unroll]] for (uint i = 0; i < BK; i += TK) {
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
// Load from shared into cache
|
||||
coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
|
||||
coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
|
||||
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
[[unroll]] for (uint i = 0; i < BK; i++) {
|
||||
[[unroll]] for (uint i = 0; i < BK / 2; i++) {
|
||||
// Load from shared into cache
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint j = 0; j < TM; j++) {
|
||||
@@ -357,7 +357,7 @@ void main() {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
|
||||
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]);
|
||||
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,51 +1,53 @@
|
||||
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint idx_k, const uint end_k) {
|
||||
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) {
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
||||
#if LOAD_VEC_A == 8
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPE_VEC8 aa = FLOAT_TYPE_VEC8(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa[0].x;
|
||||
buf_a[buf_idx + 1] = aa[0].y;
|
||||
buf_a[buf_idx + 2] = aa[0].z;
|
||||
buf_a[buf_idx + 3] = aa[0].w;
|
||||
buf_a[buf_idx + 4] = aa[1].x;
|
||||
buf_a[buf_idx + 5] = aa[1].y;
|
||||
buf_a[buf_idx + 6] = aa[1].z;
|
||||
buf_a[buf_idx + 7] = aa[1].w;
|
||||
buf_a[buf_idx ] = aa[0].xy;
|
||||
buf_a[buf_idx + 1] = aa[0].zw;
|
||||
buf_a[buf_idx + 2] = aa[1].xy;
|
||||
buf_a[buf_idx + 3] = aa[1].zw;
|
||||
#elif LOAD_VEC_A == 4
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa.x;
|
||||
buf_a[buf_idx + 1] = aa.y;
|
||||
buf_a[buf_idx + 2] = aa.z;
|
||||
buf_a[buf_idx + 3] = aa.w;
|
||||
#else
|
||||
if (idx_m < p.M && idx_k < end_k) {
|
||||
buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE(data_a[pos_a + col * p.stride_a + row]);
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
#else // LOAD_VEC_A == 2
|
||||
const uint idx = pos_a * 2 + col * p.stride_a + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx],
|
||||
data_a[idx + 1]);
|
||||
} else if (idx_m < p.M && block + row * 2 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], 0.0f);
|
||||
} else {
|
||||
buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE(0.0f);
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
|
||||
}
|
||||
#endif
|
||||
#elif defined(DATA_A_BF16)
|
||||
#if LOAD_VEC_A == 4
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));
|
||||
buf_a[buf_idx ] = aa.x;
|
||||
buf_a[buf_idx + 1] = aa.y;
|
||||
buf_a[buf_idx + 2] = aa.z;
|
||||
buf_a[buf_idx + 3] = aa.w;
|
||||
#else
|
||||
if (idx_m < p.M && idx_k < end_k) {
|
||||
buf_a[col * SHMEM_STRIDE + row] = TO_FLOAT_TYPE(data_a[pos_a + col * p.stride_a + row]);
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
#else // LOAD_VEC_A == 2
|
||||
const uint idx = pos_a * 2 + col * p.stride_a + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]),
|
||||
TO_FLOAT_TYPE(data_a[idx + 1]));
|
||||
} else if (idx_m < p.M && block + row * 2 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
|
||||
} else {
|
||||
buf_a[col * SHMEM_STRIDE + row] = TO_FLOAT_TYPE(uint16_t(0));
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
|
||||
}
|
||||
#endif
|
||||
#elif defined(DATA_A_Q4_0)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + 4 * row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
|
||||
|
||||
const uint ib = idx / 4;
|
||||
const uint iqs = idx & 0x03;
|
||||
@@ -55,17 +57,13 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
|
||||
const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
|
||||
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
|
||||
buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
|
||||
buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
|
||||
buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
|
||||
buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
|
||||
buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
|
||||
buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy);
|
||||
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
|
||||
#elif defined(DATA_A_Q4_1)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + 4 * row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
|
||||
|
||||
const uint ib = idx / 4;
|
||||
const uint iqs = idx & 0x03;
|
||||
@@ -76,17 +74,13 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
|
||||
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
|
||||
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
|
||||
buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
|
||||
buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
|
||||
buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
|
||||
buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
|
||||
buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
|
||||
buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
|
||||
buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw);
|
||||
buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy);
|
||||
buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
|
||||
const uint ib = idx / 8;
|
||||
const uint iqs = idx & 0x07;
|
||||
@@ -99,13 +93,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
|
||||
const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
|
||||
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
|
||||
const uint ib = idx / 8;
|
||||
const uint iqs = idx & 0x07;
|
||||
@@ -119,13 +111,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
|
||||
const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
|
||||
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
|
||||
#elif defined(DATA_A_Q8_0)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
|
||||
const uint ib = idx / 8;
|
||||
const uint iqs = idx & 0x07;
|
||||
@@ -135,13 +125,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
|
||||
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
|
||||
buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
|
||||
#elif defined(DATA_A_Q2_K)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = idx % 128; // 0..127
|
||||
@@ -156,11 +144,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
|
||||
const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
|
||||
#elif defined(DATA_A_Q3_K)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = idx % 128; // 0..127
|
||||
@@ -178,11 +165,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
| (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
|
||||
const float dl = float(data_a[ib].d) * float(us - 32);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)),
|
||||
dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
|
||||
#elif defined(DATA_A_Q4_K)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = idx % 128; // 0..127
|
||||
@@ -211,11 +198,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const float d = loadd.x * sc;
|
||||
const float m = -loadd.y * mbyte;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m),
|
||||
fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
|
||||
#elif defined(DATA_A_Q5_K)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = idx % 128; // 0..127
|
||||
@@ -247,11 +234,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const float d = loadd.x * sc;
|
||||
const float m = -loadd.y * mbyte;
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m),
|
||||
fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
|
||||
#elif defined(DATA_A_Q6_K)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = idx % 128; // 0..127
|
||||
@@ -266,11 +253,11 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
|
||||
const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
|
||||
buf_a[buf_idx] = FLOAT_TYPE_VEC2(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32),
|
||||
dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
|
||||
#elif defined(DATA_A_IQ1_S)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib32 = (idx % 32) / 4; // 0..7
|
||||
@@ -283,12 +270,13 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
|
||||
|
||||
[[unroll]] for (int k = 0; k < 8; ++k) {
|
||||
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
|
||||
[[unroll]] for (int k = 0; k < 4; ++k) {
|
||||
buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
|
||||
dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
|
||||
}
|
||||
#elif defined(DATA_A_IQ1_M)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib8 = idx % 32;
|
||||
@@ -304,12 +292,13 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
||||
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||
|
||||
[[unroll]] for (int k = 0; k < 8; ++k) {
|
||||
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
|
||||
[[unroll]] for (int k = 0; k < 4; ++k) {
|
||||
buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
|
||||
dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
|
||||
}
|
||||
#elif defined(DATA_A_IQ2_XXS)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib32 = (idx % 32) / 4; // 0..7
|
||||
@@ -330,17 +319,17 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
||||
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
||||
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
|
||||
(sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
|
||||
(sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
|
||||
(sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
|
||||
(sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
#elif defined(DATA_A_IQ2_XS)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib32 = (idx % 32) / 4; // 0..7
|
||||
@@ -356,17 +345,17 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
||||
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
||||
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
|
||||
(sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
|
||||
(sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
|
||||
(sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
|
||||
(sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
#elif defined(DATA_A_IQ2_S)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
|
||||
const uint ib = idx / 32; // 8 values per idx
|
||||
const uint ib8 = idx % 32; // 0..31
|
||||
@@ -384,17 +373,17 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
||||
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
||||
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
|
||||
(sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||
buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
|
||||
(sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||
buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
|
||||
(sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||
buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
|
||||
(sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||
#elif defined(DATA_A_IQ3_XXS)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
|
||||
const uint ib = idx / 64; // 4 values per idx
|
||||
const uint iqs = idx % 64; // 0..63
|
||||
@@ -414,13 +403,13 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint grid = iq3xxs_grid[qs];
|
||||
const vec4 v = db * vec4(unpack8(grid));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
|
||||
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
|
||||
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x,
|
||||
(sign & 2) != 0 ? -v.y : v.y);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z,
|
||||
(sign & 8) != 0 ? -v.w : v.w);
|
||||
#elif defined(DATA_A_IQ3_S)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
|
||||
const uint ib = idx / 64; // 4 values per idx
|
||||
const uint iqs = idx % 64; // 0..63
|
||||
@@ -436,13 +425,13 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
|
||||
const vec4 v = db * vec4(unpack8(grid));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
|
||||
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
|
||||
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x,
|
||||
(sign & 2) != 0 ? -v.y : v.y);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z,
|
||||
(sign & 8) != 0 ? -v.w : v.w);
|
||||
#elif defined(DATA_A_IQ4_XS)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
@@ -457,11 +446,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const float d = float(data_a[ib].d);
|
||||
const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
|
||||
const uint ib = idx / 8;
|
||||
const uint iqs = idx & 0x07;
|
||||
@@ -469,13 +457,13 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
|
||||
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d;
|
||||
buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
|
||||
buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
|
||||
buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
|
||||
buf_a[buf_idx ] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF],
|
||||
kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]);
|
||||
buf_a[buf_idx + 8] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)],
|
||||
kvalues_iq4nl[vui >> 12]);
|
||||
#elif defined(DATA_A_MXFP4)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
|
||||
const uint ib = idx / 8;
|
||||
const uint iqs = (idx & 0x07) * 2;
|
||||
@@ -484,84 +472,84 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
const uint vui = uint(data_a[ib].qs[iqs]);
|
||||
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(kvalues_mxfp4[vui & 0xF] * d);
|
||||
buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_mxfp4[vui >> 4] * d);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(kvalues_mxfp4[vui2 & 0xF] * d);
|
||||
buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_mxfp4[vui2 >> 4] * d);
|
||||
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui & 0xF] * d,
|
||||
kvalues_mxfp4[vui2 & 0xF] * d);
|
||||
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui >> 4] * d,
|
||||
kvalues_mxfp4[vui2 >> 4] * d);
|
||||
#endif
|
||||
}
|
||||
|
||||
#if !defined(MUL_MAT_ID)
|
||||
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint idx_k, const uint end_k) {
|
||||
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) {
|
||||
#if LOAD_VEC_B == 8
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].x;
|
||||
buf_b[buf_idx + 1] = bb[0].y;
|
||||
buf_b[buf_idx + 2] = bb[0].z;
|
||||
buf_b[buf_idx + 3] = bb[0].w;
|
||||
buf_b[buf_idx + 4] = bb[1].x;
|
||||
buf_b[buf_idx + 5] = bb[1].y;
|
||||
buf_b[buf_idx + 6] = bb[1].z;
|
||||
buf_b[buf_idx + 7] = bb[1].w;
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
buf_b[buf_idx + 3] = bb[1].zw;
|
||||
#elif LOAD_VEC_B == 4
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
#if defined(DATA_B_BF16)
|
||||
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
#else
|
||||
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.x;
|
||||
buf_b[buf_idx + 1] = bb.y;
|
||||
buf_b[buf_idx + 2] = bb.z;
|
||||
buf_b[buf_idx + 3] = bb.w;
|
||||
#else // LOAD_VEC_B == 1
|
||||
if (idx_n < p.N && idx_k < end_k) {
|
||||
buf_b[col * SHMEM_STRIDE + row] = TO_FLOAT_TYPE(data_b[pos_b + col * p.stride_b + row]);
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
#else // LOAD_VEC_B == 2
|
||||
const uint idx = pos_b * 2 + col * p.stride_b + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
|
||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
|
||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||
} else if (idx_n < p.N && block + row * 2 < end_k) {
|
||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||
} else {
|
||||
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE(0.0f);
|
||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint idx_k, const uint end_k) {
|
||||
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) {
|
||||
#if LOAD_VEC_B == 8
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].x;
|
||||
buf_b[buf_idx + 1] = bb[0].y;
|
||||
buf_b[buf_idx + 2] = bb[0].z;
|
||||
buf_b[buf_idx + 3] = bb[0].w;
|
||||
buf_b[buf_idx + 4] = bb[1].x;
|
||||
buf_b[buf_idx + 5] = bb[1].y;
|
||||
buf_b[buf_idx + 6] = bb[1].z;
|
||||
buf_b[buf_idx + 7] = bb[1].w;
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
buf_b[buf_idx + 3] = bb[1].zw;
|
||||
#elif LOAD_VEC_B == 4
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
#if defined(DATA_B_BF16)
|
||||
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
#else
|
||||
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.x;
|
||||
buf_b[buf_idx + 1] = bb.y;
|
||||
buf_b[buf_idx + 2] = bb.z;
|
||||
buf_b[buf_idx + 3] = bb.w;
|
||||
#else // LOAD_VEC_B == 1
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
#else // LOAD_VEC_B == 2
|
||||
const uint row_i = ic * BN + col;
|
||||
if (row_i < _ne1 && idx_k < end_k) {
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
buf_b[col * SHMEM_STRIDE + row] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row]);
|
||||
const uint idx = pos_b * 2 + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
|
||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||
} else if (row_i < _ne1 && block + row * 2 < end_k) {
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b * 2 + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||
} else {
|
||||
buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE(0.0f);
|
||||
buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -11,12 +11,12 @@
|
||||
#define QUANT_K 1
|
||||
#define QUANT_R 1
|
||||
|
||||
#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
|
||||
#define A_TYPE float
|
||||
#elif LOAD_VEC_A == 4
|
||||
#if LOAD_VEC_A == 4
|
||||
#define A_TYPE vec4
|
||||
#elif LOAD_VEC_A == 8
|
||||
#define A_TYPE mat2x4
|
||||
#else
|
||||
#define A_TYPE float
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -24,12 +24,12 @@
|
||||
#define QUANT_K 1
|
||||
#define QUANT_R 1
|
||||
|
||||
#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
|
||||
#define A_TYPE float16_t
|
||||
#elif LOAD_VEC_A == 4
|
||||
#if LOAD_VEC_A == 4
|
||||
#define A_TYPE f16vec4
|
||||
#elif LOAD_VEC_A == 8
|
||||
#define A_TYPE f16mat2x4
|
||||
#else
|
||||
#define A_TYPE float16_t
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@@ -37,12 +37,12 @@
|
||||
#define QUANT_K 1
|
||||
#define QUANT_R 1
|
||||
|
||||
#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
|
||||
#define A_TYPE uint16_t
|
||||
#elif LOAD_VEC_A == 4
|
||||
#if LOAD_VEC_A == 4
|
||||
#define A_TYPE u16vec4
|
||||
#elif LOAD_VEC_A == 8
|
||||
#error unsupported
|
||||
#else
|
||||
#define A_TYPE uint16_t
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
||||
@@ -336,7 +336,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
base_dict["FLOAT16"] = "1";
|
||||
}
|
||||
|
||||
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
|
||||
base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float";
|
||||
base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2";
|
||||
if (f16acc) {
|
||||
base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
|
||||
}
|
||||
@@ -418,7 +419,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
|
||||
// bf16
|
||||
{
|
||||
std::string load_vec_a_unaligned = "1";
|
||||
// For aligned matmul loads
|
||||
std::string load_vec_a = coopmat2 ? "1" : "4";
|
||||
|
||||
@@ -436,8 +436,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
if (!(coopmat || coopmat2))
|
||||
#endif
|
||||
{
|
||||
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -454,7 +454,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
// For unaligned, load one at a time for f32/f16, or two at a time for quants
|
||||
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
|
||||
std::string load_vec_a_unaligned = coopmat2 ? "1" : (tname == "f32" || tname == "f16" || tname == "bf16") ? "2" : load_vec_quant;
|
||||
// For aligned matmul loads
|
||||
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
323951f1bdcdfbd5b5ff3a9a7c3770e63b1a560e
|
||||
978f6e1993f2eeb4e99b63d4e70b4401c0a2dae2
|
||||
|
||||
@@ -6629,9 +6629,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
|
||||
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
|
||||
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_Q4_0, {8192, 512, 2, 1}));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32, {8192, 512, 2, 1}));
|
||||
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
|
||||
|
||||
@@ -1402,6 +1402,12 @@ static void test_template_output_parsers() {
|
||||
"Hello, world!\nWhat's up?",
|
||||
/* is_partial= */ false,
|
||||
{COMMON_CHAT_FORMAT_GRANITE}));
|
||||
assert_msg_equals(
|
||||
message_assist,
|
||||
common_chat_parse(
|
||||
"Hello, world!\nWhat's up?",
|
||||
/* is_partial= */ true,
|
||||
{COMMON_CHAT_FORMAT_GRANITE}));
|
||||
|
||||
// Test parsing content with thinking
|
||||
assert_msg_equals(message_assist_thoughts,
|
||||
@@ -1412,6 +1418,59 @@ static void test_template_output_parsers() {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_GRANITE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
}));
|
||||
assert_msg_equals(message_assist_thoughts_unparsed_deepseek,
|
||||
common_chat_parse(
|
||||
"<think>I'm\nthinking</think>Hello, world!\nWhat's up?",
|
||||
/* is_partial= */ false,
|
||||
{COMMON_CHAT_FORMAT_GRANITE}));
|
||||
assert_msg_equals(message_assist_thoughts,
|
||||
common_chat_parse(
|
||||
"<think>I'm\nthinking</think><response>Hello, world!\nWhat's up?",
|
||||
/* is_partial= */ true,
|
||||
{
|
||||
/* .format = */ COMMON_CHAT_FORMAT_GRANITE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
}));
|
||||
assert_msg_equals(message_assist_thoughts,
|
||||
common_chat_parse(
|
||||
"<think>I'm\nthinking</think><response>Hello, world!\nWhat's up?</response>",
|
||||
/* is_partial= */ false,
|
||||
{
|
||||
/* .format = */ COMMON_CHAT_FORMAT_GRANITE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
}));
|
||||
assert_msg_equals(simple_assist_msg("<think>I'm\nthinking</think><response>Hello, world!\nWhat's up?</response>"),
|
||||
common_chat_parse(
|
||||
"<think>I'm\nthinking</think><response>Hello, world!\nWhat's up?</response>",
|
||||
/* is_partial= */ false,
|
||||
{COMMON_CHAT_FORMAT_GRANITE}));
|
||||
assert_msg_equals(message_assist_empty,
|
||||
common_chat_parse(
|
||||
"<think",
|
||||
/* is_partial= */ true,
|
||||
{
|
||||
/* .format = */ COMMON_CHAT_FORMAT_GRANITE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
}));
|
||||
assert_msg_equals(message_assist_empty,
|
||||
common_chat_parse(
|
||||
"<think",
|
||||
/* is_partial= */ true,
|
||||
{COMMON_CHAT_FORMAT_GRANITE}));
|
||||
assert_msg_equals(message_assist_thoughts_no_content,
|
||||
common_chat_parse(
|
||||
"<think>I'm\nthinking",
|
||||
/* is_partial= */ true,
|
||||
{
|
||||
/* .format = */ COMMON_CHAT_FORMAT_GRANITE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
}));
|
||||
assert_msg_equals(
|
||||
message_assist_empty,
|
||||
common_chat_parse(
|
||||
"<think>I'm\nthinking</think><response",
|
||||
/* is_partial= */ true,
|
||||
{COMMON_CHAT_FORMAT_GRANITE}));
|
||||
|
||||
// Test parsing tool calls
|
||||
assert_msg_equals(message_assist_call,
|
||||
@@ -1419,6 +1478,38 @@ static void test_template_output_parsers() {
|
||||
"<|tool_call|>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]",
|
||||
/* is_partial= */ false,
|
||||
{COMMON_CHAT_FORMAT_GRANITE}));
|
||||
assert_msg_equals(
|
||||
message_assist_call_empty_args,
|
||||
common_chat_parse(
|
||||
"<|tool_call|>[{\"name\": \"special_function\"",
|
||||
/* is_partial= */ true,
|
||||
{COMMON_CHAT_FORMAT_GRANITE}));
|
||||
assert_msg_equals(
|
||||
message_assist_call_cutoff_args,
|
||||
common_chat_parse(
|
||||
"<|tool_call|>[{\"name\": \"special_function\", \"arguments\": {\"arg",
|
||||
/* is_partial= */ true,
|
||||
{COMMON_CHAT_FORMAT_GRANITE}));
|
||||
assert_msg_equals(
|
||||
message_assist_call_cutoff_args,
|
||||
common_chat_parse(
|
||||
"<|tool_call|>[{\"name\": \"special_function\", \"arguments\": {\"arg",
|
||||
/* is_partial= */ true,
|
||||
{
|
||||
/* .format = */ COMMON_CHAT_FORMAT_GRANITE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
}));
|
||||
|
||||
// Test parsing tool calls with thinking
|
||||
assert_msg_equals(
|
||||
message_assist_call_thoughts,
|
||||
common_chat_parse(
|
||||
"<think>I'm\nthinking</think><|tool_call|>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, {",
|
||||
/* is_partial= */ true,
|
||||
{
|
||||
/* .format = */ COMMON_CHAT_FORMAT_GRANITE,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
}));
|
||||
|
||||
// Test template generation for regular content
|
||||
test_templates(tmpls.get(), end_tokens, message_assist, tools,
|
||||
|
||||
@@ -18,6 +18,7 @@ else()
|
||||
add_subdirectory(gguf-split)
|
||||
add_subdirectory(imatrix)
|
||||
add_subdirectory(llama-bench)
|
||||
add_subdirectory(pull)
|
||||
add_subdirectory(main)
|
||||
add_subdirectory(perplexity)
|
||||
add_subdirectory(quantize)
|
||||
|
||||
@@ -30,8 +30,10 @@ options:
|
||||
--delay <0...N> (seconds) delay between each test (default: 0)
|
||||
-o, --output <csv|json|jsonl|md|sql> output format printed to stdout (default: md)
|
||||
-oe, --output-err <csv|json|jsonl|md|sql> output format printed to stderr (default: none)
|
||||
--list-devices list available devices and exit
|
||||
-v, --verbose verbose output
|
||||
--progress print test progress indicators
|
||||
-rpc, --rpc <rpc_servers> register RPC devices (comma separated)
|
||||
|
||||
test parameters:
|
||||
-m, --model <filename> (default: models/7B/ggml-model-q4_0.gguf)
|
||||
@@ -48,11 +50,12 @@ test parameters:
|
||||
--cpu-strict <0|1> (default: 0)
|
||||
--poll <0...100> (default: 50)
|
||||
-ngl, --n-gpu-layers <n> (default: 99)
|
||||
-rpc, --rpc <rpc_servers> (default: none)
|
||||
-ncmoe, --n-cpu-moe <n> (default: 0)
|
||||
-sm, --split-mode <none|layer|row> (default: layer)
|
||||
-mg, --main-gpu <i> (default: 0)
|
||||
-nkvo, --no-kv-offload <0|1> (default: 0)
|
||||
-fa, --flash-attn <0|1> (default: 0)
|
||||
-dev, --device <dev0/dev1/...> (default: auto)
|
||||
-mmp, --mmap <0|1> (default: 1)
|
||||
-embd, --embeddings <0|1> (default: 0)
|
||||
-ts, --tensor-split <ts0/ts1/..> (default: 0)
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "common.h"
|
||||
#include "ggml.h"
|
||||
@@ -135,6 +136,101 @@ static std::string get_gpu_info() {
|
||||
return join(gpu_list, ", ");
|
||||
}
|
||||
|
||||
static std::vector<ggml_backend_dev_t> parse_devices_arg(const std::string & value) {
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
std::string trimmed = string_strip(value);
|
||||
if (trimmed.empty()) {
|
||||
throw std::invalid_argument("no devices specified");
|
||||
}
|
||||
if (trimmed == "auto") {
|
||||
return devices;
|
||||
}
|
||||
|
||||
auto dev_names = string_split<std::string>(trimmed, '/');
|
||||
if (dev_names.size() == 1 && string_strip(dev_names[0]) == "none") {
|
||||
devices.push_back(nullptr);
|
||||
return devices;
|
||||
}
|
||||
|
||||
for (auto & name : dev_names) {
|
||||
std::string dev_name = string_strip(name);
|
||||
if (dev_name.empty()) {
|
||||
throw std::invalid_argument("invalid device specification");
|
||||
}
|
||||
auto * dev = ggml_backend_dev_by_name(dev_name.c_str());
|
||||
if (!dev || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
|
||||
throw std::invalid_argument(string_format("invalid device: %s", dev_name.c_str()));
|
||||
}
|
||||
devices.push_back(dev);
|
||||
}
|
||||
|
||||
devices.push_back(nullptr);
|
||||
return devices;
|
||||
}
|
||||
|
||||
static std::vector<ggml_backend_dev_t> register_rpc_device_list(const std::string & servers) {
|
||||
auto rpc_servers = string_split<std::string>(servers, ',');
|
||||
if (rpc_servers.empty()) {
|
||||
throw std::invalid_argument("no RPC servers specified");
|
||||
}
|
||||
|
||||
auto * rpc_reg = ggml_backend_reg_by_name("RPC");
|
||||
if (!rpc_reg) {
|
||||
throw std::invalid_argument("failed to find RPC backend");
|
||||
}
|
||||
|
||||
using add_rpc_device_fn = ggml_backend_dev_t (*)(const char * endpoint);
|
||||
auto * ggml_backend_rpc_add_device_fn = (add_rpc_device_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
|
||||
if (!ggml_backend_rpc_add_device_fn) {
|
||||
throw std::invalid_argument("failed to find RPC device add function");
|
||||
}
|
||||
|
||||
static std::unordered_set<std::string> registered;
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
for (const auto & server : rpc_servers) {
|
||||
ggml_backend_dev_t dev = nullptr;
|
||||
|
||||
std::string name = string_format("RPC[%s]", server.c_str());
|
||||
|
||||
if (registered.find(server) != registered.end()) {
|
||||
dev = ggml_backend_dev_by_name(name.c_str());
|
||||
}
|
||||
|
||||
if (!dev) {
|
||||
dev = ggml_backend_rpc_add_device_fn(server.c_str());
|
||||
if (!dev) {
|
||||
throw std::invalid_argument(string_format("failed to add RPC device for server '%s'", server.c_str()));
|
||||
}
|
||||
ggml_backend_device_register(dev);
|
||||
registered.insert(server);
|
||||
}
|
||||
|
||||
devices.push_back(dev);
|
||||
}
|
||||
|
||||
return devices;
|
||||
}
|
||||
|
||||
static std::string devices_to_string(const std::vector<ggml_backend_dev_t> & devices) {
|
||||
if (devices.empty()) {
|
||||
return "auto";
|
||||
}
|
||||
|
||||
if (devices.size() == 1 && devices[0] == nullptr) {
|
||||
return "none";
|
||||
}
|
||||
|
||||
std::vector<std::string> names;
|
||||
for (auto * dev : devices) {
|
||||
if (dev == nullptr) {
|
||||
break;
|
||||
}
|
||||
names.push_back(ggml_backend_dev_name(dev));
|
||||
}
|
||||
|
||||
return join(names, "/");
|
||||
}
|
||||
|
||||
// command line params
|
||||
enum output_formats { NONE, CSV, JSON, JSONL, MARKDOWN, SQL };
|
||||
|
||||
@@ -251,11 +347,11 @@ struct cmd_params {
|
||||
std::vector<int> poll;
|
||||
std::vector<int> n_gpu_layers;
|
||||
std::vector<int> n_cpu_moe;
|
||||
std::vector<std::string> rpc_servers;
|
||||
std::vector<llama_split_mode> split_mode;
|
||||
std::vector<int> main_gpu;
|
||||
std::vector<bool> no_kv_offload;
|
||||
std::vector<bool> flash_attn;
|
||||
std::vector<std::vector<ggml_backend_dev_t>> devices;
|
||||
std::vector<std::vector<float>> tensor_split;
|
||||
std::vector<std::vector<llama_model_tensor_buft_override>> tensor_buft_overrides;
|
||||
std::vector<bool> use_mmap;
|
||||
@@ -288,11 +384,11 @@ static const cmd_params cmd_params_defaults = {
|
||||
/* poll */ { 50 },
|
||||
/* n_gpu_layers */ { 99 },
|
||||
/* n_cpu_moe */ { 0 },
|
||||
/* rpc_servers */ { "" },
|
||||
/* split_mode */ { LLAMA_SPLIT_MODE_LAYER },
|
||||
/* main_gpu */ { 0 },
|
||||
/* no_kv_offload */ { false },
|
||||
/* flash_attn */ { false },
|
||||
/* devices */ { {} },
|
||||
/* tensor_split */ { std::vector<float>(llama_max_devices(), 0.0f) },
|
||||
/* tensor_buft_overrides*/ { std::vector<llama_model_tensor_buft_override>{ { nullptr, nullptr } } },
|
||||
/* use_mmap */ { true },
|
||||
@@ -325,9 +421,13 @@ static void print_usage(int /* argc */, char ** argv) {
|
||||
output_format_str(cmd_params_defaults.output_format));
|
||||
printf(" -oe, --output-err <csv|json|jsonl|md|sql> output format printed to stderr (default: %s)\n",
|
||||
output_format_str(cmd_params_defaults.output_format_stderr));
|
||||
printf(" --list-devices list available devices and exit\n");
|
||||
printf(" -v, --verbose verbose output\n");
|
||||
printf(" --progress print test progress indicators\n");
|
||||
printf(" --no-warmup skip warmup runs before benchmarking\n");
|
||||
if (llama_supports_rpc()) {
|
||||
printf(" -rpc, --rpc <rpc_servers> register RPC devices (comma separated)\n");
|
||||
}
|
||||
printf("\n");
|
||||
printf("test parameters:\n");
|
||||
printf(" -m, --model <filename> (default: %s)\n", join(cmd_params_defaults.model, ",").c_str());
|
||||
@@ -357,10 +457,6 @@ static void print_usage(int /* argc */, char ** argv) {
|
||||
join(cmd_params_defaults.n_gpu_layers, ",").c_str());
|
||||
printf(" -ncmoe, --n-cpu-moe <n> (default: %s)\n",
|
||||
join(cmd_params_defaults.n_cpu_moe, ",").c_str());
|
||||
if (llama_supports_rpc()) {
|
||||
printf(" -rpc, --rpc <rpc_servers> (default: %s)\n",
|
||||
join(cmd_params_defaults.rpc_servers, ",").c_str());
|
||||
}
|
||||
printf(" -sm, --split-mode <none|layer|row> (default: %s)\n",
|
||||
join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
|
||||
printf(" -mg, --main-gpu <i> (default: %s)\n",
|
||||
@@ -369,6 +465,7 @@ static void print_usage(int /* argc */, char ** argv) {
|
||||
join(cmd_params_defaults.no_kv_offload, ",").c_str());
|
||||
printf(" -fa, --flash-attn <0|1> (default: %s)\n",
|
||||
join(cmd_params_defaults.flash_attn, ",").c_str());
|
||||
printf(" -dev, --device <dev0/dev1/...> (default: auto)\n");
|
||||
printf(" -mmp, --mmap <0|1> (default: %s)\n",
|
||||
join(cmd_params_defaults.use_mmap, ",").c_str());
|
||||
printf(" -embd, --embeddings <0|1> (default: %s)\n",
|
||||
@@ -533,6 +630,42 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
break;
|
||||
}
|
||||
params.type_v.insert(params.type_v.end(), types.begin(), types.end());
|
||||
} else if (arg == "-dev" || arg == "--device") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
auto combos = string_split<std::string>(argv[i], split_delim);
|
||||
for (const auto & combo : combos) {
|
||||
try {
|
||||
params.devices.push_back(parse_devices_arg(combo));
|
||||
} catch (const std::exception & e) {
|
||||
fprintf(stderr, "error: %s\n", e.what());
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (invalid_param) {
|
||||
break;
|
||||
}
|
||||
} else if (arg == "--list-devices") {
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
auto * dev = ggml_backend_dev_get(i);
|
||||
if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) {
|
||||
devices.push_back(dev);
|
||||
}
|
||||
}
|
||||
printf("Available devices:\n");
|
||||
if (devices.empty()) {
|
||||
printf(" (none)\n");
|
||||
}
|
||||
for (auto * dev : devices) {
|
||||
size_t free, total;
|
||||
ggml_backend_dev_memory(dev, &free, &total);
|
||||
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
|
||||
}
|
||||
exit(0);
|
||||
} else if (arg == "-t" || arg == "--threads") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
@@ -580,7 +713,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.rpc_servers.push_back(argv[i]);
|
||||
try {
|
||||
register_rpc_device_list(argv[i]);
|
||||
} catch (const std::exception & e) {
|
||||
fprintf(stderr, "error: %s\n", e.what());
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
} else if (arg == "-sm" || arg == "--split-mode") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
@@ -855,9 +994,6 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
if (params.n_cpu_moe.empty()) {
|
||||
params.n_cpu_moe = cmd_params_defaults.n_cpu_moe;
|
||||
}
|
||||
if (params.rpc_servers.empty()) {
|
||||
params.rpc_servers = cmd_params_defaults.rpc_servers;
|
||||
}
|
||||
if (params.split_mode.empty()) {
|
||||
params.split_mode = cmd_params_defaults.split_mode;
|
||||
}
|
||||
@@ -870,6 +1006,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
if (params.flash_attn.empty()) {
|
||||
params.flash_attn = cmd_params_defaults.flash_attn;
|
||||
}
|
||||
if (params.devices.empty()) {
|
||||
params.devices = cmd_params_defaults.devices;
|
||||
}
|
||||
if (params.tensor_split.empty()) {
|
||||
params.tensor_split = cmd_params_defaults.tensor_split;
|
||||
}
|
||||
@@ -916,11 +1055,11 @@ struct cmd_params_instance {
|
||||
int poll;
|
||||
int n_gpu_layers;
|
||||
int n_cpu_moe;
|
||||
std::string rpc_servers_str;
|
||||
llama_split_mode split_mode;
|
||||
int main_gpu;
|
||||
bool no_kv_offload;
|
||||
bool flash_attn;
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
std::vector<float> tensor_split;
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
bool use_mmap;
|
||||
@@ -931,57 +1070,8 @@ struct cmd_params_instance {
|
||||
llama_model_params mparams = llama_model_default_params();
|
||||
|
||||
mparams.n_gpu_layers = n_gpu_layers;
|
||||
if (!rpc_servers_str.empty()) {
|
||||
auto rpc_servers = string_split<std::string>(rpc_servers_str, ',');
|
||||
|
||||
// add RPC devices
|
||||
if (!rpc_servers.empty()) {
|
||||
ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
|
||||
if (!rpc_reg) {
|
||||
fprintf(stderr, "%s: failed to find RPC backend\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
|
||||
ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
|
||||
if (!ggml_backend_rpc_add_device_fn) {
|
||||
fprintf(stderr, "%s: failed to find RPC device add function\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
static std::vector<ggml_backend_dev_t> devices;
|
||||
devices.clear();
|
||||
// RPC devices should always come first for performance reasons
|
||||
for (const std::string & server : rpc_servers) {
|
||||
ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
|
||||
if (dev) {
|
||||
devices.push_back(dev);
|
||||
} else {
|
||||
fprintf(stderr, "%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
// FIXME: use llama.cpp device selection logic
|
||||
// add local GPU devices if any
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
||||
switch (ggml_backend_dev_type(dev)) {
|
||||
case GGML_BACKEND_DEVICE_TYPE_CPU:
|
||||
case GGML_BACKEND_DEVICE_TYPE_ACCEL:
|
||||
// skip CPU backends since they are handled separately
|
||||
break;
|
||||
|
||||
case GGML_BACKEND_DEVICE_TYPE_GPU:
|
||||
devices.push_back(dev);
|
||||
break;
|
||||
|
||||
case GGML_BACKEND_DEVICE_TYPE_IGPU:
|
||||
// iGPUs are not used when there are RPC servers
|
||||
break;
|
||||
}
|
||||
}
|
||||
devices.push_back(nullptr);
|
||||
mparams.devices = devices.data();
|
||||
}
|
||||
if (!devices.empty()) {
|
||||
mparams.devices = const_cast<ggml_backend_dev_t *>(devices.data());
|
||||
}
|
||||
mparams.split_mode = split_mode;
|
||||
mparams.main_gpu = main_gpu;
|
||||
@@ -1029,8 +1119,9 @@ struct cmd_params_instance {
|
||||
|
||||
bool equal_mparams(const cmd_params_instance & other) const {
|
||||
return model == other.model && n_gpu_layers == other.n_gpu_layers && n_cpu_moe == other.n_cpu_moe &&
|
||||
rpc_servers_str == other.rpc_servers_str && split_mode == other.split_mode &&
|
||||
split_mode == other.split_mode &&
|
||||
main_gpu == other.main_gpu && use_mmap == other.use_mmap && tensor_split == other.tensor_split &&
|
||||
devices == other.devices &&
|
||||
vec_tensor_buft_override_equal(tensor_buft_overrides, other.tensor_buft_overrides);
|
||||
}
|
||||
|
||||
@@ -1060,9 +1151,9 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
for (const auto & m : params.model)
|
||||
for (const auto & nl : params.n_gpu_layers)
|
||||
for (const auto & ncmoe : params.n_cpu_moe)
|
||||
for (const auto & rpc : params.rpc_servers)
|
||||
for (const auto & sm : params.split_mode)
|
||||
for (const auto & mg : params.main_gpu)
|
||||
for (const auto & devs : params.devices)
|
||||
for (const auto & ts : params.tensor_split)
|
||||
for (const auto & ot : params.tensor_buft_overrides)
|
||||
for (const auto & mmp : params.use_mmap)
|
||||
@@ -1098,11 +1189,11 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
/* .poll = */ pl,
|
||||
/* .n_gpu_layers = */ nl,
|
||||
/* .n_cpu_moe = */ ncmoe,
|
||||
/* .rpc_servers = */ rpc,
|
||||
/* .split_mode = */ sm,
|
||||
/* .main_gpu = */ mg,
|
||||
/* .no_kv_offload= */ nkvo,
|
||||
/* .flash_attn = */ fa,
|
||||
/* .devices = */ devs,
|
||||
/* .tensor_split = */ ts,
|
||||
/* .tensor_buft_overrides = */ ot,
|
||||
/* .use_mmap = */ mmp,
|
||||
@@ -1131,11 +1222,11 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
/* .poll = */ pl,
|
||||
/* .n_gpu_layers = */ nl,
|
||||
/* .n_cpu_moe = */ ncmoe,
|
||||
/* .rpc_servers = */ rpc,
|
||||
/* .split_mode = */ sm,
|
||||
/* .main_gpu = */ mg,
|
||||
/* .no_kv_offload= */ nkvo,
|
||||
/* .flash_attn = */ fa,
|
||||
/* .devices = */ devs,
|
||||
/* .tensor_split = */ ts,
|
||||
/* .tensor_buft_overrides = */ ot,
|
||||
/* .use_mmap = */ mmp,
|
||||
@@ -1164,11 +1255,11 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
/* .poll = */ pl,
|
||||
/* .n_gpu_layers = */ nl,
|
||||
/* .n_cpu_moe = */ ncmoe,
|
||||
/* .rpc_servers = */ rpc,
|
||||
/* .split_mode = */ sm,
|
||||
/* .main_gpu = */ mg,
|
||||
/* .no_kv_offload= */ nkvo,
|
||||
/* .flash_attn = */ fa,
|
||||
/* .devices = */ devs,
|
||||
/* .tensor_split = */ ts,
|
||||
/* .tensor_buft_overrides = */ ot,
|
||||
/* .use_mmap = */ mmp,
|
||||
@@ -1206,6 +1297,7 @@ struct test {
|
||||
int main_gpu;
|
||||
bool no_kv_offload;
|
||||
bool flash_attn;
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
std::vector<float> tensor_split;
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
bool use_mmap;
|
||||
@@ -1241,6 +1333,7 @@ struct test {
|
||||
main_gpu = inst.main_gpu;
|
||||
no_kv_offload = inst.no_kv_offload;
|
||||
flash_attn = inst.flash_attn;
|
||||
devices = inst.devices;
|
||||
tensor_split = inst.tensor_split;
|
||||
tensor_buft_overrides = inst.tensor_buft_overrides;
|
||||
use_mmap = inst.use_mmap;
|
||||
@@ -1287,14 +1380,14 @@ struct test {
|
||||
|
||||
static const std::vector<std::string> & get_fields() {
|
||||
static const std::vector<std::string> fields = {
|
||||
"build_commit", "build_number", "cpu_info", "gpu_info", "backends",
|
||||
"model_filename", "model_type", "model_size", "model_n_params", "n_batch",
|
||||
"n_ubatch", "n_threads", "cpu_mask", "cpu_strict", "poll",
|
||||
"type_k", "type_v", "n_gpu_layers", "n_cpu_moe", "split_mode",
|
||||
"main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
|
||||
"use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen",
|
||||
"n_depth", "test_time", "avg_ns", "stddev_ns", "avg_ts",
|
||||
"stddev_ts"
|
||||
"build_commit", "build_number", "cpu_info", "gpu_info", "backends",
|
||||
"model_filename", "model_type", "model_size", "model_n_params", "n_batch",
|
||||
"n_ubatch", "n_threads", "cpu_mask", "cpu_strict", "poll",
|
||||
"type_k", "type_v", "n_gpu_layers", "n_cpu_moe", "split_mode",
|
||||
"main_gpu", "no_kv_offload", "flash_attn", "devices", "tensor_split",
|
||||
"tensor_buft_overrides", "use_mmap", "embeddings", "no_op_offload",
|
||||
"n_prompt", "n_gen", "n_depth", "test_time", "avg_ns",
|
||||
"stddev_ns", "avg_ts", "stddev_ts"
|
||||
};
|
||||
return fields;
|
||||
}
|
||||
@@ -1378,6 +1471,7 @@ struct test {
|
||||
std::to_string(main_gpu),
|
||||
std::to_string(no_kv_offload),
|
||||
std::to_string(flash_attn),
|
||||
devices_to_string(devices),
|
||||
tensor_split_str,
|
||||
tensor_buft_overrides_str,
|
||||
std::to_string(use_mmap),
|
||||
@@ -1559,6 +1653,9 @@ struct markdown_printer : public printer {
|
||||
if (field == "flash_attn") {
|
||||
return 2;
|
||||
}
|
||||
if (field == "devices") {
|
||||
return -12;
|
||||
}
|
||||
if (field == "use_mmap") {
|
||||
return 4;
|
||||
}
|
||||
@@ -1602,6 +1699,9 @@ struct markdown_printer : public printer {
|
||||
if (field == "no_op_offload") {
|
||||
return "nopo";
|
||||
}
|
||||
if (field == "devices") {
|
||||
return "dev";
|
||||
}
|
||||
if (field == "tensor_split") {
|
||||
return "ts";
|
||||
}
|
||||
@@ -1661,6 +1761,9 @@ struct markdown_printer : public printer {
|
||||
if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) {
|
||||
fields.emplace_back("flash_attn");
|
||||
}
|
||||
if (params.devices.size() > 1 || params.devices != cmd_params_defaults.devices) {
|
||||
fields.emplace_back("devices");
|
||||
}
|
||||
if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
|
||||
fields.emplace_back("tensor_split");
|
||||
}
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
set(TARGET llama-pull)
|
||||
add_executable(${TARGET} pull.cpp)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
@@ -0,0 +1,43 @@
|
||||
# llama-pull - Model Download Tool
|
||||
|
||||
A command-line tool for downloading AI models from HuggingFace and Docker Hub for use with llama.cpp.
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Download from HuggingFace
|
||||
llama-pull -hf <user>/<model>[:<quant>]
|
||||
|
||||
# Download from Docker Hub
|
||||
llama-pull -dr [<repo>/]<model>[:<quant>]
|
||||
```
|
||||
|
||||
## Options
|
||||
|
||||
- `-hf, --hf-repo REPO` - Download model from HuggingFace repository
|
||||
- `-dr, --docker-repo REPO` - Download model from Docker Hub
|
||||
- `--hf-token TOKEN` - HuggingFace token for private repositories
|
||||
- `-h, --help` - Show help message
|
||||
|
||||
## Examples
|
||||
|
||||
```bash
|
||||
# Download a HuggingFace model
|
||||
llama-pull -hf microsoft/DialoGPT-medium
|
||||
|
||||
# Download a Docker model (ai/ repo is default)
|
||||
llama-pull -dr gemma3
|
||||
|
||||
# Download with specific quantization
|
||||
llama-pull -hf bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M
|
||||
```
|
||||
|
||||
## Model Storage
|
||||
|
||||
Downloaded models are stored in the standard llama.cpp cache directory:
|
||||
- Linux/macOS: `~/.cache/llama.cpp/`
|
||||
- The models can then be used with other llama.cpp tools
|
||||
|
||||
## Requirements
|
||||
|
||||
- Built with `LLAMA_USE_CURL=ON` (default) for download functionality
|
||||
@@ -0,0 +1,84 @@
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
|
||||
static void print_usage(int, char ** argv) {
|
||||
LOG("Usage: %s [options]\n", argv[0]);
|
||||
LOG("\n");
|
||||
LOG("Download models from HuggingFace or Docker Hub\n");
|
||||
LOG("\n");
|
||||
LOG("Options:\n");
|
||||
LOG(" -h, --help show this help message and exit\n");
|
||||
LOG(" -hf, -hfr, --hf-repo REPO download model from HuggingFace repository\n");
|
||||
LOG(" format: <user>/<model>[:<quant>]\n");
|
||||
LOG(" example: microsoft/DialoGPT-medium\n");
|
||||
LOG(" -dr, --docker-repo REPO download model from Docker Hub\n");
|
||||
LOG(" format: [<repo>/]<model>[:<quant>]\n");
|
||||
LOG(" example: gemma3\n");
|
||||
LOG(" -o, --output PATH output path for downloaded model\n");
|
||||
LOG(" (default: cache directory)\n");
|
||||
LOG(" --hf-token TOKEN HuggingFace token for private repositories\n");
|
||||
LOG("\n");
|
||||
LOG("Examples:\n");
|
||||
LOG(" %s -hf microsoft/DialoGPT-medium\n", argv[0]);
|
||||
LOG(" %s -dr gemma3\n", argv[0]);
|
||||
LOG(" %s -hf microsoft/DialoGPT-medium -o ./my-model.gguf\n", argv[0]);
|
||||
LOG("\n");
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
// Set up argument parsing context
|
||||
auto ctx = common_params_parser_init(params, LLAMA_EXAMPLE_COMMON, print_usage);
|
||||
|
||||
// Parse command line arguments
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
|
||||
print_usage(argc, argv);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Check if help was requested or no download option provided
|
||||
if (params.model.hf_repo.empty() && params.model.docker_repo.empty()) {
|
||||
LOG_ERR("error: must specify either -hf <repo> or -dr <repo>\n");
|
||||
print_usage(argc, argv);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Both cannot be specified at the same time
|
||||
if (!params.model.hf_repo.empty() && !params.model.docker_repo.empty()) {
|
||||
LOG_ERR("error: cannot specify both -hf and -dr options\n");
|
||||
print_usage(argc, argv);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Initialize llama backend for download functionality
|
||||
llama_backend_init();
|
||||
|
||||
LOG_INF("llama-pull: downloading model...\n");
|
||||
|
||||
try {
|
||||
// Use the existing model handling logic which downloads the model
|
||||
common_init_result llama_init = common_init_from_params(params);
|
||||
|
||||
if (llama_init.model != nullptr) {
|
||||
LOG_INF("Model downloaded and loaded successfully to: %s\n", params.model.path.c_str());
|
||||
|
||||
// We only want to download, not keep the model loaded
|
||||
// The download happens during common_init_from_params
|
||||
} else {
|
||||
LOG_ERR("Failed to download or load model\n");
|
||||
return 1;
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("Error: %s\n", e.what());
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Clean up
|
||||
llama_backend_free();
|
||||
return 0;
|
||||
}
|
||||
Binary file not shown.
@@ -4679,17 +4679,17 @@ int main(int argc, char ** argv) {
|
||||
json res_json = result->to_json();
|
||||
if (res_json.is_array()) {
|
||||
for (const auto & res : res_json) {
|
||||
if (!server_sent_event(sink, "data", res)) {
|
||||
if (!server_sent_event(sink, res)) {
|
||||
// sending failed (HTTP connection closed), cancel the generation
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
return server_sent_event(sink, "data", res_json);
|
||||
return server_sent_event(sink, res_json);
|
||||
}
|
||||
}, [&](const json & error_data) {
|
||||
server_sent_event(sink, "error", error_data);
|
||||
server_sent_event(sink, json{{"error", error_data}});
|
||||
}, [&sink]() {
|
||||
// note: do not use req.is_connection_closed here because req is already destroyed
|
||||
return !sink.is_writable();
|
||||
|
||||
@@ -459,9 +459,9 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx,
|
||||
return out;
|
||||
}
|
||||
|
||||
static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
|
||||
static bool server_sent_event(httplib::DataSink & sink, const json & data) {
|
||||
const std::string str =
|
||||
std::string(event) + ": " +
|
||||
"data: " +
|
||||
data.dump(-1, ' ', false, json::error_handler_t::replace) +
|
||||
"\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to install pre-commit and post-commit hooks for webui
|
||||
# Pre-commit: formats code and builds, stashes unstaged changes
|
||||
# Pre-commit: formats, lints, checks, and builds code, stashes unstaged changes
|
||||
# Post-commit: automatically unstashes changes
|
||||
|
||||
REPO_ROOT=$(git rev-parse --show-toplevel)
|
||||
@@ -44,6 +44,18 @@ if git diff --cached --name-only | grep -q "^tools/server/webui/"; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run the lint command
|
||||
npm run lint
|
||||
|
||||
# Check if lint command succeeded
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: npm run lint failed"
|
||||
if [ $STASH_CREATED -eq 0 ]; then
|
||||
echo "You can restore your unstaged changes with: git stash pop"
|
||||
fi
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run the check command
|
||||
npm run check
|
||||
|
||||
@@ -112,7 +124,7 @@ if [ $? -eq 0 ]; then
|
||||
echo " Post-commit: $POST_COMMIT_HOOK"
|
||||
echo ""
|
||||
echo "The hooks will automatically:"
|
||||
echo " • Format and build webui code before commits"
|
||||
echo " • Format, lint, check, and build webui code before commits"
|
||||
echo " • Stash unstaged changes during the process"
|
||||
echo " • Restore your unstaged changes after the commit"
|
||||
echo ""
|
||||
|
||||
@@ -121,3 +121,15 @@
|
||||
@apply bg-background text-foreground;
|
||||
}
|
||||
}
|
||||
|
||||
@layer utilities {
|
||||
.scrollbar-hide {
|
||||
/* Hide scrollbar for Chrome, Safari and Opera */
|
||||
&::-webkit-scrollbar {
|
||||
display: none;
|
||||
}
|
||||
/* Hide scrollbar for IE, Edge and Firefox */
|
||||
-ms-overflow-style: none;
|
||||
scrollbar-width: none;
|
||||
}
|
||||
}
|
||||
|
||||
+139
-136
@@ -1,15 +1,20 @@
|
||||
<script lang="ts">
|
||||
import { Settings, Funnel, AlertTriangle, Brain, Cog, Monitor, Sun, Moon } from '@lucide/svelte';
|
||||
import { ChatSettingsFooter, ChatSettingsSection } from '$lib/components/app';
|
||||
import { Checkbox } from '$lib/components/ui/checkbox';
|
||||
import {
|
||||
Settings,
|
||||
Funnel,
|
||||
AlertTriangle,
|
||||
Brain,
|
||||
Cog,
|
||||
Monitor,
|
||||
Sun,
|
||||
Moon,
|
||||
ChevronLeft,
|
||||
ChevronRight
|
||||
} from '@lucide/svelte';
|
||||
import { ChatSettingsFooter, ChatSettingsFields } from '$lib/components/app';
|
||||
import * as Dialog from '$lib/components/ui/dialog';
|
||||
import { Input } from '$lib/components/ui/input';
|
||||
import Label from '$lib/components/ui/label/label.svelte';
|
||||
import { ScrollArea } from '$lib/components/ui/scroll-area';
|
||||
import * as Select from '$lib/components/ui/select';
|
||||
import { Textarea } from '$lib/components/ui/textarea';
|
||||
import { SETTING_CONFIG_DEFAULT, SETTING_CONFIG_INFO } from '$lib/constants/settings-config';
|
||||
import { supportsVision } from '$lib/stores/server.svelte';
|
||||
import { SETTING_CONFIG_DEFAULT } from '$lib/constants/settings-config';
|
||||
import { config, updateMultipleConfig, resetConfig } from '$lib/stores/settings.svelte';
|
||||
import { setMode } from 'mode-watcher';
|
||||
import type { Component } from 'svelte';
|
||||
@@ -224,12 +229,20 @@
|
||||
let localConfig: SettingsConfigType = $state({ ...config() });
|
||||
let originalTheme: string = $state('');
|
||||
|
||||
let canScrollLeft = $state(false);
|
||||
let canScrollRight = $state(false);
|
||||
let scrollContainer: HTMLDivElement | undefined = $state();
|
||||
|
||||
function handleThemeChange(newTheme: string) {
|
||||
localConfig.theme = newTheme;
|
||||
|
||||
setMode(newTheme as 'light' | 'dark' | 'system');
|
||||
}
|
||||
|
||||
function handleConfigChange(key: string, value: string | boolean) {
|
||||
localConfig[key] = value;
|
||||
}
|
||||
|
||||
function handleClose() {
|
||||
if (localConfig.theme !== originalTheme) {
|
||||
setMode(originalTheme as 'light' | 'dark' | 'system');
|
||||
@@ -298,18 +311,63 @@
|
||||
onOpenChange?.(false);
|
||||
}
|
||||
|
||||
function scrollToCenter(element: HTMLElement) {
|
||||
if (!scrollContainer) return;
|
||||
|
||||
const containerRect = scrollContainer.getBoundingClientRect();
|
||||
const elementRect = element.getBoundingClientRect();
|
||||
|
||||
const elementCenter = elementRect.left + elementRect.width / 2;
|
||||
const containerCenter = containerRect.left + containerRect.width / 2;
|
||||
const scrollOffset = elementCenter - containerCenter;
|
||||
|
||||
scrollContainer.scrollBy({ left: scrollOffset, behavior: 'smooth' });
|
||||
}
|
||||
|
||||
function scrollLeft() {
|
||||
if (!scrollContainer) return;
|
||||
|
||||
scrollContainer.scrollBy({ left: -250, behavior: 'smooth' });
|
||||
}
|
||||
|
||||
function scrollRight() {
|
||||
if (!scrollContainer) return;
|
||||
|
||||
scrollContainer.scrollBy({ left: 250, behavior: 'smooth' });
|
||||
}
|
||||
|
||||
function updateScrollButtons() {
|
||||
if (!scrollContainer) return;
|
||||
|
||||
const { scrollLeft, scrollWidth, clientWidth } = scrollContainer;
|
||||
canScrollLeft = scrollLeft > 0;
|
||||
canScrollRight = scrollLeft < scrollWidth - clientWidth - 1; // -1 for rounding
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (open) {
|
||||
localConfig = { ...config() };
|
||||
originalTheme = config().theme as string;
|
||||
|
||||
setTimeout(updateScrollButtons, 100);
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (scrollContainer) {
|
||||
updateScrollButtons();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<Dialog.Root {open} onOpenChange={handleClose}>
|
||||
<Dialog.Content class="flex h-[64vh] flex-col gap-0 p-0" style="max-width: 48rem;">
|
||||
<div class="flex flex-1 overflow-hidden">
|
||||
<div class="w-64 border-r border-border/30 p-6">
|
||||
<Dialog.Content
|
||||
class="z-999999 flex h-[100vh] flex-col gap-0 rounded-none p-0 md:h-[64vh] md:rounded-lg"
|
||||
style="max-width: 48rem;"
|
||||
>
|
||||
<div class="flex flex-1 flex-col overflow-hidden md:flex-row">
|
||||
<!-- Desktop Sidebar -->
|
||||
<div class="hidden w-64 border-r border-border/30 p-6 md:block">
|
||||
<nav class="space-y-1 py-2">
|
||||
<Dialog.Title class="mb-6 flex items-center gap-2">Settings</Dialog.Title>
|
||||
|
||||
@@ -329,134 +387,79 @@
|
||||
</nav>
|
||||
</div>
|
||||
|
||||
<ScrollArea class="flex-1">
|
||||
<div class="space-y-6 p-6">
|
||||
<ChatSettingsSection title={currentSection.title} Icon={currentSection.icon}>
|
||||
{#each currentSection.fields as field (field.key)}
|
||||
<div class="space-y-2">
|
||||
{#if field.type === 'input'}
|
||||
<Label for={field.key} class="block text-sm font-medium">
|
||||
{field.label}
|
||||
</Label>
|
||||
<!-- Mobile Header with Horizontal Scrollable Menu -->
|
||||
<div class="flex flex-col md:hidden">
|
||||
<div class="border-b border-border/30 py-4">
|
||||
<Dialog.Title class="mb-6 flex items-center gap-2 px-4">Settings</Dialog.Title>
|
||||
|
||||
<Input
|
||||
id={field.key}
|
||||
value={String(localConfig[field.key] || '')}
|
||||
onchange={(e) => (localConfig[field.key] = e.currentTarget.value)}
|
||||
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] || 'none'}`}
|
||||
class="max-w-md"
|
||||
/>
|
||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
{field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
</p>
|
||||
{/if}
|
||||
{:else if field.type === 'textarea'}
|
||||
<Label for={field.key} class="block text-sm font-medium">
|
||||
{field.label}
|
||||
</Label>
|
||||
<!-- Horizontal Scrollable Category Menu with Navigation -->
|
||||
<div class="relative flex items-center" style="scroll-padding: 1rem;">
|
||||
<button
|
||||
class="absolute left-2 z-10 flex h-6 w-6 items-center justify-center rounded-full bg-muted shadow-md backdrop-blur-sm transition-opacity hover:bg-accent {canScrollLeft
|
||||
? 'opacity-100'
|
||||
: 'pointer-events-none opacity-0'}"
|
||||
onclick={scrollLeft}
|
||||
aria-label="Scroll left"
|
||||
>
|
||||
<ChevronLeft class="h-4 w-4" />
|
||||
</button>
|
||||
|
||||
<Textarea
|
||||
id={field.key}
|
||||
value={String(localConfig[field.key] || '')}
|
||||
onchange={(e) => (localConfig[field.key] = e.currentTarget.value)}
|
||||
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] || 'none'}`}
|
||||
class="min-h-[100px] max-w-2xl"
|
||||
/>
|
||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
{field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
</p>
|
||||
{/if}
|
||||
{:else if field.type === 'select'}
|
||||
{@const selectedOption = field.options?.find(
|
||||
(opt: { value: string; label: string; icon?: Component }) =>
|
||||
opt.value === localConfig[field.key]
|
||||
)}
|
||||
|
||||
<Label for={field.key} class="block text-sm font-medium">
|
||||
{field.label}
|
||||
</Label>
|
||||
|
||||
<Select.Root
|
||||
type="single"
|
||||
value={localConfig[field.key]}
|
||||
onValueChange={(value) => {
|
||||
if (field.key === 'theme' && value) {
|
||||
handleThemeChange(value);
|
||||
} else {
|
||||
localConfig[field.key] = value;
|
||||
}
|
||||
<div
|
||||
class="scrollbar-hide overflow-x-auto py-2"
|
||||
bind:this={scrollContainer}
|
||||
onscroll={updateScrollButtons}
|
||||
>
|
||||
<div class="flex min-w-max gap-2">
|
||||
{#each settingSections as section (section.title)}
|
||||
<button
|
||||
class="flex cursor-pointer items-center gap-2 rounded-lg px-3 py-2 text-sm whitespace-nowrap transition-colors first:ml-4 last:mr-4 hover:bg-accent {activeSection ===
|
||||
section.title
|
||||
? 'bg-accent text-accent-foreground'
|
||||
: 'text-muted-foreground'}"
|
||||
onclick={(e: MouseEvent) => {
|
||||
activeSection = section.title;
|
||||
scrollToCenter(e.currentTarget as HTMLElement);
|
||||
}}
|
||||
>
|
||||
<Select.Trigger class="max-w-md">
|
||||
<div class="flex items-center gap-2">
|
||||
{#if selectedOption?.icon}
|
||||
{@const IconComponent = selectedOption.icon}
|
||||
<IconComponent class="h-4 w-4" />
|
||||
{/if}
|
||||
|
||||
{selectedOption?.label || `Select ${field.label.toLowerCase()}`}
|
||||
</div>
|
||||
</Select.Trigger>
|
||||
<Select.Content>
|
||||
{#if field.options}
|
||||
{#each field.options as option (option.value)}
|
||||
<Select.Item value={option.value} label={option.label}>
|
||||
<div class="flex items-center gap-2">
|
||||
{#if option.icon}
|
||||
{@const IconComponent = option.icon}
|
||||
<IconComponent class="h-4 w-4" />
|
||||
{/if}
|
||||
{option.label}
|
||||
</div>
|
||||
</Select.Item>
|
||||
{/each}
|
||||
{/if}
|
||||
</Select.Content>
|
||||
</Select.Root>
|
||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
{field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
</p>
|
||||
{/if}
|
||||
{:else if field.type === 'checkbox'}
|
||||
{@const isDisabled = field.key === 'pdfAsImage' && !supportsVision()}
|
||||
<div class="flex items-start space-x-3">
|
||||
<Checkbox
|
||||
id={field.key}
|
||||
checked={Boolean(localConfig[field.key])}
|
||||
disabled={isDisabled}
|
||||
onCheckedChange={(checked) => (localConfig[field.key] = checked)}
|
||||
class="mt-1"
|
||||
/>
|
||||
|
||||
<div class="space-y-1">
|
||||
<label
|
||||
for={field.key}
|
||||
class="cursor-pointer text-sm leading-none font-medium {isDisabled
|
||||
? 'text-muted-foreground'
|
||||
: ''}"
|
||||
>
|
||||
{field.label}
|
||||
</label>
|
||||
|
||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
<p class="text-xs text-muted-foreground">
|
||||
{field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
</p>
|
||||
{:else if field.key === 'pdfAsImage' && !supportsVision()}
|
||||
<p class="text-xs text-muted-foreground">
|
||||
PDF-to-image processing requires a vision-capable model. PDFs will be
|
||||
processed as text.
|
||||
</p>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
<section.icon class="h-4 w-4 flex-shrink-0" />
|
||||
<span>{section.title}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
{/each}
|
||||
</ChatSettingsSection>
|
||||
</div>
|
||||
|
||||
<button
|
||||
class="absolute right-2 z-10 flex h-6 w-6 items-center justify-center rounded-full bg-muted shadow-md backdrop-blur-sm transition-opacity hover:bg-accent {canScrollRight
|
||||
? 'opacity-100'
|
||||
: 'pointer-events-none opacity-0'}"
|
||||
onclick={scrollRight}
|
||||
aria-label="Scroll right"
|
||||
>
|
||||
<ChevronRight class="h-4 w-4" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<ScrollArea class="max-h-[calc(100vh-13.5rem)] flex-1">
|
||||
<div class="space-y-6 p-4 md:p-6">
|
||||
<div>
|
||||
<div class="mb-6 flex hidden items-center gap-2 border-b border-border/30 pb-6 md:flex">
|
||||
<currentSection.icon class="h-5 w-5" />
|
||||
|
||||
<h3 class="text-lg font-semibold">{currentSection.title}</h3>
|
||||
</div>
|
||||
|
||||
<div class="space-y-6">
|
||||
<ChatSettingsFields
|
||||
fields={currentSection.fields}
|
||||
{localConfig}
|
||||
onConfigChange={handleConfigChange}
|
||||
onThemeChange={handleThemeChange}
|
||||
isMobile={false}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="mt-8 border-t pt-6">
|
||||
<p class="text-xs text-muted-foreground">
|
||||
@@ -467,6 +470,6 @@
|
||||
</ScrollArea>
|
||||
</div>
|
||||
|
||||
<ChatSettingsFooter onClose={handleClose} onReset={handleReset} onSave={handleSave} />
|
||||
<ChatSettingsFooter onReset={handleReset} onSave={handleSave} />
|
||||
</Dialog.Content>
|
||||
</Dialog.Root>
|
||||
|
||||
+145
@@ -0,0 +1,145 @@
|
||||
<script lang="ts">
|
||||
import { Checkbox } from '$lib/components/ui/checkbox';
|
||||
import { Input } from '$lib/components/ui/input';
|
||||
import Label from '$lib/components/ui/label/label.svelte';
|
||||
import * as Select from '$lib/components/ui/select';
|
||||
import { Textarea } from '$lib/components/ui/textarea';
|
||||
import { SETTING_CONFIG_DEFAULT, SETTING_CONFIG_INFO } from '$lib/constants/settings-config';
|
||||
import { supportsVision } from '$lib/stores/server.svelte';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
fields: SettingsFieldConfig[];
|
||||
localConfig: SettingsConfigType;
|
||||
onConfigChange: (key: string, value: string | boolean) => void;
|
||||
onThemeChange?: (theme: string) => void;
|
||||
isMobile?: boolean;
|
||||
}
|
||||
|
||||
let { fields, localConfig, onConfigChange, onThemeChange, isMobile = false }: Props = $props();
|
||||
</script>
|
||||
|
||||
{#each fields as field (field.key)}
|
||||
<div class="space-y-2">
|
||||
{#if field.type === 'input'}
|
||||
<Label for={field.key} class="block text-sm font-medium">
|
||||
{field.label}
|
||||
</Label>
|
||||
|
||||
<Input
|
||||
id={field.key}
|
||||
value={String(localConfig[field.key] || '')}
|
||||
onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
|
||||
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] || 'none'}`}
|
||||
class={isMobile ? 'w-full' : 'max-w-md'}
|
||||
/>
|
||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
{field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
</p>
|
||||
{/if}
|
||||
{:else if field.type === 'textarea'}
|
||||
<Label for={field.key} class="block text-sm font-medium">
|
||||
{field.label}
|
||||
</Label>
|
||||
|
||||
<Textarea
|
||||
id={field.key}
|
||||
value={String(localConfig[field.key] || '')}
|
||||
onchange={(e) => onConfigChange(field.key, e.currentTarget.value)}
|
||||
placeholder={`Default: ${SETTING_CONFIG_DEFAULT[field.key] || 'none'}`}
|
||||
class={isMobile ? 'min-h-[100px] w-full' : 'min-h-[100px] max-w-2xl'}
|
||||
/>
|
||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
{field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
</p>
|
||||
{/if}
|
||||
{:else if field.type === 'select'}
|
||||
{@const selectedOption = field.options?.find(
|
||||
(opt: { value: string; label: string; icon?: Component }) =>
|
||||
opt.value === localConfig[field.key]
|
||||
)}
|
||||
|
||||
<Label for={field.key} class="block text-sm font-medium">
|
||||
{field.label}
|
||||
</Label>
|
||||
|
||||
<Select.Root
|
||||
type="single"
|
||||
value={localConfig[field.key]}
|
||||
onValueChange={(value) => {
|
||||
if (field.key === 'theme' && value && onThemeChange) {
|
||||
onThemeChange(value);
|
||||
} else {
|
||||
onConfigChange(field.key, value);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Select.Trigger class={isMobile ? 'w-full' : 'max-w-md'}>
|
||||
<div class="flex items-center gap-2">
|
||||
{#if selectedOption?.icon}
|
||||
{@const IconComponent = selectedOption.icon}
|
||||
<IconComponent class="h-4 w-4" />
|
||||
{/if}
|
||||
|
||||
{selectedOption?.label || `Select ${field.label.toLowerCase()}`}
|
||||
</div>
|
||||
</Select.Trigger>
|
||||
<Select.Content>
|
||||
{#if field.options}
|
||||
{#each field.options as option (option.value)}
|
||||
<Select.Item value={option.value} label={option.label}>
|
||||
<div class="flex items-center gap-2">
|
||||
{#if option.icon}
|
||||
{@const IconComponent = option.icon}
|
||||
<IconComponent class="h-4 w-4" />
|
||||
{/if}
|
||||
{option.label}
|
||||
</div>
|
||||
</Select.Item>
|
||||
{/each}
|
||||
{/if}
|
||||
</Select.Content>
|
||||
</Select.Root>
|
||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
{field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
</p>
|
||||
{/if}
|
||||
{:else if field.type === 'checkbox'}
|
||||
{@const isDisabled = field.key === 'pdfAsImage' && !supportsVision()}
|
||||
<div class="flex items-start space-x-3">
|
||||
<Checkbox
|
||||
id={field.key}
|
||||
checked={Boolean(localConfig[field.key])}
|
||||
disabled={isDisabled}
|
||||
onCheckedChange={(checked) => onConfigChange(field.key, checked)}
|
||||
class="mt-1"
|
||||
/>
|
||||
|
||||
<div class="space-y-1">
|
||||
<label
|
||||
for={field.key}
|
||||
class="cursor-pointer text-sm leading-none font-medium {isDisabled
|
||||
? 'text-muted-foreground'
|
||||
: ''}"
|
||||
>
|
||||
{field.label}
|
||||
</label>
|
||||
|
||||
{#if field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
<p class="text-xs text-muted-foreground">
|
||||
{field.help || SETTING_CONFIG_INFO[field.key]}
|
||||
</p>
|
||||
{:else if field.key === 'pdfAsImage' && !supportsVision()}
|
||||
<p class="text-xs text-muted-foreground">
|
||||
PDF-to-image processing requires a vision-capable model. PDFs will be processed as
|
||||
text.
|
||||
</p>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/each}
|
||||
+2
-11
@@ -2,16 +2,11 @@
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
|
||||
interface Props {
|
||||
onClose?: () => void;
|
||||
onReset?: () => void;
|
||||
onSave?: () => void;
|
||||
}
|
||||
|
||||
let { onClose, onReset, onSave }: Props = $props();
|
||||
|
||||
function handleClose() {
|
||||
onClose?.();
|
||||
}
|
||||
let { onReset, onSave }: Props = $props();
|
||||
|
||||
function handleReset() {
|
||||
onReset?.();
|
||||
@@ -25,9 +20,5 @@
|
||||
<div class="flex justify-between border-t border-border/30 p-6">
|
||||
<Button variant="outline" onclick={handleReset}>Reset to default</Button>
|
||||
|
||||
<div class="flex gap-2">
|
||||
<Button variant="outline" onclick={handleClose}>Close</Button>
|
||||
|
||||
<Button onclick={handleSave}>Save</Button>
|
||||
</div>
|
||||
<Button onclick={handleSave}>Save settings</Button>
|
||||
</div>
|
||||
|
||||
-23
@@ -1,23 +0,0 @@
|
||||
<script lang="ts">
|
||||
import type { Component, Snippet } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
children: Snippet;
|
||||
title: string;
|
||||
Icon: Component;
|
||||
}
|
||||
|
||||
let { children, title, Icon }: Props = $props();
|
||||
</script>
|
||||
|
||||
<div>
|
||||
<div class="mb-6 flex items-center gap-2 border-b border-border/30 pb-6">
|
||||
<Icon class="h-5 w-5" />
|
||||
|
||||
<h3 class="text-lg font-semibold">{title}</h3>
|
||||
</div>
|
||||
|
||||
<div class="space-y-6">
|
||||
{@render children()}
|
||||
</div>
|
||||
</div>
|
||||
@@ -22,8 +22,8 @@ export { default as ChatScreenHeader } from './chat/ChatScreen/ChatScreenHeader.
|
||||
export { default as ChatScreen } from './chat/ChatScreen/ChatScreen.svelte';
|
||||
|
||||
export { default as ChatSettingsDialog } from './chat/ChatSettings/ChatSettingsDialog.svelte';
|
||||
export { default as ChatSettingsSection } from './chat/ChatSettings/ChatSettingsSection.svelte';
|
||||
export { default as ChatSettingsFooter } from './chat/ChatSettings/ChatSettingsFooter.svelte';
|
||||
export { default as ChatSettingsFields } from './chat/ChatSettings/ChatSettingsFields.svelte';
|
||||
|
||||
export { default as ChatSidebar } from './chat/ChatSidebar/ChatSidebar.svelte';
|
||||
export { default as ChatSidebarConversationItem } from './chat/ChatSidebar/ChatSidebarConversationItem.svelte';
|
||||
|
||||
Reference in New Issue
Block a user