Compare commits

...

27 Commits

Author SHA1 Message Date
Xuan-Son Nguyen ee8dd5c658 server: move res_error/res_ok to static function (#17167) 2025-11-12 14:17:24 +01:00
Alberto Cabrera Pérez 1c398dc9ec ggml-cpu: handle 3d tensors in repack mat_mul (#17030)
* ggml-cpu: handle 3d tensors in repack mul_mat

* Removed unnecessary branch, removed need for <algorithm>

* Fixed dst_ptr pointer in chunk + clang_format

* GGML_ASSERT to check wdata within bounds

* Accidental ggml.h inclusion

* Improved GGML_ASSERT on wdata boundaries
2025-11-12 14:52:19 +02:00
Adrien Gallouët 52cf111b31 cmake : cleanup (#17199) 2025-11-12 14:48:30 +02:00
Adrien Gallouët 78010a0d52 cmake : move OpenSSL linking to vendor/cpp-httplib (#17177)
* cmake : move OpenSSL linking to vendor/cpp-httplib

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* bring back httplib 0.27.0

* add -DLLAMA_HTTPLIB

* update cmake config for visionos

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2025-11-12 12:32:50 +01:00
TecJesh 655cddd174 CANN: Add L2_NORM op support (#16856)
* update L2_NORM op support

* update L2_NORM op support

* remove extra whitespace
2025-11-12 15:11:42 +08:00
Neo Zhang Jianyu 5da7664960 [SYCL]fix ci crash about SSM_CONV (#17169)
* fix ci crash

* Update ggml-sycl.cpp

* Update ggml/src/ggml-sycl/ggml-sycl.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Zhang Jianyu <zhang.jianyu@outlook.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2025-11-12 14:44:29 +08:00
Raul Torres 23a46ce972 CANN: GGML_CANN_ACL_GRAPH works only USE_ACL_GRAPH enabled (#16861)
The documentation should state that `GGML_CANN_ACL_GRAPH` is only effective if `USE_ACL_GRAPH` was enabled at compilation time.
2025-11-12 14:37:52 +08:00
Max Krasnyansky c273d75375 hexagon: various Op fixes (#17135)
* hexagon: explicitly check for ops with zero nrows

llm_graph_context::build_inp_out_ids() can generate tensors with zero nrows.
Somehow other backends seems to handle this without obvious explicit checks.
In the hexagon case we need to check explicitly and skip them.

* hexagon: introduce fastdiv, fix test-backend-ops for ADD/SUB/MUL

Co-authored-by: chraac <chraac@gmail.com>

* hexagon: use fastdiv in ADD_ID

* hexagon: use ggml_op_is_empty and ggml_is_empty to check for NOPs

---------

Co-authored-by: chraac <chraac@gmail.com>
2025-11-11 15:25:04 -08:00
Eve 7d019cff74 disable rms norm mul rope for chips with no fp16 rte (#17134) 2025-11-11 12:53:30 -06:00
sudhiarm 3fe36c3238 ci: add Arm-hosted Graviton4 runner (#17021)
* ci: add Arm-hosted Graviton4 runner

* ci: add missing dependencies for graviton4 build

* ci: enable LFS checkout on graviton4

* ci: move git-lfs install to dependencies in Graviton4 workflow
2025-11-11 17:58:05 +02:00
Xuan-Son Nguyen 1d45b4228f vendor: split httplib to cpp/h files (#17150)
* vendor: split httplib to cpp/h files

* move defines

* include httplib if curl is not used

* add TODO

* fix build ios

* fix build visionos instead
2025-11-11 13:32:58 +01:00
ixgbe ca4844062b ggml-cpu : add RISC-V RVV (Zvfh) optimization for FP16 to FP32 conversion (#17161)
Signed-off-by: Wang Yang <yangwang@iscas.ac.cn>
2025-11-11 13:41:51 +02:00
duduta 73460f6278 ggml-cpu: templateify ggml_compute_forward_rope_f32 and _f16 (#16805)
* extract rotate_pairs logic from ggml_compute_forward_rope_f32

* templateify ggml_compute_forward_rope_f32 and _f16

* abort when rope type not supported, remove GLM from test-rope

* add imrope branch to switch

* add rope tests for perf

* Update ggml/src/ggml-cpu/ops.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml/src/ggml-cpu/ops.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-11-11 13:33:24 +02:00
Charles Xu 8c583242ad kleidiai: add optimized per-channel kernels for Q8_0 (#16993) 2025-11-11 13:20:31 +02:00
Mike Abbott 4a5b8aff40 cmake : add version to all shared object files (#17091)
When compiling llama.cpp in Yocto, it fails QA checks because the generated so files aren't versioned.  This applies a version to all generated so files, allowing the package to build without errors.
2025-11-11 13:19:50 +02:00
Nicolas B. Pierron d2d626938a Install rpc-server when GGML_RPC is ON. (#17149) 2025-11-11 10:53:59 +00:00
levkropp 2fc392ce35 convert : register UMT5Model architecture for T5 conversion (#17160)
Register UMT5Model as a supported architecture variant for T5 model conversion.
This allows the conversion to work for models downloaded with AutoModel.
2025-11-11 09:38:30 +01:00
lhez ece0f5c177 opencl: add fastdiv and use it in set_rows, ported from cuda (#17090)
* opencl: add fastdiv for mm q8_0

* opencl: use uint4 for fastdiv vals

* opencl: use fastdiv for set_rows

* opencl: do not use fastdiv for q8_0 mm
2025-11-10 15:00:13 -08:00
Sigbjørn Skjæret 7bef684118 models : move build_inp_out_ids outside loop (#17151)
* move build_inp_out_ids outside loop

* realign
2025-11-10 22:55:30 +01:00
Max Krasnyansky 395e286bc9 cpu: skip NOPs to avoid barriers (#17133)
* cpu: skip NOPs to avoid barriers

* cpu: use ggml_op_is_empty
2025-11-10 12:44:49 -08:00
Georgi Gerganov 13730c183b metal : cap threadgroups size of set_rows (#17146) 2025-11-10 21:33:35 +02:00
Adrien Gallouët 967eb4b2bf ggml-cpu : inspect -march and -mcpu to found the CPU (#16333)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2025-11-10 21:03:36 +02:00
Ruben Ortlam f117be185e vulkan: check glslc executable string (#17144) 2025-11-10 16:59:26 +01:00
Ruben Ortlam 85234a4b3a vulkan: fix validation issue introduced by #16868 (#17145) 2025-11-10 16:59:10 +01:00
Gabe Goodhart 0c74f32632 memory: Hybrid context shift (#17009)
* feat(memory): Only fail partial erasure of recurrent tail

The recurrent state is always assumed to be the state as of the last update
from the final token in the sequence. When doing a partial erasure, if the
range does not include the final token, the erasure can be considered a
success since any memory used for the sequence prior to the final token
(which is no memory) has been successfully removed.

There is one potential case that this doesn't address which is the pruning
of cache to remove sensitive data from the context. This wouldn't work for
attention cache partial removal (in the middle) either since the KV state
is linearly-dependent and states in later sequence positions would still be
based on the state from the sensitive data, even if that data is no longer
cached, so I don't think this is relevant, but it is worth noting that the
semantics of this change for a partial erasure in the middle of the cache
are essentially "my context is already compressed" and not "all trace of
the removed tokens has been removed."

https://github.com/ggml-org/llama.cpp/issues/16768
Branch: HybridContextShift-16768

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(main): Check the output of seq_rm for prefix matching

This prefix matching is explicitly attempting to remove the tokens at the
end of the sequence that don't match. This is the operation that can't be
performed on a recurrent cache due to the state being updated in place, so
if this removal fails, we need to clear the whole cache.

https://github.com/ggml-org/llama.cpp/issues/16768
Branch: HybridContextShift-16768

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(memory): Fix condition for partial erasure failure if p0 > pos

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

Co-authored-by: compilade <git@compilade.net>

* style: Fix extra parens

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* fix(main.cpp): Set n_matching_session_tokens to 0 on cache clear

https://github.com/ggml-org/llama.cpp/issues/16768
Branch: HybridContextShift-16768

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

---------

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Co-authored-by: compilade <git@compilade.net>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-11-10 17:14:23 +02:00
Georgi Gerganov c27efd2bd1 metal : enable tensor API for A19 (#17087) 2025-11-10 15:38:42 +02:00
fj-y-saito df70bedda7 arm64: add i8mm route with SVE ggml_vec_dot_q4_K_q8_K and ggml_vec_dot_q6_K_… (#15277)
* add i8mm route with SVE ggml_vec_dot_q4_K_q8_K and ggml_vec_dot_q6_K_q8_K

* Surround SVE function with compiler directive

* fix compile switch

* fix coding style

* ggml : fix indent

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-11-10 15:12:59 +02:00
50 changed files with 11058 additions and 9982 deletions
+2
View File
@@ -34,6 +34,7 @@
rocmGpuTargets ? builtins.concatStringsSep ";" rocmPackages.clr.gpuTargets,
enableCurl ? true,
useVulkan ? false,
useRpc ? false,
llamaVersion ? "0.0.0", # Arbitrary version, substituted by the flake
# It's necessary to consistently use backendStdenv when building with CUDA support,
@@ -175,6 +176,7 @@ effectiveStdenv.mkDerivation (finalAttrs: {
(cmakeBool "GGML_METAL" useMetalKit)
(cmakeBool "GGML_VULKAN" useVulkan)
(cmakeBool "GGML_STATIC" enableStatic)
(cmakeBool "GGML_RPC" useRpc)
]
++ optionals useCuda [
(
+47
View File
@@ -1651,3 +1651,50 @@ jobs:
run: |
GG_BUILD_KLEIDIAI=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
ggml-ci-arm64-graviton4-kleidiai:
runs-on: ah-ubuntu_22_04-c8g_8x
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Dependencies
id: depends
run: |
set -euxo pipefail
sudo apt-get update
sudo DEBIAN_FRONTEND=noninteractive NEEDRESTART_MODE=a \
apt-get install -y \
build-essential \
libcurl4-openssl-dev \
python3-venv \
gpg \
wget \
time \
git-lfs
git lfs install
# install the latest cmake
sudo install -d /usr/share/keyrings
wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc \
| gpg --dearmor \
| sudo tee /usr/share/keyrings/kitware-archive-keyring.gpg >/dev/null
echo 'deb [signed-by=/usr/share/keyrings/kitware-archive-keyring.gpg] https://apt.kitware.com/ubuntu/ jammy main' \
| sudo tee /etc/apt/sources.list.d/kitware.list
sudo apt-get update
sudo apt-get install -y cmake
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
key: ggml-ci-arm64-graviton4-kleidiai
evict-old-files: 1d
- name: Test
id: ggml-ci
run: |
GG_BUILD_KLEIDIAI=1 \
GG_BUILD_EXTRA_TESTS_0=1 \
bash ./ci/run.sh ./tmp/results ./tmp/mnt
+4
View File
@@ -92,6 +92,7 @@ option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_
# 3rd party libs
option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON)
option(LLAMA_HTTPLIB "llama: if libcurl is disabled, use httplib to download model from an URL" ON)
option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" OFF)
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
@@ -200,6 +201,9 @@ endif()
if (LLAMA_BUILD_COMMON)
add_subdirectory(common)
if (LLAMA_HTTPLIB)
add_subdirectory(vendor/cpp-httplib)
endif()
endif()
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
+4
View File
@@ -454,6 +454,8 @@ cmake -B build-visionos -G Xcode \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
-DLLAMA_CURL=OFF \
-DLLAMA_HTTPLIB=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-S .
cmake --build build-visionos --config Release -- -quiet
@@ -468,6 +470,8 @@ cmake -B build-visionos-sim -G Xcode \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
-DLLAMA_CURL=OFF \
-DLLAMA_HTTPLIB=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-S .
cmake --build build-visionos-sim --config Release -- -quiet
+6 -1
View File
@@ -121,7 +121,12 @@ fi
if [ -n "${GG_BUILD_KLEIDIAI}" ]; then
echo ">>===== Enabling KleidiAI support"
CANDIDATES=("armv9-a+dotprod+i8mm" "armv8.6-a+dotprod+i8mm" "armv8.2-a+dotprod")
CANDIDATES=(
"armv9-a+dotprod+i8mm+sve2"
"armv9-a+dotprod+i8mm"
"armv8.6-a+dotprod+i8mm"
"armv8.2-a+dotprod"
)
CPU=""
for cpu in "${CANDIDATES[@]}"; do
+6 -37
View File
@@ -79,10 +79,11 @@ if (BUILD_SHARED_LIBS)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
# TODO: use list(APPEND LLAMA_COMMON_EXTRA_LIBS ...)
set(LLAMA_COMMON_EXTRA_LIBS build_info)
# Use curl to download model url
if (LLAMA_CURL)
# Use curl to download model url
find_package(CURL)
if (NOT CURL_FOUND)
message(FATAL_ERROR "Could NOT find CURL. Hint: to disable this feature, set -DLLAMA_CURL=OFF")
@@ -90,42 +91,10 @@ if (LLAMA_CURL)
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL)
include_directories(${CURL_INCLUDE_DIRS})
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARIES})
endif()
if (LLAMA_OPENSSL)
find_package(OpenSSL)
if (OpenSSL_FOUND)
include(CheckCSourceCompiles)
set(SAVED_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES})
set(CMAKE_REQUIRED_INCLUDES ${OPENSSL_INCLUDE_DIR})
check_c_source_compiles("
#include <openssl/opensslv.h>
#if defined(OPENSSL_IS_BORINGSSL) || defined(LIBRESSL_VERSION_NUMBER)
# if OPENSSL_VERSION_NUMBER < 0x1010107f
# error bad version
# endif
#else
# if OPENSSL_VERSION_NUMBER < 0x30000000L
# error bad version
# endif
#endif
int main() { return 0; }
" OPENSSL_VERSION_SUPPORTED)
set(CMAKE_REQUIRED_INCLUDES ${SAVED_CMAKE_REQUIRED_INCLUDES})
if (OPENSSL_VERSION_SUPPORTED)
message(STATUS "OpenSSL found: ${OPENSSL_VERSION}")
target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_OPENSSL_SUPPORT)
target_link_libraries(${TARGET} PUBLIC OpenSSL::SSL OpenSSL::Crypto)
if (APPLE AND CMAKE_SYSTEM_NAME STREQUAL "Darwin")
target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
find_library(CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED)
find_library(SECURITY_FRAMEWORK Security REQUIRED)
target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK})
endif()
endif()
else()
message(STATUS "OpenSSL not found, SSL support disabled")
endif()
elseif (LLAMA_HTTPLIB)
# otherwise, use cpp-httplib
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_HTTPLIB)
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib)
endif()
if (LLAMA_LLGUIDANCE)
+47 -29
View File
@@ -20,7 +20,7 @@
#if defined(LLAMA_USE_CURL)
#include <curl/curl.h>
#include <curl/easy.h>
#else
#elif defined(LLAMA_USE_HTTPLIB)
#include "http.h"
#endif
@@ -467,7 +467,7 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
return { res_code, std::move(res_buffer) };
}
#else
#elif defined(LLAMA_USE_HTTPLIB)
static bool is_output_a_tty() {
#if defined(_WIN32)
@@ -713,6 +713,8 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
#endif // LLAMA_USE_CURL
#if defined(LLAMA_USE_CURL) || defined(LLAMA_USE_HTTPLIB)
static bool common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
@@ -907,33 +909,6 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons
return { hf_repo, ggufFile, mmprojFile };
}
std::vector<common_cached_model_info> common_list_cached_models() {
std::vector<common_cached_model_info> models;
const std::string cache_dir = fs_get_cache_directory();
const std::vector<common_file_info> files = fs_list_files(cache_dir);
for (const auto & file : files) {
if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) {
common_cached_model_info model_info;
model_info.manifest_path = file.path;
std::string fname = file.name;
string_replace_all(fname, ".json", ""); // remove extension
auto parts = string_split<std::string>(fname, '=');
if (parts.size() == 4) {
// expect format: manifest=<user>=<model>=<tag>=<other>
model_info.user = parts[1];
model_info.model = parts[2];
model_info.tag = parts[3];
} else {
// invalid format
continue;
}
model_info.size = 0; // TODO: get GGUF size, not manifest size
models.push_back(model_info);
}
}
return models;
}
//
// Docker registry functions
//
@@ -1052,3 +1027,46 @@ std::string common_docker_resolve_model(const std::string & docker) {
throw;
}
}
#else
common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool) {
throw std::runtime_error("download functionality is not enabled in this build");
}
bool common_download_model(const common_params_model &, const std::string &, bool) {
throw std::runtime_error("download functionality is not enabled in this build");
}
std::string common_docker_resolve_model(const std::string &) {
throw std::runtime_error("download functionality is not enabled in this build");
}
#endif // LLAMA_USE_CURL || LLAMA_USE_HTTPLIB
std::vector<common_cached_model_info> common_list_cached_models() {
std::vector<common_cached_model_info> models;
const std::string cache_dir = fs_get_cache_directory();
const std::vector<common_file_info> files = fs_list_files(cache_dir);
for (const auto & file : files) {
if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) {
common_cached_model_info model_info;
model_info.manifest_path = file.path;
std::string fname = file.name;
string_replace_all(fname, ".json", ""); // remove extension
auto parts = string_split<std::string>(fname, '=');
if (parts.size() == 4) {
// expect format: manifest=<user>=<model>=<tag>=<other>
model_info.user = parts[1];
model_info.model = parts[2];
model_info.tag = parts[3];
} else {
// invalid format
continue;
}
model_info.size = 0; // TODO: get GGUF size, not manifest size
models.push_back(model_info);
}
}
return models;
}
+1
View File
@@ -7354,6 +7354,7 @@ class PLMModel(TextModel):
@ModelBase.register("T5ForConditionalGeneration")
@ModelBase.register("MT5ForConditionalGeneration")
@ModelBase.register("UMT5ForConditionalGeneration")
@ModelBase.register("UMT5Model")
class T5Model(TextModel):
model_arch = gguf.MODEL_ARCH.T5
+6 -1
View File
@@ -313,7 +313,12 @@ Converting the matmul weight format from ND to NZ to improve performance. Enable
### GGML_CANN_ACL_GRAPH
Operators are executed using ACL graph execution, rather than in op-by-op (eager) mode. Enabled by default.
Operators are executed using ACL graph execution, rather than in op-by-op (eager) mode. Enabled by default. This option is only effective if `USE_ACL_GRAPH` was enabled at compilation time. To enable it, recompile using:
```sh
cmake -B build -DGGML_CANN=on -DCMAKE_BUILD_TYPE=release -DUSE_ACL_GRAPH=ON
cmake --build build --config release
```
### GGML_CANN_GRAPH_CACHE_CAPACITY
+16
View File
@@ -211,6 +211,11 @@ add_library(ggml-base
ggml-quants.h
gguf.cpp)
set_target_properties(ggml-base PROPERTIES
VERSION ${GGML_VERSION}
SOVERSION ${GGML_VERSION_MAJOR}
)
target_include_directories(ggml-base PRIVATE .)
if (GGML_BACKEND_DL)
target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL)
@@ -220,6 +225,11 @@ add_library(ggml
ggml-backend-reg.cpp)
add_library(ggml::ggml ALIAS ggml)
set_target_properties(ggml PROPERTIES
VERSION ${GGML_VERSION}
SOVERSION ${GGML_VERSION_MAJOR}
)
if (GGML_BACKEND_DIR)
if (NOT GGML_BACKEND_DL)
message(FATAL_ERROR "GGML_BACKEND_DIR requires GGML_BACKEND_DL")
@@ -259,6 +269,12 @@ function(ggml_add_backend_library backend)
target_compile_definitions(${backend} PUBLIC GGML_BACKEND_SHARED)
endif()
# Set versioning properties for all backend libraries
set_target_properties(${backend} PROPERTIES
VERSION ${GGML_VERSION}
SOVERSION ${GGML_VERSION_MAJOR}
)
if(NOT GGML_AVAILABLE_BACKENDS)
set(GGML_AVAILABLE_BACKENDS "${backend}"
CACHE INTERNAL "List of backends for cmake package")
+29
View File
@@ -448,6 +448,35 @@ void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_cann_release_resources(ctx, norm, acl_src, acl_dst);
}
void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src = dst->src[0];
aclTensor * acl_src = ggml_cann_create_tensor(src);
aclTensor * acl_dst = ggml_cann_create_tensor(dst);
size_t type_size = ggml_type_size(src->type);
int64_t n_bytes = src->ne[3]* src->ne[2]* src->ne[1]* type_size;
ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), n_bytes);
void * buffer = temp_buffer_allocator.get();
int64_t div_ne[] = {1, src->ne[1], src->ne[2], src->ne[3]};
size_t div_nb[GGML_MAX_DIMS];
div_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; ++i) {
div_nb[i] = div_nb[i - 1] * div_ne[i - 1];
}
aclTensor * acl_div = ggml_cann_create_tensor(buffer, ACL_FLOAT, type_size, div_ne, div_nb, GGML_MAX_DIMS);
std::vector<int64_t> norm_dims = { 3 };
aclIntArray * dims_array = aclCreateIntArray(norm_dims.data(), norm_dims.size());
float p_value = 2.0f;
aclScalar * p_scalar = aclCreateScalar(&p_value, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src, p_scalar, dims_array, true, acl_div);
GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src, acl_div, acl_dst);
ggml_cann_release_resources(ctx, dims_array, p_scalar, acl_src, acl_dst, acl_div);
}
void ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src = dst->src[0];
+24
View File
@@ -46,6 +46,7 @@
#include <aclnnop/aclnn_cos.h>
#include <aclnnop/aclnn_log.h>
#include <aclnnop/aclnn_sign.h>
#include <aclnnop/aclnn_norm.h>
#include "acl_tensor.h"
#include "common.h"
@@ -187,6 +188,29 @@ void ggml_cann_argsort(ggml_backend_cann_context & ctx, ggml_tensor * dst);
*/
void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes the L2 Normalization for a ggml tensor using the CANN
* backend.
*
* @details This function applies the L2 Normalization operation on the
* input tensor `src` and stores the result in the destination tensor
* `dst`. L2 Normalization scales the input tensor such that the
* L2 norm along the specified dimension equals 1. This operation
* is commonly used in neural networks for feature normalization
* and vector scaling.
* The operation is defined as:
* \f[
* \text{out} = \frac{x}{\sqrt{\sum{x^2}}}
* \f]
* The normalization is performed along the last dimension by default.
*
* @param ctx The CANN context used for operations.
* @param dst The destination tensor where the normalized values will be stored.
* @attention The normalization is performed along the last dimension of the
* input tensor by default.
*/
void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes the Group Normalization for a ggml tensor using the CANN
* backend.
+4
View File
@@ -1777,6 +1777,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
case GGML_OP_GROUP_NORM:
ggml_cann_group_norm(ctx, dst);
break;
case GGML_OP_L2_NORM:
ggml_cann_l2_norm(ctx, dst);
break;
case GGML_OP_CONCAT:
ggml_cann_concat(ctx, dst);
break;
@@ -2515,6 +2518,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
// value of paddingW should be at most half of kernelW
return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
}
case GGML_OP_L2_NORM:
case GGML_OP_DUP:
case GGML_OP_SUM:
case GGML_OP_IM2COL:
+34 -11
View File
@@ -126,25 +126,36 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
)
if (NOT ARM_MCPU_RESULT)
string(REGEX MATCH "-mcpu=[^ ']+" ARM_MCPU_FLAG "${ARM_MCPU}")
string(REGEX MATCH "-march=[^ ']+" ARM_MARCH_FLAG "${ARM_MCPU}")
# on some old GCC we need to read -march=
if (ARM_MARCH_FLAG AND NOT "${ARM_MARCH_FLAG}" STREQUAL "-march=native")
set(ARM_NATIVE_FLAG "${ARM_MARCH_FLAG}")
elseif(ARM_MCPU_FLAG AND NOT "${ARM_MCPU_FLAG}" STREQUAL "-mcpu=native")
set(ARM_NATIVE_FLAG "${ARM_MCPU_FLAG}")
endif()
endif()
if ("${ARM_MCPU_FLAG}" STREQUAL "")
set(ARM_MCPU_FLAG -mcpu=native)
message(STATUS "ARM -mcpu not found, -mcpu=native will be used")
if ("${ARM_NATIVE_FLAG}" STREQUAL "")
set(ARM_NATIVE_FLAG -mcpu=native)
message(WARNING "ARM -march/-mcpu not found, -mcpu=native will be used")
else()
message(STATUS "ARM detected flags: ${ARM_NATIVE_FLAG}")
endif()
include(CheckCXXSourceRuns)
function(check_arm_feature tag code)
set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}")
set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+${tag}")
check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag})
if (GGML_MACHINE_SUPPORTS_${tag})
set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE)
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+${tag}" PARENT_SCOPE)
else()
set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+no${tag}")
set(CMAKE_REQUIRED_FLAGS "${ARM_NATIVE_FLAG}+no${tag}")
check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag})
if (GGML_MACHINE_SUPPORTS_no${tag})
set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
set(ARM_NATIVE_FLAG_FIX "${ARM_NATIVE_FLAG_FIX}+no${tag}" PARENT_SCOPE)
endif()
endif()
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
@@ -155,7 +166,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
check_arm_feature(sve "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
check_arm_feature(sme "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
list(APPEND ARCH_FLAGS "${ARM_MCPU_FLAG}${ARM_MCPU_FLAG_FIX}")
list(APPEND ARCH_FLAGS "${ARM_NATIVE_FLAG}${ARM_NATIVE_FLAG_FIX}")
else()
if (GGML_CPU_ARM_ARCH)
list(APPEND ARCH_FLAGS -march=${GGML_CPU_ARM_ARCH})
@@ -579,6 +590,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
${KLEIDIAI_SRC}/kai/ukernels/
${KLEIDIAI_SRC}/kai/ukernels/matmul/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
@@ -597,23 +609,34 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c)
if (NOT DOTPROD_ENABLED MATCHES -1)
list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c)
endif()
if (NOT I8MM_ENABLED MATCHES -1)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c)
list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c)
endif()
if (NOT SME_ENABLED MATCHES -1)
list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
+428 -26
View File
@@ -2044,6 +2044,26 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
}
#ifdef __ARM_FEATURE_SVE
static inline svuint32_t ggml_decode_q4scales_and_mins_for_mmla(const uint32_t * vx_scales) {
const svbool_t pg_all = svptrue_pat_b32(SV_VL4);
const svbool_t pg_false = svpfalse_b(); // 0x0000
const svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8); // 0x00ff
const svbool_t pg_odd = svzip1_b32(pg_false, pg_lo_8);
svuint32_t vutmp_hi, vutmp_lo;
svuint32_t vx01 = svld1_u32(pg_lo_8, vx_scales);
vutmp_hi = svzip1_u32(vx01, vx01);
vutmp_hi = svlsr_n_u32_m(pg_odd, vutmp_hi, 2);
vutmp_hi = svreinterpret_u32_u64(svand_n_u64_x(pg_all, svreinterpret_u64_u32(vutmp_hi), UINT64_C(0x303030303f3f3f3f)));
const svuint32_t vx2 = svdup_u32(vx_scales[2]);
vutmp_lo = svlsr_u32_x(pg_all, vx2, svreinterpret_u32_s32(svindex_s32(-2, 2)));
vutmp_lo = svand_n_u32_z(pg_odd, vutmp_lo, UINT32_C(0x0f0f0f0f));
svuint32_t vutmp = svorr_u32_z(pg_all, vutmp_hi, vutmp_lo);
return vutmp;
}
#endif
void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
#ifdef __ARM_FEATURE_MATMUL_INT8
@@ -2066,8 +2086,220 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
static const uint32_t kmask3 = 0x03030303;
uint32_t utmp[4];
#ifdef __ARM_FEATURE_SVE
const int vector_length = ggml_cpu_get_sve_cnt()*8;
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
const block_q4_K * GGML_RESTRICT vx0 = vx;
const block_q8_K * GGML_RESTRICT vy0 = vy;
const block_q4_K * GGML_RESTRICT vx1 = (const block_q4_K *) ((const uint8_t*)vx + bx);
const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
union {
uint32_t u32[8];
uint64_t u64[4];
} new_utmp;
svfloat32_t sumf1 = svdup_n_f32(0);
switch (vector_length) {
case 128:
{
svbool_t pg_false = svpfalse_b();
svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8);
svbool_t vmins_mask1= svzip1_b32(pg_lo_8, pg_false);
svbool_t vmins_mask2 = svzip1_b32(pg_false, pg_lo_8);
svbool_t pg128_all = svptrue_pat_b8(SV_VL16);
for (int i = 0; i < nb; ++i) {
svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));
svfloat32_t vy_dmins = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
svfloat32_t svdmins = svmul_n_f32_x(pg128_all, svmul_f32_x(pg128_all, vy_dmins, vx_dmins), -1);
const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;
const int8_t * GGML_RESTRICT q8_0 = vy0[i].qs;
const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;
const int8_t * GGML_RESTRICT q8_1 = vy1[i].qs;
svint16_t lo = svld1_s16(pg128_all, vy0[i].bsums + 0);
svint16_t hi = svld1_s16(pg128_all, vy0[i].bsums + 8);
svint16_t sum_tmp1 = svuzp1_s16(lo, hi);
svint16_t sum_tmp2 = svuzp2_s16(lo, hi);
svint16_t svq8sums_0 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
lo = svld1_s16(pg128_all, vy1[i].bsums + 0);
hi = svld1_s16(pg128_all, vy1[i].bsums + 8);
sum_tmp1 = svuzp1(lo, hi);
sum_tmp2 = svuzp2(lo, hi);
svint16_t svq8sums_1 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
svst2_u32(pg128_all, new_utmp.u32, decoded_scales);
svint16_t svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp1_u32(svld1_u32(vmins_mask1, new_utmp.u32+4), svdup_n_u32(0)))));
svint16_t svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp2_u32(svld1_u32(vmins_mask2, new_utmp.u32+4), svdup_n_u32(0)))));
svint32_t svsumfs_tmp1 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_0));
svint32_t svsumfs_tmp2 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_1));
svint32_t svsumfs_tmp3 = svtrn1_s32(svsumfs_tmp1, svsumfs_tmp2);
svint32_t svsumfs_tmp4 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_0));
svint32_t svsumfs_tmp5 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_1));
svint32_t svsumfs_tmp6 = svtrn1_s32(svsumfs_tmp4, svsumfs_tmp5);
svint32_t svsumfs_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
svint32_t svsumfs_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
svint32_t svsumfs_tmp = svadd_s32_x(pg128_all, svsumfs_tmp7, svsumfs_tmp8);
svint32_t svscales, sumi1, sumi2;
svint32_t acc_sumif1 = svdup_n_s32(0);
svint32_t acc_sumif2 = svdup_n_s32(0);
svint8_t q4bytes_0_l, q4bytes_0_h, q4bytes_1_l, q4bytes_1_h, l0, l1, l2, l3,
q8bytes_0_h, q8bytes_0_l, q8bytes_1_h, q8bytes_1_l, r0, r1, r2, r3;
#pragma GCC unroll 1
for (int j = 0; j < QK_K/64; ++j) {
q4bytes_0_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 0xf));
q4bytes_1_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 0xf));
q4bytes_0_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 0xf));
q4bytes_1_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 0xf));
l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
q8bytes_0_h = svld1_s8(pg128_all, q8_0);
q8bytes_1_h = svld1_s8(pg128_all, q8_1);
q8bytes_0_l = svld1_s8(pg128_all, q8_0+16);
q8bytes_1_l = svld1_s8(pg128_all, q8_1+16);
r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
sumi1 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
acc_sumif1 = svmla_s32_x(pg128_all, acc_sumif1, svscales, sumi1);
q4bytes_0_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 4));
q4bytes_1_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 4));
q4bytes_0_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 4));
q4bytes_1_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 4));
l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
q8bytes_0_h = svld1_s8(pg128_all, q8_0+32);
q8bytes_1_h = svld1_s8(pg128_all, q8_1+32);
q8bytes_0_l = svld1_s8(pg128_all, q8_0+48);
q8bytes_1_l = svld1_s8(pg128_all, q8_1+48);
r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
sumi2 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
acc_sumif2 = svmla_s32_x(pg128_all, acc_sumif2, svscales, sumi2);
q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
}
sumf1 = svmla_f32_x(pg128_all,
svmla_f32_x(pg128_all,
sumf1,
svcvt_f32_x(pg128_all,
svadd_s32_x(pg128_all, acc_sumif1, acc_sumif2)),
svsuper_block_scales),
svdmins,
svcvt_f32_s32_x(pg128_all, svsumfs_tmp));
} //end of for nb
} // end of case 128
break;
case 256:
case 512:
{
const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
for (int i = 0; i < nb; ++i) {
const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;
const int8_t * GGML_RESTRICT q8_0 = vy0[i].qs;
const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;
const int8_t * GGML_RESTRICT q8_1 = vy1[i].qs;
svint32_t svscales, sumi1, sumi2;
svint32_t acc_sumif1 = svdup_n_s32(0);
svint32_t acc_sumif2 = svdup_n_s32(0);
svint8_t l0, l1, l2, l3, r0, r1, r2, r3;
svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
svfloat32_t svsuper_block_scales = svmul_f32_z(pg32_4, vy_d, vx_d);
svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));
svfloat64_t vy_dmins_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
svfloat32_t vy_dmins = svreinterpret_f32_f64(svuzp1_f64(vy_dmins_tmp, vy_dmins_tmp));
svfloat32_t svdmins = svmul_n_f32_x(pg32_4, svmul_f32_x(pg32_4, vx_dmins, vy_dmins), -1);
svint16_t rc1 = svuzp1_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
svint16_t rc2 = svuzp2_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
svint16_t svq8sums = svadd_s16_x(pg256_all, rc1, rc2);
svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
svst2_u32(pg8_16, new_utmp.u32, decoded_scales);
svint16_t new_svq8sums_0 = svreinterpret_s16_u64(svtrn1_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
svint16_t new_svq8sums_1 = svreinterpret_s16_u64(svtrn2_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
svuint64_t new_mins_0 = svdup_u64(new_utmp.u64[2]);
svuint64_t new_mins_1 = svdup_u64(new_utmp.u64[3]);
svint16_t new_svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_0)));
svint16_t new_svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_1)));
svint64_t dot_prod_0 = svdot_s64(svdup_s64(0), new_svmins8_0, new_svq8sums_0);
svint64_t dot_prod_1 = svdot_s64(dot_prod_0, new_svmins8_1, new_svq8sums_1);
svfloat32_t converted_dot_prod_1 = svcvt_f32_s64_x(pg256_all, dot_prod_1);
svfloat32_t svsumfs_tmp = svuzp1_f32(converted_dot_prod_1, converted_dot_prod_1);
#pragma GCC unroll 1
for (int j = 0; j < QK_K/64; ++j) {
svuint8_t q4bytes_0 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 0xf);
svuint8_t q4bytes_1 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 0xf);
svuint8_t q4bytes_2 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 4);
svuint8_t q4bytes_3 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 4);
l0 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
l1 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
l2 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
l3 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
svint8_t q8bytes_0 = svld1_s8(pg256_all, q8_0);
svint8_t q8bytes_1 = svld1_s8(pg256_all, q8_1);
svint8_t q8bytes_2 = svld1_s8(pg256_all, q8_0+32);
svint8_t q8bytes_3 = svld1_s8(pg256_all, q8_1+32);
r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
sumi1 = svmmla(svmmla(svdup_n_s32(0), r0, l0), r1, l1);
svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
acc_sumif1 = svmla_s32_x(pg256_all, acc_sumif1, svscales, sumi1);
sumi2 = svmmla(svmmla(svdup_n_s32(0), r2, l2), r3, l3);
svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
acc_sumif2 = svmla_s32_x(pg256_all, acc_sumif2, svscales, sumi2);
q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
}
svint32_t acc_sumif = svadd_s32_x(pg256_all, acc_sumif1, acc_sumif2);
svint32_t swap_acc_sumif = svext_s32(acc_sumif, acc_sumif, 4);
acc_sumif = svadd_s32_x(pg32_4, acc_sumif, swap_acc_sumif);
sumf1 = svmla_f32_x(pg32_4,
svmla_f32_x(pg32_4,
sumf1,
svcvt_f32_x(pg32_4, acc_sumif),
svsuper_block_scales),
svdmins,
svsumfs_tmp);
} // end of for nb
} // end of case 256-512
break;
default:
assert(false && "Unsupported vector length");
break;
}
svst1_f32(pg32_2, s, sumf1);
svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sumf1), svdup_n_u8(0), 8)));
return;
}
#elif defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const block_q4_K * GGML_RESTRICT x0 = x;
const block_q4_K * GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx);
@@ -2235,7 +2467,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const uint8_t * GGML_RESTRICT q4 = x[i].qs;
const int8_t * GGML_RESTRICT q8 = y[i].qs;
const int vector_length = ggml_cpu_get_sve_cnt()*8;
const svuint8_t m4b = svdup_n_u8(0xf);
const svint32_t mzero = svdup_n_s32(0);
svint32_t sumi1 = svdup_n_s32(0);
@@ -2480,7 +2711,201 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const int nb = n / QK_K;
#if defined(__ARM_FEATURE_MATMUL_INT8)
#ifdef __ARM_FEATURE_SVE
const int vector_length = ggml_cpu_get_sve_cnt()*8;
#endif
#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
svfloat32_t sum = svdup_n_f32(0);
const block_q6_K * GGML_RESTRICT vx0 = vx;
const block_q8_K * GGML_RESTRICT vy0 = vy;
const block_q6_K * GGML_RESTRICT vx1 = (const block_q6_K *) ((const uint8_t*)vx + bx);
const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
switch (vector_length) {
case 128:
{
const svbool_t pg128_all = svptrue_pat_b8(SV_ALL);
for (int i = 0; i < nb; ++i) {
const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;
const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;
const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;
const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;
const int8_t * GGML_RESTRICT q80 = vy0[i].qs;
const int8_t * GGML_RESTRICT q81 = vy1[i].qs;
const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;
const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;
svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
// process q8sum summation 128 bit route
const svint16_t q8sums_01 = svld1_s16(pg128_all, vy0[i].bsums);
const svint16_t q8sums_02 = svld1_s16(pg128_all, vy0[i].bsums + 8);
const svint16_t q8sums_11 = svld1_s16(pg128_all, vy1[i].bsums);
const svint16_t q8sums_12 = svld1_s16(pg128_all, vy1[i].bsums + 8);
const svint64x2_t q6scales_0_tmp = svld2_s64(pg128_all, (const int64_t *)scale0);
const svint16_t q6scales_01 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 0)));
const svint16_t q6scales_02 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 1)));
const svint64x2_t q6scales_1_tmp = svld2_s64(pg128_all, (const int64_t *)scale1);
const svint16_t q6scales_11 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 0)));
const svint16_t q6scales_12 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 1)));
const svint64_t prod = svdup_n_s64(0);
svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_01), q8sums_02, q6scales_02));
svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_11), q8sums_02, q6scales_12));
svint32_t isum_tmp3 = svtrn1_s32(isum_tmp1, isum_tmp2);
svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_01), q8sums_12, q6scales_02));
svint32_t isum_tmp5 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_11), q8sums_12, q6scales_12));
svint32_t isum_tmp6 = svtrn1_s32(isum_tmp4, isum_tmp5);
svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
svint32_t svisum_mins = svadd_s32_x(pg128_all, isum_tmp7, isum_tmp8);
// process mmla
svint8_t l0, l1, r0, r1;
svint32_t isum_tmp = svdup_n_s32(0);
for (int j = 0; j < QK_K/128; ++j) {
for (int k = 0; k < 8; ++k) {
svuint8_t qhbits_0 = svld1_u8(pg128_all, qh0+16*(k%2));
svuint8_t qhbits_1 = svld1_u8(pg128_all, qh1+16*(k%2));
svuint8_t q6bits_0 = svld1_u8(pg128_all, ql0+16*(k%4));
svuint8_t q6bits_1 = svld1_u8(pg128_all, ql1+16*(k%4));
const int ql_pos = (k/4)*4;
svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_0, 4);
svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_1, 4);
const int qh_pos = (k/2)*2;
svuint8_t q6bytes_0_hi = svand_n_u8_x(pg128_all, qhbits_0, 0x3 << qh_pos);
svuint8_t q6bytes_1_hi = svand_n_u8_x(pg128_all, qhbits_1, 0x3 << qh_pos);
svint8_t q6bytes_0, q6bytes_1;
if (qh_pos <= 4) {
q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
} else {
q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_0_lo, svlsr_n_u8_x(pg128_all, q6bytes_0_hi, (qh_pos - 4))));
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_1_lo, svlsr_n_u8_x(pg128_all, q6bytes_1_hi, (qh_pos - 4))));
}
svint8_t q8bytes_0 = svld1_s8(pg128_all, q80+16*(k%8));
svint8_t q8bytes_1 = svld1_s8(pg128_all, q81+16*(k%8));
l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
svint32_t svscale = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
isum_tmp = svmla_s32_x(pg128_all, isum_tmp, svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), svscale);
}
qh0 += 32; qh1 += 32;
ql0 += 64; ql1 += 64;
q80 += 128; q81 += 128;
scale0 += 8; scale1 += 8;
}
sum = svmla_f32_x(pg128_all, sum,
svcvt_f32_x(pg128_all, svmla_s32_x(pg128_all, isum_tmp,
svisum_mins, svdup_n_s32(-32))),
svsuper_block_scales);
}
} // end of case 128
break;
case 256:
case 512:
{
const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
for (int i = 0; i < nb; ++i) {
const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;
const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;
const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;
const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;
const int8_t * GGML_RESTRICT q80 = vy0[i].qs;
const int8_t * GGML_RESTRICT q81 = vy1[i].qs;
const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;
const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;
svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
svfloat32_t svsuper_block_scales = svmul_f32_x(pg32_4, vy_d, vx_d);
// process q8sum summation 256 bit route
const svint16_t q8sums_0 = svld1_s16(pg256_all, vy0[i].bsums);
const svint16_t q8sums_1 = svld1_s16(pg256_all, vy1[i].bsums);
const svint16_t q6scales_0 = svunpklo_s16(svld1_s8(pg256_all, scale0));
const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(pg256_all, scale1));
const svint64_t prod = svdup_n_s64(0);
svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_0));
svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_1));
svint32_t isum_tmp3 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_0));
svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_1));
svint32_t isum_tmp5 = svtrn1_s32(isum_tmp1, isum_tmp2);
svint32_t isum_tmp6 = svtrn1_s32(isum_tmp3, isum_tmp4);
svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
svint32_t isum_tmp9 = svadd_s32_x(pg256_all, isum_tmp7, isum_tmp8);
svint32_t isum_tmp10 = svreinterpret_s32_u8(svext_u8(svreinterpret_u8_s32(isum_tmp9), svreinterpret_u8_s32(isum_tmp9), 16));
svint32_t svisum_mins = svadd_s32_z(pg32_4, isum_tmp9, isum_tmp10);
// process mmla
svint8_t l0, l1, r0, r1;
svint32_t isum_tmp = svdup_n_s32(0);
for (int j = 0; j < QK_K/128; ++j) {
for (int k = 0; k < 8; k+=2) { // process 2 block
svuint8_t qhbits_0 = svld1_u8(pg256_all, qh0);
svuint8_t qhbits_1 = svld1_u8(pg256_all, qh1);
svuint8_t q6bits_0 = svld1_u8(pg256_all, ql0+32*((k%4)/2));
svuint8_t q6bits_1 = svld1_u8(pg256_all, ql1+32*((k%4)/2));
const int ql_pos = (k/4)*4;
svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_0, 4);
svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_1, 4);
const int qh_pos = (k/2)*2;
svuint8_t q6bytes_0_hi = svand_n_u8_x(pg256_all, qhbits_0, 0x3 << qh_pos);
svuint8_t q6bytes_1_hi = svand_n_u8_x(pg256_all, qhbits_1, 0x3 << qh_pos);
svint8_t q6bytes_0, q6bytes_1;
if (qh_pos <= 4) {
q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
} else {
q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_0_lo, svlsr_n_u8_x(pg256_all, q6bytes_0_hi, (qh_pos - 4))));
q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_1_lo, svlsr_n_u8_x(pg256_all, q6bytes_1_hi, (qh_pos - 4))));
}
svint8_t q8bytes_0 = svld1_s8(pg256_all, q80+32*(k/2));
svint8_t q8bytes_1 = svld1_s8(pg256_all, q81+32*(k/2));
l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
svint32_t svscale0 = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
svint32_t svscale1 = svzip1_s32(svdup_n_s32(scale0[k+1]), svdup_n_s32(scale1[k+1]));
isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r0, l0), svscale0);
isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r1, l1), svscale1);
}
qh0 += 32; qh1 += 32;
ql0 += 64; ql1 += 64;
q80 += 128; q81 += 128;
scale0 += 8; scale1 += 8;
} // end of for
svint32_t swap_isum_tmp = svext_s32(isum_tmp, isum_tmp, 4);
isum_tmp = svadd_s32_x(pg32_4, isum_tmp, swap_isum_tmp);
sum = svmla_f32_x(pg32_4, sum,
svcvt_f32_x(pg32_4, svmla_s32_x(pg32_4, isum_tmp,
svisum_mins, svdup_n_s32(-32))),
svsuper_block_scales);
}
} // end of case 256
break;
default:
assert(false && "Unsupported vector length");
break;
} // end of switch
svst1_f32(pg32_2, s, sum);
svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sum), svdup_n_u8(0), 8)));
return;
}
#elif defined(__ARM_FEATURE_MATMUL_INT8)
if (nrc == 2) {
const block_q6_K * GGML_RESTRICT x0 = x;
const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
@@ -2594,27 +3019,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
// adjust bias, apply superblock scale
{
int32_t bias[4];
#ifdef __ARM_FEATURE_SVE
const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
const svint64_t zero = svdup_n_s64(0);
bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
#else
// NEON doesn't support int16 dot product, fallback to separated mul and add
const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
@@ -2646,7 +3050,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
bias[3] = vaddvq_s32(prod);
#endif
const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
const float32x4_t superblock_scale = {
@@ -2672,7 +3075,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
#endif
#ifdef __ARM_FEATURE_SVE
const int vector_length = ggml_cpu_get_sve_cnt()*8;
float sum = 0;
svuint8_t m4b = svdup_n_u8(0xf);
svint32_t vzero = svdup_n_s32(0);
+28 -16
View File
@@ -1807,22 +1807,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_cont(params, tensor);
} break;
case GGML_OP_RESHAPE:
{
ggml_compute_forward_reshape(params, tensor);
} break;
case GGML_OP_VIEW:
{
ggml_compute_forward_view(params, tensor);
} break;
case GGML_OP_PERMUTE:
{
ggml_compute_forward_permute(params, tensor);
} break;
case GGML_OP_TRANSPOSE:
{
ggml_compute_forward_transpose(params, tensor);
} break;
case GGML_OP_GET_ROWS:
{
ggml_compute_forward_get_rows(params, tensor);
@@ -2042,6 +2026,22 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
// nop
} break;
case GGML_OP_RESHAPE:
{
// nop
} break;
case GGML_OP_PERMUTE:
{
// nop
} break;
case GGML_OP_VIEW:
{
// nop
} break;
case GGML_OP_TRANSPOSE:
{
// nop
} break;
case GGML_OP_COUNT:
{
GGML_ABORT("fatal error");
@@ -2884,6 +2884,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
struct ggml_tensor * node = cgraph->nodes[node_n];
if (ggml_op_is_empty(node->op)) {
// skip NOPs
continue;
}
ggml_compute_forward(&params, node);
if (state->ith == 0 && cplan->abort_callback &&
@@ -3269,6 +3274,13 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) {
__m128 y_vec = _mm_cvtph_ps(x_vec);
_mm_storeu_ps(y + i, y_vec);
}
#elif defined(__riscv_zvfh)
for (int vl; i < n; i += vl) {
vl = __riscv_vsetvl_e16m1(n - i);
vfloat16m1_t vx = __riscv_vle16_v_f16m1((_Float16 *)&x[i], vl);
vfloat32m2_t vy = __riscv_vfwcvt_f_f_v_f32m2(vx, vl);
__riscv_vse32_v_f32m2(&y[i], vy, vl);
}
#endif
for (; i < n; ++i) {
+283
View File
@@ -4,6 +4,7 @@
// KleidiAI micro-kernels
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
#include "kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
@@ -11,20 +12,31 @@
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
#include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
#include "kai_lhs_quant_pack_qai8dxp_f32.h"
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
#include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
#include "kai_common.h"
#include "simd-mappings.h"
#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"
#include "kernels.h"
#define NELEMS(x) sizeof(x) / sizeof(*x)
@@ -55,6 +67,14 @@ static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
}
template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
const void* lhs, const void* rhs, void* dst,
size_t dst_stride_row, size_t dst_stride_col,
float clamp_min, float clamp_max) {
Fn(m, n, k, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
}
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
return Fn(m, k, bl, mr, kr, sr);
@@ -93,6 +113,12 @@ static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t m
Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
}
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>
static inline void lhs_pack_float_fn9_no_bl(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed) {
Fn(m, k, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);
}
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
return Fn(n, k, nr, kr, bl);
@@ -124,6 +150,18 @@ static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t n
static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
}
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const int8_t*,const float*,const float*,void*,size_t,const struct kai_rhs_pack_qsi8cx_params*)>
static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
void* rhs_packed, size_t extra_bytes, const void* params) {
Fn(num_groups, n, k, nr, kr, sr,
static_cast<const int8_t*>(rhs),
static_cast<const float*>(bias),
static_cast<const float*>(scale),
rhs_packed, extra_bytes,
static_cast<const kai_rhs_pack_qsi8cx_params*>(params));
}
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
@@ -213,6 +251,57 @@ static void dequantize_row_qsi4c32ps1s0scalef16(
GGML_UNUSED(kr);
}
static void dequantize_row_qsi8cxp(
const void *packed_data,
int32_t row_idx,
int64_t k,
float *out,
size_t nr,
size_t packed_row_stride,
size_t kr,
size_t bl,
size_t num_bytes_multiplier
) {
GGML_UNUSED(bl);
GGML_UNUSED(num_bytes_multiplier);
const size_t k_internal = ((size_t) k + QK8_0 - 1) / QK8_0 * QK8_0;
const size_t group_idx = row_idx / nr;
const size_t row_in_group = row_idx % nr;
const uint8_t * group_ptr = static_cast<const uint8_t *>(packed_data) + group_idx * packed_row_stride;
const int8_t * data_base = reinterpret_cast<const int8_t *>(group_ptr);
const size_t num_blocks = k_internal / kr;
for (size_t block = 0; block < num_blocks; ++block) {
const int8_t * block_ptr = data_base + (block * nr + row_in_group) * kr;
for (size_t i = 0; i < kr; ++i) {
const size_t k_idx = block * kr + i;
if (k_idx < (size_t) k) {
out[k_idx] = static_cast<float>(block_ptr[i]);
}
}
}
const uint8_t * sums_ptr = group_ptr + nr * k_internal;
GGML_UNUSED(sums_ptr);
const float * scale_ptr = reinterpret_cast<const float *>(sums_ptr + nr * sizeof(int32_t));
const float scale = scale_ptr[row_in_group];
if (scale == 0.0f) {
for (size_t i = 0; i < (size_t) k; ++i) {
out[i] = 0.0f;
}
return;
}
for (size_t i = 0; i < (size_t) k; ++i) {
out[i] *= scale;
}
}
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
#if defined(__ARM_FEATURE_SME)
{
@@ -548,6 +637,174 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
#endif
};
static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
#if defined(__ARM_FEATURE_SME)
{
/* SME GEMM */
{
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
},
/* SME GEMV */
{
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
},
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
},
/* .rhs_info = */ {
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
/* .to_float = */ dequantize_row_qsi8cxp,
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
},
/* .required_cpu = */ CPU_FEATURE_SME,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q8_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
{
/* I8MM GEMM */
{
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
},
/* I8MM GEMV (dotprod fallback) */
{
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
},
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
},
/* .rhs_info = */ {
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
/* .to_float = */ dequantize_row_qsi8cxp,
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q8_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
#if defined(__ARM_FEATURE_DOTPROD)
{
/* DOTPROD GEMM */
{
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
},
/* DOTPROD GEMV */
{
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
},
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
},
/* .rhs_info = */ {
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
/* .to_float = */ dequantize_row_qsi8cxp,
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q8_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
};
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
ggml_kleidiai_kernels * kernel = nullptr;
@@ -562,6 +819,17 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
break;
}
}
if (!kernel) {
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
gemm_gemv_kernels_q8[i].op_type == tensor->type) {
kernel = &gemm_gemv_kernels_q8[i];
break;
}
}
}
#endif
}
@@ -582,3 +850,18 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features)
return kernels;
}
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features) {
ggml_kleidiai_kernels * kernels = nullptr;
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) {
kernels = &gemm_gemv_kernels_q8[i];
break;
}
}
#endif
return kernels;
}
+1
View File
@@ -87,3 +87,4 @@ struct ggml_kleidiai_kernels {
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features);
+239 -38
View File
@@ -5,10 +5,13 @@
#include <assert.h>
#include <atomic>
#include <cfloat>
#include <cmath>
#include <algorithm>
#include <stdexcept>
#include <stdint.h>
#include <string.h>
#include <string>
#include <vector>
#if defined(__linux__)
#include <asm/hwcap.h>
#include <sys/auxv.h>
@@ -38,8 +41,9 @@
struct ggml_kleidiai_context {
cpu_feature features;
ggml_kleidiai_kernels * kernels;
} static ctx = { CPU_FEATURE_NONE, NULL };
ggml_kleidiai_kernels * kernels_q4;
ggml_kleidiai_kernels * kernels_q8;
} static ctx = { CPU_FEATURE_NONE, NULL, NULL };
static const char* cpu_feature_to_string(cpu_feature f) {
switch (f) {
@@ -73,10 +77,14 @@ static void init_kleidiai_context(void) {
if (sme_enabled != 0) {
ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
}
ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
#ifndef NDEBUG
if (ctx.kernels) {
GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
if (ctx.kernels_q4) {
GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
}
if (ctx.kernels_q8) {
GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
}
#endif
}
@@ -130,6 +138,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
if (!lhs_info->packed_size_ex) return false;
size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
} else if (kernels->rhs_type == GGML_TYPE_Q8_0) {
if (!lhs_info->packed_size_ex) return false;
size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr);
} else if (kernels->rhs_type == GGML_TYPE_F16) {
if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
const int64_t lhs_batch_size0 = op->src[1]->ne[2];
@@ -149,11 +160,13 @@ class tensor_traits : public ggml::cpu::tensor_traits {
if (dst->op == GGML_OP_MUL_MAT) {
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
return compute_forward_q4_0(params, dst);
} else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
return compute_forward_q8_0(params, dst);
} else if (dst->src[0]->type == GGML_TYPE_F16) {
return compute_forward_fp16(params, dst);
}
} else if (dst->op == GGML_OP_GET_ROWS) {
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
return compute_forward_get_rows(params, dst);
}
}
@@ -400,19 +413,120 @@ class tensor_traits : public ggml::cpu::tensor_traits {
return true;
}
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
if (!ctx.kernels) {
return false;
}
bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0);
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
kernel_info * kernel = &ctx.kernels->gemm;
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
if (!kernels) {
return false;
}
bool is_gemv = src1->ne[1] == 1;
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
!kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
return false;
}
const int ith = params->ith;
const int nth_raw = params->nth;
const int nth = nth_raw > 0 ? nth_raw : 1;
const size_t k = ne00;
const size_t m = ne11;
const size_t n = ne01;
size_t mr = kernel->get_mr();
size_t kr = kernel->get_kr();
size_t sr = kernel->get_sr();
const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
const size_t n_step = kernel->get_n_step();
const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
const size_t n_start = ith * num_n_per_thread;
size_t n_to_process = 0;
if (n_start < n) {
n_to_process = num_n_per_thread;
if ((n_start + n_to_process) > n) {
n_to_process = n - n_start;
}
}
const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
const size_t m_start = ith * num_m_per_thread;
size_t m_to_process = num_m_per_thread;
if ((m_start + m_to_process) > m) {
m_to_process = m - m_start;
}
if (m_start < m) {
const size_t src_stride = src1->nb[1];
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
}
ggml_barrier(params->threadpool);
const size_t dst_stride = dst->nb[1];
const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
const void * lhs_ptr = static_cast<const void *>(lhs_packed + lhs_packed_offset);
float * dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
if (n_to_process > 0) {
kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
sizeof(float), -FLT_MAX, FLT_MAX);
}
return true;
}
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
ggml_kleidiai_kernels * kernels = nullptr;
size_t block_len = 0;
size_t num_bytes_multiplier = 0;
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
if (!ctx.kernels_q4) {
return false;
}
kernels = ctx.kernels_q4;
block_len = QK4_0;
num_bytes_multiplier = sizeof(uint16_t);
} else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
if (!ctx.kernels_q8) {
return false;
}
kernels = ctx.kernels_q8;
block_len = QK8_0;
num_bytes_multiplier = sizeof(float);
} else {
return false;
}
rhs_packing_info * rhs_info = &kernels->rhs_info;
kernel_info * kernel = &kernels->gemm;
if (!rhs_info->to_float || !kernel->get_nr) {
return false;
}
@@ -423,8 +537,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
const size_t block_rows = kernel->get_nr();
const size_t kr = kernel->get_kr();
const size_t num_bytes_multiplier = sizeof(uint16_t);
const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, block_len);
const int ith = params->ith;
const int nth = params->nth;
@@ -439,7 +552,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
float *out = (float *)((char *)dst->data + i * nb1);
rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
}
return true;
@@ -447,21 +560,91 @@ class tensor_traits : public ggml::cpu::tensor_traits {
public:
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
GGML_ASSERT(ctx.kernels);
const size_t n = tensor->ne[1];
const size_t k = tensor->ne[0];
size_t nr = ctx.kernels->gemm.get_nr();
size_t kr = ctx.kernels->gemm.get_kr();
size_t sr = ctx.kernels->gemm.get_sr();
struct kai_rhs_pack_qs4cxs1s0_param params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
ctx.kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, (const uint8_t*)data, nullptr, nullptr, tensor->data, 0, &params);
if (tensor->type == GGML_TYPE_Q4_0) {
if (!ctx.kernels_q4) {
return -1;
}
size_t nr = ctx.kernels_q4->gemm.get_nr();
size_t kr = ctx.kernels_q4->gemm.get_kr();
size_t sr = ctx.kernels_q4->gemm.get_sr();
struct kai_rhs_pack_qs4cxs1s0_param params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
static_cast<const uint8_t *>(data),
nullptr, nullptr, tensor->data, 0, &params);
GGML_UNUSED(data_size);
return 0;
} else if (tensor->type == GGML_TYPE_Q8_0) {
if (!ctx.kernels_q8) {
return -1;
}
const size_t row_stride = tensor->nb[1];
const size_t k_blocks = (k + QK8_0 - 1) / QK8_0;
std::vector<int8_t> qdata(n * k, 0);
std::vector<float> scales(n, 0.0f);
for (size_t row = 0; row < n; ++row) {
const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
static_cast<const uint8_t *>(data) + row * row_stride);
float max_abs = 0.0f;
for (size_t block = 0; block < k_blocks; ++block) {
const block_q8_0 & blk = row_blocks[block];
const float d = GGML_FP16_TO_FP32(blk.d);
for (size_t l = 0; l < QK8_0; ++l) {
const size_t linear_idx = block * QK8_0 + l;
if (linear_idx >= k) {
break;
}
const float value = d * blk.qs[l];
max_abs = std::max(max_abs, std::fabs(value));
}
}
float scale = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;
scales[row] = scale;
const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;
for (size_t block = 0; block < k_blocks; ++block) {
const block_q8_0 & blk = row_blocks[block];
const float d = GGML_FP16_TO_FP32(blk.d);
for (size_t l = 0; l < QK8_0; ++l) {
const size_t linear_idx = block * QK8_0 + l;
if (linear_idx >= k) {
break;
}
const float value = d * blk.qs[l];
int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
q = std::clamp(q, -127, 127);
qdata[row * k + linear_idx] = static_cast<int8_t>(q);
}
}
}
size_t nr = ctx.kernels_q8->gemm.get_nr();
size_t kr = ctx.kernels_q8->gemm.get_kr();
size_t sr = ctx.kernels_q8->gemm.get_sr();
struct kai_rhs_pack_qsi8cx_params params;
params.lhs_zero_point = 1;
params.scale_multiplier = 1.0f;
ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
qdata.data(), nullptr, scales.data(),
tensor->data, 0, &params);
GGML_UNUSED(data_size);
return 0;
}
return 0;
GGML_UNUSED(data_size);
return -1;
}
};
@@ -518,27 +701,45 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
}
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
GGML_ASSERT(ctx.kernels);
const size_t n = tensor->ne[1];
const size_t k = tensor->ne[0];
const size_t nr = ctx.kernels->gemm.get_nr();
const size_t kr = ctx.kernels->gemm.get_kr();
return ctx.kernels->rhs_info.packed_size_ex(n, k, nr, kr, QK4_0);
GGML_UNUSED(buft);
const size_t n = tensor->ne[1];
const size_t k = tensor->ne[0];
ggml_kleidiai_kernels * kernels = nullptr;
size_t block_len = 0;
if (tensor->type == GGML_TYPE_Q4_0) {
GGML_ASSERT(ctx.kernels_q4);
kernels = ctx.kernels_q4;
block_len = QK4_0;
} else if (tensor->type == GGML_TYPE_Q8_0) {
GGML_ASSERT(ctx.kernels_q8);
kernels = ctx.kernels_q8;
block_len = QK8_0;
} else {
return 0;
}
const size_t nr = kernels->gemm.get_nr();
const size_t kr = kernels->gemm.get_kr();
const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len);
const size_t raw = ggml_nbytes(tensor);
return packed > raw ? packed : raw;
}
namespace ggml::cpu::kleidiai {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
op->src[0]->type == GGML_TYPE_Q4_0 &&
(op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
op->src[0]->buffer &&
(ggml_n_dims(op->src[0]) == 2) &&
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) {
return false;
}
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}
+54 -303
View File
@@ -4455,46 +4455,6 @@ void ggml_compute_forward_cont(
ggml_compute_forward_dup(params, dst);
}
// ggml_compute_forward_reshape
void ggml_compute_forward_reshape(
const ggml_compute_params * params,
ggml_tensor * dst) {
// NOP
GGML_UNUSED(params);
GGML_UNUSED(dst);
}
// ggml_compute_forward_view
void ggml_compute_forward_view(
const ggml_compute_params * params,
ggml_tensor * dst) {
// NOP
GGML_UNUSED(params);
GGML_UNUSED(dst);
}
// ggml_compute_forward_permute
void ggml_compute_forward_permute(
const ggml_compute_params * params,
ggml_tensor * dst) {
// NOP
GGML_UNUSED(params);
GGML_UNUSED(dst);
}
// ggml_compute_forward_transpose
void ggml_compute_forward_transpose(
const ggml_compute_params * params,
ggml_tensor * dst) {
// NOP
GGML_UNUSED(params);
GGML_UNUSED(dst);
}
// ggml_compute_forward_get_rows
static void ggml_compute_forward_get_rows_q(
@@ -5543,7 +5503,28 @@ static void ggml_mrope_cache_init(
}
}
static void ggml_compute_forward_rope_f32(
template<typename T>
static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
for (int64_t i0 = 0; i0 < n; i0 += 2) {
const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const T * const src = src_data + ic;
T * dst = dst_data + ic;
const float x0 = type_conversion_table<T>::to_f32(src[0]);
const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
}
}
template<typename T> //float or ggml_fp16_t
static void ggml_compute_forward_rope_flt(
const ggml_compute_params * params,
ggml_tensor * dst,
const bool forward) {
@@ -5552,6 +5533,9 @@ static void ggml_compute_forward_rope_f32(
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_I32);
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4];
@@ -5574,7 +5558,8 @@ static void ggml_compute_forward_rope_f32(
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
GGML_ASSERT(nb00 == sizeof(float));
GGML_ASSERT(nb0 == nb00);
GGML_ASSERT(nb0 == sizeof(T));
const int ith = params->ith;
const int nth = params->nth;
@@ -5599,12 +5584,11 @@ static void ggml_compute_forward_rope_f32(
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
if (mrope_used) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
}
@@ -5630,7 +5614,7 @@ static void ggml_compute_forward_rope_f32(
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
if (!is_mrope) {
if (!mrope_used) {
const int64_t p = pos[i2];
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
@@ -5648,269 +5632,36 @@ static void ggml_compute_forward_rope_f32(
if (ir++ < ir0) continue;
if (ir > ir1) break;
if (is_neox || is_mrope) {
if (is_vision){
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
}
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = src[0];
const float x1 = src[1];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta;
}
switch (mode) {
case GGML_ROPE_TYPE_NORMAL:
rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
break;
case GGML_ROPE_TYPE_NEOX:
case GGML_ROPE_TYPE_MROPE:
case GGML_ROPE_TYPE_IMROPE:
rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
break;
case GGML_ROPE_TYPE_VISION:
rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
break;
default:
GGML_ABORT("rope type not supported");
}
if (is_vision) {
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = src[0];
const float x1 = src[n_dims];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
}
} else {
if (!is_vision) {
// fill the remain channels with data from src tensor
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
}
}
}
// TODO: deduplicate f16/f32 code
static void ggml_compute_forward_rope_f16(
const ggml_compute_params * params,
ggml_tensor * dst,
const bool forward) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
int sections[4];
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
//const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
GGML_TENSOR_UNARY_OP_LOCALS
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_nrows(dst);
GGML_ASSERT(n_dims <= ne0);
GGML_ASSERT(n_dims % 2 == 0);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
// row index used to determine which thread to use
int ir = 0;
const float theta_scale = powf(freq_base, -2.0f/n_dims);
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
if (is_mrope) {
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
}
if (is_vision) {
GGML_ASSERT(n_dims == ne0/2);
}
const float * freq_factors = NULL;
if (src2 != NULL) {
GGML_ASSERT(src2->type == GGML_TYPE_F32);
GGML_ASSERT(src2->ne[0] >= n_dims / 2);
freq_factors = (const float *) src2->data;
}
// backward process uses inverse rotation by cos and sin.
// cos and sin build a rotation matrix, where the inverse is the transpose.
// this essentially just switches the sign of sin.
const float sin_sign = forward ? 1.0f : -1.0f;
const int32_t * pos = (const int32_t *) src1->data;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) {
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
if (!is_mrope) {
const int64_t p = pos[i2];
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
else {
const int64_t p_t = pos[i2];
const int64_t p_h = pos[i2 + ne2];
const int64_t p_w = pos[i2 + ne2 * 2];
const int64_t p_e = pos[i2 + ne2 * 3];
ggml_mrope_cache_init(
p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
}
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (ir++ < ir0) continue;
if (ir > ir1) break;
if (is_neox || is_mrope) {
if (is_vision) {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
}
} else {
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
}
if (is_vision) {
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const int64_t ic = i0/2;
const float cos_theta = cache[i0 + 0];
const float sin_theta = cache[i0 + 1];
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
} else {
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
} //attn-heads
}
}
}
@@ -5924,11 +5675,11 @@ void ggml_compute_forward_rope(
switch (src0->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_rope_f16(params, dst, true);
ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
} break;
case GGML_TYPE_F32:
{
ggml_compute_forward_rope_f32(params, dst, true);
ggml_compute_forward_rope_flt<float>(params, dst, true);
} break;
default:
{
@@ -5948,11 +5699,11 @@ void ggml_compute_forward_rope_back(
switch (src0->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_rope_f16(params, dst, false);
ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
} break;
case GGML_TYPE_F32:
{
ggml_compute_forward_rope_f32(params, dst, false);
ggml_compute_forward_rope_flt<float>(params, dst, false);
} break;
default:
{
-4
View File
@@ -51,10 +51,6 @@ void ggml_compute_forward_scale(const struct ggml_compute_params * params, struc
void ggml_compute_forward_set(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_cpy(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_cont(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_reshape(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_view(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_permute(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+90 -42
View File
@@ -1600,29 +1600,52 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
return false;
}
void forward_mul_mat_one_chunk(ggml_compute_params * params, ggml_tensor * op, int64_t src0_start, int64_t src0_end) {
void forward_mul_mat_one_chunk(ggml_compute_params * params,
ggml_tensor * op,
int64_t src0_start,
int64_t src0_end,
int64_t src1_start,
int64_t src1_end) {
const ggml_tensor * src0 = op->src[0];
const ggml_tensor * src1 = op->src[1];
ggml_tensor * dst = op;
GGML_TENSOR_BINARY_OP_LOCALS
const void * src1_wdata = params->wdata;
const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
GGML_ASSERT(ne03 == 1 && ne13 == 1);
GGML_ASSERT(ne12 % ne02 == 0);
const int64_t r2 = ne12 / ne02;
const int64_t i12 = src1_start / ne1;
const int64_t i11 = src1_start - i12 * ne1;
// Determine batch index
const int64_t i02 = i12 / r2;
const int64_t i1 = i11;
const int64_t i2 = i12;
const char * src0_ptr = (const char *) src0->data + i02 * nb02;
const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride;
char * dst_ptr = ((char *) dst->data + (i1 * nb1 + i2 * nb2));
const int64_t nrows = src1_end - src1_start;
const int64_t ncols = src0_end - src0_start;
GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize);
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
if (ne11 > 3) {
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
(float *) ((char *) dst->data) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01,
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
if (nrows > 3) {
gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,
src0_ptr + src0_start * nb01, src1_ptr,
nrows - (nrows % 4), ncols);
}
for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
(float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01,
(const char *) src1_wdata + (src1_col_stride * iter), 1,
src0_end - src0_start);
for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start,
ne01, src0_ptr + src0_start * nb01,
src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols);
}
}
@@ -1647,6 +1670,12 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
// TODO: General batched mul mat for 4D tensors
// Currently only supports 3D tensors
GGML_ASSERT(ne03 == 1);
GGML_ASSERT(ne13 == 1);
GGML_ASSERT(ne3 == 1);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
@@ -1654,47 +1683,60 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
char * wdata = static_cast<char *>(params->wdata);
const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
const size_t nbw2 = nbw1 * ne11;
assert(params->wsize >= nbw1 * ne11);
assert(params->wsize >= nbw2 * ne12);
const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
int64_t i11_processed = 0;
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10);
}
for (int64_t i12 = 0; i12 < ne12; i12++) {
char * data_ptr = (char *) src1->data + i12 * nb12;
char * wdata_ptr = wdata + i12 * nbw2;
i11_processed = ne11 - ne11 % 4;
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
(void *) (wdata_ptr + i11 * nbw1), 4, ne10);
}
const int64_t i11_processed = ne11 - ne11 % 4;
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);
}
}
// disable for NUMA
const bool disable_chunking = ggml_is_numa();
// 4x chunks per thread
int64_t nr = ggml_nrows(op->src[0]);
int nth_scaled = nth * 4;
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
const int64_t nr0 = ggml_nrows(op->src[0]);
const int64_t nr1 = ne1 * ne2 * ne3;
int nth_scaled = nth * 4;
int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
// avoid too small chunks for narrow src1
int64_t chunk_size1 = MAX(16, (nr1 + nth - 1) / nth);
int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0;
int64_t nchunk1 = (nr1 + chunk_size1 - 1) / chunk_size1;
// Ensure minimum chunk size to avoid alignment issues with high thread counts
// Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
const int64_t min_chunk_size = NB_COLS;
if (nchunk > 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) {
nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
}
if (nth == 1 || nchunk < nth || disable_chunking) {
nchunk = nth;
if (nth == 1 || nchunk0 * nchunk1 < nth || disable_chunking) {
nchunk0 = nr0 > nr1 ? nth : 1;
nchunk1 = nr0 > nr1 ? 1 : nth;
}
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
// Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
// This prevents creating too many tiny chunks that could overlap after alignment
const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size;
if (nchunk > max_nchunk) {
nchunk = max_nchunk;
}
const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
nchunk0 = MIN(nchunk0, max_nchunk);
if (ith == 0) {
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
@@ -1706,23 +1748,29 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
// The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith;
while (current_chunk < nchunk) {
int64_t src0_start = (current_chunk * ne01) / nchunk;
int64_t src0_end = ((current_chunk + 1) * ne01) / nchunk;
while (current_chunk < nchunk0 * nchunk1) {
const int64_t ith0 = current_chunk % nchunk0;
const int64_t ith1 = current_chunk / nchunk0;
int64_t src0_start = dr0 * ith0;
int64_t src0_end = MIN(src0_start + dr0, nr0);
int64_t src1_start = dr1 * ith1;
int64_t src1_end = MIN(src1_start + dr1, nr1);
// Align boundaries to NB_COLS - round up to ensure all data is included
// The chunk size limiting above ensures chunks are large enough to prevent overlaps
src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
if (src0_end > ne01) {
src0_end = ne01;
}
src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
src0_end = MIN(src0_end, ne01);
// Make sure current plane is the last one before exiting
if (src0_start >= src0_end) {
break;
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
continue;
}
forward_mul_mat_one_chunk(params, dst, src0_start, src0_end);
forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
}
+12 -25
View File
@@ -3156,26 +3156,17 @@ static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op
return (op0 && op0->src[1] == op1->src[1]);
}
static inline bool is_compute_op(ggml_tensor *node)
{
return !(ggml_op_is_empty(node->op) || ggml_is_empty(node));
}
// scan the graph and figure out last compute op index
static inline int last_compute_op(ggml_cgraph * graph) {
int last;
int last = 0;
for (int i = 0; i < graph->n_nodes; ++i) {
ggml_tensor * node = graph->nodes[i];
switch (node->op) {
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
case GGML_OP_MUL:
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_RMS_NORM:
case GGML_OP_GLU:
case GGML_OP_ADD_ID:
last = i;
break;
default:
break;
if (is_compute_op(graph->nodes[i])) {
last = i;
}
}
@@ -3194,6 +3185,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
for (int i = 0; i < graph->n_nodes; ++i) {
ggml_tensor * node = graph->nodes[i];
if (!is_compute_op(node)) {
continue;
}
uint32_t flags = 0;
// skip quantizer if src1 is reused
@@ -3245,14 +3240,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
ggml_hexagon_rope(node, flags);
break;
// non-compute ops
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
break;
default:
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
}
+45 -29
View File
@@ -34,6 +34,11 @@ static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32,
static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f32_opt, hvx_sub_f32_opt };
#define htp_binary_preamble \
const struct htp_tensor * src0 = &octx->src0; \
const struct htp_tensor * src1 = &octx->src1; \
const struct htp_tensor * src2 = &octx->src2; \
struct htp_tensor * dst = &octx->dst; \
\
const uint32_t ne00 = src0->ne[0]; \
const uint32_t ne01 = src0->ne[1]; \
const uint32_t ne02 = src0->ne[2]; \
@@ -62,16 +67,15 @@ static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f
const uint32_t nb0 = dst->nb[0]; \
const uint32_t nb1 = dst->nb[1]; \
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3];
const uint32_t nb3 = dst->nb[3]; \
\
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
static void binary_job_f32_per_thread(const struct htp_tensor * src0,
const struct htp_tensor * src1,
struct htp_tensor * dst,
uint8_t * spad_data,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
enum htp_op op) {
static void binary_job_f32_per_thread(struct htp_ops_context * octx,
uint8_t * spad_data,
uint32_t nth,
uint32_t ith,
enum htp_op op) {
htp_binary_preamble;
const size_t src0_row_size = nb01;
@@ -107,16 +111,23 @@ static void binary_job_f32_per_thread(const struct htp_tensor * src0,
uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size);
const uint32_t nr0 = ne00 / ne10;
const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size);
uint8_t * restrict dst_ptr = (uint8_t *) dst->data + (src0_start_row * dst_row_size);
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
const uint8_t * restrict src1_ptr = NULL;
const uint32_t ne02_ne01 = ne02 * ne01;
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
src1_ptr = data_src1 + (ir % src1_nrows) * src1_row_size;
const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
const uint32_t i13 = fastmodulo(i03, ne13, &octx->src1_div3);
const uint32_t i12 = fastmodulo(i02, ne12, &octx->src1_div2);
const uint32_t i11 = fastmodulo(i01, ne11, &octx->src1_div1);
const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size;
if (ir + 1 < src0_end_row) {
htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
@@ -125,6 +136,7 @@ static void binary_job_f32_per_thread(const struct htp_tensor * src0,
}
}
const uint32_t nr0 = ne00 / ne10;
if (nr0 > 1) {
if ((1 == is_aligned) && (nr0 == ne00)) {
hvx_bcast_fp32_a(spad_data_th, *(float *) src1_ptr, nr0);
@@ -149,22 +161,17 @@ static void binary_job_f32_per_thread(const struct htp_tensor * src0,
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void binary_add_id_job_f32_per_thread(const struct htp_tensor * src0,
const struct htp_tensor * src1,
const struct htp_tensor * src2,
struct htp_tensor * dst,
uint8_t * spad_data,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
hvx_elemwise_f32_func func_HVX) {
static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx,
uint8_t * spad_data,
uint32_t nth,
uint32_t ith,
hvx_elemwise_f32_func func_HVX) {
htp_binary_preamble;
const size_t src0_row_size = nb01;
const size_t src1_row_size = nb11;
const size_t dst_row_size = nb1;
const uint32_t ne02_ne01 = ne02 * ne01;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
@@ -187,10 +194,11 @@ static void binary_add_id_job_f32_per_thread(const struct htp_tensor * src0,
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
const uint32_t ne02_ne01 = ne02 * ne01;
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
// src0 indices
const uint32_t i03 = ir / ne02_ne01;
const uint32_t i02 = (ir - i03 * ne02_ne01) / ne01;
const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
// src1 indices
@@ -234,13 +242,11 @@ static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * dat
case HTP_OP_MUL:
case HTP_OP_ADD:
case HTP_OP_SUB:
binary_job_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->src1_spad.data, n, i,
octx->src0_nrows_per_thread, octx->op);
binary_job_f32_per_thread(octx, octx->src1_spad.data, n, i, octx->op);
break;
case HTP_OP_ADD_ID:
binary_add_id_job_f32_per_thread(&octx->src0, &octx->src1, &octx->src2, &octx->dst, octx->src0_spad.data, n,
i, octx->src0_nrows_per_thread, hvx_add_f32);
binary_add_id_job_f32_per_thread(octx, octx->src0_spad.data, n, i, hvx_add_f32);
break;
default:
@@ -321,6 +327,16 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
octx->src0_div21 = init_fastdiv_values(src0->ne[2] * src0->ne[1]);
octx->src0_div3 = init_fastdiv_values(src0->ne[3]);
octx->src0_div2 = init_fastdiv_values(src0->ne[2]);
octx->src0_div1 = init_fastdiv_values(src0->ne[1]);
octx->src1_div21 = init_fastdiv_values(src1->ne[2] * src1->ne[1]);
octx->src1_div3 = init_fastdiv_values(src1->ne[3]);
octx->src1_div2 = init_fastdiv_values(src1->ne[2]);
octx->src1_div1 = init_fastdiv_values(src1->ne[1]);
worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs);
}
+4 -4
View File
@@ -119,10 +119,10 @@ static const char * htp_type_name(uint32_t t) {
#define HTP_MAX_DIMS 4
struct htp_tensor {
uint32_t data; // Buffer offset in the messages, and data pointer on the NSP
uint32_t type; // Data type
uint32_t ne[HTP_MAX_DIMS]; // Number of elements
uint32_t nb[HTP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor)
uint32_t data; // Buffer offset in the messages, and data pointer on the NSP
uint32_t type; // Data type
uint32_t ne[HTP_MAX_DIMS]; // Number of elements
uint32_t nb[HTP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor)
};
#define HTP_MAX_OP_PARAMS 64
+11
View File
@@ -4,6 +4,7 @@
#include "htp-ctx.h"
#include "htp-msg.h"
#include "worker-pool.h"
#include "ops-utils.h"
#include <assert.h>
#include <stdint.h>
@@ -38,6 +39,16 @@ struct htp_ops_context {
uint32_t src0_nrows_per_thread;
uint32_t src1_nrows_per_thread;
struct fastdiv_values src0_div1; // fastdiv values for ne1
struct fastdiv_values src0_div2; // fastdiv values for ne2
struct fastdiv_values src0_div3; // fastdiv values for ne3
struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1
struct fastdiv_values src1_div1; // fastdiv values for ne1
struct fastdiv_values src1_div2; // fastdiv values for ne2
struct fastdiv_values src1_div3; // fastdiv values for ne3
struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
uint32_t flags;
};
+33
View File
@@ -31,6 +31,39 @@ static inline uint32_t htp_round_up(uint32_t n, uint32_t m) {
return m * ((n + m - 1) / m);
}
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
// and a shift:
//
// n/d = (mulhi(n, mp) + n) >> L;
struct fastdiv_values {
uint32_t mp;
uint32_t l;
};
static inline struct fastdiv_values init_fastdiv_values(uint32_t d) {
struct fastdiv_values result = { 0, 0 };
// compute L = ceil(log2(d));
while (result.l < 32 && ((uint32_t) 1 << result.l) < d) {
++(result.l);
}
result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1);
return result;
}
static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) {
// Compute high 32 bits of n * mp
const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32); // mulhi(n, mp)
// add n, apply bit shift
return (hi + n) >> vals->l;
}
static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) {
return n - fastdiv(n, vals) * d;
}
static inline void htp_l2fetch(const void * p, uint32_t height, uint32_t width, uint32_t stride) {
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
asm volatile(" l2fetch(%0,%1) " : : "r"(p), "r"(control));
+4 -2
View File
@@ -564,8 +564,10 @@ ggml_metal_device_t ggml_metal_device_init(void) {
// TODO: try to update the tensor API kernels to at least match the simdgroup performance
if (getenv("GGML_METAL_TENSOR_ENABLE") == NULL &&
![[dev->mtl_device name] containsString:@"M5"] &&
![[dev->mtl_device name] containsString:@"M6"]) {
GGML_LOG_WARN("%s: tensor API disabled for pre-M5 device\n", __func__);
![[dev->mtl_device name] containsString:@"M6"] &&
![[dev->mtl_device name] containsString:@"A19"] &&
![[dev->mtl_device name] containsString:@"A20"]) {
GGML_LOG_WARN("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__);
dev->props.has_tensor = false;
}
+5
View File
@@ -1036,6 +1036,11 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
nth = std::min(nth, nk0);
if (nth*nrptg > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
nrptg = 1;
}
ggml_metal_kargs_set_rows args = {
/*.nk0 =*/ nk0,
/*.ne01 =*/ ne01,
+36 -2
View File
@@ -53,6 +53,37 @@
bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor);
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
// and a shift:
//
// n/d = (mulhi(n, mp) + n) >> L;
struct fastdiv_vals {
uint32_t mp;
uint32_t L;
uint32_t d;
uint32_t pad;
};
static_assert(sizeof(fastdiv_vals) == 16, "fastdiv_vals size incorrect");
static fastdiv_vals init_fastdiv_values(uint64_t d_64) {
GGML_ASSERT(d_64 != 0);
GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());
uint32_t d = (uint32_t)d_64;
// compute L = ceil(log2(d));
uint32_t L = 0;
while (L < 32 && (uint32_t{ 1 } << L) < d) {
L++;
}
uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
// pack divisor as well to reduce error surface
return { mp, L, d, 0 };
}
enum GPU_FAMILY {
ADRENO,
INTEL,
@@ -4464,6 +4495,9 @@ static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, c
GGML_ABORT("not implemented");
}
fastdiv_vals ne11_ = init_fastdiv_values(ne11);
fastdiv_vals ne12_ = init_fastdiv_values(ne12);
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
@@ -4474,8 +4508,8 @@ static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, c
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(fastdiv_vals), &ne11_));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(fastdiv_vals), &ne12_));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb10));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb12));
+35 -16
View File
@@ -1,5 +1,16 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
// v = { mp, L, d }
inline uint fastdiv(uint n, uint4 v) {
uint msbs;
msbs = mul_hi(n, v.s0);
return (msbs + n) >> v.s1;
}
inline uint fastmod(uint n, uint4 v) {
uint q = fastdiv(n, v);
return n - q * v.s2;
}
kernel void kernel_set_rows_f32_i64(
global char * src0,
ulong offset0,
@@ -11,8 +22,8 @@ kernel void kernel_set_rows_f32_i64(
ulong nb01,
ulong nb02,
ulong nb03,
int ne11,
int ne12,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
@@ -33,8 +44,10 @@ kernel void kernel_set_rows_f32_i64(
return;
}
int i12 = i03%ne12;
int i11 = i02%ne11;
//int i12 = i03%ne12;
//int i11 = i02%ne11;
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
@@ -58,8 +71,8 @@ kernel void kernel_set_rows_f16_i64(
ulong nb01,
ulong nb02,
ulong nb03,
int ne11,
int ne12,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
@@ -80,8 +93,10 @@ kernel void kernel_set_rows_f16_i64(
return;
}
int i12 = i03%ne12;
int i11 = i02%ne11;
//int i12 = i03%ne12;
//int i11 = i02%ne11;
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
@@ -105,8 +120,8 @@ kernel void kernel_set_rows_f32_i32(
ulong nb01,
ulong nb02,
ulong nb03,
int ne11,
int ne12,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
@@ -127,8 +142,10 @@ kernel void kernel_set_rows_f32_i32(
return;
}
int i12 = i03%ne12;
int i11 = i02%ne11;
//int i12 = i03%ne12;
//int i11 = i02%ne11;
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
@@ -152,8 +169,8 @@ kernel void kernel_set_rows_f16_i32(
ulong nb01,
ulong nb02,
ulong nb03,
int ne11,
int ne12,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
@@ -174,8 +191,10 @@ kernel void kernel_set_rows_f16_i32(
return;
}
int i12 = i03%ne12;
int i11 = i02%ne11;
//int i12 = i03%ne12;
//int i11 = i02%ne11;
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
+1
View File
@@ -3933,6 +3933,7 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
break;
case GGML_OP_SSM_CONV:
ggml_sycl_ssm_conv(ctx, dst);
break;
case GGML_OP_ROLL:
ggml_sycl_roll(ctx, dst);
break;
+10 -4
View File
@@ -6831,7 +6831,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
vk_buffer d_B = d_D;
size_t b_buf_offset = 0;
uint64_t b_sz = 0;
uint64_t b_sz = 1;
if (enable_bias) {
const ggml_tensor * add = cgraph->nodes[node_idx + 1];
@@ -6965,7 +6965,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
vk_buffer d_B = d_D;
size_t b_buf_offset = 0;
uint64_t b_sz = 0;
uint64_t b_sz = 1;
if (enable_bias) {
const ggml_tensor * add = cgraph->nodes[node_idx + 1];
@@ -7101,7 +7101,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
vk_buffer d_B = d_D;
size_t b_buf_offset = 0;
uint64_t b_sz = 0;
uint64_t b_sz = 1;
if (enable_bias) {
const ggml_tensor * add = cgraph->nodes[node_idx + 1];
@@ -7676,7 +7676,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
vk_buffer d_B = d_D;
size_t b_buf_offset = 0;
uint64_t b_sz = 0;
uint64_t b_sz = 1;
if (enable_bias || enable_scale) {
const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];
@@ -12670,6 +12670,12 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co
return false;
}
// conditions for pipeline creation
if (!(ctx->device->float_controls_rte_fp16 &&
sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) {
return false;
}
return true;
}
@@ -18,6 +18,7 @@
#include <algorithm>
#include <sys/stat.h>
#include <sys/types.h>
#include <filesystem>
#ifdef _WIN32
#define NOMINMAX
@@ -1080,6 +1081,11 @@ int main(int argc, char** argv) {
if (args.find("--glslc") != args.end()) {
GLSLC = args["--glslc"]; // Path to glslc
if (!std::filesystem::exists(GLSLC) || !std::filesystem::is_regular_file(GLSLC)) {
std::cerr << "Error: glslc not found at " << GLSLC << std::endl;
return EXIT_FAILURE;
}
}
if (args.find("--source") != args.end()) {
input_filepath = args["--source"]; // The shader source file to compile
+18 -1
View File
@@ -14,9 +14,26 @@ vendor = {
"https://github.com/mackron/miniaudio/raw/refs/tags/0.11.22/miniaudio.h": "vendor/miniaudio/miniaudio.h",
"https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.20.1/httplib.h": "vendor/cpp-httplib/httplib.h",
"https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.27.0/httplib.h": "vendor/cpp-httplib/httplib.h",
}
for url, filename in vendor.items():
print(f"downloading {url} to {filename}") # noqa: NP100
urllib.request.urlretrieve(url, filename)
# split cpp/h files for httplib
# see: https://github.com/yhirose/cpp-httplib/blob/master/split.py
if 'httplib.h' in filename:
border = '// ----------------------------------------------------------------------------'
with open(filename, 'r') as f:
content = f.read()
header, implementation, footer = content.split(border, 2)
fname_cpp = filename.replace('.h', '.cpp')
with open(filename, 'w') as fh:
fh.write(header)
fh.write(footer)
with open(fname_cpp, 'w') as fc:
fc.write('#include "httplib.h"\n')
fc.write('namespace httplib {\n')
fc.write(implementation.replace('\ninline ', '\n'))
fc.write('} // namespace httplib\n')
+5
View File
@@ -132,6 +132,11 @@ add_library(llama
models/graph-context-mamba.cpp
)
set_target_properties(llama PROPERTIES
VERSION ${LLAMA_INSTALL_VERSION}
SOVERSION 0
)
target_include_directories(llama PRIVATE .)
target_include_directories(llama PUBLIC ../include)
target_compile_features (llama PRIVATE cxx_std_17) # don't bump
+4 -3
View File
@@ -151,7 +151,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
p1 = std::numeric_limits<llama_pos>::max();
}
// models like Mamba or RWKV can't have a state partially erased
// models like Mamba or RWKV can't have a state partially erased at the end
// of the sequence because their state isn't preserved for previous tokens
if (seq_id >= (int64_t) size) {
// could be fatal
return false;
@@ -160,8 +161,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
int32_t & tail_id = cells[seq_id].tail;
if (tail_id >= 0) {
const auto & cell = cells[tail_id];
// partial intersection is invalid
if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) {
// partial intersection is invalid if it includes the final pos
if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
return false;
}
+4 -5
View File
@@ -1,7 +1,5 @@
#include "models.h"
llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -19,6 +17,8 @@ llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_grap
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
@@ -67,9 +67,8 @@ llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_grap
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
+2 -1
View File
@@ -11,6 +11,8 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model,
auto * inp_attn = build_attn_inp_kv_iswa();
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
@@ -69,7 +71,6 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model,
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
+16
View File
@@ -7603,6 +7603,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 32, 4, n_token));
}
for (bool fw : {true, false}) { // fw == forward
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (bool ff : {false, true}) { // freq_factors
for (float v : { 0, 1 }) {
test_cases.emplace_back(new test_rope(type, {128, 32, 512, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // llama 7B
test_cases.emplace_back(new test_rope(type, {128, 64, 512, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // llama 65B
test_cases.emplace_back(new test_rope(type, { 80, 32, 512, 1}, 20, GGML_ROPE_TYPE_NEOX, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // neox (stablelm)
test_cases.emplace_back(new test_rope(type, { 64, 8, 512, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // neox (falcon 40B)
test_cases.emplace_back(new test_rope(type, {128, 12, 512, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,imrope (qwen3vl 2B)
test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
}
}
}
}
std::vector<std::array<int64_t, 4>> reduce_rows_cases = {
{ 8192, 1, 1, 1 },
{ 8192, 8192, 1, 1 },
+6 -5
View File
@@ -138,7 +138,7 @@ int main(int /*argc*/, const char ** /*argv*/) {
struct ggml_tensor * x;
// rope f32
for (int m = 0; m < 6; ++m) {
for (int m = 0; m < 5; ++m) {
const int ndims = 4;
const int64_t n_rot = 128;
@@ -153,7 +153,7 @@ int main(int /*argc*/, const char ** /*argv*/) {
x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
int mode = -1;
if (m < 3) {
if (m < 2) {
struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
@@ -163,8 +163,8 @@ int main(int /*argc*/, const char ** /*argv*/) {
((int32_t *) p1->data)[i] = n_past_2 - n_past_0;
((int32_t *) p2->data)[i] = n_past_2 + i;
}
// test mode 0, 2, 4 (standard, GPT-NeoX, GLM)
mode = m == 0 ? 0 : m == 1 ? 2 : 4;
// test mode 0, 2 (standard, GPT-NeoX)
mode = m == 0 ? GGML_ROPE_TYPE_NORMAL : GGML_ROPE_TYPE_NEOX;
// 100, 101, 102, ..., 172
r0 = ggml_rope(ctx0, x, p0, n_rot, mode);
@@ -180,7 +180,8 @@ int main(int /*argc*/, const char ** /*argv*/) {
struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2] * 4);
int sections[4] = {16, 24, 24, 0};
mode = (m == 3) ? GGML_ROPE_TYPE_MROPE : (m == 4) ? GGML_ROPE_TYPE_VISION : GGML_ROPE_TYPE_IMROPE;
mode = (m == 2) ? GGML_ROPE_TYPE_MROPE : (m == 3) ? GGML_ROPE_TYPE_VISION : GGML_ROPE_TYPE_IMROPE;
for (int i = 0; i < ne[2]; ++i) {
for (int j = 0; j < 4; ++j) {
+5 -1
View File
@@ -354,7 +354,11 @@ int main(int argc, char ** argv) {
}
// remove any "future" tokens that we might have inherited from the previous session
llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1);
if (!llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1)) {
LOG_INF("%s: unable to resuse common prefix\n", __func__);
n_matching_session_tokens = 0;
llama_memory_seq_rm(mem, -1, -1, -1);
}
}
LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
+5
View File
@@ -13,6 +13,11 @@ add_library(mtmd
mtmd-helper.h
)
set_target_properties(mtmd PROPERTIES
VERSION ${LLAMA_INSTALL_VERSION}
SOVERSION 0
)
target_link_libraries (mtmd PUBLIC ggml llama)
target_link_libraries (mtmd PRIVATE Threads::Threads)
target_include_directories(mtmd PUBLIC .)
+4
View File
@@ -2,3 +2,7 @@ set(TARGET rpc-server)
add_executable(${TARGET} rpc-server.cpp)
target_link_libraries(${TARGET} PRIVATE ggml)
target_compile_features(${TARGET} PRIVATE cxx_std_17)
if(LLAMA_TOOLS_INSTALL)
install(TARGETS ${TARGET} RUNTIME)
endif()
+5 -1
View File
@@ -7,6 +7,10 @@ if (MINGW)
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
endif()
if (NOT LLAMA_HTTPLIB)
message(FATAL_ERROR "LLAMA_HTTPLIB is OFF, cannot build llama-server. Hint: to skip building server, set -DLLAMA_BUILD_SERVER=OFF")
endif()
set(TARGET_SRCS
server.cpp
utils.hpp
@@ -33,7 +37,7 @@ install(TARGETS ${TARGET} RUNTIME)
target_include_directories(${TARGET} PRIVATE ../mtmd)
target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PRIVATE common mtmd cpp-httplib ${CMAKE_THREAD_LIBS_INIT})
if (WIN32)
TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
+30 -31
View File
@@ -4432,6 +4432,17 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
SRV_DBG("response: %s\n", res.body.c_str());
}
static void res_error(httplib::Response & res, const json & error_data) {
json final_response {{"error", error_data}};
res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
res.status = json_value(error_data, "code", 500);
}
static void res_ok(httplib::Response & res, const json & data) {
res.set_content(safe_json_to_str(data), MIMETYPE_JSON);
res.status = 200;
}
std::function<void(int)> shutdown_handler;
std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
@@ -4501,19 +4512,7 @@ int main(int argc, char ** argv) {
svr->set_default_headers({{"Server", "llama.cpp"}});
svr->set_logger(log_server_request);
auto res_error = [](httplib::Response & res, const json & error_data) {
json final_response {{"error", error_data}};
res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
res.status = json_value(error_data, "code", 500);
};
auto res_ok = [](httplib::Response & res, const json & data) {
res.set_content(safe_json_to_str(data), MIMETYPE_JSON);
res.status = 200;
};
svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
std::string message;
try {
std::rethrow_exception(ep);
@@ -4532,7 +4531,7 @@ int main(int argc, char ** argv) {
}
});
svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
if (res.status == 404) {
res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
}
@@ -4562,7 +4561,7 @@ int main(int argc, char ** argv) {
// Middlewares
//
auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
auto middleware_validate_api_key = [&params](const httplib::Request & req, httplib::Response & res) {
static const std::unordered_set<std::string> public_endpoints = {
"/health",
"/v1/health",
@@ -4600,7 +4599,7 @@ int main(int argc, char ** argv) {
return false;
};
auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
auto middleware_server_state = [&state](const httplib::Request & req, httplib::Response & res) {
server_state current_state = state.load();
if (current_state == SERVER_STATE_LOADING_MODEL) {
auto tmp = string_split<std::string>(req.path, '.');
@@ -4788,7 +4787,7 @@ int main(int argc, char ** argv) {
res.status = 200; // HTTP OK
};
const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
const auto handle_slots_save = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
json request_data = json::parse(req.body);
std::string filename = request_data.at("filename");
if (!fs_validate_filename(filename)) {
@@ -4820,7 +4819,7 @@ int main(int argc, char ** argv) {
res_ok(res, result->to_json());
};
const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
const auto handle_slots_restore = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
json request_data = json::parse(req.body);
std::string filename = request_data.at("filename");
if (!fs_validate_filename(filename)) {
@@ -4853,7 +4852,7 @@ int main(int argc, char ** argv) {
res_ok(res, result->to_json());
};
const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
int task_id = ctx_server.queue_tasks.get_new_id();
{
server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
@@ -4876,7 +4875,7 @@ int main(int argc, char ** argv) {
res_ok(res, result->to_json());
};
const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
const auto handle_slots_action = [&params, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
if (params.slot_save_path.empty()) {
res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
return;
@@ -4905,7 +4904,7 @@ int main(int argc, char ** argv) {
}
};
const auto handle_props = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
const auto handle_props = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
json default_generation_settings_for_props;
{
@@ -4947,7 +4946,7 @@ int main(int argc, char ** argv) {
res_ok(res, data);
};
const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
const auto handle_props_change = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
if (!ctx_server.params_base.endpoint_props) {
res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
return;
@@ -4960,7 +4959,7 @@ int main(int argc, char ** argv) {
res_ok(res, {{ "success", true }});
};
const auto handle_api_show = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
const auto handle_api_show = [&ctx_server](const httplib::Request &, httplib::Response & res) {
bool has_mtmd = ctx_server.mctx != nullptr;
json data = {
{
@@ -4991,7 +4990,7 @@ int main(int argc, char ** argv) {
// handle completion-like requests (completion, chat, infill)
// we can optionally provide a custom format for partial results and final results
const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
const auto handle_completions_impl = [&ctx_server](
server_task_type type,
json & data,
const std::vector<raw_buffer> & files,
@@ -5139,7 +5138,7 @@ int main(int argc, char ** argv) {
OAICOMPAT_TYPE_COMPLETION);
};
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
const auto handle_infill = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
// check model compatibility
std::string err;
if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
@@ -5238,7 +5237,7 @@ int main(int argc, char ** argv) {
};
// same with handle_chat_completions, but without inference part
const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
const auto handle_apply_template = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
auto body = json::parse(req.body);
std::vector<raw_buffer> files; // dummy, unused
json data = oaicompat_chat_params_parse(
@@ -5248,7 +5247,7 @@ int main(int argc, char ** argv) {
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
};
const auto handle_models = [&params, &ctx_server, &state, &res_ok](const httplib::Request &, httplib::Response & res) {
const auto handle_models = [&params, &ctx_server, &state](const httplib::Request &, httplib::Response & res) {
server_state current_state = state.load();
json model_meta = nullptr;
if (current_state == SERVER_STATE_READY) {
@@ -5293,7 +5292,7 @@ int main(int argc, char ** argv) {
res_ok(res, models);
};
const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
const json body = json::parse(req.body);
json tokens_response = json::array();
@@ -5334,7 +5333,7 @@ int main(int argc, char ** argv) {
res_ok(res, data);
};
const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
const json body = json::parse(req.body);
std::string content;
@@ -5347,7 +5346,7 @@ int main(int argc, char ** argv) {
res_ok(res, data);
};
const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
const auto handle_embeddings_impl = [&ctx_server](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
if (!ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
@@ -5457,7 +5456,7 @@ int main(int argc, char ** argv) {
handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING);
};
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
const auto handle_rerank = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return;
-8
View File
@@ -9,14 +9,6 @@
#include "mtmd-helper.h"
#include "chat.h"
// increase max payload length to allow use of larger context size
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
// increase backlog size to avoid connection resets for >> 1 slots
#define CPPHTTPLIB_LISTEN_BACKLOG 512
// increase max URI length to handle longer prompts in query string
#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 32768
// disable Nagle's algorithm
#define CPPHTTPLIB_TCP_NODELAY true
#include <cpp-httplib/httplib.h>
#define JSON_ASSERT GGML_ASSERT
+60
View File
@@ -0,0 +1,60 @@
set(TARGET cpp-httplib)
find_package(Threads REQUIRED)
add_library(${TARGET} STATIC httplib.cpp httplib.h)
if (NOT MSVC)
# disable warnings in 3rd party code
target_compile_options(${TARGET} PRIVATE -w)
endif()
target_link_libraries (${TARGET} PRIVATE Threads::Threads)
target_compile_features(${TARGET} PRIVATE cxx_std_17)
target_compile_definitions(${TARGET} PRIVATE
# increase max payload length to allow use of larger context size
CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH=1048576
# increase backlog size to avoid connection resets for >> 1 slots
CPPHTTPLIB_LISTEN_BACKLOG=512
# increase max URI length to handle longer prompts in query string
CPPHTTPLIB_REQUEST_URI_MAX_LENGTH=32768
# disable Nagle's algorithm
CPPHTTPLIB_TCP_NODELAY=1
)
if (LLAMA_OPENSSL)
find_package(OpenSSL)
if (OpenSSL_FOUND)
include(CheckCSourceCompiles)
set(SAVED_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES})
set(CMAKE_REQUIRED_INCLUDES ${OPENSSL_INCLUDE_DIR})
check_c_source_compiles("
#include <openssl/opensslv.h>
#if defined(OPENSSL_IS_BORINGSSL) || defined(LIBRESSL_VERSION_NUMBER)
# if OPENSSL_VERSION_NUMBER < 0x1010107f
# error bad version
# endif
#else
# if OPENSSL_VERSION_NUMBER < 0x30000000L
# error bad version
# endif
#endif
int main() { return 0; }
" OPENSSL_VERSION_SUPPORTED)
set(CMAKE_REQUIRED_INCLUDES ${SAVED_CMAKE_REQUIRED_INCLUDES})
if (OPENSSL_VERSION_SUPPORTED)
message(STATUS "OpenSSL found: ${OPENSSL_VERSION}")
target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_OPENSSL_SUPPORT)
target_link_libraries(${TARGET} PUBLIC OpenSSL::SSL OpenSSL::Crypto)
if (APPLE AND CMAKE_SYSTEM_NAME STREQUAL "Darwin")
target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)
find_library(CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED)
find_library(SECURITY_FRAMEWORK Security REQUIRED)
target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK})
endif()
endif()
else()
message(STATUS "OpenSSL not found, SSL support disabled")
endif()
endif()
+9339
View File
File diff suppressed because it is too large Load Diff
-9336
View File
File diff suppressed because it is too large Load Diff