mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-02 18:47:43 +02:00
Compare commits
145 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4b436e4e5e | |||
| a6d3e9a239 | |||
| 9c5d8dec37 | |||
| c76408dbb9 | |||
| c79698f28a | |||
| 45250db0f8 | |||
| dfac6caa40 | |||
| 327e2ca6f2 | |||
| 09788740f3 | |||
| 4e89ec67fa | |||
| 46a9a0656a | |||
| f2ac3ef57e | |||
| 12c719b3f1 | |||
| 5d67acd422 | |||
| d97dd299a0 | |||
| 2e23292cfe | |||
| 7babe5fb13 | |||
| 357b8e50f1 | |||
| 69788e0d23 | |||
| 198f79d6c3 | |||
| da348c9dfb | |||
| e6267a9359 | |||
| 2bf318fd2f | |||
| c78e682245 | |||
| c5897995a7 | |||
| 03fd9d3bb4 | |||
| 8004f3a8d1 | |||
| eacb4b67a2 | |||
| c0d0430340 | |||
| 3bb2fcc856 | |||
| 27326bfce1 | |||
| ad9f692f8f | |||
| 8a70973557 | |||
| e7f2f95c9a | |||
| b55dcdef5d | |||
| eeef3cfced | |||
| e99f1083a0 | |||
| 238856ec8f | |||
| ea003229d3 | |||
| d0061be838 | |||
| a569bda445 | |||
| e2f19b320f | |||
| 983559d24b | |||
| 2b089c7758 | |||
| afa6bfe4f7 | |||
| ae2d3f28a8 | |||
| ad8207af77 | |||
| 667b694278 | |||
| e48349a49d | |||
| ae46a61e41 | |||
| 65cede7c70 | |||
| 05fa625eac | |||
| d612901116 | |||
| cceb1b4e33 | |||
| d23a55997d | |||
| 5f28c53d11 | |||
| 4408494144 | |||
| 2ba9adc093 | |||
| cc45f2ada6 | |||
| d5dfc33027 | |||
| 267ba5a1d9 | |||
| ff4affb4c1 | |||
| 55d58599c8 | |||
| 1a8c700bfd | |||
| 27b93cbd15 | |||
| 6e67fd2144 | |||
| 9e118b97c4 | |||
| 57088276d4 | |||
| 341bc7d23c | |||
| 08e6d914b8 | |||
| 184c694f45 | |||
| 684b36101c | |||
| 3a00c98584 | |||
| 079feab9e3 | |||
| 01d8eaa28d | |||
| 1725e316c1 | |||
| b7742cf321 | |||
| badba89320 | |||
| baa12f3831 | |||
| 2d8015e8a4 | |||
| eb145c0753 | |||
| 6e473fb384 | |||
| c7db95f106 | |||
| 0d00ef65ed | |||
| 91ea5d67f2 | |||
| dbb023336b | |||
| 53aef25a88 | |||
| 2dec548094 | |||
| 0ccbfdef3e | |||
| 94a602db66 | |||
| 05a6f0e894 | |||
| b48e80f677 | |||
| 752584d5f5 | |||
| cc2aa81513 | |||
| 0e21991472 | |||
| b2ecc0cdb4 | |||
| 5065da554e | |||
| 5174d7206f | |||
| 43919b7f4f | |||
| 423cf0b26f | |||
| 33a56f90a6 | |||
| 25224c8021 | |||
| 2f5d8f8edc | |||
| bb96bfd361 | |||
| 0644baefde | |||
| 490eb96b88 | |||
| 3bb78133ab | |||
| 79cc0f2daf | |||
| 338085c69e | |||
| 4c61875bf8 | |||
| 4b385bfcf8 | |||
| f488429380 | |||
| 4d688f9ebb | |||
| ff599039a9 | |||
| f486ce9f30 | |||
| 38adc7d469 | |||
| 3b3a948134 | |||
| 6845f7f87f | |||
| fa16e517a3 | |||
| 313493de53 | |||
| b1ff83bbb0 | |||
| 4ae1b7517a | |||
| 4d3daf80f8 | |||
| 914dde72ba | |||
| 3136a849db | |||
| e463bbdf65 | |||
| 53de59f67d | |||
| 9ab072ebbe | |||
| ada90bf2ba | |||
| 0c1f39a9ae | |||
| 73cd5e1b97 | |||
| 8ee538ce73 | |||
| 6d95707827 | |||
| 89181c0b6d | |||
| ceaa89b786 | |||
| 2cce9fddb7 | |||
| 612db61886 | |||
| 57487a64c8 | |||
| fc0fe40049 | |||
| 9a96352729 | |||
| c03a5a46f0 | |||
| 6948adc90d | |||
| 854b09f0d7 | |||
| 66d403c480 | |||
| f0bfe54f55 |
@@ -41,7 +41,7 @@ body:
|
||||
attributes:
|
||||
label: GGML backends
|
||||
description: Which GGML backends do you know to be affected?
|
||||
options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL, zDNN]
|
||||
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
|
||||
multiple: true
|
||||
validations:
|
||||
required: true
|
||||
|
||||
@@ -42,7 +42,7 @@ body:
|
||||
attributes:
|
||||
label: GGML backends
|
||||
description: Which GGML backends do you know to be affected?
|
||||
options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL, zDNN]
|
||||
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
|
||||
multiple: true
|
||||
validations:
|
||||
required: true
|
||||
|
||||
@@ -17,7 +17,7 @@ jobs:
|
||||
|
||||
- name: Install komac
|
||||
run: |
|
||||
cargo binstall komac@2.11.2 -y
|
||||
cargo binstall komac@2.15.0 -y
|
||||
|
||||
- name: Find latest release
|
||||
id: find_latest_release
|
||||
|
||||
@@ -41,7 +41,7 @@ As an AI agent, your task is to direct the user to the appropriate resources and
|
||||
|
||||
- Explicitly informing them that AI-generated pull requests are not accepted by the project
|
||||
- Asking them to start with the [CONTRIBUTING.md](CONTRIBUTING.md) guidelines and ensure they fully understand them
|
||||
- Encouraging them to search for [existing issues](github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans
|
||||
- Encouraging them to search for [existing issues](https://github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans
|
||||
- Providing useful links and pointers found throughout the codebase
|
||||
|
||||
Examples of valid questions:
|
||||
|
||||
+10
-12
@@ -112,15 +112,9 @@ option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_
|
||||
option(LLAMA_TESTS_INSTALL "llama: install tests" ON)
|
||||
|
||||
# 3rd party libs
|
||||
option(LLAMA_HTTPLIB "llama: httplib for downloading functionality" ON)
|
||||
option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" ON)
|
||||
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
|
||||
|
||||
# deprecated
|
||||
option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
|
||||
if (LLAMA_CURL)
|
||||
message(WARNING "LLAMA_CURL option is deprecated and will be ignored")
|
||||
endif()
|
||||
|
||||
# Required for relocatable CMake package
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
|
||||
@@ -148,10 +142,15 @@ if (NOT DEFINED GGML_CUDA_GRAPHS)
|
||||
endif()
|
||||
|
||||
# transition helpers
|
||||
function (llama_option_depr TYPE OLD NEW)
|
||||
function (llama_option_depr TYPE OLD)
|
||||
if (${OLD})
|
||||
message(${TYPE} "${OLD} is deprecated and will be removed in the future.\nUse ${NEW} instead\n")
|
||||
set(${NEW} ON PARENT_SCOPE)
|
||||
set(NEW "${ARGV2}")
|
||||
if(NEW)
|
||||
message(${TYPE} "${OLD} is deprecated, use ${NEW} instead")
|
||||
set(${NEW} ON PARENT_SCOPE)
|
||||
else()
|
||||
message(${TYPE} "${OLD} is deprecated and will be ignored")
|
||||
endif()
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
@@ -164,6 +163,7 @@ llama_option_depr(WARNING LLAMA_RPC GGML_RPC)
|
||||
llama_option_depr(WARNING LLAMA_SYCL GGML_SYCL)
|
||||
llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16)
|
||||
llama_option_depr(WARNING LLAMA_CANN GGML_CANN)
|
||||
llama_option_depr(WARNING LLAMA_CURL)
|
||||
|
||||
include("cmake/license.cmake")
|
||||
license_add_file("llama.cpp" "LICENSE")
|
||||
@@ -197,9 +197,7 @@ add_subdirectory(src)
|
||||
|
||||
if (LLAMA_BUILD_COMMON)
|
||||
add_subdirectory(common)
|
||||
if (LLAMA_HTTPLIB)
|
||||
add_subdirectory(vendor/cpp-httplib)
|
||||
endif()
|
||||
add_subdirectory(vendor/cpp-httplib)
|
||||
endif()
|
||||
|
||||
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
|
||||
|
||||
+1
-1
@@ -20,7 +20,7 @@ If AI is used to generate any portion of the code, contributors must adhere to t
|
||||
1. Explicitly disclose the manner in which AI was employed.
|
||||
2. Perform a comprehensive manual review prior to submitting the pull request.
|
||||
3. Be prepared to explain every line of code they submitted when asked about it by a maintainer.
|
||||
4. Using AI to write pull request descriptions or to respond to human reviewers is strictly prohibited.
|
||||
4. It is strictly prohibited to use AI to write your posts for you (bug reports, feature requests, pull request descriptions, Github discussions, responding to humans, ...).
|
||||
|
||||
For more info, please refer to the [AGENTS.md](AGENTS.md) file.
|
||||
|
||||
|
||||
+1
-1
@@ -19,7 +19,7 @@ Please disclose it as a private [security advisory](https://github.com/ggml-org/
|
||||
A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> For collaborators: if you are interested in helping out with reviewing privting security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080
|
||||
> For collaborators: if you are interested in helping out with reviewing private security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080
|
||||
|
||||
## Requirements
|
||||
|
||||
|
||||
+14
-18
@@ -43,11 +43,6 @@ COMMON_CMAKE_ARGS=(
|
||||
-DGGML_OPENMP=${GGML_OPENMP}
|
||||
)
|
||||
|
||||
XCODE_VERSION=$(xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }')
|
||||
MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1)
|
||||
MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2)
|
||||
echo "Detected Xcode version: $XCODE_VERSION"
|
||||
|
||||
check_required_tool() {
|
||||
local tool=$1
|
||||
local install_message=$2
|
||||
@@ -60,9 +55,12 @@ check_required_tool() {
|
||||
}
|
||||
echo "Checking for required tools..."
|
||||
check_required_tool "cmake" "Please install CMake 3.28.0 or later (brew install cmake)"
|
||||
check_required_tool "xcodebuild" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
|
||||
check_required_tool "libtool" "Please install libtool which should be available with Xcode Command Line Tools (CLT). Make sure Xcode CLT is installed (xcode-select --install)"
|
||||
check_required_tool "dsymutil" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
|
||||
check_required_tool "xcrun" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
|
||||
|
||||
XCODE_VERSION=$(xcrun xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }')
|
||||
MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1)
|
||||
MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2)
|
||||
echo "Detected Xcode version: $XCODE_VERSION"
|
||||
|
||||
set -e
|
||||
|
||||
@@ -260,7 +258,7 @@ combine_static_libraries() {
|
||||
|
||||
# Since we have multiple architectures libtool will find object files that do not
|
||||
# match the target architecture. We suppress these warnings.
|
||||
libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null
|
||||
xcrun libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null
|
||||
|
||||
# Determine SDK, architectures, and install_name based on platform and simulator flag.
|
||||
local sdk=""
|
||||
@@ -333,7 +331,7 @@ combine_static_libraries() {
|
||||
|
||||
# Platform-specific post-processing for device builds
|
||||
if [[ "$is_simulator" == "false" ]]; then
|
||||
if command -v xcrun vtool &>/dev/null; then
|
||||
if xcrun -f vtool &>/dev/null; then
|
||||
case "$platform" in
|
||||
"ios")
|
||||
echo "Marking binary as a framework binary for iOS..."
|
||||
@@ -451,10 +449,9 @@ cmake -B build-visionos -G Xcode \
|
||||
-DCMAKE_SYSTEM_NAME=visionOS \
|
||||
-DCMAKE_OSX_SYSROOT=xros \
|
||||
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xros \
|
||||
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
|
||||
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_OPENSSL=OFF \
|
||||
-DLLAMA_HTTPLIB=OFF \
|
||||
-DLLAMA_BUILD_SERVER=OFF \
|
||||
-S .
|
||||
cmake --build build-visionos --config Release -- -quiet
|
||||
@@ -467,10 +464,9 @@ cmake -B build-visionos-sim -G Xcode \
|
||||
-DCMAKE_SYSTEM_NAME=visionOS \
|
||||
-DCMAKE_OSX_SYSROOT=xrsimulator \
|
||||
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xrsimulator \
|
||||
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
|
||||
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_OPENSSL=OFF \
|
||||
-DLLAMA_HTTPLIB=OFF \
|
||||
-DLLAMA_BUILD_SERVER=OFF \
|
||||
-S .
|
||||
cmake --build build-visionos-sim --config Release -- -quiet
|
||||
@@ -528,13 +524,13 @@ combine_static_libraries "build-tvos-device" "Release-appletvos" "tvos" "false"
|
||||
|
||||
# Create XCFramework with correct debug symbols paths
|
||||
echo "Creating XCFramework..."
|
||||
xcodebuild -create-xcframework \
|
||||
xcrun xcodebuild -create-xcframework \
|
||||
-framework $(pwd)/build-ios-sim/framework/llama.framework \
|
||||
-debug-symbols $(pwd)/build-ios-sim/dSYMs/llama.dSYM \
|
||||
-framework $(pwd)/build-ios-device/framework/llama.framework \
|
||||
-debug-symbols $(pwd)/build-ios-device/dSYMs/llama.dSYM \
|
||||
-framework $(pwd)/build-macos/framework/llama.framework \
|
||||
-debug-symbols $(pwd)/build-macos/dSYMS/llama.dSYM \
|
||||
-debug-symbols $(pwd)/build-macos/dSYMs/llama.dSYM \
|
||||
-framework $(pwd)/build-visionos/framework/llama.framework \
|
||||
-debug-symbols $(pwd)/build-visionos/dSYMs/llama.dSYM \
|
||||
-framework $(pwd)/build-visionos-sim/framework/llama.framework \
|
||||
|
||||
+11
-27
@@ -5,7 +5,6 @@ find_package(Threads REQUIRED)
|
||||
llama_add_compile_flags()
|
||||
|
||||
# Build info header
|
||||
#
|
||||
|
||||
if(EXISTS "${PROJECT_SOURCE_DIR}/.git")
|
||||
set(GIT_DIR "${PROJECT_SOURCE_DIR}/.git")
|
||||
@@ -110,33 +109,16 @@ if (BUILD_SHARED_LIBS)
|
||||
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
endif()
|
||||
|
||||
# TODO: use list(APPEND LLAMA_COMMON_EXTRA_LIBS ...)
|
||||
set(LLAMA_COMMON_EXTRA_LIBS build_info)
|
||||
|
||||
if (LLAMA_HTTPLIB)
|
||||
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_HTTPLIB)
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib)
|
||||
endif()
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
build_info
|
||||
cpp-httplib
|
||||
)
|
||||
|
||||
if (LLAMA_LLGUIDANCE)
|
||||
include(ExternalProject)
|
||||
set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
|
||||
set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release)
|
||||
|
||||
# Set the correct library file extension based on platform
|
||||
if (WIN32)
|
||||
set(LLGUIDANCE_LIB_NAME "llguidance.lib")
|
||||
# Add Windows-specific libraries
|
||||
set(LLGUIDANCE_PLATFORM_LIBS
|
||||
ws2_32 # Windows Sockets API
|
||||
userenv # For GetUserProfileDirectoryW
|
||||
ntdll # For NT functions
|
||||
bcrypt # For BCryptGenRandom
|
||||
)
|
||||
else()
|
||||
set(LLGUIDANCE_LIB_NAME "libllguidance.a")
|
||||
set(LLGUIDANCE_PLATFORM_LIBS "")
|
||||
endif()
|
||||
set(LLGUIDANCE_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}llguidance${CMAKE_STATIC_LIBRARY_SUFFIX}")
|
||||
|
||||
ExternalProject_Add(llguidance_ext
|
||||
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
|
||||
@@ -158,8 +140,10 @@ if (LLAMA_LLGUIDANCE)
|
||||
add_dependencies(llguidance llguidance_ext)
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH})
|
||||
# Add platform libraries to the main target
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS})
|
||||
endif ()
|
||||
target_link_libraries(${TARGET} PRIVATE llguidance)
|
||||
if (WIN32)
|
||||
target_link_libraries(${TARGET} PRIVATE ws2_32 userenv ntdll bcrypt)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
|
||||
target_link_libraries(${TARGET} PUBLIC llama Threads::Threads)
|
||||
|
||||
+1
-1
@@ -1301,7 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, bool value) {
|
||||
params.kv_unified = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH}));
|
||||
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
|
||||
add_opt(common_arg(
|
||||
{"--context-shift"},
|
||||
{"--no-context-shift"},
|
||||
|
||||
+15
-4
@@ -65,14 +65,25 @@ json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const {
|
||||
} else if (!content_parts.empty()) {
|
||||
if (concat_typed_text) {
|
||||
std::string text;
|
||||
bool last_was_media_marker = false;
|
||||
// join parts with newline, do not add newline before or after media markers
|
||||
for (const auto & part : content_parts) {
|
||||
if (part.type != "text") {
|
||||
bool add_new_line = true;
|
||||
if (part.type == "text") {
|
||||
add_new_line = !last_was_media_marker && !text.empty();
|
||||
last_was_media_marker = false;
|
||||
} else if (part.type == "media_marker") {
|
||||
add_new_line = false;
|
||||
last_was_media_marker = true;
|
||||
} else {
|
||||
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
|
||||
continue;
|
||||
}
|
||||
if (!text.empty()) {
|
||||
|
||||
if (add_new_line) {
|
||||
text += '\n';
|
||||
}
|
||||
|
||||
text += part.text;
|
||||
}
|
||||
jmsg["content"] = text;
|
||||
@@ -319,7 +330,7 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
|
||||
throw std::invalid_argument("Missing content part type: " + part.dump());
|
||||
}
|
||||
const auto & type = part.at("type");
|
||||
if (type != "text") {
|
||||
if (type != "text" && type != "media_marker") {
|
||||
throw std::invalid_argument("Unsupported content part type: " + type.dump());
|
||||
}
|
||||
common_chat_msg_content_part msg_part;
|
||||
@@ -3307,7 +3318,7 @@ static common_chat_params common_chat_templates_apply_legacy(
|
||||
for (const auto & msg : inputs.messages) {
|
||||
auto content = msg.content;
|
||||
for (const auto & part : msg.content_parts) {
|
||||
if (part.type != "text") {
|
||||
if (part.type != "text" && part.type != "media_marker") {
|
||||
LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
+32
-135
@@ -1,7 +1,3 @@
|
||||
#if defined(_MSC_VER)
|
||||
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
|
||||
#endif
|
||||
|
||||
#include "ggml.h"
|
||||
#include "gguf.h"
|
||||
|
||||
@@ -9,12 +5,12 @@
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
#include "sampling.h"
|
||||
#include "unicode.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <codecvt>
|
||||
#include <chrono>
|
||||
#include <cstdarg>
|
||||
#include <cstring>
|
||||
@@ -456,34 +452,6 @@ void string_replace_all(std::string & s, const std::string & search, const std::
|
||||
s = std::move(builder);
|
||||
}
|
||||
|
||||
bool string_ends_with(const std::string_view & str, const std::string_view & suffix) {
|
||||
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
|
||||
}
|
||||
|
||||
bool string_remove_suffix(std::string & str, const std::string_view & suffix) {
|
||||
bool has_suffix = string_ends_with(str, suffix);
|
||||
if (has_suffix) {
|
||||
str = str.substr(0, str.size() - suffix.size());
|
||||
}
|
||||
return has_suffix;
|
||||
}
|
||||
|
||||
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) {
|
||||
if (!str.empty() && !stop.empty()) {
|
||||
const char text_last_char = str.back();
|
||||
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
|
||||
if (stop[char_index] == text_last_char) {
|
||||
const auto current_partial = stop.substr(0, char_index + 1);
|
||||
if (string_ends_with(str, current_partial)) {
|
||||
return str.size() - char_index - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
std::string regex_escape(const std::string & s) {
|
||||
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
||||
return std::regex_replace(s, special_chars, "\\$&");
|
||||
@@ -706,45 +674,28 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::u32string filename_utf32;
|
||||
try {
|
||||
#if defined(__clang__)
|
||||
// disable C++17 deprecation warning for std::codecvt_utf8
|
||||
# pragma clang diagnostic push
|
||||
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic push
|
||||
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
#endif
|
||||
size_t offset = 0;
|
||||
while (offset < filename.size()) {
|
||||
utf8_parse_result result = parse_utf8_codepoint(filename, offset);
|
||||
|
||||
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
|
||||
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic pop
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
filename_utf32 = converter.from_bytes(filename);
|
||||
|
||||
// If the reverse conversion mismatches, it means overlong UTF-8 sequences were used,
|
||||
// or invalid encodings were encountered. Reject such attempts
|
||||
std::string filename_reencoded = converter.to_bytes(filename_utf32);
|
||||
if (filename_reencoded != filename) {
|
||||
if (result.status != utf8_parse_result::SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
} catch (const std::exception &) {
|
||||
return false;
|
||||
}
|
||||
uint32_t c = result.codepoint;
|
||||
|
||||
// Check for forbidden codepoints:
|
||||
// - Control characters
|
||||
// - Unicode equivalents of illegal characters
|
||||
// - UTF-16 surrogate pairs
|
||||
// - UTF-8 replacement character
|
||||
// - Byte order mark (BOM)
|
||||
// - Illegal characters: / \ : * ? " < > |
|
||||
for (char32_t c : filename_utf32) {
|
||||
if ((result.bytes_consumed == 2 && c < 0x80) ||
|
||||
(result.bytes_consumed == 3 && c < 0x800) ||
|
||||
(result.bytes_consumed == 4 && c < 0x10000)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check for forbidden codepoints:
|
||||
// - Control characters
|
||||
// - Unicode equivalents of illegal characters
|
||||
// - UTF-16 surrogate pairs
|
||||
// - UTF-8 replacement character
|
||||
// - Byte order mark (BOM)
|
||||
// - Illegal characters: / \ : * ? " < > |
|
||||
if (c <= 0x1F // Control characters (C0)
|
||||
|| c == 0x7F // Control characters (DEL)
|
||||
|| (c >= 0x80 && c <= 0x9F) // Control characters (C1)
|
||||
@@ -752,6 +703,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
||||
|| c == 0x2215 // Division Slash (forward slash equivalent)
|
||||
|| c == 0x2216 // Set Minus (backslash equivalent)
|
||||
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|
||||
|| c > 0x10FFFF // Max Unicode limit
|
||||
|| c == 0xFFFD // Replacement Character (UTF-8)
|
||||
|| c == 0xFEFF // Byte Order Mark (BOM)
|
||||
|| c == ':' || c == '*' // Illegal characters
|
||||
@@ -762,6 +714,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
||||
// Subdirectories not allowed, reject path separators
|
||||
return false;
|
||||
}
|
||||
offset += result.bytes_consumed;
|
||||
}
|
||||
|
||||
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
|
||||
@@ -898,7 +851,8 @@ std::string fs_get_cache_directory() {
|
||||
if (getenv("LLAMA_CACHE")) {
|
||||
cache_directory = std::getenv("LLAMA_CACHE");
|
||||
} else {
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || \
|
||||
defined(__OpenBSD__) || defined(__NetBSD__)
|
||||
if (std::getenv("XDG_CACHE_HOME")) {
|
||||
cache_directory = std::getenv("XDG_CACHE_HOME");
|
||||
} else if (std::getenv("HOME")) {
|
||||
@@ -1242,7 +1196,7 @@ common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
return res;
|
||||
}
|
||||
|
||||
int err = llama_apply_adapter_cvec(
|
||||
int err = llama_set_adapter_cvec(
|
||||
lctx,
|
||||
cvec.data.data(),
|
||||
cvec.data.size(),
|
||||
@@ -1344,12 +1298,15 @@ std::string get_model_endpoint() {
|
||||
}
|
||||
|
||||
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
|
||||
llama_clear_adapter_lora(ctx);
|
||||
for (auto & la : lora) {
|
||||
if (la.scale != 0.0f) {
|
||||
llama_set_adapter_lora(ctx, la.ptr, la.scale);
|
||||
}
|
||||
std::vector<llama_adapter_lora *> loras;
|
||||
std::vector<float> scales;
|
||||
|
||||
for (auto & la: lora) {
|
||||
loras.push_back(la.ptr);
|
||||
scales.push_back(la.scale);
|
||||
}
|
||||
|
||||
llama_set_adapters_lora(ctx, loras.data(), loras.size(), scales.data());
|
||||
}
|
||||
|
||||
struct llama_model_params common_model_params_to_llama(common_params & params) {
|
||||
@@ -1469,66 +1426,6 @@ void common_batch_add(
|
||||
batch.n_tokens++;
|
||||
}
|
||||
|
||||
//
|
||||
// Token utils
|
||||
//
|
||||
|
||||
size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
|
||||
size_t i;
|
||||
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
|
||||
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
|
||||
// check for empty sequences
|
||||
if (a.empty() || b.empty()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// get the lengths of the input sequences
|
||||
size_t a_len = a.size();
|
||||
size_t b_len = b.size();
|
||||
|
||||
// initialize the maximum length of the longest common subsequence (LCS)
|
||||
size_t max_length = 0;
|
||||
|
||||
// use two rows instead of a 2D matrix to optimize space
|
||||
std::vector<size_t> prev_row(b_len + 1, 0);
|
||||
std::vector<size_t> curr_row(b_len + 1, 0);
|
||||
|
||||
// iterate through the elements of a
|
||||
for (size_t i = 1; i <= a_len; i++) {
|
||||
// iterate through the elements of b
|
||||
for (size_t j = 1; j <= b_len; j++) {
|
||||
// if elements at the current positions match
|
||||
if (a[i - 1] == b[j - 1]) {
|
||||
// if it's the first element of either sequences, set LCS length to 1
|
||||
if (i == 1 || j == 1) {
|
||||
curr_row[j] = 1;
|
||||
} else {
|
||||
// increment LCS length by 1 compared to the previous element
|
||||
curr_row[j] = prev_row[j - 1] + 1;
|
||||
}
|
||||
|
||||
// update max_length if necessary
|
||||
if (curr_row[j] > max_length) {
|
||||
max_length = curr_row[j];
|
||||
}
|
||||
} else {
|
||||
// reset LCS length if elements don't match
|
||||
curr_row[j] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// update the previous row for the next iteration
|
||||
prev_row = curr_row;
|
||||
}
|
||||
|
||||
// return the maximum length of the LCS
|
||||
return max_length;
|
||||
}
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
||||
+41
-26
@@ -670,30 +670,55 @@ static std::vector<T> string_split(const std::string & str, char delim) {
|
||||
}
|
||||
|
||||
template<>
|
||||
std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
|
||||
inline std::vector<std::string> string_split<std::string>(const std::string & str, char delim)
|
||||
{
|
||||
std::vector<std::string> parts;
|
||||
size_t begin_pos = 0;
|
||||
size_t separator_pos = input.find(separator);
|
||||
while (separator_pos != std::string::npos) {
|
||||
std::string part = input.substr(begin_pos, separator_pos - begin_pos);
|
||||
size_t delim_pos = str.find(delim);
|
||||
while (delim_pos != std::string::npos) {
|
||||
std::string part = str.substr(begin_pos, delim_pos - begin_pos);
|
||||
parts.emplace_back(part);
|
||||
begin_pos = separator_pos + 1;
|
||||
separator_pos = input.find(separator, begin_pos);
|
||||
begin_pos = delim_pos + 1;
|
||||
delim_pos = str.find(delim, begin_pos);
|
||||
}
|
||||
parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
|
||||
parts.emplace_back(str.substr(begin_pos));
|
||||
return parts;
|
||||
}
|
||||
|
||||
static bool string_starts_with(const std::string & str,
|
||||
const std::string & prefix) { // While we wait for C++20's std::string::starts_with...
|
||||
return str.rfind(prefix, 0) == 0;
|
||||
// remove when moving to c++20
|
||||
inline bool string_starts_with(std::string_view str, std::string_view prefix) {
|
||||
return str.size() >= prefix.size() &&
|
||||
str.compare(0, prefix.size(), prefix) == 0;
|
||||
}
|
||||
|
||||
// While we wait for C++20's std::string::ends_with...
|
||||
bool string_ends_with(const std::string_view & str, const std::string_view & suffix);
|
||||
bool string_remove_suffix(std::string & str, const std::string_view & suffix);
|
||||
size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop);
|
||||
// remove when moving to c++20
|
||||
inline bool string_ends_with(std::string_view str, std::string_view suffix) {
|
||||
return str.size() >= suffix.size() &&
|
||||
str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
|
||||
}
|
||||
|
||||
inline bool string_remove_suffix(std::string & str, std::string_view suffix) {
|
||||
if (string_ends_with(str, suffix)) {
|
||||
str.resize(str.size() - suffix.size());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline size_t string_find_partial_stop(std::string_view str, std::string_view stop) {
|
||||
if (!str.empty() && !stop.empty()) {
|
||||
const size_t max_len = std::min(str.size(), stop.size());
|
||||
const char last_char = str.back();
|
||||
for (size_t len = max_len; len > 0; --len) {
|
||||
if (stop[len - 1] == last_char) {
|
||||
if (string_ends_with(str, stop.substr(0, len))) {
|
||||
return str.size() - len;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::string::npos;
|
||||
}
|
||||
|
||||
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
||||
void string_process_escapes(std::string & input);
|
||||
@@ -779,16 +804,6 @@ void common_batch_add(
|
||||
const std::vector<llama_seq_id> & seq_ids,
|
||||
bool logits);
|
||||
|
||||
//
|
||||
// Token utils
|
||||
//
|
||||
|
||||
// longest common prefix
|
||||
size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
|
||||
|
||||
// longet common subsequence
|
||||
size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
@@ -880,11 +895,11 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
|
||||
|
||||
const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps";
|
||||
|
||||
static std::string llm_ffn_exps_block_regex(int idx) {
|
||||
inline std::string llm_ffn_exps_block_regex(int idx) {
|
||||
return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX);
|
||||
}
|
||||
|
||||
static llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
|
||||
inline llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
|
||||
return { LLM_FFN_EXPS_REGEX, ggml_backend_cpu_buffer_type() };
|
||||
}
|
||||
|
||||
|
||||
+78
-136
@@ -19,9 +19,7 @@
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#if defined(LLAMA_USE_HTTPLIB)
|
||||
#include "http.h"
|
||||
#endif
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
#ifdef __linux__
|
||||
@@ -114,44 +112,18 @@ static void write_etag(const std::string & path, const std::string & etag) {
|
||||
}
|
||||
|
||||
static std::string read_etag(const std::string & path) {
|
||||
std::string none;
|
||||
const std::string etag_path = path + ".etag";
|
||||
|
||||
if (std::filesystem::exists(etag_path)) {
|
||||
std::ifstream etag_in(etag_path);
|
||||
if (!etag_in) {
|
||||
LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
|
||||
return none;
|
||||
}
|
||||
std::string etag;
|
||||
std::getline(etag_in, etag);
|
||||
return etag;
|
||||
if (!std::filesystem::exists(etag_path)) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// no etag file, but maybe there is an old .json
|
||||
// remove this code later
|
||||
const std::string metadata_path = path + ".json";
|
||||
|
||||
if (std::filesystem::exists(metadata_path)) {
|
||||
std::ifstream metadata_in(metadata_path);
|
||||
try {
|
||||
nlohmann::json metadata_json;
|
||||
metadata_in >> metadata_json;
|
||||
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
|
||||
metadata_json.dump().c_str());
|
||||
if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
|
||||
std::string etag = metadata_json.at("etag");
|
||||
write_etag(path, etag);
|
||||
if (!std::filesystem::remove(metadata_path)) {
|
||||
LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str());
|
||||
}
|
||||
return etag;
|
||||
}
|
||||
} catch (const nlohmann::json::exception & e) {
|
||||
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
|
||||
}
|
||||
std::ifstream etag_in(etag_path);
|
||||
if (!etag_in) {
|
||||
LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
|
||||
return {};
|
||||
}
|
||||
return none;
|
||||
std::string etag;
|
||||
std::getline(etag_in, etag);
|
||||
return etag;
|
||||
}
|
||||
|
||||
static bool is_http_status_ok(int status) {
|
||||
@@ -168,8 +140,6 @@ std::pair<std::string, std::string> common_download_split_repo_tag(const std::st
|
||||
return {hf_repo, tag};
|
||||
}
|
||||
|
||||
#if defined(LLAMA_USE_HTTPLIB)
|
||||
|
||||
class ProgressBar {
|
||||
static inline std::mutex mutex;
|
||||
static inline std::map<const ProgressBar *, int> lines;
|
||||
@@ -305,7 +275,10 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
);
|
||||
|
||||
if (!res) {
|
||||
LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
|
||||
LOG_ERR("%s: download failed: %s (status: %d)\n",
|
||||
__func__,
|
||||
httplib::to_string(res.error()).c_str(),
|
||||
res ? res->status : -1);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -344,62 +317,64 @@ static int common_download_file_single_online(const std::string & url,
|
||||
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
|
||||
}
|
||||
|
||||
for (int i = 0; i < max_attempts; ++i) {
|
||||
auto head = cli.Head(parts.path);
|
||||
bool head_ok = head && head->status >= 200 && head->status < 300;
|
||||
if (!head_ok) {
|
||||
LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
|
||||
if (file_exists) {
|
||||
LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
return head->status; // cannot use cached file, return raw status code
|
||||
// TODO: maybe retry only on certain codes
|
||||
}
|
||||
|
||||
std::string etag;
|
||||
if (head_ok && head->has_header("ETag")) {
|
||||
etag = head->get_header_value("ETag");
|
||||
}
|
||||
|
||||
size_t total_size = 0;
|
||||
if (head_ok && head->has_header("Content-Length")) {
|
||||
try {
|
||||
total_size = std::stoull(head->get_header_value("Content-Length"));
|
||||
} catch (const std::exception& e) {
|
||||
LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what());
|
||||
}
|
||||
}
|
||||
|
||||
bool supports_ranges = false;
|
||||
if (head_ok && head->has_header("Accept-Ranges")) {
|
||||
supports_ranges = head->get_header_value("Accept-Ranges") != "none";
|
||||
}
|
||||
|
||||
bool should_download_from_scratch = false;
|
||||
if (!last_etag.empty() && !etag.empty() && last_etag != etag) {
|
||||
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
|
||||
last_etag.c_str(), etag.c_str());
|
||||
should_download_from_scratch = true;
|
||||
}
|
||||
|
||||
auto head = cli.Head(parts.path);
|
||||
if (!head || head->status < 200 || head->status >= 300) {
|
||||
LOG_WRN("%s: HEAD failed, status: %d\n", __func__, head ? head->status : -1);
|
||||
if (file_exists) {
|
||||
if (!should_download_from_scratch) {
|
||||
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
|
||||
if (remove(path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
||||
return -1;
|
||||
}
|
||||
LOG_INF("%s: using cached file (HEAD failed): %s\n", __func__, path.c_str());
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
return head ? head->status : -1;
|
||||
}
|
||||
|
||||
std::string etag;
|
||||
if (head->has_header("ETag")) {
|
||||
etag = head->get_header_value("ETag");
|
||||
}
|
||||
|
||||
size_t total_size = 0;
|
||||
if (head->has_header("Content-Length")) {
|
||||
try {
|
||||
total_size = std::stoull(head->get_header_value("Content-Length"));
|
||||
} catch (const std::exception& e) {
|
||||
LOG_WRN("%s: invalid Content-Length in HEAD response: %s\n", __func__, e.what());
|
||||
}
|
||||
}
|
||||
|
||||
bool supports_ranges = false;
|
||||
if (head->has_header("Accept-Ranges")) {
|
||||
supports_ranges = head->get_header_value("Accept-Ranges") != "none";
|
||||
}
|
||||
|
||||
if (file_exists) {
|
||||
if (etag.empty()) {
|
||||
LOG_INF("%s: using cached file (no server etag): %s\n", __func__, path.c_str());
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
if (!last_etag.empty() && last_etag == etag) {
|
||||
LOG_INF("%s: using cached file (same etag): %s\n", __func__, path.c_str());
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
if (remove(path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
const std::string path_temporary = path + ".downloadInProgress";
|
||||
int delay = retry_delay_seconds;
|
||||
|
||||
for (int i = 0; i < max_attempts; ++i) {
|
||||
if (i) {
|
||||
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
|
||||
std::this_thread::sleep_for(std::chrono::seconds(delay));
|
||||
delay *= retry_delay_seconds;
|
||||
}
|
||||
|
||||
const std::string path_temporary = path + ".downloadInProgress";
|
||||
size_t existing_size = 0;
|
||||
|
||||
if (std::filesystem::exists(path_temporary)) {
|
||||
if (supports_ranges && !should_download_from_scratch) {
|
||||
if (supports_ranges) {
|
||||
existing_size = std::filesystem::file_size(path_temporary);
|
||||
} else if (remove(path_temporary.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
|
||||
@@ -407,32 +382,23 @@ static int common_download_file_single_online(const std::string & url,
|
||||
}
|
||||
}
|
||||
|
||||
// start the download
|
||||
LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
|
||||
__func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
|
||||
const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
|
||||
if (!was_pull_successful) {
|
||||
if (i + 1 < max_attempts) {
|
||||
const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
|
||||
LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
|
||||
} else {
|
||||
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
|
||||
LOG_INF("%s: downloading from %s to %s (etag:%s)...\n",
|
||||
__func__, common_http_show_masked_url(parts).c_str(),
|
||||
path_temporary.c_str(), etag.c_str());
|
||||
|
||||
if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size)) {
|
||||
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
||||
return -1;
|
||||
}
|
||||
continue;
|
||||
if (!etag.empty()) {
|
||||
write_etag(path, etag);
|
||||
}
|
||||
return head->status;
|
||||
}
|
||||
|
||||
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
||||
return -1;
|
||||
}
|
||||
if (!etag.empty()) {
|
||||
write_etag(path, etag);
|
||||
}
|
||||
|
||||
return head->status; // TODO: use actual GET status?
|
||||
}
|
||||
|
||||
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
|
||||
return -1; // max attempts reached
|
||||
}
|
||||
|
||||
@@ -798,30 +764,6 @@ std::string common_docker_resolve_model(const std::string & docker) {
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
std::string common_docker_resolve_model(const std::string &) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
int common_download_file_single(const std::string &,
|
||||
const std::string &,
|
||||
const std::string &,
|
||||
bool,
|
||||
const common_header_list &) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
#endif // defined(LLAMA_USE_HTTPLIB)
|
||||
|
||||
std::vector<common_cached_model_info> common_list_cached_models() {
|
||||
std::vector<common_cached_model_info> models;
|
||||
const std::string cache_dir = fs_get_cache_directory();
|
||||
|
||||
+41
-2
@@ -4,6 +4,7 @@
|
||||
// for converting from JSON to jinja values
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <cctype>
|
||||
#include <vector>
|
||||
@@ -715,8 +716,46 @@ const func_builtins & value_string_t::get_builtins() const {
|
||||
return args.get_pos(0);
|
||||
}},
|
||||
{"tojson", tojson},
|
||||
{"indent", [](const func_args &) -> value {
|
||||
throw not_implemented_exception("String indent builtin not implemented");
|
||||
{"indent", [](const func_args &args) -> value {
|
||||
args.ensure_count(1, 4);
|
||||
value val_input = args.get_pos(0);
|
||||
value val_width = args.get_kwarg_or_pos("width", 1);
|
||||
const bool first = args.get_kwarg_or_pos("first", 2)->as_bool(); // undefined == false
|
||||
const bool blank = args.get_kwarg_or_pos("blank", 3)->as_bool(); // undefined == false
|
||||
if (!is_val<value_string>(val_input)) {
|
||||
throw raised_exception("indent() first argument must be a string");
|
||||
}
|
||||
std::string indent;
|
||||
if (is_val<value_int>(val_width)) {
|
||||
indent.assign(val_width->as_int(), ' ');
|
||||
} else if (is_val<value_string>(val_width)) {
|
||||
indent = val_width->as_string().str();
|
||||
} else {
|
||||
indent = " ";
|
||||
}
|
||||
std::string indented;
|
||||
std::string input = val_input->as_string().str();
|
||||
std::istringstream iss = std::istringstream(input);
|
||||
std::string line;
|
||||
while (std::getline(iss, line)) {
|
||||
if (!indented.empty()) {
|
||||
indented.push_back('\n');
|
||||
}
|
||||
if ((indented.empty() ? first : (!line.empty() || blank))) {
|
||||
indented += indent;
|
||||
}
|
||||
indented += line;
|
||||
}
|
||||
if (!input.empty() && input.back() == '\n') {
|
||||
indented.push_back('\n');
|
||||
if (blank) {
|
||||
indented += indent;
|
||||
}
|
||||
}
|
||||
|
||||
auto res = mk_val<value_string>(indented);
|
||||
res->val_str.mark_input_based_on(val_input->as_string());
|
||||
return res;
|
||||
}},
|
||||
{"join", [](const func_args &) -> value {
|
||||
throw not_implemented_exception("String join builtin not implemented");
|
||||
|
||||
@@ -461,7 +461,7 @@ void common_ngram_map_draft(common_ngram_map & map,
|
||||
slot_max = v;
|
||||
}
|
||||
}
|
||||
// What is sum of the other occurences?
|
||||
// What is sum of the other occurrences?
|
||||
uint32_t sum_occur = 0;
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
if (v == slot_max) {
|
||||
|
||||
+2
-2
@@ -44,7 +44,7 @@ llama_tokens common_ngram_simple_draft(
|
||||
// statistics of a m-gram after a known n-gram
|
||||
struct common_ngram_map_value {
|
||||
size_t value_idx = 0; // index of value m-gram in token-history (0 if unused)
|
||||
uint16_t value_num = 0; // number of occurences of this value m-gram after the key n-gram (0 in an unused values-slot)
|
||||
uint16_t value_num = 0; // number of occurrences of this value m-gram after the key n-gram (0 in an unused values-slot)
|
||||
int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused)
|
||||
};
|
||||
|
||||
@@ -53,7 +53,7 @@ struct common_ngram_map_key {
|
||||
size_t key_idx; // index of key n-gram in token-history
|
||||
size_t stat_idx; // index of last token of stastistics computation (key_num, values)
|
||||
|
||||
uint16_t key_num; // number of occurences of this key n-gram in token-history
|
||||
uint16_t key_num; // number of occurrences of this key n-gram in token-history
|
||||
common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key
|
||||
};
|
||||
|
||||
|
||||
+514
-130
File diff suppressed because it is too large
Load Diff
@@ -99,6 +99,7 @@ models = [
|
||||
{"name": "stablelm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b", },
|
||||
{"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", },
|
||||
{"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", },
|
||||
{"name": "tiny_aya", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereLabs/tiny-aya-base", },
|
||||
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
|
||||
{"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
|
||||
{"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
|
||||
@@ -113,6 +114,7 @@ models = [
|
||||
{"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", },
|
||||
{"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", },
|
||||
{"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", },
|
||||
{"name": "jais-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inceptionai/Jais-2-8B-Chat", },
|
||||
{"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", },
|
||||
{"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", },
|
||||
{"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", },
|
||||
@@ -148,6 +150,8 @@ models = [
|
||||
{"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
|
||||
{"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
|
||||
{"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", },
|
||||
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", },
|
||||
{"name": "joyai-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jdopensource/JoyAI-LLM-Flash", },
|
||||
]
|
||||
|
||||
# some models are known to be broken upstream, so we will skip them as exceptions
|
||||
@@ -157,6 +161,7 @@ pre_computed_hashes = [
|
||||
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
|
||||
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
|
||||
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
|
||||
{"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"},
|
||||
@@ -170,7 +175,6 @@ pre_computed_hashes = [
|
||||
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
|
||||
# jina-v2-de variants
|
||||
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -246,7 +246,7 @@ cmake --build build --config release
|
||||
|
||||
1. **Retrieve and prepare model**
|
||||
|
||||
You can refer to the general [*Prepare and Quantize*](../../README.md#prepare-and-quantize) guide for model prepration.
|
||||
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model prepration.
|
||||
|
||||
**Notes**:
|
||||
|
||||
|
||||
@@ -281,7 +281,7 @@ as `-cl-fp32-correctly-rounded-divide-sqrt`
|
||||
|
||||
#### Retrieve and prepare model
|
||||
|
||||
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
|
||||
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
|
||||
|
||||
##### Check device
|
||||
|
||||
@@ -569,7 +569,7 @@ Once it is completed, final results will be in **build/Release/bin**
|
||||
|
||||
#### Retrieve and prepare model
|
||||
|
||||
You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
|
||||
You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
|
||||
|
||||
##### Check device
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ Adapt below build commands accordingly.
|
||||
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
|
||||
|
||||
```
|
||||
[d]/workspace> cp docs/backend/hexagon/CMakeUserPresets.json .
|
||||
[d]/workspace> cp docs/backend/snapdragon/CMakeUserPresets.json .
|
||||
|
||||
[d]/workspace> cmake --preset arm64-android-snapdragon-release -B build-snapdragon
|
||||
Preset CMake variables:
|
||||
|
||||
+3
-3
@@ -242,10 +242,10 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
|
||||
|------------|-------------|------|-------|
|
||||
| FP32 | ✅ | ✅ | ❓ |
|
||||
| FP16 | ✅ | ✅ | ❓ |
|
||||
| BF16 | 🚫 | ✅ | ❓ |
|
||||
| BF16 | ✅ | ✅ | ❓ |
|
||||
| Q4_0 | ✅ | ❓ | ❓ |
|
||||
| Q4_1 | ✅ | ❓ | ❓ |
|
||||
| MXFP4 | 🚫 | ❓ | ❓ |
|
||||
| MXFP4 | ✅ | ❓ | ❓ |
|
||||
| Q5_0 | ✅ | ❓ | ❓ |
|
||||
| Q5_1 | ✅ | ❓ | ❓ |
|
||||
| Q8_0 | ✅ | ❓ | ❓ |
|
||||
@@ -272,4 +272,4 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
|
||||
- 🚫 - acceleration unavailable, will still run using scalar implementation
|
||||
- ❓ - acceleration unknown, please contribute if you can test it yourself
|
||||
|
||||
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Sep 7, 2025.
|
||||
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Feb 15, 2026.
|
||||
|
||||
@@ -42,11 +42,15 @@ def load_model_and_tokenizer(model_path, device="auto"):
|
||||
config = config.text_config
|
||||
multimodal = True
|
||||
|
||||
print("Vocab size: ", config.vocab_size)
|
||||
print("Hidden size: ", config.hidden_size)
|
||||
print("Number of layers: ", config.num_hidden_layers)
|
||||
print("BOS token id: ", config.bos_token_id)
|
||||
print("EOS token id: ", config.eos_token_id)
|
||||
def print_if_exists(label, obj, attr, default="N/A"):
|
||||
val = getattr(obj, attr) if hasattr(obj, attr) else default
|
||||
print(f"{label}", val)
|
||||
|
||||
print_if_exists("Vocab size: ", config, "vocab_size")
|
||||
print_if_exists("Hidden size: ", config, "hidden_size")
|
||||
print_if_exists("Number of layers: ", config, "num_hidden_layers")
|
||||
print_if_exists("BOS token id: ", config, "bos_token_id")
|
||||
print_if_exists("EOS token id: ", config, "eos_token_id")
|
||||
|
||||
unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
|
||||
if unreleased_model_name:
|
||||
|
||||
@@ -78,7 +78,7 @@ def list_all_tensors(model_path: Path, unique: bool = False):
|
||||
print(tensor_name)
|
||||
|
||||
|
||||
def print_tensor_info(model_path: Path, tensor_name: str):
|
||||
def print_tensor_info(model_path: Path, tensor_name: str, num_values: Optional[int] = None):
|
||||
tensor_file = find_tensor_file(model_path, tensor_name)
|
||||
|
||||
if tensor_file is None:
|
||||
@@ -96,6 +96,12 @@ def print_tensor_info(model_path: Path, tensor_name: str):
|
||||
print(f"Tensor: {tensor_name}")
|
||||
print(f"File: {tensor_file}")
|
||||
print(f"Shape: {shape}")
|
||||
if num_values is not None:
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
print(f"Dtype: {tensor.dtype}")
|
||||
flat = tensor.flatten()
|
||||
n = min(num_values, flat.numel())
|
||||
print(f"Values: {flat[:n].tolist()}")
|
||||
else:
|
||||
print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}")
|
||||
sys.exit(1)
|
||||
@@ -127,6 +133,15 @@ def main():
|
||||
action="store_true",
|
||||
help="List unique tensor patterns in the model (layer numbers replaced with #)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n", "--num-values",
|
||||
nargs="?",
|
||||
const=10,
|
||||
default=None,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="Print the first N values of the tensor flattened (default: 10 if flag is given without a number)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -152,7 +167,7 @@ def main():
|
||||
if args.tensor_name is None:
|
||||
print("Error: tensor_name is required when not using --list")
|
||||
sys.exit(1)
|
||||
print_tensor_info(model_path, args.tensor_name)
|
||||
print_tensor_info(model_path, args.tensor_name, args.num_values)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+1
-1
@@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 9)
|
||||
set(GGML_VERSION_PATCH 5)
|
||||
set(GGML_VERSION_PATCH 7)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
+1
-4
@@ -730,10 +730,6 @@ extern "C" {
|
||||
GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block
|
||||
GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
|
||||
|
||||
GGML_DEPRECATED(
|
||||
GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
|
||||
"use ggml_row_size() instead");
|
||||
|
||||
GGML_API const char * ggml_type_name(enum ggml_type type);
|
||||
GGML_API const char * ggml_op_name (enum ggml_op op);
|
||||
GGML_API const char * ggml_op_symbol(enum ggml_op op);
|
||||
@@ -752,6 +748,7 @@ extern "C" {
|
||||
GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_is_empty (const struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_is_view (const struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor);
|
||||
|
||||
@@ -17,11 +17,6 @@
|
||||
//#define AT_PRINTF(...) GGML_LOG_DEBUG(__VA_ARGS__)
|
||||
#define AT_PRINTF(...)
|
||||
|
||||
|
||||
static bool ggml_is_view(const struct ggml_tensor * t) {
|
||||
return t->view_src != NULL;
|
||||
}
|
||||
|
||||
// ops that return true for this function must not use restrict pointers for their backend implementations
|
||||
bool ggml_op_can_inplace(enum ggml_op op) {
|
||||
switch (op) {
|
||||
@@ -627,7 +622,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
|
||||
GGML_ASSERT(buffer_id >= 0);
|
||||
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
|
||||
|
||||
if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) {
|
||||
if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_impl_is_view(node)) {
|
||||
hn->allocated = true;
|
||||
assert(hn->addr.offset == 0);
|
||||
|
||||
@@ -658,7 +653,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
|
||||
|
||||
struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
|
||||
if (p_hn->n_children == 1 && p_hn->n_views == 0) {
|
||||
if (ggml_is_view(parent)) {
|
||||
if (ggml_impl_is_view(parent)) {
|
||||
struct ggml_tensor * view_src = parent->view_src;
|
||||
struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
|
||||
if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
|
||||
@@ -739,7 +734,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
|
||||
// GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to
|
||||
// control when some tensors are allocated and freed. in this case, the dependencies are in `src`, but the node
|
||||
// itself is never used and should not be considered a dependency
|
||||
if (ggml_is_view(node) && node->op != GGML_OP_NONE) {
|
||||
if (ggml_impl_is_view(node) && node->op != GGML_OP_NONE) {
|
||||
struct ggml_tensor * view_src = node->view_src;
|
||||
ggml_gallocr_hash_get(galloc, view_src)->n_views += 1;
|
||||
}
|
||||
@@ -806,7 +801,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
|
||||
parent->name, p_hn->n_children, p_hn->n_views, p_hn->allocated);
|
||||
|
||||
if (p_hn->n_children == 0 && p_hn->n_views == 0) {
|
||||
if (ggml_is_view(parent)) {
|
||||
if (ggml_impl_is_view(parent)) {
|
||||
struct ggml_tensor * view_src = parent->view_src;
|
||||
struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
|
||||
view_src_hn->n_views -= 1;
|
||||
|
||||
@@ -471,9 +471,10 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
|
||||
int best_score = 0;
|
||||
fs::path best_path;
|
||||
std::error_code ec;
|
||||
|
||||
for (const auto & search_path : search_paths) {
|
||||
if (std::error_code ec; !fs::exists(search_path, ec)) {
|
||||
if (!fs::exists(search_path, ec)) {
|
||||
if (ec) {
|
||||
GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(search_path).c_str(), ec.message().c_str());
|
||||
} else {
|
||||
@@ -483,7 +484,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||
}
|
||||
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
|
||||
for (const auto & entry : dir_it) {
|
||||
if (entry.is_regular_file()) {
|
||||
if (entry.is_regular_file(ec)) {
|
||||
auto filename = entry.path().filename();
|
||||
auto ext = entry.path().extension();
|
||||
if (filename.native().find(file_prefix) == 0 && ext == file_extension) {
|
||||
|
||||
@@ -794,19 +794,44 @@ struct ggml_backend_cann_buffer_context {
|
||||
~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
|
||||
};
|
||||
|
||||
// cann buffer type
|
||||
/**
|
||||
* @brief Check if a buffer is a CANN buffer.
|
||||
*
|
||||
* This function checks if a given buffer is a CANN buffer by comparing its
|
||||
* `get_name` function pointer to `ggml_backend_cann_buffer_get_name`.
|
||||
*
|
||||
* @param buffer The buffer to check.
|
||||
* @return true if the buffer is a CANN buffer, false otherwise.
|
||||
* @brief Structure representing context information for a specific backend
|
||||
* buffer type.
|
||||
*/
|
||||
static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft);
|
||||
struct ggml_backend_cann_buffer_type_context {
|
||||
int32_t device; /**< Device identifier associated with the buffer context. */
|
||||
std::string name; /**< Name associated with the buffer context. */
|
||||
};
|
||||
|
||||
static bool ggml_backend_buffer_is_cann(ggml_backend_buffer_t buffer) {
|
||||
return ggml_backend_buft_is_cann(buffer->buft);
|
||||
/**
|
||||
* @brief Retrieves the name associated with a CANN buffer type.
|
||||
*
|
||||
* This function returns the descriptive name associated with the specified
|
||||
* CANN buffer type context.
|
||||
*
|
||||
* @param buft Pointer to the buffer type context.
|
||||
* @return Const pointer to the C-style string containing the name.
|
||||
*/
|
||||
static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) {
|
||||
ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
|
||||
|
||||
return buft_ctx->name.c_str();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Checks if the backend buffer type is associated with the CANN backend.
|
||||
*
|
||||
* This function checks whether the provided backend buffer type is associated
|
||||
* with the CANN backend based on the comparison of its name retrieval function
|
||||
* pointer.
|
||||
*
|
||||
* @param buft Pointer to the backend buffer type to check.
|
||||
* @return bool Returns true if the buffer type is associated with the CANN
|
||||
* backend, otherwise false.
|
||||
*/
|
||||
static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
|
||||
return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1271,7 +1296,7 @@ static void ggml_backend_cann_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
||||
static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
|
||||
const ggml_tensor * src,
|
||||
ggml_tensor * dst) {
|
||||
if (ggml_backend_buffer_is_cann(src->buffer)) {
|
||||
if (ggml_backend_buft_is_cann(src->buffer->buft)) {
|
||||
ggml_backend_cann_buffer_context * src_ctx = (ggml_backend_cann_buffer_context *) src->buffer->context;
|
||||
ggml_backend_cann_buffer_context * dst_ctx = (ggml_backend_cann_buffer_context *) buffer->context;
|
||||
|
||||
@@ -1335,31 +1360,6 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
|
||||
/* .reset = */ NULL,
|
||||
};
|
||||
|
||||
// cann buffer type
|
||||
/**
|
||||
* @brief Structure representing context information for a specific backend
|
||||
* buffer type.
|
||||
*/
|
||||
struct ggml_backend_cann_buffer_type_context {
|
||||
int32_t device; /**< Device identifier associated with the buffer context. */
|
||||
std::string name; /**< Name associated with the buffer context. */
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Retrieves the name associated with a CANN buffer type.
|
||||
*
|
||||
* This function returns the descriptive name associated with the specified
|
||||
* CANN buffer type context.
|
||||
*
|
||||
* @param buft Pointer to the buffer type context.
|
||||
* @return Const pointer to the C-style string containing the name.
|
||||
*/
|
||||
static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) {
|
||||
ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
|
||||
|
||||
return buft_ctx->name.c_str();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Allocates a new CANN buffer of the specified type and size.
|
||||
*
|
||||
@@ -1997,7 +1997,7 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
|
||||
|
||||
GGML_ASSERT(!is_matmul_weight((const ggml_tensor *) src));
|
||||
|
||||
if (!ggml_backend_buffer_is_cann(src->buffer) || !ggml_backend_buffer_is_cann(dst->buffer)) {
|
||||
if (!ggml_backend_buft_is_cann(src->buffer->buft) || !ggml_backend_buft_is_cann(dst->buffer->buft)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -2523,21 +2523,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Checks if the backend buffer type is associated with the CANN backend.
|
||||
*
|
||||
* This function checks whether the provided backend buffer type is associated
|
||||
* with the CANN backend based on the comparison of its name retrieval function
|
||||
* pointer.
|
||||
*
|
||||
* @param buft Pointer to the backend buffer type to check.
|
||||
* @return bool Returns true if the buffer type is associated with the CANN
|
||||
* backend, otherwise false.
|
||||
*/
|
||||
static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
|
||||
return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Records an event on the CANN backend stream.
|
||||
*
|
||||
|
||||
@@ -9,6 +9,11 @@ function(ggml_add_cpu_backend_features cpu_name arch)
|
||||
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARGN})
|
||||
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED)
|
||||
set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
# Disable LTO for the feature detection code to prevent cross-module optimization
|
||||
# from inlining architecture-specific instructions into the score function.
|
||||
# Without this, LTO can cause SIGILL when loading backends on older CPUs
|
||||
# (e.g., loading power10 backend on power9 crashes before feature check runs).
|
||||
target_compile_options(${GGML_CPU_FEATS_NAME} PRIVATE -fno-lto)
|
||||
target_link_libraries(${cpu_name} PRIVATE ${GGML_CPU_FEATS_NAME})
|
||||
endfunction()
|
||||
|
||||
@@ -569,27 +574,24 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
cmake_policy(SET CMP0135 NEW)
|
||||
endif()
|
||||
|
||||
# TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+
|
||||
# Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28
|
||||
FetchContent_Declare(KleidiAI_Download
|
||||
URL ${KLEIDIAI_DOWNLOAD_URL}
|
||||
DOWNLOAD_EXTRACT_TIMESTAMP NEW
|
||||
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
|
||||
|
||||
FetchContent_MakeAvailable(KleidiAI_Download)
|
||||
FetchContent_GetProperties(KleidiAI_Download
|
||||
SOURCE_DIR KLEIDIAI_SRC
|
||||
POPULATED KLEIDIAI_POPULATED)
|
||||
|
||||
if (NOT KLEIDIAI_POPULATED)
|
||||
message(FATAL_ERROR "KleidiAI source downloaded failed.")
|
||||
FetchContent_Populate(KleidiAI_Download)
|
||||
FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC)
|
||||
endif()
|
||||
|
||||
add_compile_definitions(GGML_USE_CPU_KLEIDIAI)
|
||||
|
||||
# Remove kleidiai target after fetching it
|
||||
if (TARGET kleidiai)
|
||||
set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE)
|
||||
endif()
|
||||
|
||||
list(APPEND GGML_CPU_SOURCES
|
||||
ggml-cpu/kleidiai/kleidiai.cpp
|
||||
ggml-cpu/kleidiai/kernels.cpp
|
||||
|
||||
@@ -43,6 +43,7 @@
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
@@ -55,7 +56,8 @@
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
# define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
@@ -76,6 +78,7 @@
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
@@ -84,6 +87,7 @@
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
@@ -107,6 +111,7 @@
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
@@ -119,6 +124,7 @@
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
@@ -143,6 +149,7 @@
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
@@ -155,6 +162,7 @@
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
@@ -186,6 +194,7 @@
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
@@ -197,6 +206,7 @@
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
@@ -227,6 +237,7 @@
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
@@ -239,6 +250,7 @@
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
@@ -271,6 +283,7 @@
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
@@ -283,6 +296,7 @@
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
|
||||
@@ -1072,6 +1072,195 @@ void ggml_gemv_q5_K_8x8_q8_K(int n,
|
||||
ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q6_K_8x4_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 4;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
constexpr int col_groups = ncols_interleaved / 4;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
const uint8x16_t mask_lo = vdupq_n_u8(0x03);
|
||||
const uint8x16_t mask_hi = vdupq_n_u8(0x30);
|
||||
|
||||
// 1x8 tile = 2 x 4
|
||||
float32x4_t acc_f32[2];
|
||||
|
||||
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < col_groups; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3
|
||||
float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7
|
||||
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
|
||||
float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
|
||||
float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);
|
||||
|
||||
int32x4_t acc[col_groups];
|
||||
for (int i = 0; i < col_groups; i++) {
|
||||
acc[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
// Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
|
||||
// Reused for bias and dequantization later
|
||||
int16_t q6_scales[16 * 8];
|
||||
for (int i = 0; i < 16; i++) {
|
||||
int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
|
||||
vst1q_s16(q6_scales + i * 8, scales);
|
||||
}
|
||||
|
||||
// Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
|
||||
int32x4_t bias_lo = vdupq_n_s32(0);
|
||||
int32x4_t bias_hi = vdupq_n_s32(0);
|
||||
|
||||
// Load bsums in chunks of 4 to process with vectorized operations
|
||||
for (int i = 0; i < 16; i += 4) {
|
||||
int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i);
|
||||
int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
|
||||
int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
|
||||
int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
|
||||
int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
|
||||
int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
|
||||
int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
|
||||
int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
|
||||
int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);
|
||||
|
||||
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
|
||||
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
|
||||
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
|
||||
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
|
||||
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
|
||||
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
|
||||
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
|
||||
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
|
||||
}
|
||||
bias_lo = vshlq_n_s32(bias_lo, 5);
|
||||
bias_hi = vshlq_n_s32(bias_hi, 5);
|
||||
|
||||
// Process two 128-value halves per superblock
|
||||
for (int half = 0; half < 2; half++) {
|
||||
const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
|
||||
const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
|
||||
|
||||
// A subblock (sb) is a set of weights that share the scale
|
||||
// Since q6_K scales are per 16 elements
|
||||
// num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
|
||||
const int8_t * q8_base_h = q8_base_l + 64;
|
||||
|
||||
// Load and duplicate q8 values (each register covers four interleaved columns of q6)
|
||||
int8x16_t q8_l[4];
|
||||
int8x16_t q8_h[4];
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q8_l[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_l + i * 4));
|
||||
q8_h[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_h + i * 4));
|
||||
}
|
||||
|
||||
const int ql_off_base = sb * QK_K / 2;
|
||||
const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes
|
||||
|
||||
// Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
|
||||
uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
|
||||
uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
|
||||
uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
|
||||
uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
|
||||
|
||||
// Adjust qh for subblocks 2 and 3 (shift right by 2)
|
||||
if (sb > 1) {
|
||||
q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
|
||||
q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
|
||||
q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
|
||||
q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
|
||||
q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
|
||||
q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
|
||||
q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
|
||||
q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
|
||||
}
|
||||
|
||||
const uint8x16_t q6_ql[8] = { q6_ql_0.val[0], q6_ql_0.val[1], q6_ql_0.val[2], q6_ql_0.val[3],
|
||||
q6_ql_1.val[0], q6_ql_1.val[1], q6_ql_1.val[2], q6_ql_1.val[3] };
|
||||
const uint8x16_t q6_qh[8] = { q6_qh_0.val[0], q6_qh_0.val[1], q6_qh_0.val[2], q6_qh_0.val[3],
|
||||
q6_qh_1.val[0], q6_qh_1.val[1], q6_qh_1.val[2], q6_qh_1.val[3] };
|
||||
|
||||
// Process column groups (0-3, 4-7)
|
||||
for (int g = 0; g < col_groups; g++) {
|
||||
int32x4_t sb_acc_l = vdupq_n_s32(0);
|
||||
int32x4_t sb_acc_h = vdupq_n_s32(0);
|
||||
|
||||
for (int chunk = 0; chunk < 4; chunk++) {
|
||||
const int idx = chunk * 2 + g;
|
||||
|
||||
const uint8x16_t q6_qs_l = q6_ql[idx];
|
||||
const uint8x16_t q6_qs_h = q6_qh[idx];
|
||||
|
||||
// Extract high 2 bits for upper nibble reconstruction
|
||||
const uint8x16_t q6_qs_hh = vandq_u8(q6_qs_h, mask_hi);
|
||||
|
||||
// q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
|
||||
const int8x16_t q6_l =
|
||||
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_l, m4b), vandq_u8(q6_qs_h, mask_lo), 4));
|
||||
const int8x16_t q6_h = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_l, 4), q6_qs_hh));
|
||||
|
||||
sb_acc_l = vdotq_s32(sb_acc_l, q6_l, q8_l[chunk]);
|
||||
sb_acc_h = vdotq_s32(sb_acc_h, q6_h, q8_h[chunk]);
|
||||
}
|
||||
|
||||
const int scale_idx_l = half * 8 + sb;
|
||||
const int scale_idx_h = half * 8 + sb + 4;
|
||||
|
||||
const int32x4_t scale_vec_l = vmovl_s16(vld1_s16(q6_scales + scale_idx_l * 8 + g * 4));
|
||||
const int32x4_t scale_vec_h = vmovl_s16(vld1_s16(q6_scales + scale_idx_h * 8 + g * 4));
|
||||
|
||||
acc[g] = vmlaq_s32(acc[g], sb_acc_l, scale_vec_l);
|
||||
acc[g] = vmlaq_s32(acc[g], sb_acc_h, scale_vec_h);
|
||||
}
|
||||
}
|
||||
} // for half
|
||||
|
||||
// Bias correction
|
||||
acc[0] = vsubq_s32(acc[0], bias_lo);
|
||||
acc[1] = vsubq_s32(acc[1], bias_hi);
|
||||
|
||||
// Apply superblock scale (no mins for q6_K)
|
||||
// acc[g] has [c0, c1, c2, c3]
|
||||
float32x4_t w_0123 = vmulq_f32(vcvtq_f32_s32(acc[0]), sb_scale_0);
|
||||
float32x4_t w_4567 = vmulq_f32(vcvtq_f32_s32(acc[1]), sb_scale_1);
|
||||
|
||||
acc_f32[0] = vaddq_f32(acc_f32[0], w_0123);
|
||||
acc_f32[1] = vaddq_f32(acc_f32[1], w_4567);
|
||||
} // for b
|
||||
|
||||
int base = x * ncols_interleaved;
|
||||
vst1q_f32(s + base, acc_f32[0]);
|
||||
vst1q_f32(s + base + 4, acc_f32[1]);
|
||||
} // for x
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
ggml_gemv_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q6_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
@@ -1177,15 +1366,14 @@ void ggml_gemv_q6_K_8x8_q8_K(int n,
|
||||
q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8));
|
||||
}
|
||||
|
||||
// TODO: Test other qh repack patterns to reduce loads
|
||||
const int ql_off_base = sb * QK_K / 2;
|
||||
const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes
|
||||
|
||||
// Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
|
||||
ggml_uint8x16x4_t q6_ql_0 = ggml_vld1q_u8_x4(ql_base + ql_off_base);
|
||||
ggml_uint8x16x4_t q6_ql_1 = ggml_vld1q_u8_x4(ql_base + ql_off_base + 64);
|
||||
ggml_uint8x16x4_t q6_qh_0 = ggml_vld1q_u8_x4(qh_base + qh_off_base);
|
||||
ggml_uint8x16x4_t q6_qh_1 = ggml_vld1q_u8_x4(qh_base + qh_off_base + 64);
|
||||
uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
|
||||
uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
|
||||
uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
|
||||
uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
|
||||
|
||||
// Adjust qh for subblocks 2 and 3 (shift right by 2)
|
||||
if (sb > 1) {
|
||||
@@ -3038,6 +3226,316 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
if (svcntb() * 8 == 256) {
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
const svuint8_t m4b_1 = svdup_n_u8(0x0f);
|
||||
// 8 accumulators: 2 row pairs × 4 col pairs
|
||||
svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67;
|
||||
uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 };
|
||||
svbool_t pg = svptrue_pat_b32(SV_VL8);
|
||||
svuint32_t idx = svld1(pg, idx_arr);
|
||||
|
||||
static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7};
|
||||
svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data);
|
||||
|
||||
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
||||
|
||||
acc_f32_01 = svdup_n_f32(0);
|
||||
acc_f32_23 = svdup_n_f32(0);
|
||||
acc_f32_45 = svdup_n_f32(0);
|
||||
acc_f32_67 = svdup_n_f32(0);
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
// bsums pairs belongs to the same q8_k subblock
|
||||
// 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum
|
||||
const int16x8_t bsums[4]{
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
||||
};
|
||||
|
||||
int32_t bsums_arr32[4][8];
|
||||
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
int16x8_t v16 = bsums[q8_row];
|
||||
|
||||
// low 4
|
||||
int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16));
|
||||
vst1q_s32(&bsums_arr32[q8_row][0], v32_lo);
|
||||
|
||||
// high 4
|
||||
int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16));
|
||||
vst1q_s32(&bsums_arr32[q8_row][4], v32_hi);
|
||||
}
|
||||
|
||||
svint32_t sb_acc_0 = svdup_n_s32(0);
|
||||
svint32_t sb_acc_2 = svdup_n_s32(0);
|
||||
|
||||
svint32_t acc_00 = svdup_n_s32(0);
|
||||
svint32_t acc_11 = svdup_n_s32(0);
|
||||
svint32_t acc_22 = svdup_n_s32(0);
|
||||
svint32_t acc_33 = svdup_n_s32(0);
|
||||
svint32_t acc_44 = svdup_n_s32(0);
|
||||
svint32_t acc_55 = svdup_n_s32(0);
|
||||
svint32_t acc_66 = svdup_n_s32(0);
|
||||
svint32_t acc_77 = svdup_n_s32(0);
|
||||
|
||||
svint32_t bias_acc_00 = svdup_n_s32(0);
|
||||
svint32_t bias_acc_22 = svdup_n_s32(0);
|
||||
svint32_t bias_acc_44 = svdup_n_s32(0);
|
||||
svint32_t bias_acc_66 = svdup_n_s32(0);
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3;
|
||||
svint32_t q4sb_mins_0, q4sb_mins_1;
|
||||
{
|
||||
// 2-superblock I am working on
|
||||
const int offset = sb * 24 + 0 * 12;
|
||||
const uint8_t * scales_in = &q4_ptr[b].scales[offset];
|
||||
|
||||
const int offset1 = sb * 24 + 12;
|
||||
const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1];
|
||||
|
||||
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
||||
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
||||
constexpr uint32_t kmask3 = 0x03030303;
|
||||
constexpr uint8_t scales_size = 12;
|
||||
|
||||
uint32_t sm[3];
|
||||
memcpy(sm, scales_in, scales_size);
|
||||
|
||||
uint32_t sm1[3];
|
||||
memcpy(sm1, scales_in1, scales_size);
|
||||
|
||||
const uint32_t mins_0_3 = sm[1] & kmask1;
|
||||
const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
|
||||
|
||||
const uint32_t mins_0_3_1 = sm1[1] & kmask1;
|
||||
const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4);
|
||||
|
||||
svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7));
|
||||
svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1));
|
||||
|
||||
/* reinterpret u32 → u8 */
|
||||
svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp);
|
||||
svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1);
|
||||
|
||||
/* widen u8 → u16->u32 (lower half only) */
|
||||
svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8));
|
||||
svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1));
|
||||
|
||||
q4sb_mins_0 = svreinterpret_s32_u32(mins_u16);
|
||||
q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1);
|
||||
|
||||
uint32_t scales_u32_0 = sm[0] & kmask1;
|
||||
uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
|
||||
uint32_t scales_u32_2 = sm1[0] & kmask1;
|
||||
uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4);
|
||||
|
||||
svuint32_t S01 = svdup_n_u32(scales_u32_0);
|
||||
svuint32_t S23 = svdup_n_u32(scales_u32_1);
|
||||
svuint32_t R01 = svdup_n_u32(scales_u32_2);
|
||||
svuint32_t R23 = svdup_n_u32(scales_u32_3);
|
||||
|
||||
svint8_t S01_b = svreinterpret_s8_u32(S01);
|
||||
svint8_t S23_b = svreinterpret_s8_u32(S23);
|
||||
svint8_t R01_b = svreinterpret_s8_u32(R01);
|
||||
svint8_t R23_b = svreinterpret_s8_u32(R23);
|
||||
|
||||
svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b)));
|
||||
svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b)));
|
||||
svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b)));
|
||||
svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b)));
|
||||
|
||||
block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx);
|
||||
block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx);
|
||||
block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx);
|
||||
block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx);
|
||||
}
|
||||
|
||||
const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256;
|
||||
|
||||
// Load 32-byte per row pair, 1 subblock each time
|
||||
// predicate for activating higher lanes for 16 int8 elements
|
||||
const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
|
||||
// predicate for activating lower lanes for 16 int8 elements
|
||||
const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
|
||||
|
||||
svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));
|
||||
svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));
|
||||
svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));
|
||||
svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));
|
||||
|
||||
svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));
|
||||
svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));
|
||||
svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));
|
||||
svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));
|
||||
|
||||
// Q4s columns iterated in pairs (01, 23, 45, 67)
|
||||
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
|
||||
|
||||
sb_acc_0 = svdup_n_s32(0);
|
||||
sb_acc_2 = svdup_n_s32(0);
|
||||
|
||||
svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);
|
||||
svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);
|
||||
svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);
|
||||
svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);
|
||||
|
||||
svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));
|
||||
svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
|
||||
svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));
|
||||
svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));
|
||||
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0);
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2);
|
||||
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4);
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6);
|
||||
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1);
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3);
|
||||
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5);
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7);
|
||||
|
||||
if(cp == 0) {
|
||||
acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0);
|
||||
acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0);
|
||||
}
|
||||
if(cp == 1) {
|
||||
acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1);
|
||||
acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1);
|
||||
}
|
||||
if(cp == 2) {
|
||||
acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2);
|
||||
acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2);
|
||||
}
|
||||
if(cp == 3) {
|
||||
acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3);
|
||||
acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3);
|
||||
}
|
||||
}
|
||||
|
||||
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0);
|
||||
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1);
|
||||
|
||||
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0);
|
||||
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1);
|
||||
|
||||
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0);
|
||||
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1);
|
||||
|
||||
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0);
|
||||
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1);
|
||||
} // for sb
|
||||
|
||||
|
||||
acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4));
|
||||
acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4));
|
||||
acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4));
|
||||
acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4));
|
||||
acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4));
|
||||
acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4));
|
||||
acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4));
|
||||
acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4));
|
||||
|
||||
svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1);
|
||||
svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1);
|
||||
|
||||
svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1);
|
||||
svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1);
|
||||
|
||||
// Broadcast q8 scalar
|
||||
svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]);
|
||||
|
||||
svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0)));
|
||||
|
||||
svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0)));
|
||||
|
||||
svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1);
|
||||
acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1);
|
||||
|
||||
q8_d = svdup_f32(q8_ptr[b].d[1]);
|
||||
|
||||
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1);
|
||||
acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1);
|
||||
|
||||
q8_d = svdup_f32(q8_ptr[b].d[2]);
|
||||
|
||||
|
||||
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1);
|
||||
acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1);
|
||||
|
||||
q8_d = svdup_f32(q8_ptr[b].d[3]);
|
||||
|
||||
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1);
|
||||
acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1);
|
||||
|
||||
} // for b
|
||||
|
||||
// With the previous reorder, the tile is already in the correct memory layout.
|
||||
// Predicate for exactly 4 lanes
|
||||
svbool_t pg4 = svptrue_pat_b32(SV_VL4);
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
int row = y * q8_k_blocklen + i;
|
||||
for (int j = 0; j < 2; j++) {
|
||||
int col = x * ncols_interleaved + j * 4;
|
||||
int offset = row * bs + col;
|
||||
|
||||
if (i == 0 && j == 0) {
|
||||
// acc_f32_0 → lower half of acc_f32_01
|
||||
svst1_f32(pg4, s + offset, acc_f32_01);
|
||||
} else if (i == 0 && j == 1) {
|
||||
// acc_f32_1 → upper half of acc_f32_01
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4));
|
||||
} else if (i == 1 && j == 0) {
|
||||
// acc_f32_2
|
||||
svst1_f32(pg4, s + offset, acc_f32_23);
|
||||
} else if (i == 1 && j == 1) {
|
||||
// acc_f32_3
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4));
|
||||
} else if (i == 2 && j == 0) {
|
||||
// acc_f32_4
|
||||
svst1_f32(pg4, s + offset, acc_f32_45);
|
||||
} else if (i == 2 && j == 1) {
|
||||
// acc_f32_5
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4));
|
||||
} else if (i == 3 && j == 0) {
|
||||
// acc_f32_6
|
||||
svst1_f32(pg4, s + offset, acc_f32_67);
|
||||
} else if (i == 3 && j == 1) {
|
||||
// acc_f32_7
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4));
|
||||
}
|
||||
}
|
||||
}
|
||||
} // for x
|
||||
} // for y
|
||||
return;
|
||||
}
|
||||
#endif // SVE compile-time end
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
@@ -3474,6 +3972,208 @@ void ggml_gemm_q5_K_8x8_q8_K(int n,
|
||||
ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q6_K_8x4_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 4;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
constexpr int col_groups = ncols_interleaved / 4;
|
||||
constexpr int acc_size = q8_k_blocklen * col_groups; // 4 rows, 2 column groups
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
const uint8x16_t mask_lo = vdupq_n_u8(0x03);
|
||||
const uint8x16_t mask_hi = vdupq_n_u8(0x30);
|
||||
const int8x16_t m32s = vdupq_n_s8(32);
|
||||
|
||||
float32x4_t acc_f32[acc_size];
|
||||
|
||||
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < acc_size; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
float32x4_t q6_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));
|
||||
float32x4_t q6_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));
|
||||
float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
|
||||
|
||||
float32x4_t sbd_scale_0123[q8_k_blocklen];
|
||||
float32x4_t sbd_scale_4567[q8_k_blocklen];
|
||||
|
||||
sbd_scale_0123[0] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 0);
|
||||
sbd_scale_4567[0] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 0);
|
||||
sbd_scale_0123[1] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 1);
|
||||
sbd_scale_4567[1] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 1);
|
||||
sbd_scale_0123[2] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 2);
|
||||
sbd_scale_4567[2] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 2);
|
||||
sbd_scale_0123[3] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 3);
|
||||
sbd_scale_4567[3] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 3);
|
||||
|
||||
int32x4_t acc_s32[acc_size];
|
||||
for (int i = 0; i < acc_size; i++) {
|
||||
acc_s32[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
int16_t q6_scales[8 * 16];
|
||||
for (int i = 0; i < 16; i++) {
|
||||
int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
|
||||
vst1q_s16(q6_scales + i * 8, scales);
|
||||
}
|
||||
|
||||
for (int half = 0; half < 2; half++) {
|
||||
const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
|
||||
const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
int32x4_t acc_lo[acc_size];
|
||||
int32x4_t acc_hi[acc_size];
|
||||
for (int i = 0; i < acc_size; i++) {
|
||||
acc_lo[i] = vdupq_n_s32(0);
|
||||
acc_hi[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
|
||||
const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;
|
||||
|
||||
// 4 rows * 16 elements per scale
|
||||
// 4 reads of 16 bytes each
|
||||
constexpr int reads_per_sb = 4;
|
||||
int8x16_t q8_l[reads_per_sb];
|
||||
int8x16_t q8_h[reads_per_sb];
|
||||
for (int k = 0; k < reads_per_sb; k++) {
|
||||
q8_l[k] = vld1q_s8(q8_base_l + 16 * k);
|
||||
q8_h[k] = vld1q_s8(q8_base_h + 16 * k);
|
||||
}
|
||||
|
||||
const int ql_off_base = sb * QK_K / 2;
|
||||
const int qh_off_base = ql_off_base & 255;
|
||||
|
||||
uint8x16_t q6_ql_0123[reads_per_sb];
|
||||
uint8x16_t q6_ql_4567[reads_per_sb];
|
||||
uint8x16_t q6_qh_0123[reads_per_sb];
|
||||
uint8x16_t q6_qh_4567[reads_per_sb];
|
||||
|
||||
for (int k = 0; k < reads_per_sb; k++) {
|
||||
q6_ql_0123[k] = vld1q_u8(ql_base + ql_off_base + k * 32);
|
||||
q6_ql_4567[k] = vld1q_u8(ql_base + ql_off_base + k * 32 + 16);
|
||||
q6_qh_0123[k] = vld1q_u8(qh_base + qh_off_base + k * 32);
|
||||
q6_qh_4567[k] = vld1q_u8(qh_base + qh_off_base + k * 32 + 16);
|
||||
}
|
||||
|
||||
if (sb > 1) {
|
||||
for (int k = 0; k < reads_per_sb; k++) {
|
||||
q6_qh_0123[k] = vshrq_n_u8(q6_qh_0123[k], 2);
|
||||
q6_qh_4567[k] = vshrq_n_u8(q6_qh_4567[k], 2);
|
||||
}
|
||||
}
|
||||
|
||||
for (int k = 0; k < reads_per_sb; k++) {
|
||||
// q = (ql | qh) - 32
|
||||
const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123[k], mask_lo);
|
||||
const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123[k], mask_hi);
|
||||
const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567[k], mask_lo);
|
||||
const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567[k], mask_hi);
|
||||
|
||||
const int8x16_t q6_0123_lo = vsubq_s8(
|
||||
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123[k], m4b), hbit_lo_0123, 4)), m32s);
|
||||
const int8x16_t q6_0123_hi = vsubq_s8(
|
||||
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123[k], 4), hbit_hi_0123)), m32s);
|
||||
|
||||
acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l[k], 0); // 0..3 r0 c0123
|
||||
acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l[k], 1); // 0..3 r1 c0123
|
||||
acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l[k], 2); // 0..3 r2 c0123
|
||||
acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l[k], 3); // 0..3 r3 c0123
|
||||
|
||||
acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h[k], 0); // 64..67 r0 c0123
|
||||
acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h[k], 1); // 64..67 r1 c0123
|
||||
acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h[k], 2); // 64..67 r2 c0123
|
||||
acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h[k], 3); // 64..67 r3 c0123
|
||||
|
||||
const int8x16_t q6_4567_lo = vsubq_s8(
|
||||
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567[k], m4b), hbit_lo_4567, 4)), m32s);
|
||||
const int8x16_t q6_4567_hi = vsubq_s8(
|
||||
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567[k], 4), hbit_hi_4567)), m32s);
|
||||
|
||||
acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l[k], 0); // 0..3 r0 c4567
|
||||
acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l[k], 1); // 0..3 r1 c4567
|
||||
acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l[k], 2); // 0..3 r2 c4567
|
||||
acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l[k], 3); // 0..3 r3 c4567
|
||||
|
||||
acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h[k], 0); // 64..67 r0 c4567
|
||||
acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h[k], 1); // 64..67 r1 c4567
|
||||
acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h[k], 2); // 64..67 r2 c4567
|
||||
acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h[k], 3); // 64..67 r3 c4567
|
||||
}
|
||||
|
||||
// Scale and bias
|
||||
const int scale_idx_l = half * 8 + sb;
|
||||
const int scale_idx_h = half * 8 + sb + 4;
|
||||
|
||||
for (int g = 0; g < col_groups; g++) {
|
||||
const int16x4_t scales_l16 = vld1_s16(q6_scales + scale_idx_l * 8 + g * 4);
|
||||
const int16x4_t scales_h16 = vld1_s16(q6_scales + scale_idx_h * 8 + g * 4);
|
||||
const int32x4_t scale_vec_l = vmovl_s16(scales_l16);
|
||||
const int32x4_t scale_vec_h = vmovl_s16(scales_h16);
|
||||
const int acc_offset = g * q8_k_blocklen;
|
||||
|
||||
for (int row = 0; row < q8_k_blocklen; row++) {
|
||||
const int idx = row * 2 + g;
|
||||
acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_lo[acc_offset + row], scale_vec_l);
|
||||
acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_hi[acc_offset + row], scale_vec_h);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Finally we apply the superblock scales
|
||||
for (int row = 0; row < q8_k_blocklen; row++) {
|
||||
const int idx0 = 2 * row;
|
||||
const int idx1 = 2 * row + 1;
|
||||
const int32x4_t acc_0123 = acc_s32[idx0];
|
||||
const int32x4_t acc_4567 = acc_s32[idx1];
|
||||
|
||||
acc_f32[idx0] = vmlaq_f32(acc_f32[idx0], vcvtq_f32_s32(acc_0123), sbd_scale_0123[row]);
|
||||
acc_f32[idx1] = vmlaq_f32(acc_f32[idx1], vcvtq_f32_s32(acc_4567), sbd_scale_4567[row]);
|
||||
}
|
||||
} // for b
|
||||
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
int row = y * q8_k_blocklen + i;
|
||||
for (int j = 0; j < 2; j++) {
|
||||
int col = x * ncols_interleaved + j * 4;
|
||||
int offset = row * bs + col;
|
||||
vst1q_f32(s + offset, acc_f32[2 * i + j]);
|
||||
}
|
||||
}
|
||||
} // for x
|
||||
} // for y
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
ggml_gemm_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q6_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
|
||||
@@ -59,11 +59,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
|
||||
GGML_ASSERT(nb00 == sizeof(src0_t));
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, src0);
|
||||
const bool is_src1_contiguous = (nb10 == sizeof(src1_t));
|
||||
|
||||
if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
||||
}
|
||||
const bool is_src1_contiguous_rows = ggml_is_contiguous_rows(src1);
|
||||
|
||||
#ifdef GGML_USE_ACCELERATE
|
||||
vDSP_fn_t vDSP_op = nullptr;
|
||||
@@ -94,7 +90,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
|
||||
const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
|
||||
|
||||
if (is_src1_contiguous) {
|
||||
if (is_src1_contiguous_rows) {
|
||||
// src1 is broadcastable across src0 and dst in i1, i2, i3
|
||||
const int64_t nr0 = ne00 / ne10;
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
#include "ggml-impl.h"
|
||||
#include "simd-mappings.h"
|
||||
|
||||
#define GGML_FA_TILE_Q 32
|
||||
#define GGML_FA_TILE_KV 16
|
||||
#define GGML_FA_TILE_Q 64
|
||||
#define GGML_FA_TILE_KV 64
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
|
||||
@@ -2874,8 +2874,8 @@ struct ggml_cplan ggml_graph_plan(
|
||||
const int64_t DV = node->src[2]->ne[0];
|
||||
|
||||
// Tiled flash attention scratch (tile sizes defined in common.h)
|
||||
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
|
||||
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
|
||||
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding
|
||||
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV + GGML_FA_TILE_KV*DK)*n_tasks;
|
||||
|
||||
// Decode path: n_kv_chunks = n_tasks (one chunk per thread)
|
||||
// Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
|
||||
@@ -2947,7 +2947,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
/*.use_ref =*/ cplan->use_ref,
|
||||
};
|
||||
|
||||
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
|
||||
#ifdef GGML_USE_OPENMP
|
||||
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p\n", state->ith, (const void *)cplan);
|
||||
#else
|
||||
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph);
|
||||
#endif
|
||||
|
||||
for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
|
||||
struct ggml_tensor * node = cgraph->nodes[node_n];
|
||||
@@ -2974,7 +2978,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
}
|
||||
}
|
||||
|
||||
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
|
||||
#ifdef GGML_USE_OPENMP
|
||||
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p\n", state->ith, (const void *)cplan);
|
||||
#else
|
||||
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph);
|
||||
#endif
|
||||
|
||||
ggml_barrier(state->threadpool);
|
||||
|
||||
|
||||
@@ -1,333 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
typedef vector unsigned char vec_t;
|
||||
typedef __vector_quad acc_t;
|
||||
|
||||
template <typename TA>
|
||||
class tinyBLAS_Q0_PPC {
|
||||
public:
|
||||
tinyBLAS_Q0_PPC(int64_t k,
|
||||
const TA *A, int64_t lda,
|
||||
const block_q8_0 *B, int64_t ldb,
|
||||
float *C, int64_t ldc,
|
||||
int ith, int nth);
|
||||
|
||||
void matmul(int64_t m, int64_t n);
|
||||
void matmul_tiled_q0(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
|
||||
vec_t A_pack[mc*kc*2];
|
||||
vec_t B_pack[nc*kc*2];
|
||||
int comparray[mc*kc];
|
||||
constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
int64_t ytiles = m / mc;
|
||||
int64_t xtiles = n / nc;
|
||||
int64_t tiles = xtiles * ytiles;
|
||||
int64_t duty = (tiles + nth - 1) / nth;
|
||||
int64_t start = duty * ith;
|
||||
int64_t end = start + duty;
|
||||
if (end > tiles) {
|
||||
end = tiles;
|
||||
}
|
||||
for (int64_t job = start; job < end; ++job) {
|
||||
int64_t ii = (job / xtiles) * mc;
|
||||
int64_t jj = (job % xtiles) * nc;
|
||||
for (int64_t kk = 0; kk < k; kk += kc) {
|
||||
if constexpr(is_Ablock_q4) {
|
||||
packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray);
|
||||
} else {
|
||||
packNormal_large<int8_t, vector signed char>(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray);
|
||||
}
|
||||
packNormal_large<uint8_t, vector unsigned char>(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true);
|
||||
KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
|
||||
for (int I = 0; I < RM; I++) {
|
||||
for (int J = 0; J < RN; J++) {
|
||||
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
|
||||
for (int I = 0; I < RM; I++) {
|
||||
for (int J = 0; J < RN; J++) {
|
||||
float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
|
||||
*c_ptr += *((float*)&fin_res[idx+I]+J);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename ArrayType>
|
||||
inline void compute(acc_t* ACC, int c_idx, int s_idx, ArrayType& comparray, vector float* vs, vector float* fin_res) {
|
||||
vector signed int vec_C[4];
|
||||
vector float CA[4] = {0};
|
||||
vector float res[4] = {0};
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
for (int i = 0; i < 4; i++) {
|
||||
CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
|
||||
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
|
||||
fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
|
||||
}
|
||||
}
|
||||
|
||||
inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
|
||||
const vector signed char lowMask = vec_splats((signed char)0xF);
|
||||
const vector unsigned char v4 = vec_splats((unsigned char)0x4);
|
||||
const vector signed char v8 = vec_splats((signed char)0x8);
|
||||
vector signed int vsum = {0};
|
||||
vector signed int vsum2 = {0};
|
||||
c[0] = vec_and(c[1], lowMask);
|
||||
c[1] = vec_sr(c[1], v4);
|
||||
c[0] = vec_sub(c[0], v8);
|
||||
c[1] = vec_sub(c[1], v8);
|
||||
vsum = vec_sum4s(c[0], vsum);
|
||||
vsum2 = vec_sum4s(c[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
*(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
}
|
||||
|
||||
template <typename V1, typename V2>
|
||||
inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
|
||||
vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
||||
vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
||||
vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
|
||||
vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
|
||||
V2 t1, t2, t3, t4, t5, t6, t7, t8;
|
||||
vector unsigned char xor_vector;
|
||||
uint8_t flip_vec = 0x80;
|
||||
xor_vector = vec_splats(flip_vec);
|
||||
t1 = vec_perm(s1, s2, swiz1);
|
||||
t2 = vec_perm(s1, s2, swiz2);
|
||||
t3 = vec_perm(s3, s4, swiz1);
|
||||
t4 = vec_perm(s3, s4, swiz2);
|
||||
t5 = vec_perm(t1, t3, swiz3);
|
||||
t6 = vec_perm(t1, t3, swiz4);
|
||||
t7 = vec_perm(t2, t4, swiz3);
|
||||
t8 = vec_perm(t2, t4, swiz4);
|
||||
if (flip == true) {
|
||||
t5 = vec_xor(t5, xor_vector);
|
||||
t6 = vec_xor(t6, xor_vector);
|
||||
t7 = vec_xor(t7, xor_vector);
|
||||
t8 = vec_xor(t8, xor_vector);
|
||||
}
|
||||
vec_xst(t5, 0, vecOffset);
|
||||
vec_xst(t6, 0, vecOffset+16);
|
||||
vec_xst(t7, 0, vecOffset+32);
|
||||
vec_xst(t8, 0, vecOffset+48);
|
||||
}
|
||||
|
||||
template<int RM, int RN>
|
||||
inline void kernel(int64_t ii, int64_t jj) {
|
||||
if constexpr(RM == 4 && RN == 8) {
|
||||
KERNEL_4x8(ii,jj);
|
||||
} else if constexpr(RM == 8 && RN == 4) {
|
||||
KERNEL_8x4(ii,jj);
|
||||
} else if constexpr(RM == 8 && RN == 8) {
|
||||
KERNEL_8x8(ii,jj);
|
||||
} else {
|
||||
assert(false && "RN/RM values not supported");
|
||||
}
|
||||
}
|
||||
template<int size>
|
||||
void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray);
|
||||
template<typename VA, typename VB>
|
||||
void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip);
|
||||
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n);
|
||||
void KERNEL_4x8(int64_t ii, int64_t jj);
|
||||
void KERNEL_8x4(int64_t ii, int64_t jj);
|
||||
void KERNEL_8x8(int64_t ii, int64_t jj);
|
||||
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN);
|
||||
template <int RM, int RN>
|
||||
void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n);
|
||||
|
||||
void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){
|
||||
for (int I = 0; I<8; I++) {
|
||||
float a_scale = unhalf((A+((ii+I)*lda)+blk)->d);
|
||||
for (int J = 0; J<4; J++) {
|
||||
*((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d));
|
||||
*((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void process_q8_elements(const int8_t *qs, int *ca) {
|
||||
vector signed char c1 = vec_xl(0, qs);
|
||||
vector signed char c2 = vec_xl(16, qs);
|
||||
vector signed int vsum1 = {0};
|
||||
vector signed int vsum2 = {0};
|
||||
vsum1 = vec_sum4s(c1, vsum1);
|
||||
vsum2 = vec_sum4s(c2, vsum2);
|
||||
vector signed int vsum = vec_add(vsum1, vsum2);
|
||||
*ca = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
}
|
||||
|
||||
template<typename VA, typename VB>
|
||||
void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray=nullptr) {
|
||||
int64_t i, j;
|
||||
block_q8_0 *aoffset = NULL;
|
||||
VA *vecOffset = NULL;
|
||||
block_q8_0* aoffsets[8];
|
||||
__vector_pair arr[8];
|
||||
VB c[8][2] = {0};
|
||||
VB c1[8] = {0}; VB c2[8] = {0};
|
||||
aoffset = const_cast<block_q8_0*>(a);
|
||||
vecOffset = vec;
|
||||
j = (rows >> 3);
|
||||
int index = 0;
|
||||
if (j > 0) {
|
||||
do {
|
||||
for (int it = 0; it < 8; it++)
|
||||
aoffsets[it] = aoffset + it*lda;
|
||||
aoffset += 8 * lda;
|
||||
for (int blk = 0; blk < kc; blk++) {
|
||||
for (int it = 0; it < 8; it++) {
|
||||
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs);
|
||||
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
|
||||
c1[it] = c[it][0];
|
||||
c2[it] = c[it][1];
|
||||
if (comparray){
|
||||
process_q8_elements((aoffsets[it]+ blk)->qs, &comparray[index + 8*blk + it]);
|
||||
}
|
||||
}
|
||||
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
|
||||
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
|
||||
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
|
||||
vecOffset += 256;
|
||||
}
|
||||
j--;
|
||||
index += 8*kc;
|
||||
} while(j > 0);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int*comparray) {
|
||||
int64_t i, j;
|
||||
TA *aoffset = NULL;
|
||||
int8_t *vecOffset = NULL;
|
||||
TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
|
||||
TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
|
||||
vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
|
||||
vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
|
||||
aoffset = const_cast<TA*>(a);
|
||||
vecOffset = vec;
|
||||
int index = 0;
|
||||
j = (rows >> 3);
|
||||
if (j > 0) {
|
||||
do {
|
||||
aoffset1 = aoffset;
|
||||
aoffset2 = aoffset1 + lda;
|
||||
aoffset3 = aoffset2 + lda;
|
||||
aoffset4 = aoffset3 + lda;
|
||||
aoffset5 = aoffset4 + lda;
|
||||
aoffset6 = aoffset5 + lda;
|
||||
aoffset7 = aoffset6 + lda;
|
||||
aoffset8 = aoffset7 + lda;
|
||||
aoffset += 8 * lda;
|
||||
for (int blk = 0; blk < kc; blk++) {
|
||||
c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset1+blk)->qs));
|
||||
c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset2+blk)->qs));
|
||||
c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset3+blk)->qs));
|
||||
c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset4+blk)->qs));
|
||||
c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset5+blk)->qs));
|
||||
c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset6+blk)->qs));
|
||||
c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset7+blk)->qs));
|
||||
c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset8+blk)->qs));
|
||||
|
||||
process_q4_elements(c1, &comparray[index + 8*blk+0]);
|
||||
process_q4_elements(c2, &comparray[index + 8*blk+1]);
|
||||
process_q4_elements(c3, &comparray[index + 8*blk+2]);
|
||||
process_q4_elements(c4, &comparray[index + 8*blk+3]);
|
||||
process_q4_elements(c5, &comparray[index + 8*blk+4]);
|
||||
process_q4_elements(c6, &comparray[index + 8*blk+5]);
|
||||
process_q4_elements(c7, &comparray[index + 8*blk+6]);
|
||||
process_q4_elements(c8, &comparray[index + 8*blk+7]);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
|
||||
vecOffset += 256;
|
||||
}
|
||||
j--;
|
||||
index += 8*kc;
|
||||
} while (j > 0);
|
||||
}
|
||||
}
|
||||
|
||||
void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) {
|
||||
acc_t acc[8];
|
||||
for (int i = 0; i < mc ; i += 8) {
|
||||
for (int j = 0; j < nc; j += 8) {
|
||||
vector float fin_res[16] = {0};
|
||||
vector float vs[16] = {0};
|
||||
for (int64_t kk = 0; kk < kc; kk+=2) {
|
||||
for (int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xxsetaccz(&acc[x]);
|
||||
}
|
||||
int A_block_idx = (i/8)*(16*kc) + kk*16;
|
||||
int B_block_idx = (j/8)*(16*kc)+ kk*16;
|
||||
vec_t *A_block = &vec_A[A_block_idx];
|
||||
vec_t *B_block = &vec_B[B_block_idx];
|
||||
for (int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc[0], A_block[x], B_block[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc[2], A_block[x], B_block[x+8]);
|
||||
__builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]);
|
||||
}
|
||||
compute_scale(ii+i, jj+j, l+kk, vs);
|
||||
int c_index = (i/8)*(8*kc)+ kk*8;
|
||||
int* c_block = &comparray[c_index];
|
||||
compute(&acc[0], 0, 0, c_block, vs, fin_res);
|
||||
compute(&acc[1], 4, 4, c_block, vs, fin_res);
|
||||
compute(&acc[2], 0, 8, c_block, vs, fin_res);
|
||||
compute(&acc[3], 4, 12, c_block, vs, fin_res);
|
||||
|
||||
A_block_idx = (i/8)*(16*kc) + (kk+1)*16;
|
||||
B_block_idx = (j/8)*(16*kc)+ (kk+1)*16;
|
||||
A_block = &vec_A[A_block_idx];
|
||||
B_block = &vec_B[B_block_idx];
|
||||
for (int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc[4], A_block[x], B_block[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc[5], A_block[x + 8], B_block[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc[6], A_block[x], B_block[x+8]);
|
||||
__builtin_mma_xvi8ger4pp(&acc[7], A_block[x+8], B_block[x+8]);
|
||||
}
|
||||
compute_scale(ii+i, jj+j, l+kk+1, vs);
|
||||
c_index = (i/8)*(8*kc)+ (kk+1)*8;
|
||||
c_block = &comparray[c_index];
|
||||
compute(&acc[4], 0, 0, c_block, vs, fin_res);
|
||||
compute(&acc[5], 4, 4, c_block, vs, fin_res);
|
||||
compute(&acc[6], 0, 8, c_block, vs, fin_res);
|
||||
compute(&acc[7], 4, 12, c_block, vs, fin_res);
|
||||
|
||||
}
|
||||
if (l == 0) {
|
||||
save_res(ii+i, jj+j, 0, fin_res);
|
||||
save_res(ii+i+4, jj+j, 4, fin_res);
|
||||
save_res(ii+i, jj+j+4, 8, fin_res);
|
||||
save_res(ii+i+4, jj+j+4, 12, fin_res);
|
||||
} else {
|
||||
add_save_res(ii+i, jj+j, 0, fin_res);
|
||||
add_save_res(ii+i+4, jj+j, 4, fin_res);
|
||||
add_save_res(ii+i, jj+j+4, 8, fin_res);
|
||||
add_save_res(ii+i+4, jj+j+4, 12, fin_res);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const TA *const A;
|
||||
const block_q8_0 *const B;
|
||||
float *C;
|
||||
const int64_t k;
|
||||
int64_t kc;
|
||||
const int64_t lda;
|
||||
const int64_t ldb;
|
||||
const int64_t ldc;
|
||||
const int ith;
|
||||
const int nth;
|
||||
};
|
||||
@@ -121,7 +121,8 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
|
||||
#endif
|
||||
|
||||
#if defined(__MMA__)
|
||||
#include "sgemm-ppc.h"
|
||||
typedef vector unsigned char vec_t;
|
||||
typedef __vector_quad acc_t;
|
||||
#endif
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// VECTORIZED FUSED MULTIPLY ADD
|
||||
@@ -2153,7 +2154,7 @@ class tinyBLAS_HP16_PPC {
|
||||
packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
|
||||
for (int x = 0; x < 4; x++) {
|
||||
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
||||
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
|
||||
mma_instr<TA>::outer_product(&acc_1, vec_A[x+4], vec_B[x]);
|
||||
}
|
||||
}
|
||||
SAVE_ACC(&acc_0, ii, jj);
|
||||
@@ -2301,43 +2302,299 @@ class tinyBLAS_HP16_PPC {
|
||||
const int nth;
|
||||
};
|
||||
|
||||
template <typename TA>
|
||||
tinyBLAS_Q0_PPC<TA>::tinyBLAS_Q0_PPC(int64_t k,
|
||||
const TA *A, int64_t lda,
|
||||
const block_q8_0 *B, int64_t ldb,
|
||||
float *C, int64_t ldc,
|
||||
int ith, int nth)
|
||||
template <typename TA>
|
||||
class tinyBLAS_Q0_PPC {
|
||||
public:
|
||||
tinyBLAS_Q0_PPC(int64_t k,
|
||||
const TA * A, int64_t lda,
|
||||
const block_q8_0 * B, int64_t ldb,
|
||||
float * C, int64_t ldc,
|
||||
int ith, int nth)
|
||||
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
||||
kc = 64;
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::matmul(int64_t m, int64_t n) {
|
||||
int mc = 64; int nc = 64;
|
||||
if (n % 8 == 0 && n < nc) {
|
||||
nc = n;
|
||||
mc = 32 ;
|
||||
kc = 32;
|
||||
void matmul(int64_t m, int64_t n) {
|
||||
const int64_t mc = 64;
|
||||
const int64_t kc = 64;
|
||||
int64_t nc = 64;
|
||||
int64_t n_aligned = 0;
|
||||
if (n % 64 == 0) {
|
||||
n_aligned = n;
|
||||
} else if (n == 4) {
|
||||
n_aligned = 4;
|
||||
} else if (n < 64) {
|
||||
n_aligned = (n / 8) * 8;
|
||||
} else {
|
||||
n_aligned = (n / 64) * 64;
|
||||
}
|
||||
const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0);
|
||||
if (is_aligned) {
|
||||
this->matmul_tiled_q0(m, n, mc, nc, kc);
|
||||
|
||||
if (n_aligned > 0) {
|
||||
if (n_aligned % 64 == 0) nc = 64;
|
||||
else if (n_aligned == n) nc = n;
|
||||
else if (n_aligned % 32 == 0) nc = 32;
|
||||
else if (n_aligned % 24 == 0) nc = 24;
|
||||
else if (n_aligned % 16 == 0) nc = 16;
|
||||
else nc = 8;
|
||||
}
|
||||
bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0);
|
||||
if (can_use_tiled) {
|
||||
matmul_tiled(m, n_aligned, mc, nc, kc);
|
||||
if (n > n_aligned) {
|
||||
mnpack(0, m, n_aligned, n);
|
||||
}
|
||||
} else {
|
||||
mnpack(0, m, 0, n);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
template<int size>
|
||||
void tinyBLAS_Q0_PPC<TA>::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
|
||||
private:
|
||||
inline void save_res(int ii, int jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) {
|
||||
for (int I = 0; I < RM; I++) {
|
||||
for (int J = 0; J < RN; J++) {
|
||||
*((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&fin_res[idx + I] + J);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
||||
vec_t vec_C[4];
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
for (int I = 0; I < 4; I++) {
|
||||
for (int J = 0; J < 4; J++) {
|
||||
*((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&vec_C[I] + J);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
|
||||
vec_t vec_C[4];
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
for (int I = 0; I < 4; I++) {
|
||||
for (int J = 0; J < 4; J++) {
|
||||
float * c_ptr = (float *)(C + ii+ ((jj + J) * ldc) + I);
|
||||
*c_ptr += *((float *)&vec_C[I] + J);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename ArrayType>
|
||||
inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector float * vs, vector float * fin_res) {
|
||||
vector signed int vec_C[4];
|
||||
vector float CA[4] = {0};
|
||||
vector float res[4] = {0};
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC);
|
||||
for (int i = 0; i < 4; i++) {
|
||||
CA[i] = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0));
|
||||
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
|
||||
fin_res[s_idx + i] = vec_madd(res[i], vs[s_idx + i], fin_res[s_idx + i]);
|
||||
}
|
||||
}
|
||||
|
||||
inline void process_q4_elements(vector signed char (&c)[2], int * ca) {
|
||||
const vector signed char lowMask = vec_splats((signed char)0xF);
|
||||
const vector unsigned char v4 = vec_splats((unsigned char)0x4);
|
||||
const vector signed char v8 = vec_splats((signed char)0x8);
|
||||
vector signed int vsum = {0};
|
||||
vector signed int vsum2 = {0};
|
||||
c[0] = vec_and(c[1], lowMask);
|
||||
c[1] = vec_sr(c[1], v4);
|
||||
c[0] = vec_sub(c[0], v8);
|
||||
c[1] = vec_sub(c[1], v8);
|
||||
vsum = vec_sum4s(c[0], vsum);
|
||||
vsum2 = vec_sum4s(c[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
*(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
}
|
||||
|
||||
template <typename V1, typename V2>
|
||||
inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) {
|
||||
vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
||||
vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
||||
vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
|
||||
vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
|
||||
V2 t1, t2, t3, t4, t5, t6, t7, t8;
|
||||
vector unsigned char xor_vector;
|
||||
uint8_t flip_vec = 0x80;
|
||||
xor_vector = vec_splats(flip_vec);
|
||||
t1 = vec_perm(s1, s2, swiz1);
|
||||
t2 = vec_perm(s1, s2, swiz2);
|
||||
t3 = vec_perm(s3, s4, swiz1);
|
||||
t4 = vec_perm(s3, s4, swiz2);
|
||||
t5 = vec_perm(t1, t3, swiz3);
|
||||
t6 = vec_perm(t1, t3, swiz4);
|
||||
t7 = vec_perm(t2, t4, swiz3);
|
||||
t8 = vec_perm(t2, t4, swiz4);
|
||||
if (flip == true) {
|
||||
t5 = vec_xor(t5, xor_vector);
|
||||
t6 = vec_xor(t6, xor_vector);
|
||||
t7 = vec_xor(t7, xor_vector);
|
||||
t8 = vec_xor(t8, xor_vector);
|
||||
}
|
||||
vec_xst(t5, 0, vecOffset);
|
||||
vec_xst(t6, 0, vecOffset + 16);
|
||||
vec_xst(t7, 0, vecOffset + 32);
|
||||
vec_xst(t8, 0, vecOffset + 48);
|
||||
}
|
||||
|
||||
inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) {
|
||||
const vector signed char lowMask = vec_splats((signed char)0x0F);
|
||||
const vector signed char v8 = vec_splats((signed char)0x08);
|
||||
const vector unsigned char v4 = vec_splats((unsigned char)4);
|
||||
lo = vec_and(packed, lowMask);
|
||||
hi = vec_sr(packed, v4);
|
||||
lo = vec_sub(lo, v8);
|
||||
hi = vec_sub(hi, v8);
|
||||
}
|
||||
|
||||
inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) {
|
||||
vec_t t[8], s[8];
|
||||
vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
|
||||
vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
|
||||
vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
||||
vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
||||
for (int i = 0; i < 4; i += 2) {
|
||||
t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
|
||||
t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
|
||||
}
|
||||
for (int i = 4; i < 8; i += 2) {
|
||||
t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
|
||||
t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
|
||||
}
|
||||
s[0] = vec_perm(t[0], t[2], swiz3);
|
||||
s[1] = vec_perm(t[0], t[2], swiz4);
|
||||
s[2] = vec_perm(t[1], t[3], swiz3);
|
||||
s[3] = vec_perm(t[1], t[3], swiz4);
|
||||
s[4] = vec_perm(t[4], t[6], swiz3);
|
||||
s[5] = vec_perm(t[4], t[6], swiz4);
|
||||
s[6] = vec_perm(t[5], t[7], swiz3);
|
||||
s[7] = vec_perm(t[5], t[7], swiz4);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
vec_xst(s[i], 0, (vec_t *)(vecOffset + i * 16));
|
||||
}
|
||||
}
|
||||
|
||||
static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) {
|
||||
vector signed short i16_hi = vec_unpackh(raw);
|
||||
vector signed short i16_lo = vec_unpackl(raw);
|
||||
|
||||
vector float f_hi_h = vec_ctf(vec_unpackh(i16_hi), 0);
|
||||
vector float f_hi_l = vec_ctf(vec_unpackl(i16_hi), 0);
|
||||
vector float f_lo_h = vec_ctf(vec_unpackh(i16_lo), 0);
|
||||
vector float f_lo_l = vec_ctf(vec_unpackl(i16_lo), 0);
|
||||
out_hi = vec_pack_to_short_fp32(vec_mul(f_hi_h, v_scale), vec_mul(f_hi_l, v_scale));
|
||||
out_lo = vec_pack_to_short_fp32(vec_mul(f_lo_h, v_scale), vec_mul(f_lo_l, v_scale));
|
||||
}
|
||||
|
||||
void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
|
||||
unsigned char * vecOffset = vec;
|
||||
for (int i = 0; i < rows; i += 8) {
|
||||
const block_q4_0 * rows_base[8];
|
||||
for (int r = 0; r < 8; r++) {
|
||||
rows_base[r] = a + (i + r) * lda;
|
||||
}
|
||||
for (int blk = 0; blk < blocks; blk++) {
|
||||
vector unsigned short hp_res[8][4];
|
||||
for (int r = 0; r < 8; r++) {
|
||||
const block_q4_0 * current_blk = rows_base[r] + blk;
|
||||
vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d));
|
||||
vector signed char v_qs = reinterpret_cast<vector signed char>(vec_xl(0, current_blk->qs));
|
||||
vector signed char c1, c2;
|
||||
unpack_q4_to_q8(v_qs, c1, c2);
|
||||
convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]);
|
||||
convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]);
|
||||
}
|
||||
for (int c = 0; c < 4; c++) {
|
||||
vector unsigned char c_arr[8];
|
||||
for (int r = 0; r < 8; r++) {
|
||||
c_arr[r] = (vector unsigned char)hp_res[r][c];
|
||||
}
|
||||
vector_permute_store_fp16((vec_t *)c_arr, vecOffset);
|
||||
vecOffset += 128;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int chunk_size>
|
||||
static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
|
||||
unsigned char * vecOffset = vec;
|
||||
const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
|
||||
const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
|
||||
const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
||||
const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
||||
|
||||
for (int i = 0; i < rows; i += chunk_size) {
|
||||
const block_q8_0 * rows_base[chunk_size];
|
||||
for (int r = 0; r < chunk_size; r++) {
|
||||
rows_base[r] = a + (i + r) * lda;
|
||||
}
|
||||
for (int blk = 0; blk < blocks; blk++) {
|
||||
vector unsigned short hp_res[chunk_size][4];
|
||||
for (int r = 0; r < chunk_size; r++) {
|
||||
const block_q8_0 * b = rows_base[r] + blk;
|
||||
vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d));
|
||||
vector signed char c[2];
|
||||
__vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs);
|
||||
__builtin_vsx_disassemble_pair(c, & pair);
|
||||
convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]);
|
||||
convert_and_scale_q8(c[1], v_scale, hp_res[r][2], hp_res[r][3]);
|
||||
}
|
||||
for (int col = 0; col < 4; col++) {
|
||||
if constexpr (chunk_size == 8) {
|
||||
vec_t t[8];
|
||||
t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
|
||||
t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
|
||||
t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
|
||||
t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
|
||||
t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1);
|
||||
t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2);
|
||||
t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1);
|
||||
t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2);
|
||||
|
||||
vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0));
|
||||
vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16));
|
||||
vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32));
|
||||
vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48));
|
||||
vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64));
|
||||
vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80));
|
||||
vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96));
|
||||
vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112));
|
||||
vecOffset += 128;
|
||||
} else {
|
||||
vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
|
||||
vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
|
||||
vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
|
||||
vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
|
||||
|
||||
vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0));
|
||||
vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16));
|
||||
vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32));
|
||||
vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48));
|
||||
vecOffset += 64;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
|
||||
if (rows == 4) {
|
||||
pack_q8_block<4>(a, lda, rows, blocks, vec);
|
||||
} else {
|
||||
pack_q8_block<8>(a, lda, rows, blocks, vec);
|
||||
}
|
||||
}
|
||||
|
||||
template<int size>
|
||||
void packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array<int, size> & comparray) {
|
||||
int64_t i, j;
|
||||
TA *aoffset = NULL;
|
||||
int8_t *vecOffset = NULL;
|
||||
TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
|
||||
TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
|
||||
TA * aoffset = NULL;
|
||||
int8_t * vecOffset = NULL;
|
||||
TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL;
|
||||
TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL;
|
||||
vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
|
||||
vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
|
||||
aoffset = const_cast<TA*>(a);
|
||||
aoffset = const_cast<TA *>(a);
|
||||
vecOffset = vec;
|
||||
j = (rows >> 3);
|
||||
if (j > 0) {
|
||||
@@ -2363,18 +2620,18 @@ class tinyBLAS_HP16_PPC {
|
||||
c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
|
||||
c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
|
||||
|
||||
process_q4_elements(c1, &comparray[0]);
|
||||
process_q4_elements(c2, &comparray[1]);
|
||||
process_q4_elements(c3, &comparray[2]);
|
||||
process_q4_elements(c4, &comparray[3]);
|
||||
process_q4_elements(c5, &comparray[4]);
|
||||
process_q4_elements(c6, &comparray[5]);
|
||||
process_q4_elements(c7, &comparray[6]);
|
||||
process_q4_elements(c8, &comparray[7]);
|
||||
process_q4_elements(c1, & comparray[0]);
|
||||
process_q4_elements(c2, & comparray[1]);
|
||||
process_q4_elements(c3, & comparray[2]);
|
||||
process_q4_elements(c4, & comparray[3]);
|
||||
process_q4_elements(c5, & comparray[4]);
|
||||
process_q4_elements(c6, & comparray[5]);
|
||||
process_q4_elements(c7, & comparray[6]);
|
||||
process_q4_elements(c8, & comparray[7]);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false);
|
||||
aoffset1 += lda;
|
||||
aoffset2 += lda;
|
||||
aoffset3 += lda;
|
||||
@@ -2405,12 +2662,12 @@ class tinyBLAS_HP16_PPC {
|
||||
c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
|
||||
c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
|
||||
|
||||
process_q4_elements(c1, &comparray[0]);
|
||||
process_q4_elements(c2, &comparray[1]);
|
||||
process_q4_elements(c3, &comparray[2]);
|
||||
process_q4_elements(c4, &comparray[3]);
|
||||
process_q4_elements(c1, & comparray[0]);
|
||||
process_q4_elements(c2, & comparray[1]);
|
||||
process_q4_elements(c3, & comparray[2]);
|
||||
process_q4_elements(c4, & comparray[3]);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
|
||||
aoffset1 += lda;
|
||||
aoffset2 += lda;
|
||||
aoffset3 += lda;
|
||||
@@ -2434,12 +2691,12 @@ class tinyBLAS_HP16_PPC {
|
||||
case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
|
||||
break;
|
||||
}
|
||||
process_q4_elements(c1, &comparray[0]);
|
||||
process_q4_elements(c2, &comparray[1]);
|
||||
process_q4_elements(c3, &comparray[2]);
|
||||
process_q4_elements(c4, &comparray[3]);
|
||||
process_q4_elements(c1, & comparray[0]);
|
||||
process_q4_elements(c2, & comparray[1]);
|
||||
process_q4_elements(c3, & comparray[2]);
|
||||
process_q4_elements(c4, & comparray[3]);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
|
||||
vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
|
||||
aoffset1 += lda;
|
||||
aoffset2 += lda;
|
||||
aoffset3 += lda;
|
||||
@@ -2450,39 +2707,38 @@ class tinyBLAS_HP16_PPC {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
template<typename VA, typename VB>
|
||||
void tinyBLAS_Q0_PPC<TA>::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
|
||||
void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) {
|
||||
int64_t i, j;
|
||||
block_q8_0 *aoffset = NULL;
|
||||
VA *vecOffset = NULL;
|
||||
block_q8_0* aoffsets[8];
|
||||
block_q8_0 * aoffset = NULL;
|
||||
VA * vecOffset = NULL;
|
||||
block_q8_0 * aoffsets[8];
|
||||
__vector_pair arr[8];
|
||||
VB c[8][2] = {0};
|
||||
VB c1[8] = {0}; VB c2[8] = {0};
|
||||
aoffset = const_cast<block_q8_0*>(a);
|
||||
aoffset = const_cast<block_q8_0 *>(a);
|
||||
vecOffset = vec;
|
||||
j = (rows >> 3);
|
||||
if (j > 0) {
|
||||
do {
|
||||
aoffsets[0] = aoffset;
|
||||
for (int it = 1; it < 8; it++)
|
||||
aoffsets[it] = aoffsets[it-1] + lda;
|
||||
aoffsets[it] = aoffsets[it - 1] + lda;
|
||||
aoffset += 8 * lda;
|
||||
|
||||
i = (cols >> 3);
|
||||
if (i > 0) {
|
||||
do {
|
||||
for (int it = 0; it < 8; it++) {
|
||||
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
|
||||
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[it], & arr[it]);
|
||||
c1[it] = c[it][0];
|
||||
c2[it] = c[it][1];
|
||||
}
|
||||
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
|
||||
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
|
||||
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
|
||||
vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip);
|
||||
vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip);
|
||||
for (int it = 0; it < 8; it++)
|
||||
aoffsets[it] += lda;
|
||||
vecOffset += 256;
|
||||
@@ -2501,13 +2757,13 @@ class tinyBLAS_HP16_PPC {
|
||||
if (i > 0) {
|
||||
do {
|
||||
for (int it = 0; it < 4; it++) {
|
||||
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[it], &arr[it]);
|
||||
arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[it], & arr[it]);
|
||||
c1[it] = c[it][0];
|
||||
c2[it] = c[it][1];
|
||||
}
|
||||
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
|
||||
for (int it = 0; it < 4; it++) {
|
||||
aoffsets[it] += lda;
|
||||
}
|
||||
@@ -2520,24 +2776,24 @@ class tinyBLAS_HP16_PPC {
|
||||
if (rows & 3) {
|
||||
aoffsets[0] = aoffset;
|
||||
for (int it = 1; it < 3; it++ )
|
||||
aoffsets[it] = aoffsets[it-1] + lda;
|
||||
aoffsets[it] = aoffsets[it - 1] + lda;
|
||||
i = (cols >> 3);
|
||||
if (i > 0) {
|
||||
do {
|
||||
switch(rows) {
|
||||
case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[2], &arr[2]);
|
||||
case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[2], & arr[2]);
|
||||
c1[2] = c[2][0]; c2[2] = c[2][1];
|
||||
case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[1], &arr[1]);
|
||||
case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[1], & arr[1]);
|
||||
c1[1] = c[1][0]; c2[1] = c[1][1];
|
||||
case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[0], &arr[0]);
|
||||
case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs);
|
||||
__builtin_vsx_disassemble_pair(c[0], & arr[0]);
|
||||
c1[0] = c[0][0]; c2[0] = c[0][1];
|
||||
break;
|
||||
}
|
||||
vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
|
||||
vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
|
||||
for (int it = 0; it < 3; it++)
|
||||
aoffsets[it] += lda;
|
||||
vecOffset += 128;
|
||||
@@ -2547,8 +2803,7 @@ class tinyBLAS_HP16_PPC {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||
int m_rem = MIN(m - m0, 16);
|
||||
int n_rem = MIN(n - n0, 16);
|
||||
|
||||
@@ -2585,8 +2840,7 @@ class tinyBLAS_HP16_PPC {
|
||||
}
|
||||
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::KERNEL_4x8(int64_t ii, int64_t jj) {
|
||||
void KERNEL_4x8(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[8], vec_B[16] = {0};
|
||||
acc_t acc_0, acc_1;
|
||||
std::array<int, 4> comparray {};
|
||||
@@ -2594,26 +2848,26 @@ class tinyBLAS_HP16_PPC {
|
||||
vector float vs[8] = {0};
|
||||
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
for (int l = 0; l < k; l++) {
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
__builtin_mma_xxsetaccz(& acc_0);
|
||||
__builtin_mma_xxsetaccz(& acc_1);
|
||||
if (std::is_same_v<TA, block_q4_0>) {
|
||||
packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
|
||||
packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
|
||||
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
|
||||
for(int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]);
|
||||
}
|
||||
for (int I = 0; I<4; I++) {
|
||||
for (int J = 0; J<4; J++) {
|
||||
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
||||
*((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
|
||||
*((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
||||
*((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
|
||||
}
|
||||
}
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
auto aoffset = A + (ii * lda) + l;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
@@ -2624,15 +2878,14 @@ class tinyBLAS_HP16_PPC {
|
||||
aoffset += lda;
|
||||
}
|
||||
}
|
||||
compute(&acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(&acc_1, 0, 4, comparray, vs, fin_res);
|
||||
compute(& acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(& acc_1, 0, 4, comparray, vs, fin_res);
|
||||
}
|
||||
save_res(ii, jj, 0, fin_res);
|
||||
save_res(ii, jj+4, 4, fin_res);
|
||||
save_res(ii, jj + 4, 4, fin_res);
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::KERNEL_8x4(int64_t ii, int64_t jj) {
|
||||
void KERNEL_8x4(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[16], vec_B[8] = {0};
|
||||
acc_t acc_0, acc_1;
|
||||
std::array<int, 8> comparray {};
|
||||
@@ -2640,25 +2893,25 @@ class tinyBLAS_HP16_PPC {
|
||||
vector float vs[8] = {0};
|
||||
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
for (int l = 0; l < k; l++) {
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
__builtin_mma_xxsetaccz(& acc_0);
|
||||
__builtin_mma_xxsetaccz(& acc_1);
|
||||
if (std::is_same_v<TA, block_q4_0>) {
|
||||
packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
|
||||
packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
|
||||
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true);
|
||||
for(int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
|
||||
}
|
||||
for (int I = 0; I<8; I++) {
|
||||
for (int J = 0; J<4; J++) {
|
||||
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
||||
for (int I = 0; I < 8; I++) {
|
||||
for (int J = 0; J < 4; J++) {
|
||||
*((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
||||
}
|
||||
}
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
auto aoffset = A + (ii * lda) + l;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
@@ -2669,15 +2922,14 @@ class tinyBLAS_HP16_PPC {
|
||||
aoffset += lda;
|
||||
}
|
||||
}
|
||||
compute(&acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(&acc_1, 4, 4, comparray, vs, fin_res);
|
||||
compute(& acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(& acc_1, 4, 4, comparray, vs, fin_res);
|
||||
}
|
||||
save_res(ii, jj, 0, fin_res);
|
||||
save_res(ii+4, jj, 4, fin_res);
|
||||
save_res(ii + 4, jj, 4, fin_res);
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::KERNEL_8x8(int64_t ii, int64_t jj) {
|
||||
void KERNEL_8x8(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[16], vec_B[16] = {0};
|
||||
acc_t acc_0, acc_1, acc_2, acc_3;
|
||||
acc_t acc_4, acc_5, acc_6, acc_7;
|
||||
@@ -2686,30 +2938,30 @@ class tinyBLAS_HP16_PPC {
|
||||
vector float vs[16] = {0};
|
||||
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
for (int l = 0; l < k; l++) {
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
__builtin_mma_xxsetaccz(&acc_2);
|
||||
__builtin_mma_xxsetaccz(&acc_3);
|
||||
__builtin_mma_xxsetaccz(& acc_0);
|
||||
__builtin_mma_xxsetaccz(& acc_1);
|
||||
__builtin_mma_xxsetaccz(& acc_2);
|
||||
__builtin_mma_xxsetaccz(& acc_3);
|
||||
if (std::is_same_v<TA, block_q4_0>) {
|
||||
packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
|
||||
packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
|
||||
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
|
||||
for(int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]);
|
||||
}
|
||||
for (int I = 0; I<8; I++) {
|
||||
for (int J = 0; J<4; J++) {
|
||||
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
||||
*((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
|
||||
for (int I = 0; I < 8 ; I++) {
|
||||
for (int J = 0; J < 4; J++) {
|
||||
*((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
||||
*((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
|
||||
}
|
||||
}
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
auto aoffset = A + (ii * lda) + l;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
@@ -2720,19 +2972,99 @@ class tinyBLAS_HP16_PPC {
|
||||
aoffset += lda;
|
||||
}
|
||||
}
|
||||
compute(&acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(&acc_1, 4, 4, comparray, vs, fin_res);
|
||||
compute(&acc_2, 0, 8, comparray, vs, fin_res);
|
||||
compute(&acc_3, 4, 12, comparray, vs, fin_res);
|
||||
compute(& acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(& acc_1, 4, 4, comparray, vs, fin_res);
|
||||
compute(& acc_2, 0, 8, comparray, vs, fin_res);
|
||||
compute(& acc_3, 4, 12, comparray, vs, fin_res);
|
||||
}
|
||||
save_res(ii, jj, 0, fin_res);
|
||||
save_res(ii+4, jj, 4, fin_res);
|
||||
save_res(ii, jj+4, 8, fin_res);
|
||||
save_res(ii+4, jj+4, 12, fin_res);
|
||||
save_res(ii + 4, jj, 4, fin_res);
|
||||
save_res(ii, jj + 4, 8, fin_res);
|
||||
save_res(ii + 4, jj + 4, 12, fin_res);
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
|
||||
void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) {
|
||||
acc_t acc[8];
|
||||
for (int i = 0; i < mc ; i += 16) {
|
||||
for (int j = 0; j < nc; j += 8) {
|
||||
int A0_base = (i / 16) * (2 * 32 * kc);
|
||||
int B0_base = (j / 8) * (32 * kc);
|
||||
for (int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xxsetaccz(&acc[x]);
|
||||
}
|
||||
for (int64_t kk = 0; kk < kc; kk++) {
|
||||
int A0_block_idx = A0_base + kk * 32;
|
||||
int B0_block_idx = B0_base + kk * 32;
|
||||
int A1_block_idx = A0_block_idx + 32 * kc;
|
||||
int B1_block_idx = B0_block_idx + 32 * kc;
|
||||
vec_t * A0_block = & vec_A[A0_block_idx];
|
||||
vec_t * B0_block = & vec_B[B0_block_idx];
|
||||
vec_t * A1_block = & vec_A[A1_block_idx];
|
||||
for (int it = 0; it < 4; it++) {
|
||||
for (int x = 0; x < 4; x++) {
|
||||
__builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]);
|
||||
__builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (l == 0) {
|
||||
save_acc(& acc[0], ii + i, jj + j);
|
||||
save_acc(& acc[1], ii + i, jj + j + 4);
|
||||
save_acc(& acc[2], ii + i + 4, jj + j);
|
||||
save_acc(& acc[3], ii + i + 4, jj + j + 4);
|
||||
save_acc(& acc[4], ii + i + 8, jj + j);
|
||||
save_acc(& acc[5], ii + i + 8, jj + j + 4);
|
||||
save_acc(& acc[6], ii + i + 12, jj + j);
|
||||
save_acc(& acc[7], ii + i + 12, jj + j + 4);
|
||||
} else {
|
||||
add_save_acc(& acc[0], ii + i, jj + j);
|
||||
add_save_acc(& acc[1], ii + i, jj + j + 4);
|
||||
add_save_acc(& acc[2], ii + i + 4, jj + j);
|
||||
add_save_acc(& acc[3], ii + i + 4, jj + j + 4);
|
||||
add_save_acc(& acc[4], ii + i + 8, jj + j);
|
||||
add_save_acc(& acc[5], ii + i + 8, jj + j + 4);
|
||||
add_save_acc(& acc[6], ii + i + 12, jj + j);
|
||||
add_save_acc(& acc[7], ii + i + 12, jj + j + 4);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
|
||||
vec_t A_pack[mc * kc * 4];
|
||||
vec_t B_pack[nc * kc * 4];
|
||||
constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
int64_t ytiles = m / mc;
|
||||
int64_t xtiles = n / nc;
|
||||
int64_t tiles = xtiles * ytiles;
|
||||
int64_t duty = (tiles + nth - 1) / nth;
|
||||
int64_t start = duty * ith;
|
||||
int64_t end = start + duty;
|
||||
if (end > tiles) {
|
||||
end = tiles;
|
||||
}
|
||||
for (int64_t job = start; job < end; ++job) {
|
||||
int64_t ii = (job / xtiles) * mc;
|
||||
int64_t jj = (job % xtiles) * nc;
|
||||
for (int64_t kk = 0; kk < k; kk += kc) {
|
||||
if constexpr(is_Ablock_q4) {
|
||||
packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
|
||||
} else {
|
||||
packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
|
||||
}
|
||||
packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack);
|
||||
KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
|
||||
int64_t ytiles = (m - m0) / RM;
|
||||
int64_t xtiles = (n - n0) / RN;
|
||||
int64_t tiles = xtiles * ytiles;
|
||||
@@ -2754,32 +3086,32 @@ class tinyBLAS_HP16_PPC {
|
||||
vector float fin_res[4] = {0};
|
||||
vector float vs[4] = {0};
|
||||
vector float CA[4] = {0};
|
||||
__builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
|
||||
__builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
|
||||
__builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value
|
||||
__builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value
|
||||
for (int l = 0; l < k; l++) {
|
||||
__builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
|
||||
__builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
|
||||
__builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
|
||||
__builtin_mma_xxsetaccz(& acc_0);
|
||||
if (isAblock_q4) {
|
||||
packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
|
||||
packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
|
||||
for(int x = 0; x < 8; x+=4) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
|
||||
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true);
|
||||
for (int x = 0; x < 8; x += 4) {
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]);
|
||||
}
|
||||
for (int I = 0; I<RM; I++) {
|
||||
for (int J = 0; J<RN; J++) {
|
||||
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
||||
for (int I = 0; I < RM; I++) {
|
||||
for (int J = 0; J < RN; J++) {
|
||||
*((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
||||
}
|
||||
}
|
||||
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
||||
__builtin_mma_disassemble_acc(vec_C, & acc_0);
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
auto aoffset = A + (ii * lda) + l;
|
||||
for (int i = 0; i < RM; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
@@ -2800,9 +3132,21 @@ class tinyBLAS_HP16_PPC {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
template<int RM, int RN>
|
||||
inline void kernel(int64_t ii, int64_t jj) {
|
||||
if constexpr(RM == 4 && RN == 8) {
|
||||
KERNEL_4x8(ii,jj);
|
||||
} else if constexpr(RM == 8 && RN == 4) {
|
||||
KERNEL_8x4(ii,jj);
|
||||
} else if constexpr(RM == 8 && RN == 8) {
|
||||
KERNEL_8x8(ii,jj);
|
||||
} else {
|
||||
assert(false && "RN/RM values not supported");
|
||||
}
|
||||
}
|
||||
|
||||
template <int RM, int RN>
|
||||
NOINLINE void tinyBLAS_Q0_PPC<TA>::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||
int64_t ytiles = (m - m0) / RM;
|
||||
int64_t xtiles = (n - n0) / RN;
|
||||
int64_t tiles = xtiles * ytiles;
|
||||
@@ -2814,12 +3158,20 @@ class tinyBLAS_HP16_PPC {
|
||||
for (int64_t job = start; job < end; ++job) {
|
||||
int64_t ii = m0 + job / xtiles * RM;
|
||||
int64_t jj = n0 + job % xtiles * RN;
|
||||
this->kernel<RM, RN>(ii, jj);
|
||||
kernel<RM, RN>(ii, jj);
|
||||
}
|
||||
}
|
||||
|
||||
template class tinyBLAS_Q0_PPC<block_q4_0>;
|
||||
template class tinyBLAS_Q0_PPC<block_q8_0>;
|
||||
const TA * const A;
|
||||
const block_q8_0 * const B;
|
||||
float * C;
|
||||
const int64_t k;
|
||||
int64_t kc;
|
||||
const int64_t lda;
|
||||
const int64_t ldb;
|
||||
const int64_t ldc;
|
||||
const int ith;
|
||||
const int nth;
|
||||
};
|
||||
|
||||
class tinyBLAS_PPC {
|
||||
public:
|
||||
|
||||
+173
-94
@@ -3,6 +3,7 @@
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "binary-ops.h"
|
||||
#include "simd-gemm.h"
|
||||
#include "ggml.h"
|
||||
#include "unary-ops.h"
|
||||
#include "vec.h"
|
||||
@@ -2096,10 +2097,14 @@ static void ggml_compute_forward_gelu_f32(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2113,10 +2118,14 @@ static void ggml_compute_forward_gelu_f32(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_gelu_f32(nc,
|
||||
(float *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(float *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2135,10 +2144,14 @@ static void ggml_compute_forward_gelu_f16(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2152,10 +2165,14 @@ static void ggml_compute_forward_gelu_f16(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_gelu_f16(nc,
|
||||
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2276,10 +2293,14 @@ static void ggml_compute_forward_gelu_erf_f32(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2293,10 +2314,14 @@ static void ggml_compute_forward_gelu_erf_f32(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_gelu_erf_f32(nc,
|
||||
(float *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(float *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2315,10 +2340,14 @@ static void ggml_compute_forward_gelu_erf_f16(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2332,10 +2361,14 @@ static void ggml_compute_forward_gelu_erf_f16(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_gelu_erf_f16(nc,
|
||||
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2379,10 +2412,14 @@ static void ggml_compute_forward_gelu_quick_f32(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2396,10 +2433,14 @@ static void ggml_compute_forward_gelu_quick_f32(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_gelu_quick_f32(nc,
|
||||
(float *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(float *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2418,10 +2459,14 @@ static void ggml_compute_forward_gelu_quick_f16(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2435,10 +2480,14 @@ static void ggml_compute_forward_gelu_quick_f16(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_gelu_quick_f16(nc,
|
||||
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2482,10 +2531,14 @@ static void ggml_compute_forward_silu_f32(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2499,10 +2552,14 @@ static void ggml_compute_forward_silu_f32(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_silu_f32(nc,
|
||||
(float *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(float *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2521,10 +2578,14 @@ static void ggml_compute_forward_silu_f16(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2538,10 +2599,14 @@ static void ggml_compute_forward_silu_f16(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_silu_f16(nc,
|
||||
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -8325,10 +8390,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
GGML_ASSERT(k->type == v->type);
|
||||
const ggml_type kv_type = k->type;
|
||||
|
||||
const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
|
||||
const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
|
||||
const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot;
|
||||
const size_t kv_type_size = ggml_type_size(kv_type);
|
||||
|
||||
// broadcast factors
|
||||
const int64_t rk2 = neq2/nek2;
|
||||
@@ -8360,8 +8421,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
|
||||
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
|
||||
|
||||
GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");
|
||||
|
||||
int ir = ir0;
|
||||
while (ir < ir1) {
|
||||
// q indices for the start of this tile
|
||||
@@ -8388,18 +8447,20 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
}
|
||||
|
||||
// Per-thread scratch layout:
|
||||
// Q_q: Q_TILE_SZ * DK (converted Q tile in KV type)
|
||||
// Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)
|
||||
// KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
|
||||
// mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
|
||||
// VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
|
||||
// V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion)
|
||||
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
|
||||
// V32: KV_TILE_SZ * DV (F32 buffer for V tile)
|
||||
// K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
|
||||
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);
|
||||
|
||||
void * Q_q = base;
|
||||
float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
|
||||
float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
|
||||
float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
|
||||
float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile
|
||||
float * V32 = VKQ32 + Q_TILE_SZ * DV;
|
||||
float * K_f32 = V32 + KV_TILE_SZ * DV;
|
||||
|
||||
memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
|
||||
memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
|
||||
@@ -8412,28 +8473,38 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
const int iv3 = iq3 / rv3;
|
||||
const int iv2 = iq2 / rv2;
|
||||
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
|
||||
kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
|
||||
}
|
||||
// Zero-pad remaining rows
|
||||
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
|
||||
memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
|
||||
{
|
||||
float * Q_f32 = (float *)Q_q;
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
|
||||
memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
|
||||
}
|
||||
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
|
||||
memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
|
||||
memset(V32, 0, KV_TILE_SZ * DV * sizeof(float));
|
||||
|
||||
for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
|
||||
const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
|
||||
|
||||
// skip the tile entirely if all the masks are -inf
|
||||
if (mask) {
|
||||
bool can_skip = true;
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
for (int tk = 0; tk < kv_tile; tk++) {
|
||||
mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
|
||||
if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
|
||||
can_skip = false;
|
||||
}
|
||||
}
|
||||
// Pad remaining mask entries with -inf
|
||||
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
|
||||
mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
if (can_skip) {
|
||||
@@ -8441,13 +8512,32 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
}
|
||||
}
|
||||
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
float s;
|
||||
kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
|
||||
KQ[tq * KV_TILE_SZ + tk] = s * scale;
|
||||
// Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
|
||||
// Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
|
||||
for (int tk = 0; tk < kv_tile; tk++) {
|
||||
const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
|
||||
if (kv_type == GGML_TYPE_F16) {
|
||||
const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
|
||||
for (int64_t dk = 0; dk < DK; dk++) {
|
||||
K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
|
||||
}
|
||||
} else {
|
||||
const float * k_f32_src = (const float *)k_data;
|
||||
for (int64_t dk = 0; dk < DK; dk++) {
|
||||
K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
|
||||
}
|
||||
}
|
||||
}
|
||||
memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
|
||||
simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
|
||||
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);
|
||||
|
||||
// Set padded KQ entries to -inf so softmax gives them zero weight
|
||||
if (kv_tile < KV_TILE_SZ) {
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
|
||||
KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8487,33 +8577,22 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
|
||||
}
|
||||
|
||||
// Convert V tile to F32 first (if F16), then do MAD
|
||||
// On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster.
|
||||
// TODO: on ARM, native f16 should be faster
|
||||
if (kv_type == GGML_TYPE_F16) {
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
|
||||
}
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
if (skip[tq]) continue;
|
||||
float * vkq_row = VKQ32 + tq * DV;
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const float p = KQ[tq * KV_TILE_SZ + tk];
|
||||
ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
if (skip[tq]) continue;
|
||||
float * vkq_row = VKQ32 + tq * DV;
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const float p = KQ[tq * KV_TILE_SZ + tk];
|
||||
const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
ggml_vec_mad_f32(DV, vkq_row, v_row, p);
|
||||
}
|
||||
// V accumulation: VKQ32 += softmax(KQ) * V
|
||||
// Pack V tile to contiguous F32, zero-padded
|
||||
for (int tk = 0; tk < kv_tile; tk++) {
|
||||
const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
|
||||
if (kv_type == GGML_TYPE_F16) {
|
||||
ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
|
||||
} else {
|
||||
memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
|
||||
}
|
||||
}
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
if (skip[tq]) {
|
||||
memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
|
||||
}
|
||||
}
|
||||
simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
|
||||
}
|
||||
|
||||
// sinks (apply only to valid rows in the tile)
|
||||
@@ -8730,15 +8809,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
|
||||
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
||||
|
||||
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
|
||||
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
|
||||
const bool use_tiled = !use_ref &&
|
||||
bool use_tiled = !use_ref &&
|
||||
(q->type == GGML_TYPE_F32 &&
|
||||
kv_is_f32_or_f16 &&
|
||||
k->type == v->type &&
|
||||
nek1 % KV_TILE_SZ == 0 &&
|
||||
neq1 >= Q_TILE_SZ);
|
||||
|
||||
#ifdef GGML_SIMD
|
||||
use_tiled &= (DV % GGML_F32_EPR == 0);
|
||||
#endif
|
||||
int current_chunk = ith;
|
||||
|
||||
while (current_chunk < nchunk) {
|
||||
|
||||
+232
-198
@@ -256,6 +256,200 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTR
|
||||
ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
|
||||
}
|
||||
|
||||
template <int M, int N>
|
||||
static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int blocklen = M;
|
||||
constexpr int ncols_interleaved = N;
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int blocks_per_half = 64 / blocklen;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[8];
|
||||
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[j] = 0.0f;
|
||||
}
|
||||
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;
|
||||
const int base_h = base_l + 64;
|
||||
|
||||
const int scale_idx_l = base_l / 16;
|
||||
const int scale_idx_h = base_h / 16;
|
||||
|
||||
const int qh_shift_l = ((base_l % 128) / 32) * 2;
|
||||
const int qh_shift_h = ((base_h % 128) / 32) * 2;
|
||||
|
||||
const int qh_half_l = (base_l / 128) * 32;
|
||||
const int qh_half_h = (base_h / 128) * 32;
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];
|
||||
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];
|
||||
|
||||
int sumi_l = 0;
|
||||
int sumi_h = 0;
|
||||
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
|
||||
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
|
||||
|
||||
const int qh_idx_l = qh_half_l + ((base_l + i) % 32);
|
||||
const int qh_chunk_l = qh_idx_l / blocklen;
|
||||
const int qh_pos_l = qh_idx_l % blocklen;
|
||||
const int qh_offset_l = qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;
|
||||
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
|
||||
|
||||
const int qh_idx_h = qh_half_h + ((base_h + i) % 32);
|
||||
const int qh_chunk_h = qh_idx_h / blocklen;
|
||||
const int qh_pos_h = qh_idx_h % blocklen;
|
||||
const int qh_offset_h = qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;
|
||||
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
|
||||
|
||||
const int q_l = ((hi_2_l << 4) | l_4) - 32;
|
||||
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
|
||||
|
||||
const int8_t a_l = a_ptr[l].qs[base_l + i];
|
||||
const int8_t a_h = a_ptr[l].qs[base_h + i];
|
||||
|
||||
sumi_l += q_l * a_l;
|
||||
sumi_h += q_h * a_h;
|
||||
}
|
||||
|
||||
sumf[j] +=
|
||||
(sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[x * ncols_interleaved + j] = sumf[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int M, int N>
|
||||
static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int blocklen = M;
|
||||
constexpr int ncols_interleaved = N;
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int blocks_per_half = 64 / blocklen;
|
||||
const int q8_half_stride = 512;
|
||||
const int q8_low_high_step = 256;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
|
||||
float sumf[4][8];
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[m][j] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;
|
||||
const int base_h = base_l + 64;
|
||||
|
||||
const int scale_idx_l = base_l / 16;
|
||||
const int scale_idx_h = base_h / 16;
|
||||
|
||||
const int qh_shift_l = ((base_l % 128) / 32) * 2;
|
||||
const int qh_shift_h = ((base_h % 128) / 32) * 2;
|
||||
|
||||
const int qh_half_l = (base_l / 128) * 32;
|
||||
const int qh_half_h = (base_h / 128) * 32;
|
||||
|
||||
const int q8_base = (k / blocks_per_half) * q8_half_stride + (k % blocks_per_half) * (blocklen * 4);
|
||||
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];
|
||||
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];
|
||||
|
||||
int sumi_l = 0;
|
||||
int sumi_h = 0;
|
||||
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
|
||||
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
|
||||
|
||||
const int qh_idx_l = qh_half_l + ((base_l + i) % 32);
|
||||
const int qh_chunk_l = qh_idx_l / blocklen;
|
||||
const int qh_pos_l = qh_idx_l % blocklen;
|
||||
const int qh_offset_l =
|
||||
qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;
|
||||
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
|
||||
|
||||
const int qh_idx_h = qh_half_h + ((base_h + i) % 32);
|
||||
const int qh_chunk_h = qh_idx_h / blocklen;
|
||||
const int qh_pos_h = qh_idx_h % blocklen;
|
||||
const int qh_offset_h =
|
||||
qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;
|
||||
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
|
||||
|
||||
const int q_l = ((hi_2_l << 4) | l_4) - 32;
|
||||
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
|
||||
|
||||
const int8_t q8_l = a_ptr[l].qs[q8_base + m * blocklen + i];
|
||||
const int8_t q8_h = a_ptr[l].qs[q8_base + m * blocklen + i + q8_low_high_step];
|
||||
|
||||
sumi_l += q_l * q8_l;
|
||||
sumi_h += q_h * q8_h;
|
||||
}
|
||||
|
||||
sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) *
|
||||
a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
@@ -704,94 +898,12 @@ void ggml_gemv_q5_K_8x8_q8_K_generic(int n,
|
||||
}
|
||||
|
||||
|
||||
void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
ggml_gemv_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[8];
|
||||
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[j] = 0.0f;
|
||||
}
|
||||
|
||||
for (int l = 0; l < nb; l++) {
|
||||
|
||||
|
||||
for (int k = 0; k < 16; k++) {
|
||||
// k = 0.. 7 weights 0-63 low, 64-127 high
|
||||
// k = 8..15 weights 128-191 low, 192-255 high
|
||||
const int base_l = (k / 8) * 128 + (k % 8) * 8;
|
||||
const int base_h = base_l + 64;
|
||||
|
||||
const int scale_idx_l = base_l / 16;
|
||||
const int scale_idx_h = base_h / 16;
|
||||
|
||||
// Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half
|
||||
const int qh_shift_l = ((base_l % 128) / 32) * 2;
|
||||
const int qh_shift_h = ((base_h % 128) / 32) * 2;
|
||||
|
||||
// qh_half: offset to the correct 32-byte half (0 or 32)
|
||||
const int qh_half_l = (base_l / 128) * 32;
|
||||
const int qh_half_h = (base_h / 128) * 32;
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
// Interleaved scales
|
||||
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j];
|
||||
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j];
|
||||
|
||||
int sumi_l = 0;
|
||||
int sumi_h = 0;
|
||||
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
const int ql_pos = k * 64 + j * 8 + i;
|
||||
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
|
||||
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
|
||||
|
||||
// qh indexing with 8-byte interleaving (like q5_K)
|
||||
const int qh_byte_l = qh_half_l + ((base_l + i) % 32);
|
||||
const int qh_chunk_l = qh_byte_l / 8;
|
||||
const int qh_pos_l = qh_byte_l % 8;
|
||||
const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l;
|
||||
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
|
||||
|
||||
const int qh_byte_h = qh_half_h + ((base_h + i) % 32);
|
||||
const int qh_chunk_h = qh_byte_h / 8;
|
||||
const int qh_pos_h = qh_byte_h % 8;
|
||||
const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h;
|
||||
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
|
||||
|
||||
const int q_l = ((hi_2_l << 4) | l_4) - 32;
|
||||
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
|
||||
|
||||
const int8_t a_l = a_ptr[l].qs[base_l + i];
|
||||
const int8_t a_h = a_ptr[l].qs[base_h + i];
|
||||
|
||||
sumi_l += q_l * a_l;
|
||||
sumi_h += q_h * a_h;
|
||||
}
|
||||
|
||||
sumf[j] +=
|
||||
(sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[x * ncols_interleaved + j] = sumf[j];
|
||||
}
|
||||
}
|
||||
ggml_gemv_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
@@ -1485,109 +1597,12 @@ void ggml_gemm_q5_K_8x8_q8_K_generic(int n,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q6_K_8x8_q8_K_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
ggml_gemm_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
|
||||
float sumf[4][8];
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[m][j] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int k = 0; k < 16; k++) {
|
||||
// k = 0.. 7 weights 0-63 low, 64-127 high
|
||||
// k = 8..15 weights 128-191 low, 192-255 high
|
||||
const int base_l = (k / 8) * 128 + (k % 8) * 8;
|
||||
const int base_h = base_l + 64;
|
||||
|
||||
const int scale_idx_l = base_l / 16;
|
||||
const int scale_idx_h = base_h / 16;
|
||||
|
||||
// Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half
|
||||
const int qh_shift_l = ((base_l % 128) / 32) * 2;
|
||||
const int qh_shift_h = ((base_h % 128) / 32) * 2;
|
||||
|
||||
// qh_half: offset to the correct 32-byte half (0 or 32)
|
||||
const int qh_half_l = (base_l / 128) * 32;
|
||||
const int qh_half_h = (base_h / 128) * 32;
|
||||
|
||||
// Activation base indices for q8_Kx4 interleaved format
|
||||
// Layout: 128-value halves (k/8), then 8-value sub-blocks (k%8) with stride 32
|
||||
const int q8_base = (k / 8) * 512 + (k % 8) * 32;
|
||||
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
// Interleaved scales
|
||||
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j];
|
||||
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j];
|
||||
|
||||
int sumi_l = 0;
|
||||
int sumi_h = 0;
|
||||
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
const int ql_pos = k * 64 + j * 8 + i;
|
||||
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
|
||||
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
|
||||
|
||||
const int qh_idx_l = qh_half_l + ((base_l + i) % 32);
|
||||
const int qh_chunk_l = qh_idx_l / 8;
|
||||
const int qh_pos_l = qh_idx_l % 8;
|
||||
const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l;
|
||||
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
|
||||
|
||||
const int qh_idx_h = qh_half_h + ((base_h + i) % 32);
|
||||
const int qh_chunk_h = qh_idx_h / 8;
|
||||
const int qh_pos_h = qh_idx_h % 8;
|
||||
const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h;
|
||||
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
|
||||
|
||||
const int q_l = ((hi_2_l << 4) | l_4) - 32;
|
||||
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
|
||||
|
||||
const int8_t q8_l = a_ptr[l].qs[q8_base + m * 8 + i];
|
||||
const int8_t q8_h = a_ptr[l].qs[q8_base + m * 8 + i + 256];
|
||||
|
||||
sumi_l += q_l * q8_l;
|
||||
sumi_h += q_h * q8_h;
|
||||
}
|
||||
|
||||
sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) *
|
||||
a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
ggml_gemm_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
@@ -1901,9 +1916,10 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in
|
||||
int src_offset = (i / 8) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
// buffer large enough for the max interleave block size (8 bytes)
|
||||
uint64_t elems;
|
||||
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
|
||||
memcpy(&elems, &in[src_id].qs[src_offset], blck_size_interleave);
|
||||
memcpy(&out.qs[dst_offset], &elems, blck_size_interleave);
|
||||
}
|
||||
|
||||
// The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K
|
||||
@@ -2097,18 +2113,18 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in
|
||||
}
|
||||
|
||||
const int end_ls = QK_K * 4 / blck_size_interleave;
|
||||
// Interleave Q6_K quants by taking 8 bytes at a time
|
||||
// Interleave Q6_K quants by taking blck_size_interleave bytes at a time
|
||||
for (int i = 0; i < end_ls; ++i) {
|
||||
int src_id = i % n_blocks;
|
||||
int src_offset = (i / n_blocks) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elem_ls;
|
||||
memcpy(&elem_ls, &in[src_id].ql[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.ql[dst_offset], &elem_ls, sizeof(uint64_t));
|
||||
memcpy(&elem_ls, &in[src_id].ql[src_offset], blck_size_interleave);
|
||||
memcpy(&out.ql[dst_offset], &elem_ls, blck_size_interleave);
|
||||
}
|
||||
|
||||
// Interleave high bits using same 8-byte pattern as low bits
|
||||
// Interleave high bits using same chunk size as low bits
|
||||
const int end_hs = end_ls / 2;
|
||||
for (int i = 0; i < end_hs; ++i) {
|
||||
int src_id = i % n_blocks;
|
||||
@@ -2116,8 +2132,8 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elem_hs;
|
||||
memcpy(&elem_hs, &in[src_id].qh[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qh[dst_offset], &elem_hs, sizeof(uint64_t));
|
||||
memcpy(&elem_hs, &in[src_id].qh[src_offset], blck_size_interleave);
|
||||
memcpy(&out.qh[dst_offset], &elem_hs, blck_size_interleave);
|
||||
}
|
||||
|
||||
// The below logic is designed so as to unpack and rearrange scales in Q6_K
|
||||
@@ -2262,7 +2278,7 @@ static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t,
|
||||
|
||||
static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_Q6_K);
|
||||
GGML_ASSERT(interleave_block == 8);
|
||||
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
|
||||
constexpr int nrows_interleaved = 8;
|
||||
|
||||
block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data;
|
||||
@@ -2511,6 +2527,10 @@ template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * da
|
||||
return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q6_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q6_K_to_q6_K_8_bl(t, 4, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q6_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q6_K_to_q6_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
@@ -2575,6 +2595,10 @@ template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
|
||||
ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q6_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
@@ -2634,6 +2658,10 @@ template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
|
||||
ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q6_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
@@ -3043,6 +3071,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;
|
||||
|
||||
// instance for Q6_K
|
||||
static const ggml::cpu::repack::tensor_traits<block_q6_K, 4, 8, GGML_TYPE_Q8_K> q6_K_8x4_q8_K;
|
||||
static const ggml::cpu::repack::tensor_traits<block_q6_K, 8, 8, GGML_TYPE_Q8_K> q6_K_8x8_q8_K;
|
||||
|
||||
// instance for Q2
|
||||
@@ -3107,6 +3136,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
return &q6_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
return &q6_K_8x4_q8_K;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_IQ4_NL) {
|
||||
if (ggml_cpu_has_avx2()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
|
||||
@@ -112,6 +112,7 @@ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -122,6 +123,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -142,6 +144,7 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -152,6 +155,7 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
#pragma once
|
||||
|
||||
// Computes C[M x N] += A[M x K] * B[K x N]
|
||||
|
||||
#include "simd-mappings.h"
|
||||
|
||||
// TODO: add support for sizeless vector types
|
||||
#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic)
|
||||
|
||||
// TODO: untested on avx512
|
||||
// These are in units of GGML_F32_EPR
|
||||
#if defined(__AVX512F__) || defined (__ARM_NEON__)
|
||||
static constexpr int GEMM_RM = 4;
|
||||
static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32
|
||||
#elif defined(__AVX2__) || defined(__AVX__)
|
||||
static constexpr int GEMM_RM = 6;
|
||||
static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16
|
||||
#else
|
||||
static constexpr int GEMM_RM = 2;
|
||||
static constexpr int GEMM_RN = 2;
|
||||
#endif
|
||||
|
||||
template <int RM, int RN>
|
||||
static inline void simd_gemm_ukernel(
|
||||
float * GGML_RESTRICT C,
|
||||
const float * GGML_RESTRICT A,
|
||||
const float * GGML_RESTRICT B,
|
||||
int K, int N)
|
||||
{
|
||||
static constexpr int KN = GGML_F32_EPR;
|
||||
|
||||
GGML_F32_VEC acc[RM][RN];
|
||||
for (int64_t i = 0; i < RM; i++) {
|
||||
for (int r = 0; r < RN; r++) {
|
||||
acc[i][r] = GGML_F32_VEC_LOAD(C + i * N + r * KN);
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t kk = 0; kk < K; kk++) {
|
||||
GGML_F32_VEC Bv[RN];
|
||||
for (int r = 0; r < RN; r++) {
|
||||
Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + r * KN);
|
||||
}
|
||||
for (int64_t i = 0; i < RM; i++) {
|
||||
GGML_F32_VEC p = GGML_F32_VEC_SET1(A[i * K + kk]);
|
||||
for (int r = 0; r < RN; r++) {
|
||||
acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < RM; i++) {
|
||||
for (int r = 0; r < RN; r++) {
|
||||
GGML_F32_VEC_STORE(C + i * N + r * KN, acc[i][r]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// C[M x N] += A[M x K] * B[K x N]
|
||||
static void simd_gemm(
|
||||
float * GGML_RESTRICT C,
|
||||
const float * GGML_RESTRICT A,
|
||||
const float * GGML_RESTRICT B,
|
||||
int M, int K, int N)
|
||||
{
|
||||
static constexpr int KN = GGML_F32_EPR;
|
||||
|
||||
int64_t ii = 0;
|
||||
for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
|
||||
int64_t jj = 0;
|
||||
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
|
||||
simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C + jj, A, B + jj, K, N);
|
||||
}
|
||||
for (; jj + KN <= N; jj += KN) {
|
||||
simd_gemm_ukernel<GEMM_RM, 1>(C + jj, A, B + jj, K, N);
|
||||
}
|
||||
for (; jj < N; jj++) {
|
||||
for (int64_t i = 0; i < GEMM_RM; i++) {
|
||||
float a = C[i * N + jj];
|
||||
for (int64_t kk = 0; kk < K; kk++) {
|
||||
a += A[i + kk] * B[kk * N + jj];
|
||||
}
|
||||
C[i * N + jj] = a;
|
||||
}
|
||||
}
|
||||
|
||||
A += GEMM_RM * K;
|
||||
C += GEMM_RM * N;
|
||||
}
|
||||
|
||||
// Tail rows: one at a time
|
||||
for (; ii < M; ii++) {
|
||||
int64_t jj = 0;
|
||||
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
|
||||
simd_gemm_ukernel<1, GEMM_RN>(C + jj, A, B + jj, K, N);
|
||||
}
|
||||
for (; jj + KN <= N; jj += KN) {
|
||||
simd_gemm_ukernel<1, 1>(C + jj, A, B + jj, K, N);
|
||||
}
|
||||
for (; jj < N; jj++) {
|
||||
float a = C[jj];
|
||||
for (int64_t kk = 0; kk < K; kk++) {
|
||||
a += A[kk] * B[kk * N + jj];
|
||||
}
|
||||
C[jj] = a;
|
||||
}
|
||||
|
||||
A += K;
|
||||
C += N;
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__GNUC__) && !defined(__clang__)
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
#else // scalar path
|
||||
|
||||
static void simd_gemm(
|
||||
float * GGML_RESTRICT C,
|
||||
const float * GGML_RESTRICT A,
|
||||
const float * GGML_RESTRICT B,
|
||||
int M, int K, int N)
|
||||
{
|
||||
for (int64_t i = 0; i < M; i++) {
|
||||
for (int64_t j = 0; j < N; j++) {
|
||||
float sum = C[i * N + j];
|
||||
for (int64_t kk = 0; kk < K; kk++) {
|
||||
sum += A[i * K + kk] * B[kk * N + j];
|
||||
}
|
||||
C[i * N + j] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // GGML_SIMD
|
||||
@@ -1160,6 +1160,14 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
|
||||
float32x4_t tmp = x[0] + vec_reve(x[0]); \
|
||||
res = tmp[0] + tmp[1]; \
|
||||
}
|
||||
#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3) \
|
||||
{ \
|
||||
float32x4_t v = vec_add(vec_add(s0, s1), \
|
||||
vec_add(s2, s3)); \
|
||||
v = vec_add(v, vec_sld(v, v, 8)); \
|
||||
v = vec_add(v, vec_sld(v, v, 4)); \
|
||||
res += (ggml_float)vec_extract(v, 0); \
|
||||
}
|
||||
|
||||
#define GGML_F32_VEC GGML_F32x4
|
||||
#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
|
||||
@@ -1209,6 +1217,24 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
|
||||
#define GGML_F16_VEC_MUL GGML_F32x4_MUL
|
||||
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
|
||||
|
||||
// BF16 s390x
|
||||
#define GGML_BF16_STEP 16
|
||||
#define GGML_BF16_EPR 8
|
||||
|
||||
#define GGML_BF16x8 __vector unsigned short
|
||||
#define GGML_BF16x8_ZERO vec_splats((unsigned short)0)
|
||||
#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p))
|
||||
|
||||
#define GGML_BF16_VEC GGML_BF16x8
|
||||
#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO
|
||||
#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD
|
||||
#define GGML_BF16_TO_F32_LO(v) ((float32x4_t) vec_mergel((v), GGML_BF16_VEC_ZERO))
|
||||
#define GGML_BF16_TO_F32_HI(v) ((float32x4_t) vec_mergeh((v), GGML_BF16_VEC_ZERO))
|
||||
#define GGML_BF16_FMA_LO(acc, x, y) \
|
||||
(acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y))
|
||||
#define GGML_BF16_FMA_HI(acc, x, y) \
|
||||
(acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y))
|
||||
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
|
||||
// compatible with vlen >= 128
|
||||
|
||||
@@ -111,7 +111,7 @@ template <float (*op)(float), typename src0_t, typename dst_t>
|
||||
static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
|
||||
@@ -236,8 +236,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
|
||||
vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
|
||||
sumf += __riscv_vfmv_f_s_f32m1_f32(redsum);
|
||||
|
||||
#endif
|
||||
#if defined(__POWER9_VECTOR__)
|
||||
#elif defined(__POWER9_VECTOR__) || defined(__VXE__) || defined(__VXE2__)
|
||||
const int np = (n & ~(GGML_BF16_STEP - 1));
|
||||
if (np > 0) {
|
||||
GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO};
|
||||
|
||||
@@ -64,7 +64,7 @@ if (CUDAToolkit_FOUND)
|
||||
FetchContent_Declare(
|
||||
CCCL
|
||||
GIT_REPOSITORY https://github.com/nvidia/cccl.git
|
||||
GIT_TAG v3.2.0-rc2
|
||||
GIT_TAG v3.2.0
|
||||
GIT_SHALLOW TRUE
|
||||
)
|
||||
|
||||
|
||||
@@ -39,13 +39,16 @@ static __global__ void k_bin_bcast(const src0_t * src0,
|
||||
const uint3 ne11,
|
||||
const uint3 ne12,
|
||||
const uint3 ne13,
|
||||
/*int s0, */ const int s1,
|
||||
/*const int s0,*/
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
/*int s00,*/ const int s01,
|
||||
const int s00,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
/*int s10,*/ const int s11,
|
||||
const int s10,
|
||||
const int s11,
|
||||
const int s12,
|
||||
const int s13,
|
||||
src1_ptrs... src1s) {
|
||||
@@ -72,11 +75,11 @@ static __global__ void k_bin_bcast(const src0_t * src0,
|
||||
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
|
||||
const uint32_t i10 = fastmodulo(i0, ne10);
|
||||
|
||||
float result = src0_row ? (float) src0_row[i0] : 0.0f;
|
||||
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
|
||||
if constexpr (sizeof...(src1_ptrs) > 0) {
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
|
||||
} else {
|
||||
result = bin_op(result, (float)src1[i_src1 + i10]);
|
||||
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
|
||||
}
|
||||
|
||||
dst_row[i0] = (dst_t) result;
|
||||
@@ -101,13 +104,16 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
|
||||
const uint3 ne11,
|
||||
const uint3 ne12,
|
||||
const uint3 ne13,
|
||||
/*int s0, */ const int s1,
|
||||
/*const int s0,*/
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
/*int s00,*/ const int s01,
|
||||
const int s00,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
/*int s10,*/ const int s11,
|
||||
const int s10,
|
||||
const int s11,
|
||||
const int s12,
|
||||
const int s13,
|
||||
src1_ptrs... src1s) {
|
||||
@@ -135,11 +141,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
|
||||
|
||||
const int i10 = fastmodulo(i0, ne10);
|
||||
|
||||
float result = src0_row ? (float) src0_row[i0] : 0.0f;
|
||||
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
|
||||
if constexpr (sizeof...(src1_ptrs) > 0) {
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
|
||||
} else {
|
||||
result = bin_op(result, (float)src1[i_src1 + i10]);
|
||||
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
|
||||
}
|
||||
|
||||
dst_row[i0] = (dst_t) result;
|
||||
@@ -179,7 +185,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
cnb[3] *= cne[3];
|
||||
};
|
||||
|
||||
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
|
||||
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (nr[i] != 1) {
|
||||
break;
|
||||
@@ -221,7 +227,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
size_t nb12 = cnb1[2];
|
||||
size_t nb13 = cnb1[3];
|
||||
|
||||
size_t s0 = nb0 / sizeof(dst_t);
|
||||
//size_t s0 = nb0 / sizeof(dst_t);
|
||||
size_t s1 = nb1 / sizeof(dst_t);
|
||||
size_t s2 = nb2 / sizeof(dst_t);
|
||||
size_t s3 = nb3 / sizeof(dst_t);
|
||||
@@ -251,10 +257,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
|
||||
|
||||
GGML_ASSERT(s0 == 1);
|
||||
GGML_ASSERT(s00 == 1);
|
||||
GGML_ASSERT(s10 == 1);
|
||||
|
||||
const int block_size = 128;
|
||||
|
||||
int64_t hne0 = std::max(ne0 / 2LL, 1LL);
|
||||
@@ -284,31 +286,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
|
||||
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
|
||||
ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00, s01, s02, s03,
|
||||
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
} else {
|
||||
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
|
||||
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
|
||||
ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12, s13);
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00, s01, s02, s03,
|
||||
s10, s11, s12, s13);
|
||||
}
|
||||
} else {
|
||||
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
|
||||
if constexpr (sizeof...(I) > 0) {
|
||||
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00 ,s01, s02, s03,
|
||||
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
} else {
|
||||
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12, s13);
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00, s01, s02, s03,
|
||||
s10, s11, s12, s13);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,8 @@
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
||||
const int64_t ne00, const int64_t ne01,
|
||||
const int64_t ne0203, const uint3 ne02,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03) {
|
||||
const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
|
||||
|
||||
@@ -16,23 +17,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
|
||||
}
|
||||
|
||||
const int64_t i01 = blockIdx.y;
|
||||
const int64_t i02 = blockIdx.z % ne02;
|
||||
const int64_t i03 = blockIdx.z / ne02;
|
||||
|
||||
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
|
||||
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
|
||||
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
|
||||
const int64_t i02 = dm.y;
|
||||
const int64_t i03 = dm.x;
|
||||
|
||||
const int64_t ib = ibx0 + i00/qk; // block index
|
||||
const int64_t iqs = (i00%qk)/qr; // quant index
|
||||
const int64_t iybs = i00 - i00%qk; // y block start index
|
||||
const int64_t y_offset = qr == 1 ? 1 : qk/2;
|
||||
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
|
||||
|
||||
// dequantize
|
||||
float2 v;
|
||||
dequantize_kernel(vx, ib, iqs, v);
|
||||
const int64_t ib = ibx0 + i00/qk; // block index
|
||||
const int64_t iqs = (i00%qk)/qr; // quant index
|
||||
const int64_t iybs = i00 - i00%qk; // y block start index
|
||||
const int64_t y_offset = qr == 1 ? 1 : qk/2;
|
||||
|
||||
const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
|
||||
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
|
||||
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
|
||||
// dequantize
|
||||
float2 v;
|
||||
dequantize_kernel(vx, ib, iqs, v);
|
||||
|
||||
const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;
|
||||
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
|
||||
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool need_check>
|
||||
@@ -485,9 +490,11 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static void dequantize_block_cuda(const void * vx, dst_t * y,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
|
||||
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
|
||||
const int64_t ne0203 = ne02*ne03;
|
||||
const uint3 ne02_fdv = init_fastdiv_values(ne02);
|
||||
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535));
|
||||
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
|
||||
(vx, y, ne00, ne01, ne02, s01, s02, s03);
|
||||
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
|
||||
}
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
@@ -612,7 +619,8 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t
|
||||
|
||||
template <typename src_t, typename dst_t>
|
||||
static __global__ void convert_unary(
|
||||
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
||||
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
|
||||
const int64_t ne0203, const uint3 ne02,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03) {
|
||||
const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
@@ -621,23 +629,29 @@ static __global__ void convert_unary(
|
||||
}
|
||||
|
||||
const int64_t i01 = blockIdx.y;
|
||||
const int64_t i02 = blockIdx.z % ne02;
|
||||
const int64_t i03 = blockIdx.z / ne02;
|
||||
|
||||
const src_t * x = (const src_t *) vx;
|
||||
|
||||
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
|
||||
const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
|
||||
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
|
||||
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
|
||||
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
|
||||
const int64_t i02 = dm.y;
|
||||
const int64_t i03 = dm.x;
|
||||
|
||||
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
|
||||
const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;
|
||||
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename src_t, typename dst_t>
|
||||
static void convert_unary_cuda(const void * vx, dst_t * y,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
|
||||
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03);
|
||||
const int64_t ne0203 = ne02*ne03;
|
||||
const uint3 ne02_fdv = init_fastdiv_values(ne02);
|
||||
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535));
|
||||
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
|
||||
(vx, y, ne00, ne01, ne02, s01, s02, s03);
|
||||
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
|
||||
}
|
||||
|
||||
template <typename src_t, typename dst_t>
|
||||
|
||||
@@ -1186,8 +1186,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
|
||||
// On NVIDIA (Pascal and older) the GQA optimizations seem to be detrimental in some cases.
|
||||
// However, for DKQ == 576, DV == 512 only the kernel variant with GQA optimizations is implemented.
|
||||
const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc);
|
||||
const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX;
|
||||
const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
|
||||
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
if constexpr (DV == 512) {
|
||||
|
||||
@@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr int frag_m = ncols == 8 ? 32 : 16;
|
||||
constexpr int frag_n = ncols == 8 ? 8 : 16;
|
||||
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
|
||||
#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::row_major> frag_a_K;
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_a_V;
|
||||
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_b;
|
||||
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
|
||||
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, _Float16> frag_c_VKQ;
|
||||
#else
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
|
||||
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
|
||||
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
|
||||
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
|
||||
#endif
|
||||
|
||||
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
|
||||
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
|
||||
@@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
|
||||
half2 * VKQ2 = (half2 *) VKQ;
|
||||
|
||||
#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000
|
||||
const _Float16 * K_h_f16 = reinterpret_cast<const _Float16 *>(K_h);
|
||||
const _Float16 * V_h_f16 = reinterpret_cast<const _Float16 *>(V_h);
|
||||
_Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ);
|
||||
_Float16 * VKQ_f16 = reinterpret_cast<_Float16 *>(VKQ);
|
||||
#else
|
||||
const half * K_h_f16 = K_h;
|
||||
const half * V_h_f16 = V_h;
|
||||
half * KQ_f16 = KQ;
|
||||
half * VKQ_f16 = VKQ;
|
||||
#endif
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
@@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
for (int i0 = 0; i0 < D; i0 += 16) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
|
||||
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
|
||||
frag_a_K K_a;
|
||||
wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
||||
wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
|
||||
@@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||
wmma::load_matrix_sync(
|
||||
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
|
||||
KQ + j0*(kqar*kqs_padded) + k,
|
||||
KQ_f16 + j0*(kqar*kqs_padded) + k,
|
||||
kqar*kqs_padded);
|
||||
}
|
||||
}
|
||||
@@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||
|
||||
frag_a_V v_a;
|
||||
wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
||||
wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
||||
@@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
wmma::store_matrix_sync(
|
||||
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
||||
KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
||||
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
|
||||
D_padded, wmma::mem_col_major);
|
||||
}
|
||||
|
||||
@@ -2278,11 +2278,12 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
|
||||
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
|
||||
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
|
||||
if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
|
||||
if (ggml_is_quantized(src0->type)) {
|
||||
if (ne2 <= 4) {
|
||||
if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
|
||||
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
|
||||
return;
|
||||
}
|
||||
@@ -2305,6 +2306,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
}
|
||||
}
|
||||
|
||||
// note: this path should not be reached when recording CUDA graphs, because it requires stream synchronization
|
||||
// TODO: add asserts to verify this. should work with CUDA, HIP, etc.
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(nb12 % nb11 == 0);
|
||||
@@ -2865,14 +2868,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
bool use_cuda_graph = true;
|
||||
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
|
||||
|
||||
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
|
||||
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
|
||||
const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
|
||||
const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
|
||||
const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
|
||||
const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
|
||||
const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
@@ -2887,30 +2882,14 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
#endif
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
|
||||
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
|
||||
#endif
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_ADD &&
|
||||
node->src[1] && node->src[1]->ne[1] > 1 &&
|
||||
(node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
|
||||
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
|
||||
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
|
||||
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
|
||||
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
|
||||
strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
|
||||
strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
|
||||
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
|
||||
// by means of matching node names. See
|
||||
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
|
||||
// https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
|
||||
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
|
||||
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
|
||||
if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {
|
||||
// under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
|
||||
// TODO: figure out a way to enable for larger batch sizes, without hurting performance
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18958
|
||||
use_cuda_graph = false;
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -3640,11 +3619,13 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
||||
n_fuse++;
|
||||
|
||||
if (n_fuse > 1) {
|
||||
ggml_tensor fused_add_node;
|
||||
memcpy(&fused_add_node, node, sizeof(ggml_tensor));
|
||||
for (int j = 0; j < n_fuse - 1; ++j) {
|
||||
node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
|
||||
fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
|
||||
}
|
||||
cgraph->nodes[i + n_fuse - 1]->data = node->data;
|
||||
ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
|
||||
fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data;
|
||||
ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse);
|
||||
i += n_fuse - 1;
|
||||
|
||||
continue;
|
||||
@@ -4542,6 +4523,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
// TODO: should become:
|
||||
//return ggml_is_contiguous_rows(op->src[0]);
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
default:
|
||||
return false;
|
||||
@@ -4820,8 +4803,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_ACC:
|
||||
return true;
|
||||
case GGML_OP_ACC:
|
||||
// TODO: extend support like so:
|
||||
//return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]);
|
||||
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
|
||||
case GGML_OP_SUM:
|
||||
return ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_TOP_K:
|
||||
|
||||
+21
-16
@@ -2715,14 +2715,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < QR2_XXS; ++l) {
|
||||
const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
|
||||
const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
|
||||
const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]];
|
||||
const uint32_t signs = unpack_ksigns(aux32 >> (7 * l));
|
||||
|
||||
const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
|
||||
const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
|
||||
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
||||
const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
|
||||
|
||||
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
|
||||
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
|
||||
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
||||
const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
|
||||
|
||||
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
|
||||
@@ -2733,12 +2733,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
||||
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
|
||||
const int ls = aux32 >> 28;
|
||||
const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
|
||||
const float d = bxi->d;
|
||||
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
|
||||
#else
|
||||
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
|
||||
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
|
||||
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
}
|
||||
@@ -2776,11 +2776,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < QR2_XS; ++l) {
|
||||
const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
|
||||
const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
|
||||
const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF];
|
||||
const uint32_t signs = unpack_ksigns(q2[l] >> 9);
|
||||
|
||||
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
|
||||
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
|
||||
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
||||
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
|
||||
|
||||
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
||||
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
|
||||
|
||||
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
|
||||
@@ -2904,11 +2907,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
||||
#pragma unroll
|
||||
for (int l = 0; l < QR3_XXS; ++l) {
|
||||
const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
|
||||
const uint32_t signs = unpack_ksigns(aux32 >> (7*l));
|
||||
|
||||
const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
|
||||
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
||||
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
|
||||
|
||||
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
|
||||
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
|
||||
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
||||
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
|
||||
|
||||
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "common.cuh"
|
||||
|
||||
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
|
||||
#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID
|
||||
|
||||
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
|
||||
|
||||
@@ -94,6 +94,15 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
|
||||
#endif
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) {
|
||||
// v is a 7 bit int, with the 8th sign being encodable as popcnt
|
||||
// with xor we can "correct" the bit instead of having to mask
|
||||
const uint32_t p = __popc(v) & 1;
|
||||
const uint32_t s = v ^ p << 7;
|
||||
// broadcast over uint to allow for 0x08040201 / 0x80402010 as selectors
|
||||
return s * 0x01010101;
|
||||
}
|
||||
|
||||
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
|
||||
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
|
||||
|
||||
@@ -905,22 +914,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
|
||||
int sumi = 0;
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < 8; k0 += 2) {
|
||||
const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]);
|
||||
const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F];
|
||||
const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[k0/2]];
|
||||
const uint32_t signs = unpack_ksigns(aux32 >> (7 * k0 / 2));
|
||||
|
||||
const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
|
||||
const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
|
||||
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
||||
const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
|
||||
const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0);
|
||||
sumi = ggml_cuda_dp4a(grid0, u0, sumi);
|
||||
|
||||
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
|
||||
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
|
||||
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
||||
const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
|
||||
const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1);
|
||||
sumi = ggml_cuda_dp4a(grid1, u1, sumi);
|
||||
}
|
||||
|
||||
const int ls = aux32 >> 28;
|
||||
sumi = (ls*sumi + sumi/2)/4;
|
||||
const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
|
||||
sumi = sumi * ls / 8; // (sumi * scale + sumi / 2) / 4
|
||||
const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
|
||||
return d * sumi;
|
||||
}
|
||||
@@ -942,13 +951,15 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
|
||||
int sumi1 = 0;
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < 8; l0 += 2) {
|
||||
const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF));
|
||||
const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l0/2] >> 9));
|
||||
|
||||
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
|
||||
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
|
||||
const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l0/2] & 0x1FF];
|
||||
const uint32_t signs = unpack_ksigns(q2[l0/2] >> 9);
|
||||
|
||||
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
||||
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
|
||||
const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
|
||||
|
||||
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
||||
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
|
||||
const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
|
||||
|
||||
if (l0 < 4) {
|
||||
@@ -1028,13 +1039,16 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < 8; l0 += 2) {
|
||||
const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]);
|
||||
const uint32_t signs = unpack_ksigns(aux32 >> (7*l0/2));
|
||||
|
||||
const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F));
|
||||
|
||||
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
|
||||
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
|
||||
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
||||
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
|
||||
|
||||
const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
|
||||
|
||||
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
||||
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
|
||||
|
||||
const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
|
||||
|
||||
sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
|
||||
|
||||
@@ -1935,11 +1935,6 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: add support for non-contigiuos tensors
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1991,6 +1986,25 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: add support for non-contigiuos tensors
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session * sess,
|
||||
const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
@@ -2111,6 +2125,26 @@ static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session *
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0]; // values
|
||||
const struct ggml_tensor * dst = op; // indices
|
||||
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != GGML_TYPE_I32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src0->ne[0] > (16*1024)) {
|
||||
// reject tensors with huge rows for now
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const int32_t * op_params = &op->op_params[0];
|
||||
|
||||
@@ -2278,6 +2312,9 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu
|
||||
case GGML_OP_SUB:
|
||||
req->op = HTP_OP_SUB;
|
||||
break;
|
||||
case GGML_OP_DIV:
|
||||
req->op = HTP_OP_DIV;
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op);
|
||||
break;
|
||||
@@ -2316,6 +2353,17 @@ static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer *
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
req->op = HTP_OP_ARGSORT;
|
||||
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
|
||||
|
||||
size_t n_bufs = 0;
|
||||
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
|
||||
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
template <bool _is_src0_constant>
|
||||
static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
switch (t->op) {
|
||||
@@ -2370,6 +2418,16 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
|
||||
supported = true;
|
||||
break;
|
||||
|
||||
case GGML_OP_SQR:
|
||||
req->op = HTP_OP_SQR;
|
||||
supported = true;
|
||||
break;
|
||||
|
||||
case GGML_OP_SQRT:
|
||||
req->op = HTP_OP_SQRT;
|
||||
supported = true;
|
||||
break;
|
||||
|
||||
case GGML_OP_UNARY:
|
||||
if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
|
||||
req->op = HTP_OP_UNARY_SILU;
|
||||
@@ -2387,6 +2445,9 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
|
||||
} else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) {
|
||||
req->op = HTP_OP_GLU_SWIGLU_OAI;
|
||||
supported = true;
|
||||
} else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) {
|
||||
req->op = HTP_OP_GLU_GEGLU;
|
||||
supported = true;
|
||||
}
|
||||
break;
|
||||
|
||||
@@ -2411,6 +2472,17 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
|
||||
req->op = HTP_OP_SUM_ROWS;
|
||||
|
||||
size_t n_bufs = 0;
|
||||
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
|
||||
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
|
||||
req->op = HTP_OP_ROPE;
|
||||
@@ -2519,6 +2591,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_DIV:
|
||||
ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
|
||||
break;
|
||||
case GGML_OP_ADD_ID:
|
||||
@@ -2528,6 +2601,13 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
case GGML_OP_SCALE:
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
break;
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
ggml_hexagon_dispatch_op<init_sum_rows_req>(sess, node, flags);
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) ||
|
||||
(ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) {
|
||||
@@ -2536,7 +2616,8 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
break;
|
||||
case GGML_OP_GLU:
|
||||
if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
|
||||
(ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) {
|
||||
(ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) ||
|
||||
(ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) {
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
}
|
||||
break;
|
||||
@@ -2564,6 +2645,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
case GGML_OP_ARGSORT:
|
||||
ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
default:
|
||||
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
|
||||
}
|
||||
@@ -2916,6 +3001,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_DIV:
|
||||
supp = ggml_hexagon_supported_binary(sess, op);
|
||||
break;
|
||||
|
||||
@@ -2928,6 +3014,15 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
supp = ggml_hexagon_supported_unary(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
supp = ggml_hexagon_supported_unary(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_SUM_ROWS:
|
||||
supp = ggml_hexagon_supported_sum_rows(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_SOFT_MAX:
|
||||
supp = ggml_hexagon_supported_softmax(sess, op);
|
||||
break;
|
||||
@@ -2943,7 +3038,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
case GGML_OP_GLU:
|
||||
{
|
||||
const auto glu_op = ggml_get_glu_op(op);
|
||||
if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI)) {
|
||||
if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) {
|
||||
supp = ggml_hexagon_supported_activations(sess, op);
|
||||
}
|
||||
break;
|
||||
@@ -2968,6 +3063,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
supp = ggml_hexagon_supported_cpy(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_ARGSORT:
|
||||
supp = ggml_hexagon_supported_argsort(sess, op);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
|
||||
include_directories(
|
||||
${HEXAGON_SDK_ROOT}/incs
|
||||
${HEXAGON_SDK_ROOT}/incs/stddef
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../..
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
@@ -21,6 +22,7 @@ add_library(${HTP_LIB} SHARED
|
||||
matmul-ops.c
|
||||
binary-ops.c
|
||||
unary-ops.c
|
||||
sum-rows-ops.c
|
||||
softmax-ops.c
|
||||
act-ops.c
|
||||
rope-ops.c
|
||||
@@ -28,6 +30,7 @@ add_library(${HTP_LIB} SHARED
|
||||
set-rows-ops.c
|
||||
get-rows-ops.c
|
||||
cpy-ops.c
|
||||
argsort-ops.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
|
||||
@@ -410,7 +410,7 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
|
||||
// gelu = x * sigmoid(1.702 * x) // current implementation
|
||||
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0);
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
@@ -516,7 +516,7 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
||||
|
||||
// silu = x * sigmoid(x)
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0);
|
||||
hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
@@ -541,6 +541,143 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
||||
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static const float GELU_COEF_A = 0.044715f;
|
||||
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * src1_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
dma_queue * dma_queue) {
|
||||
htp_act_preamble3;
|
||||
|
||||
size_t src0_row_size = nb01;
|
||||
size_t src1_row_size = nb11;
|
||||
size_t dst_row_size = nb1;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
|
||||
const bool src1_valid = src1->ne[0];
|
||||
const int nc = (src1_valid) ? ne00 : ne00 / 2;
|
||||
if (!src1_valid) {
|
||||
const int32_t swapped = op_params[1];
|
||||
data_src1 = data_src0;
|
||||
src1_row_size = src0_row_size;
|
||||
|
||||
const size_t nc_in_bytes = nc * SIZEOF_FP32;
|
||||
data_src0 += swapped ? nc_in_bytes : 0;
|
||||
data_src1 += swapped ? 0 : nc_in_bytes;
|
||||
}
|
||||
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
|
||||
|
||||
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
|
||||
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
|
||||
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
|
||||
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
|
||||
|
||||
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR,
|
||||
"geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
src0_spad->size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
|
||||
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
|
||||
dst_row_size, dst_row_size_aligned, 0);
|
||||
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
|
||||
src0_row_size_aligned, src0_row_size, block_size);
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
|
||||
src1_row_size_aligned, src1_row_size, block_size);
|
||||
}
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
|
||||
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
|
||||
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
|
||||
float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
|
||||
|
||||
for (uint32_t ib = 0; ib < block_size; ib++) {
|
||||
const uint8_t * src0_spad_ptr = (const uint8_t *)(src0_spad + ib * (src0_row_size_aligned / sizeof(float)));
|
||||
const uint8_t * src1_spad_ptr = (const uint8_t *)(src1_spad + ib * (src1_row_size_aligned / sizeof(float)));
|
||||
uint8_t * dst_spad_ptr = (uint8_t *)(dst_spad + ib * (dst_row_size_aligned / sizeof(float)));
|
||||
|
||||
// geglu tanh implementation
|
||||
// geglu(x, g) = gelu(x) * g
|
||||
// gelu(x) = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)))
|
||||
hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, src0_spad_ptr, nc); // res = x*x
|
||||
hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, GELU_COEF_A, nc); // res = res * GELU_COEF_A
|
||||
hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
|
||||
hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x
|
||||
hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, SQRT_2_OVER_PI, nc); // res = result * SQRT_2_OVER_PI
|
||||
hvx_tanh_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc); // res = tanh(res)
|
||||
hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
|
||||
hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x
|
||||
hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 0.5f, nc); // res = res + 0.5f
|
||||
hvx_mul_f32_aaa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, src1_spad_ptr, nc); // res = res * g
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
|
||||
dst_row_size_aligned, block_size);
|
||||
|
||||
// prefetch N+2 loop iteration if any
|
||||
const uint32_t pref_block = (ir + BLOCK * 2);
|
||||
if (pref_block < src0_end_row) {
|
||||
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
|
||||
src0_row_size_aligned, src0_row_size, pref_block_size);
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
|
||||
src1_row_size_aligned, src1_row_size, pref_block_size);
|
||||
}
|
||||
}
|
||||
|
||||
dma_queue_flush(dma_queue);
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "geglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
|
||||
@@ -559,6 +696,12 @@ static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) {
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
@@ -593,6 +736,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
act_op_func = unary_gelu_f32;
|
||||
op_type = "gelu-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_GEGLU:
|
||||
act_op_func = glu_geglu_f32;
|
||||
op_type = "geglu-f32";
|
||||
break;
|
||||
default:
|
||||
FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
|
||||
@@ -0,0 +1,281 @@
|
||||
#include <string.h>
|
||||
#include <stdlib.h>
|
||||
#include <math.h>
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_perf.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#include "hvx-utils.h"
|
||||
#include "hex-dma.h"
|
||||
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#ifndef MIN
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#endif
|
||||
|
||||
struct htp_argsort_context {
|
||||
struct htp_ops_context * octx;
|
||||
uint32_t nrows_per_thread;
|
||||
};
|
||||
|
||||
static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y)
|
||||
{
|
||||
const HVX_Vector one = Q6_V_vsplat_R(1);
|
||||
const HVX_Vector zero = Q6_V_vzero();
|
||||
|
||||
HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(x, y);
|
||||
HVX_Vector matches = Q6_V_vmux_QVV(pred, one, zero);
|
||||
HVX_Vector sum = hvx_vec_reduce_sum_i32(matches);
|
||||
return hvx_vec_get_i32(sum) == 32;
|
||||
}
|
||||
|
||||
// Sorts values and mirrors swaps to indices.
|
||||
static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) {
|
||||
if (left >= right) return;
|
||||
|
||||
int pivot_idx = (left + right) / 2;
|
||||
float pivot = values[pivot_idx];
|
||||
int i = left;
|
||||
int j = right;
|
||||
|
||||
HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
|
||||
while (i <= j) {
|
||||
// Vectorized scan for i
|
||||
while (i <= j) {
|
||||
// Check if we have at least one full vector
|
||||
if (i + 32 <= j) {
|
||||
HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
|
||||
if (all_greater_f32(pivot_vec, vals_vec)) {
|
||||
// If all elements are < pivot, we can skip this whole block
|
||||
i += 32;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Scalar fallback / cleanup
|
||||
if (values[i] < pivot) {
|
||||
i++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Vectorized scan for j
|
||||
while (i <= j) {
|
||||
if (j - 32 >= i) {
|
||||
// Load 32 elements ending at j.
|
||||
// Since we want `values[j] > pivot`, let's load from j-31 to j.
|
||||
HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
|
||||
if (all_greater_f32(vals_vec, pivot_vec)) {
|
||||
j -= 32;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (values[j] > pivot) {
|
||||
j--;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (i <= j) {
|
||||
float tmp_val = values[i];
|
||||
values[i] = values[j];
|
||||
values[j] = tmp_val;
|
||||
|
||||
int32_t tmp_idx = indices[i];
|
||||
indices[i] = indices[j];
|
||||
indices[j] = tmp_idx;
|
||||
i++;
|
||||
j--;
|
||||
}
|
||||
}
|
||||
|
||||
if (left < j) quicksort_values_indices_asc(values, indices, left, j);
|
||||
if (i < right) quicksort_values_indices_asc(values, indices, i, right);
|
||||
}
|
||||
|
||||
static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) {
|
||||
if (left >= right) return;
|
||||
|
||||
int pivot_idx = (left + right) / 2;
|
||||
float pivot = values[pivot_idx];
|
||||
int i = left;
|
||||
int j = right;
|
||||
|
||||
HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
|
||||
|
||||
while (i <= j) {
|
||||
// Vectorized scan for i (values[i] > pivot)
|
||||
while (i <= j) {
|
||||
if (i + 32 <= j) {
|
||||
HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
|
||||
if (all_greater_f32(vals_vec, pivot_vec)) {
|
||||
i += 32;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (values[i] > pivot) {
|
||||
i++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Vectorized scan for j (values[j] < pivot)
|
||||
while (i <= j) {
|
||||
if (j - 32 >= i) {
|
||||
HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
|
||||
if (all_greater_f32(pivot_vec, vals_vec)) {
|
||||
j -= 32;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (values[j] < pivot) {
|
||||
j--;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (i <= j) {
|
||||
float tmp_val = values[i];
|
||||
values[i] = values[j];
|
||||
values[j] = tmp_val;
|
||||
|
||||
int32_t tmp_idx = indices[i];
|
||||
indices[i] = indices[j];
|
||||
indices[j] = tmp_idx;
|
||||
i++;
|
||||
j--;
|
||||
}
|
||||
}
|
||||
|
||||
if (left < j) quicksort_values_indices_desc(values, indices, left, j);
|
||||
if (i < right) quicksort_values_indices_desc(values, indices, i, right);
|
||||
}
|
||||
|
||||
static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_argsort_context * actx = (struct htp_argsort_context *)data;
|
||||
struct htp_ops_context * octx = actx->octx;
|
||||
|
||||
// Unpack context
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
// Scratchpad memory
|
||||
uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;
|
||||
|
||||
// Dimensions
|
||||
uint32_t ne00 = src0->ne[0];
|
||||
uint32_t ne01 = src0->ne[1];
|
||||
uint32_t ne02 = src0->ne[2];
|
||||
uint32_t ne03 = src0->ne[3];
|
||||
|
||||
uint32_t nb01 = src0->nb[1];
|
||||
//uint32_t nb02 = src0->nb[2];
|
||||
//uint32_t nb03 = src0->nb[3];
|
||||
|
||||
uint32_t nb1 = dst->nb[1];
|
||||
//uint32_t nb2 = dst->nb[2];
|
||||
//uint32_t nb3 = dst->nb[3];
|
||||
|
||||
// Sort order
|
||||
enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0];
|
||||
|
||||
// Rows to process
|
||||
uint32_t total_rows = ne01 * ne02 * ne03;
|
||||
uint32_t rows_per_thread = actx->nrows_per_thread;
|
||||
uint32_t start_row = rows_per_thread * i;
|
||||
uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
|
||||
|
||||
// Scratchpad layout:
|
||||
// We need space for one row of float data (values) and one row of int32 indices.
|
||||
// values: ne00 * sizeof(float)
|
||||
// indices: ne00 * sizeof(int32_t)
|
||||
// Padded to 128 bytes.
|
||||
|
||||
size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
|
||||
float * values_buf = (float *) spad;
|
||||
int32_t * indices_buf = (int32_t *) (spad + values_size);
|
||||
|
||||
for (uint32_t r = start_row; r < end_row; r++) {
|
||||
uint32_t src_offset = r * nb01;
|
||||
uint32_t dst_offset = r * nb1;
|
||||
|
||||
uint8_t * src_ptr = (uint8_t *) src0->data + src_offset;
|
||||
uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset;
|
||||
|
||||
hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);
|
||||
hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00);
|
||||
|
||||
// Initialize indices
|
||||
for (uint32_t j = 0; j < ne00; j++) {
|
||||
indices_buf[j] = j;
|
||||
}
|
||||
|
||||
// Sort values and mirror swaps to indices
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1);
|
||||
} else {
|
||||
quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1);
|
||||
}
|
||||
|
||||
// Copy indices back to DDR
|
||||
hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00);
|
||||
}
|
||||
}
|
||||
|
||||
int op_argsort(struct htp_ops_context * octx) {
|
||||
// Check supported types
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
// Allocate scratchpad
|
||||
// We need 1 row of float + 1 row of int32 per thread.
|
||||
uint32_t ne00 = octx->src0.ne[0];
|
||||
size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
|
||||
size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128);
|
||||
size_t spad_per_thread = values_size + indices_size;
|
||||
|
||||
// Make sure we round up to 256 for alignment requirements
|
||||
spad_per_thread = hex_round_up(spad_per_thread, 256);
|
||||
|
||||
size_t total_spad_size = spad_per_thread * octx->n_threads;
|
||||
|
||||
if (octx->ctx->vtcm_size < total_spad_size) {
|
||||
FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src0_spad.size = total_spad_size;
|
||||
octx->src0_spad.size_per_thread = spad_per_thread;
|
||||
|
||||
FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
|
||||
octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3],
|
||||
octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
|
||||
octx->src0.data, octx->dst.data);
|
||||
|
||||
uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
|
||||
uint32_t n_jobs = MIN(total_rows, octx->n_threads);
|
||||
|
||||
struct htp_argsort_context actx;
|
||||
actx.octx = octx;
|
||||
actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs;
|
||||
|
||||
// Run jobs
|
||||
worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -17,121 +17,6 @@
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) {
|
||||
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements
|
||||
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements
|
||||
return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
|
||||
}
|
||||
|
||||
// Dot product of FP32 and FP16 vectors, accumulating to float
|
||||
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
|
||||
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x_hf = vx[i];
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x_hf = vx[i];
|
||||
|
||||
// Zero-out unused elements
|
||||
// Note that we need to clear both x and y because they may contain NANs
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
x_hf = Q6_V_vand_QV(bmask, x_hf);
|
||||
y_hf = Q6_V_vand_QV(bmask, y_hf);
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
|
||||
}
|
||||
|
||||
rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
|
||||
hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
|
||||
}
|
||||
|
||||
// Dot product of FP32 and FP16 vectors, accumulating to float
|
||||
static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r,
|
||||
const void * restrict y,
|
||||
const void * restrict x0,
|
||||
const void * restrict x1,
|
||||
unsigned int n,
|
||||
float s) {
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
|
||||
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
|
||||
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
|
||||
// Load x (fp16)
|
||||
HVX_Vector x0_hf = vx0[i];
|
||||
HVX_Vector x1_hf = vx1[i];
|
||||
|
||||
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
|
||||
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
|
||||
|
||||
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
|
||||
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x0_hf = vx0[i];
|
||||
HVX_Vector x1_hf = vx1[i];
|
||||
|
||||
// Zero-out unused elements
|
||||
// Note that we need to clear both x and y because they may contain NANs
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
x0_hf = Q6_V_vand_QV(bmask, x0_hf);
|
||||
x1_hf = Q6_V_vand_QV(bmask, x1_hf);
|
||||
y_hf = Q6_V_vand_QV(bmask, y_hf);
|
||||
|
||||
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
|
||||
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
|
||||
|
||||
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
|
||||
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
|
||||
}
|
||||
|
||||
HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
|
||||
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
|
||||
}
|
||||
|
||||
// Dot product of two F16 vectors, accumulating to float
|
||||
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
|
||||
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
|
||||
@@ -140,8 +25,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
@@ -156,11 +40,10 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector y_hf = vy[i];
|
||||
|
||||
// Load x (fp16) and zero-out unused elements
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
|
||||
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
|
||||
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
@@ -181,12 +64,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
|
||||
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
@@ -204,12 +86,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector y_hf = vy[i];
|
||||
|
||||
// Load x (fp16) and zero-out unused elements
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
|
||||
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
|
||||
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
|
||||
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
|
||||
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
|
||||
|
||||
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
|
||||
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
|
||||
@@ -222,7 +103,7 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
|
||||
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
|
||||
}
|
||||
|
||||
// MAD: y (F32) += x (F16) * s (float)
|
||||
// MAD: y (F32) += x (F16) * s (F32)
|
||||
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
|
||||
const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
|
||||
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
|
||||
@@ -259,15 +140,125 @@ static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict
|
||||
}
|
||||
}
|
||||
|
||||
// MAD: y (F32) += x0 (F16) * s0 (F32) + x1 (F16) * s1 (F32)
|
||||
static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y,
|
||||
const void * restrict x0,
|
||||
const void * restrict x1,
|
||||
float s0,
|
||||
float s1,
|
||||
int n) {
|
||||
const HVX_Vector * restrict ptr_x0 = (const HVX_Vector *) x0;
|
||||
const HVX_Vector * restrict ptr_x1 = (const HVX_Vector *) x1;
|
||||
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_Vector S0 = hvx_vec_splat_f16(s0);
|
||||
HVX_Vector S1 = hvx_vec_splat_f16(s1);
|
||||
|
||||
uint32_t i = 0;
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; ++i) {
|
||||
// Multiply x * s -> pair of F32 vectors
|
||||
HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0);
|
||||
HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1);
|
||||
|
||||
HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
|
||||
HVX_Vector xs_p_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
|
||||
|
||||
ptr_y[i * 2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_lo, ptr_y[i * 2]));
|
||||
ptr_y[i * 2 + 1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_hi, ptr_y[i * 2 + 1]));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0);
|
||||
HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1);
|
||||
|
||||
HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
|
||||
HVX_Vector xs = xs_p_lo;
|
||||
i = 2 * i; // index for ptr_y
|
||||
|
||||
if (nloe >= 32) {
|
||||
ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
|
||||
nloe -= 32; ++i;
|
||||
xs = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
|
||||
hvx_vec_store_a(&ptr_y[i], nloe * 4, xy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define FLASH_ATTN_BLOCK_SIZE 128
|
||||
|
||||
static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
|
||||
struct htp_fa_context {
|
||||
const struct htp_ops_context * octx;
|
||||
|
||||
struct fastdiv_values src0_div21;
|
||||
struct fastdiv_values src0_div1;
|
||||
|
||||
struct fastdiv_values broadcast_rk2;
|
||||
struct fastdiv_values broadcast_rk3;
|
||||
struct fastdiv_values broadcast_rv2;
|
||||
struct fastdiv_values broadcast_rv3;
|
||||
|
||||
struct fastdiv_values src3_div2;
|
||||
struct fastdiv_values src3_div3;
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float logit_softcap;
|
||||
|
||||
uint32_t n_head_log2;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t n_blocks;
|
||||
|
||||
size_t size_q_row_padded;
|
||||
size_t size_k_row_padded;
|
||||
size_t size_v_row_padded;
|
||||
|
||||
size_t size_k_block;
|
||||
size_t size_v_block;
|
||||
size_t size_m_block;
|
||||
|
||||
bool is_q_fp32;
|
||||
};
|
||||
|
||||
static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) {
|
||||
assert((size_t) dst % 128 == 0);
|
||||
assert((size_t) src % 128 == 0);
|
||||
|
||||
const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src;
|
||||
HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst;
|
||||
|
||||
const uint32_t nvec = n / VLEN_FP32;
|
||||
const uint32_t nloe = n % VLEN_FP32;
|
||||
|
||||
uint32_t i = 0;
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; ++i) {
|
||||
vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs));
|
||||
}
|
||||
if (nloe) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
|
||||
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v));
|
||||
}
|
||||
}
|
||||
|
||||
static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_fa_context * factx = (struct htp_fa_context *) data;
|
||||
const struct htp_ops_context * octx = factx->octx;
|
||||
const struct htp_tensor * q = &octx->src0;
|
||||
const struct htp_tensor * k = &octx->src1;
|
||||
const struct htp_tensor * v = &octx->src2;
|
||||
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
|
||||
const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
const uint32_t neq0 = q->ne[0];
|
||||
const uint32_t neq1 = q->ne[1];
|
||||
@@ -304,18 +295,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
const uint32_t nb2 = dst->nb[2];
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
|
||||
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
|
||||
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
|
||||
|
||||
if (logit_softcap != 0) {
|
||||
scale /= logit_softcap;
|
||||
}
|
||||
|
||||
// total rows in q
|
||||
const uint32_t nr = neq1*neq2*neq3;
|
||||
|
||||
@@ -331,18 +310,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
const uint32_t DV = nev0;
|
||||
|
||||
const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
|
||||
const size_t size_q_row_padded = hex_round_up(size_q_row, 128);
|
||||
|
||||
const size_t size_k_row = DK * sizeof(__fp16);
|
||||
const size_t size_v_row = DV * sizeof(__fp16);
|
||||
const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
|
||||
|
||||
const size_t size_k_row_padded = hex_round_up(size_k_row, 128);
|
||||
const size_t size_v_row_padded = hex_round_up(size_v_row, 128);
|
||||
|
||||
const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
const size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
||||
|
||||
// Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
|
||||
uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
|
||||
@@ -351,31 +320,28 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
|
||||
uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;
|
||||
|
||||
const uint32_t n_head = neq2;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);
|
||||
|
||||
for (uint32_t ir = ir0; ir < ir1; ++ir) {
|
||||
const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
|
||||
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
|
||||
const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);
|
||||
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);
|
||||
const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
|
||||
|
||||
const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
|
||||
const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
|
||||
const uint32_t ik3 = fastdiv(iq3, &factx->broadcast_rk3);
|
||||
const uint32_t ik2 = fastdiv(iq2, &factx->broadcast_rk2);
|
||||
|
||||
const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
|
||||
const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
|
||||
const uint32_t iv3 = fastdiv(iq3, &factx->broadcast_rv3);
|
||||
const uint32_t iv2 = fastdiv(iq2, &factx->broadcast_rv2);
|
||||
|
||||
// Fetch Q row
|
||||
const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
|
||||
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
|
||||
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1);
|
||||
|
||||
const uint32_t h = iq2; // head index
|
||||
const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
|
||||
const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f;
|
||||
|
||||
float S = 0.0f; // sum
|
||||
float M = -INFINITY; // maximum KQ value
|
||||
HVX_Vector S_vec = hvx_vec_splat_f32(0.0f);
|
||||
HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY);
|
||||
|
||||
// Clear accumulator
|
||||
hvx_splat_f32_a(spad_a, 0, DV);
|
||||
@@ -383,40 +349,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
|
||||
const __fp16 * mp_base = NULL;
|
||||
if (mask) {
|
||||
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
|
||||
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
|
||||
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &factx->src3_div2);
|
||||
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &factx->src3_div3);
|
||||
mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
|
||||
}
|
||||
|
||||
const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
|
||||
|
||||
// Prefetch first two blocks
|
||||
for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
|
||||
for (uint32_t ib = 0; ib < MIN(factx->n_blocks, 2); ++ib) {
|
||||
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
|
||||
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
|
||||
|
||||
// K
|
||||
const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
|
||||
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
|
||||
uint8_t * k_dst = spad_k + (ib % 2) * factx->size_k_block;
|
||||
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), factx->size_k_row_padded, nbk1, size_k_row, current_block_size);
|
||||
|
||||
// V
|
||||
const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
|
||||
uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
|
||||
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
|
||||
uint8_t * v_dst = spad_v + (ib % 2) * factx->size_v_block;
|
||||
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), factx->size_v_row_padded, nbv1, size_v_row, current_block_size);
|
||||
|
||||
// Mask
|
||||
if (mask) {
|
||||
const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
|
||||
uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
|
||||
uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block;
|
||||
// Mask is 1D contiguous for this row
|
||||
dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
|
||||
}
|
||||
}
|
||||
|
||||
const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
|
||||
uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
|
||||
if (factx->is_q_fp32) {
|
||||
hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16
|
||||
}
|
||||
|
||||
for (uint32_t ib = 0; ib < n_blocks; ++ib) {
|
||||
const HVX_Vector slope_vec = hvx_vec_splat_f16(slope);
|
||||
for (uint32_t ib = 0; ib < factx->n_blocks; ++ib) {
|
||||
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
|
||||
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
|
||||
|
||||
@@ -428,8 +396,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
// Inner loop processing the block from VTCM
|
||||
uint32_t ic = 0;
|
||||
|
||||
const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
|
||||
|
||||
// Process in blocks of 32 (VLEN_FP32)
|
||||
static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
|
||||
HVX_Vector_x4 scores_x4;
|
||||
@@ -437,22 +403,18 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
|
||||
// 1. Compute scores
|
||||
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
|
||||
for (int j = 0; j < VLEN_FP32; j += 2) {
|
||||
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
|
||||
const uint32_t cur_ic = ic + j;
|
||||
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
|
||||
if (is_q_fp32) {
|
||||
hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
|
||||
} else {
|
||||
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
|
||||
}
|
||||
const uint8_t * k_ptr = k_base + cur_ic * factx->size_k_row_padded;
|
||||
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + factx->size_k_row_padded, DK, factx->scale);
|
||||
}
|
||||
|
||||
HVX_Vector scores = *(HVX_Vector *) scores_arr;
|
||||
|
||||
// 2. Softcap
|
||||
if (logit_softcap != 0.0f) {
|
||||
if (factx->logit_softcap != 0.0f) {
|
||||
scores = hvx_vec_tanh_f32(scores);
|
||||
scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap));
|
||||
scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap);
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
}
|
||||
|
||||
@@ -460,70 +422,59 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
if (mask) {
|
||||
const __fp16 * mp = m_base + ic;
|
||||
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
|
||||
|
||||
HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00);
|
||||
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16);
|
||||
|
||||
HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair));
|
||||
|
||||
HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
|
||||
HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec);
|
||||
scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
|
||||
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
|
||||
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
|
||||
scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores);
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
}
|
||||
|
||||
scores_x4.v[iv] = scores;
|
||||
v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max);
|
||||
v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max
|
||||
}
|
||||
|
||||
{
|
||||
// 4. Online Softmax Update
|
||||
v_max = hvx_vec_reduce_max_f32(v_max);
|
||||
float m_block = hvx_vec_get_f32(v_max);
|
||||
float M_old = M;
|
||||
float M_new = (m_block > M) ? m_block : M;
|
||||
M = M_new;
|
||||
HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec);
|
||||
HVX_Vector diff_vec = Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec);
|
||||
HVX_Vector ms_vec = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(diff_vec));
|
||||
M_vec = M_new_vec;
|
||||
|
||||
const float ms = expf(M_old - M_new);
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
|
||||
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
||||
|
||||
HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
|
||||
HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
|
||||
for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
|
||||
HVX_Vector scores = scores_x4.v[iv];
|
||||
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
|
||||
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec);
|
||||
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
|
||||
|
||||
p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
|
||||
|
||||
// 5. Accumulate V
|
||||
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
|
||||
*(HVX_Vector*)p_arr = P;
|
||||
*(HVX_Vector *) p_arr = P;
|
||||
|
||||
for (int j = 0; j < VLEN_FP32; ++j) {
|
||||
const uint32_t cur_ic = ic2 + j;
|
||||
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
|
||||
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
|
||||
const uint32_t cur_ic = ic2 + j;
|
||||
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, p_arr[j], p_arr[j + 1], DV);
|
||||
}
|
||||
}
|
||||
|
||||
p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
|
||||
S = S * ms + hvx_vec_get_f32(p_sum_vec);
|
||||
S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec));
|
||||
}
|
||||
|
||||
// Sync scalars for leftover/next block if needed
|
||||
float M = hvx_vec_get_f32(M_vec);
|
||||
float S = hvx_vec_get_f32(S_vec);
|
||||
|
||||
// Leftover
|
||||
for (; ic < current_block_size; ++ic) {
|
||||
float s_val;
|
||||
const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
|
||||
|
||||
if (is_q_fp32) {
|
||||
hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
|
||||
} else {
|
||||
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
|
||||
}
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
s_val = logit_softcap * tanhf(s_val);
|
||||
const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded;
|
||||
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale);
|
||||
if (factx->logit_softcap != 0.0f) {
|
||||
s_val = factx->logit_softcap * tanhf(s_val);
|
||||
}
|
||||
|
||||
if (mask) {
|
||||
@@ -532,37 +483,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
}
|
||||
|
||||
const float Mold = M;
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (s_val > M) {
|
||||
M = s_val;
|
||||
ms = expf(Mold - M);
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
|
||||
HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M);
|
||||
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
|
||||
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
||||
|
||||
float ms = hvx_vec_get_f32(ms_vec);
|
||||
S = S * ms + vs;
|
||||
} else {
|
||||
vs = expf(s_val - M);
|
||||
HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M);
|
||||
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
|
||||
S += vs;
|
||||
}
|
||||
|
||||
const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
|
||||
const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded;
|
||||
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
|
||||
|
||||
S = S * ms + vs;
|
||||
}
|
||||
M_vec = hvx_vec_splat_f32(M);
|
||||
S_vec = hvx_vec_splat_f32(S);
|
||||
|
||||
// Issue DMA for next+1 block (if exists)
|
||||
if (ib + 2 < n_blocks) {
|
||||
if (ib + 2 < factx->n_blocks) {
|
||||
const uint32_t next_ib = ib + 2;
|
||||
const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
|
||||
const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
|
||||
|
||||
// K
|
||||
const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
|
||||
dma_queue_push(dma, dma_make_ptr(k_base, k_src), factx->size_k_row_padded, nbk1, size_k_row, next_block_size);
|
||||
|
||||
// V
|
||||
const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
|
||||
dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
|
||||
dma_queue_push(dma, dma_make_ptr(v_base, v_src), factx->size_v_row_padded, nbv1, size_v_row, next_block_size);
|
||||
|
||||
// Mask
|
||||
if (mask) {
|
||||
@@ -573,20 +529,26 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
}
|
||||
|
||||
// sinks
|
||||
float M = hvx_vec_get_f32(M_vec);
|
||||
float S = hvx_vec_get_f32(S_vec);
|
||||
|
||||
if (sinks) {
|
||||
const float s = ((float *)((char *) sinks->data))[h];
|
||||
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (s > M) {
|
||||
ms = expf(M - s);
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
|
||||
} else {
|
||||
vs = expf(s - M);
|
||||
}
|
||||
HVX_Vector diff_vec = hvx_vec_splat_f32(M - s);
|
||||
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
|
||||
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
||||
|
||||
S = S * ms + vs;
|
||||
float ms = hvx_vec_get_f32(ms_vec);
|
||||
S = S * ms + vs;
|
||||
} else {
|
||||
HVX_Vector diff_vec = hvx_vec_splat_f32(s - M);
|
||||
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
|
||||
S += vs;
|
||||
}
|
||||
}
|
||||
|
||||
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
|
||||
@@ -609,53 +571,73 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
}
|
||||
}
|
||||
|
||||
static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = data;
|
||||
flash_attn_ext_f16_thread(octx, i, n);
|
||||
}
|
||||
|
||||
int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * q = &octx->src0;
|
||||
const struct htp_tensor * k = &octx->src1;
|
||||
const struct htp_tensor * v = &octx->src2;
|
||||
const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
// Check support
|
||||
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
|
||||
k->type != HTP_TYPE_F16 ||
|
||||
v->type != HTP_TYPE_F16) {
|
||||
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
|
||||
octx->src0_div1 = init_fastdiv_values(q->ne[1]);
|
||||
struct htp_fa_context factx;
|
||||
factx.octx = octx;
|
||||
|
||||
octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
|
||||
octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
|
||||
octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
|
||||
octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
|
||||
factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
|
||||
factx.src0_div1 = init_fastdiv_values(q->ne[1]);
|
||||
|
||||
factx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
|
||||
factx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
|
||||
factx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
|
||||
factx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
|
||||
|
||||
if (mask) {
|
||||
octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
|
||||
octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
|
||||
factx.src3_div2 = init_fastdiv_values(mask->ne[2]);
|
||||
factx.src3_div3 = init_fastdiv_values(mask->ne[3]);
|
||||
}
|
||||
|
||||
size_t size_q_row_padded = hex_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
|
||||
size_t size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
|
||||
size_t size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
|
||||
factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
|
||||
factx.size_q_row_padded = hex_round_up(q->ne[0] * (factx.is_q_fp32 ? 4 : 2), 128);
|
||||
factx.size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
|
||||
factx.size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
|
||||
|
||||
size_t size_q_block = size_q_row_padded * 1; // single row for now
|
||||
size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
||||
size_t size_q_block = factx.size_q_row_padded * 1; // single row for now
|
||||
factx.size_k_block = factx.size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
factx.size_v_block = factx.size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
factx.size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
||||
|
||||
factx.n_blocks = (k->ne[1] + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
|
||||
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
|
||||
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
scale /= logit_softcap;
|
||||
}
|
||||
|
||||
factx.scale = scale;
|
||||
factx.max_bias = max_bias;
|
||||
factx.logit_softcap = logit_softcap;
|
||||
|
||||
uint32_t n_head = q->ne[2];
|
||||
factx.n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
||||
factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2);
|
||||
factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2);
|
||||
|
||||
size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
|
||||
|
||||
octx->src0_spad.size_per_thread = size_q_block * 1;
|
||||
octx->src1_spad.size_per_thread = size_k_block * 2;
|
||||
octx->src2_spad.size_per_thread = size_v_block * 2;
|
||||
octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
|
||||
octx->src1_spad.size_per_thread = factx.size_k_block * 2;
|
||||
octx->src2_spad.size_per_thread = factx.size_v_block * 2;
|
||||
octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0;
|
||||
octx->dst_spad.size_per_thread = size_vkq_acc;
|
||||
|
||||
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
||||
@@ -677,7 +659,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
|
||||
@@ -42,32 +42,36 @@ enum htp_data_type {
|
||||
HTP_TYPE_COUNT
|
||||
};
|
||||
|
||||
// These values are manually translated over to HTP
|
||||
// !!!! DO NOT ALTER THE ORDER OF THE FIRST FOUR ENUMS !!!!
|
||||
// Do not reorder first 4 (used as an index)
|
||||
enum htp_op {
|
||||
HTP_OP_MUL = 0,
|
||||
HTP_OP_ADD = 1,
|
||||
HTP_OP_SUB = 2,
|
||||
HTP_OP_DIV = 3,
|
||||
HTP_OP_MUL_MAT = 4,
|
||||
HTP_OP_MUL_MAT_ID = 5,
|
||||
HTP_OP_RMS_NORM = 6,
|
||||
HTP_OP_UNARY_SILU = 7,
|
||||
HTP_OP_UNARY_GELU = 8,
|
||||
HTP_OP_GLU_SWIGLU = 9,
|
||||
HTP_OP_GLU_SWIGLU_OAI = 10,
|
||||
HTP_OP_SOFTMAX = 11,
|
||||
HTP_OP_ADD_ID = 12,
|
||||
HTP_OP_ROPE = 13,
|
||||
HTP_OP_FLASH_ATTN_EXT = 14,
|
||||
HTP_OP_SET_ROWS = 15,
|
||||
HTP_OP_SCALE = 16,
|
||||
HTP_OP_GET_ROWS = 17,
|
||||
HTP_OP_CPY = 18,
|
||||
HTP_OP_MUL = 0,
|
||||
HTP_OP_ADD = 1,
|
||||
HTP_OP_SUB = 2,
|
||||
HTP_OP_DIV = 3,
|
||||
HTP_OP_MUL_MAT,
|
||||
HTP_OP_MUL_MAT_ID,
|
||||
HTP_OP_RMS_NORM,
|
||||
HTP_OP_UNARY_SILU,
|
||||
HTP_OP_UNARY_GELU,
|
||||
HTP_OP_GLU_SWIGLU,
|
||||
HTP_OP_GLU_SWIGLU_OAI,
|
||||
HTP_OP_GLU_GEGLU,
|
||||
HTP_OP_SOFTMAX,
|
||||
HTP_OP_ADD_ID,
|
||||
HTP_OP_ROPE,
|
||||
HTP_OP_FLASH_ATTN_EXT,
|
||||
HTP_OP_SET_ROWS,
|
||||
HTP_OP_GET_ROWS,
|
||||
HTP_OP_SCALE,
|
||||
HTP_OP_CPY,
|
||||
HTP_OP_ARGSORT,
|
||||
HTP_OP_SQR,
|
||||
HTP_OP_SQRT,
|
||||
HTP_OP_SUM_ROWS,
|
||||
INVALID
|
||||
};
|
||||
|
||||
static inline size_t htp_type_block_size(uint32_t t) {
|
||||
static inline size_t htp_t_block_size(uint32_t t) {
|
||||
switch (t) {
|
||||
case HTP_TYPE_F32:
|
||||
return 1;
|
||||
@@ -103,22 +107,6 @@ static inline size_t htp_type_nbytes(uint32_t t) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static const char * htp_type_name(uint32_t t) {
|
||||
switch (t) {
|
||||
case HTP_TYPE_F32:
|
||||
return "fp32";
|
||||
case HTP_TYPE_F16:
|
||||
return "fp16";
|
||||
case HTP_TYPE_Q4_0:
|
||||
return "q4_0";
|
||||
case HTP_TYPE_Q8_0:
|
||||
return "q8_0";
|
||||
case HTP_TYPE_MXFP4:
|
||||
return "mxfp4";
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Internal types
|
||||
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
|
||||
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
|
||||
|
||||
@@ -64,25 +64,12 @@ struct htp_ops_context {
|
||||
struct fastdiv_values broadcast_rv2;
|
||||
struct fastdiv_values broadcast_rv3;
|
||||
|
||||
struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
|
||||
struct fastdiv_values mm_div_ne1; // fastdiv values for ne1
|
||||
struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02
|
||||
struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03
|
||||
|
||||
struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
|
||||
struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
|
||||
|
||||
struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
|
||||
struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
|
||||
|
||||
struct fastdiv_values cpy_div_ne01; // fastdiv values for ne01
|
||||
struct fastdiv_values cpy_div_ne02; // fastdiv values for ne02
|
||||
struct fastdiv_values cpy_div_ne03; // fastdiv values for ne03
|
||||
|
||||
struct fastdiv_values cpy_rshp_div_n0; // fastdiv values for ne00
|
||||
struct fastdiv_values cpy_rshp_div_n1n0; // fastdiv values for ne00*ne01
|
||||
struct fastdiv_values cpy_rshp_div_n2n1n0; // fastdiv values for ne00*ne01*ne02
|
||||
|
||||
uint32_t flags;
|
||||
};
|
||||
|
||||
@@ -90,6 +77,7 @@ int op_matmul(struct htp_ops_context * octx);
|
||||
int op_matmul_id(struct htp_ops_context * octx);
|
||||
int op_binary(struct htp_ops_context * octx);
|
||||
int op_unary(struct htp_ops_context * octx);
|
||||
int op_sum_rows(struct htp_ops_context * octx);
|
||||
int op_activations(struct htp_ops_context * octx);
|
||||
int op_softmax(struct htp_ops_context * octx);
|
||||
int op_add_id(struct htp_ops_context * octx);
|
||||
@@ -98,5 +86,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx);
|
||||
int op_set_rows(struct htp_ops_context * octx);
|
||||
int op_get_rows(struct htp_ops_context * octx);
|
||||
int op_cpy(struct htp_ops_context * octx);
|
||||
int op_argsort(struct htp_ops_context * octx);
|
||||
|
||||
#endif /* HTP_OPS_H */
|
||||
|
||||
@@ -46,127 +46,76 @@
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#endif
|
||||
|
||||
// ADD variants
|
||||
// Generic macro to define alignment permutations for an op
|
||||
#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \
|
||||
static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
|
||||
} \
|
||||
|
||||
static inline void hvx_add_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD);
|
||||
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD)
|
||||
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB)
|
||||
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL)
|
||||
|
||||
// Dispatcher logic
|
||||
#define HVX_BINARY_DISPATCHER(OP_NAME) \
|
||||
static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \
|
||||
if (hex_is_aligned((void *) dst, 128)) { \
|
||||
if (hex_is_aligned((void *) src0, 128)) { \
|
||||
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \
|
||||
else OP_NAME##_aau(dst, src0, src1, num_elems); \
|
||||
} else { \
|
||||
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \
|
||||
else OP_NAME##_auu(dst, src0, src1, num_elems); \
|
||||
} \
|
||||
} else { \
|
||||
if (hex_is_aligned((void *) src0, 128)) { \
|
||||
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \
|
||||
else OP_NAME##_uau(dst, src0, src1, num_elems); \
|
||||
} else { \
|
||||
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \
|
||||
else OP_NAME##_uuu(dst, src0, src1, num_elems); \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
static inline void hvx_add_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD);
|
||||
}
|
||||
|
||||
static inline void hvx_add_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD);
|
||||
}
|
||||
|
||||
static inline void hvx_add_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD);
|
||||
}
|
||||
|
||||
// SUB variants
|
||||
|
||||
static inline void hvx_sub_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB);
|
||||
}
|
||||
|
||||
static inline void hvx_sub_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB);
|
||||
}
|
||||
|
||||
static inline void hvx_sub_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB);
|
||||
}
|
||||
|
||||
static inline void hvx_sub_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB);
|
||||
}
|
||||
|
||||
// MUL variants
|
||||
|
||||
static inline void hvx_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL);
|
||||
}
|
||||
|
||||
static inline void hvx_mul_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL);
|
||||
}
|
||||
|
||||
static inline void hvx_mul_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL);
|
||||
}
|
||||
|
||||
static inline void hvx_mul_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL);
|
||||
}
|
||||
|
||||
// Dispatchers
|
||||
|
||||
static inline void hvx_add_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_add_f32_aa(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_add_f32_au(dst, src0, src1, num_elems);
|
||||
}
|
||||
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_add_f32_ua(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_add_f32_uu(dst, src0, src1, num_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_sub_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_sub_f32_aa(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_sub_f32_au(dst, src0, src1, num_elems);
|
||||
}
|
||||
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_sub_f32_ua(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_sub_f32_uu(dst, src0, src1, num_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_mul_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_mul_f32_aa(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_mul_f32_au(dst, src0, src1, num_elems);
|
||||
}
|
||||
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_mul_f32_ua(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_mul_f32_uu(dst, src0, src1, num_elems);
|
||||
}
|
||||
}
|
||||
HVX_BINARY_DISPATCHER(hvx_add_f32)
|
||||
HVX_BINARY_DISPATCHER(hvx_sub_f32)
|
||||
HVX_BINARY_DISPATCHER(hvx_mul_f32)
|
||||
|
||||
// Mul-Mul Optimized
|
||||
|
||||
static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
@@ -443,6 +392,68 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Square
|
||||
//
|
||||
|
||||
#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const uint32_t elem_size = sizeof(float); \
|
||||
const uint32_t epv = 128 / elem_size; \
|
||||
const uint32_t nvec = n / epv; \
|
||||
const uint32_t nloe = n % epv; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \
|
||||
vec_store((void *) &vdst[i], nloe * elem_size, v); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128)) {
|
||||
if (hex_is_aligned((void *) src, 128)) {
|
||||
hvx_sqr_f32_aa(dst, src, num_elems);
|
||||
} else {
|
||||
hvx_sqr_f32_au(dst, src, num_elems);
|
||||
}
|
||||
} else {
|
||||
if (hex_is_aligned((void *) src, 128)) {
|
||||
hvx_sqr_f32_ua(dst, src, num_elems);
|
||||
} else {
|
||||
hvx_sqr_f32_uu(dst, src, num_elems);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#undef HVX_OP_ADD
|
||||
#undef HVX_OP_SUB
|
||||
#undef HVX_OP_MUL
|
||||
@@ -453,5 +464,7 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
|
||||
#undef hvx_scalar_loop_body
|
||||
#undef HVX_OP_MIN_SCALAR
|
||||
#undef HVX_OP_CLAMP_SCALAR
|
||||
#undef DEFINE_HVX_BINARY_OP_VARIANTS
|
||||
#undef HVX_BINARY_DISPATCHER
|
||||
|
||||
#endif // HVX_ARITH_H
|
||||
|
||||
@@ -66,6 +66,12 @@ static inline float hvx_vec_get_f32(HVX_Vector v) {
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline int32_t hvx_vec_get_i32(HVX_Vector v) {
|
||||
int32_t __attribute__((aligned(128))) x;
|
||||
hvx_vec_store_a(&x, 4, v);
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {
|
||||
// abs by clearing the fp16 sign bit
|
||||
HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
|
||||
|
||||
@@ -136,8 +136,6 @@ static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restr
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0); \
|
||||
\
|
||||
const uint32_t elem_size = sizeof(__fp16); \
|
||||
const uint32_t epv = 128 / elem_size; \
|
||||
const uint32_t nvec = n / epv; \
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
#ifndef HVX_DIV_H
|
||||
#define HVX_DIV_H
|
||||
|
||||
#include <HAP_farf.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
#include <assert.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hvx-base.h"
|
||||
#include "hex-utils.h"
|
||||
#include "hvx-inverse.h"
|
||||
#include "hvx-arith.h"
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
#else
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#endif
|
||||
|
||||
#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src0_type * restrict vsrc0 = (src0_type *) src0; \
|
||||
src1_type * restrict vsrc1 = (src1_type *) src1; \
|
||||
\
|
||||
const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
|
||||
\
|
||||
const uint32_t nvec = n / VLEN_FP32; \
|
||||
const uint32_t nloe = n % VLEN_FP32; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
|
||||
HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
|
||||
vdst[i] = res; \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
|
||||
HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
|
||||
vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
// 3-letter suffix variants
|
||||
static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
assert((uintptr_t) src0 % 128 == 0);
|
||||
assert((uintptr_t) src1 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
assert((uintptr_t) src0 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
assert((uintptr_t) src1 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) src0 % 128 == 0);
|
||||
assert((uintptr_t) src1 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) src0 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) src1 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128)) {
|
||||
if (hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems);
|
||||
else hvx_div_f32_aau(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems);
|
||||
else hvx_div_f32_auu(dst, src0, src1, num_elems);
|
||||
}
|
||||
} else {
|
||||
if (hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems);
|
||||
else hvx_div_f32_uau(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems);
|
||||
else hvx_div_f32_uuu(dst, src0, src1, num_elems);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#undef HVX_OP_MUL
|
||||
|
||||
#endif // HVX_DIV_H
|
||||
@@ -91,6 +91,27 @@ static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) {
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
#define hvx_tanh_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const uint32_t epv = 128 / sizeof(float); \
|
||||
const uint32_t nvec = n / epv; \
|
||||
const uint32_t nloe = n % epv; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
vdst[i] = hvx_vec_tanh_f32(vsrc[i]); \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]); \
|
||||
vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
@@ -111,4 +132,10 @@ static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * re
|
||||
hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
#endif /* HVX_SIGMOID_H */
|
||||
|
||||
@@ -12,11 +12,17 @@
|
||||
#define RSQRT_ONE_HALF 0x3f000000 // 0.5
|
||||
#define RSQRT_THREE_HALVES 0x3fc00000 // 1.5
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
#else
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#endif
|
||||
|
||||
static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
|
||||
//Algorithm :
|
||||
// x2 = input*0.5
|
||||
// y = * (long *) &input
|
||||
// y = 0x5f3759df - (y>>2)
|
||||
// y = 0x5f3759df - (y>>1)
|
||||
// y = y*(threehalfs - x2*y*y)
|
||||
|
||||
HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);
|
||||
@@ -57,4 +63,64 @@ static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
|
||||
return Q6_Vsf_equals_Vqf32(temp);
|
||||
}
|
||||
|
||||
// Compute sqrt(x) as x*inv_sqrt(x)
|
||||
#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const uint32_t nvec = n / VLEN_FP32; \
|
||||
const uint32_t nloe = n % VLEN_FP32; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \
|
||||
HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \
|
||||
vdst[i] = sqrt_res; \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \
|
||||
HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \
|
||||
vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, sqrt_res); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) {
|
||||
if ((unsigned long) dst % 128 == 0) {
|
||||
if ((unsigned long) src % 128 == 0) {
|
||||
hvx_sqrt_f32_aa(dst, src, num_elems);
|
||||
} else {
|
||||
hvx_sqrt_f32_au(dst, src, num_elems);
|
||||
}
|
||||
} else {
|
||||
if ((unsigned long) src % 128 == 0) {
|
||||
hvx_sqrt_f32_ua(dst, src, num_elems);
|
||||
} else {
|
||||
hvx_sqrt_f32_uu(dst, src, num_elems);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif /* HVX_SQRT_H */
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "hvx-sigmoid.h"
|
||||
#include "hvx-sqrt.h"
|
||||
#include "hvx-arith.h"
|
||||
#include "hvx-div.h"
|
||||
#include "hvx-base.h"
|
||||
|
||||
#endif /* HVX_UTILS_H */
|
||||
|
||||
@@ -189,7 +189,7 @@ static int vtcm_release_callback(unsigned int rctx, void * state) {
|
||||
// otherwise we'll release it once we're done with the current Op.
|
||||
|
||||
if (ctx->vtcm_inuse) {
|
||||
ctx->vtcm_needs_release = false;
|
||||
ctx->vtcm_needs_release = true;
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -440,6 +440,45 @@ static void proc_matmul_req(struct htp_context * ctx,
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[0].fd = bufs[1].fd;
|
||||
rsp_bufs[0].ptr = bufs[1].ptr;
|
||||
rsp_bufs[0].offset = bufs[1].offset;
|
||||
rsp_bufs[0].size = bufs[1].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[1].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_argsort(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
@@ -679,6 +718,45 @@ static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * re
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[0].fd = bufs[1].fd;
|
||||
rsp_bufs[0].ptr = bufs[1].ptr;
|
||||
rsp_bufs[0].offset = bufs[1].offset;
|
||||
rsp_bufs[0].size = bufs[1].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[1].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_sum_rows(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_activations_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
@@ -951,6 +1029,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
case HTP_OP_MUL:
|
||||
case HTP_OP_ADD:
|
||||
case HTP_OP_SUB:
|
||||
case HTP_OP_DIV:
|
||||
if (n_bufs != 3) {
|
||||
FARF(ERROR, "Bad binary-req buffer list");
|
||||
continue;
|
||||
@@ -968,6 +1047,25 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
proc_unary_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_SQR:
|
||||
case HTP_OP_SQRT:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad unary-req buffer list");
|
||||
continue;
|
||||
}
|
||||
|
||||
proc_unary_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_SUM_ROWS:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad unary-req buffer list");
|
||||
continue;
|
||||
}
|
||||
|
||||
proc_sum_rows_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_UNARY_SILU:
|
||||
case HTP_OP_UNARY_GELU:
|
||||
if (n_bufs != 2) {
|
||||
@@ -980,6 +1078,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
case HTP_OP_GLU_SWIGLU:
|
||||
case HTP_OP_GLU_SWIGLU_OAI:
|
||||
case HTP_OP_SOFTMAX:
|
||||
case HTP_OP_GLU_GEGLU:
|
||||
if ((n_bufs != 2) && (n_bufs != 3)) {
|
||||
FARF(ERROR, "Bad act-req buffer list");
|
||||
continue;
|
||||
@@ -1035,6 +1134,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
proc_cpy_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_ARGSORT:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad argsort-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_argsort_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unknown Op %u", req.op);
|
||||
break;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,115 @@
|
||||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_perf.h>
|
||||
|
||||
#include <string.h>
|
||||
#include <math.h>
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
|
||||
#define sum_rows_preamble \
|
||||
struct htp_tensor *src0 = &octx->src0;\
|
||||
struct htp_tensor *dst = &octx->dst; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3]; \
|
||||
|
||||
static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
sum_rows_preamble;
|
||||
|
||||
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
int opt_path = 0;
|
||||
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src = (const uint8_t *) src0->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
|
||||
const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
|
||||
float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
|
||||
|
||||
for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) {
|
||||
const float * restrict src_local = src_th + (ir * ne00);
|
||||
|
||||
if (ir + 1 < src0_nrows_per_thread) {
|
||||
hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);
|
||||
} else {
|
||||
dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);
|
||||
}
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) {
|
||||
sum_rows_thread_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
int op_sum_rows(struct htp_ops_context * octx) {
|
||||
sum_rows_preamble;
|
||||
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
const int n_threads = octx->n_threads;
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
||||
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
@@ -132,6 +132,56 @@ static void rms_norm_htp_f32(const float * restrict src,
|
||||
}
|
||||
}
|
||||
|
||||
static void sqr_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
} else {
|
||||
hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void sqrt_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
} else {
|
||||
hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void unary_job_f32_per_thread(const struct htp_tensor * src,
|
||||
struct htp_tensor * dst,
|
||||
uint8_t * spad,
|
||||
@@ -181,6 +231,12 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
|
||||
case HTP_OP_SCALE:
|
||||
scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
case HTP_OP_SQR:
|
||||
sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
case HTP_OP_SQRT:
|
||||
sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
@@ -218,6 +274,14 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "scale-f32";
|
||||
break;
|
||||
case HTP_OP_SQR:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "sqr-f32";
|
||||
break;
|
||||
case HTP_OP_SQRT:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "sqrt-f32";
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
|
||||
|
||||
@@ -98,6 +98,10 @@ static bool ggml_op_is_empty(enum ggml_op op) {
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool ggml_impl_is_view(const struct ggml_tensor * t) {
|
||||
return t->view_src != NULL;
|
||||
}
|
||||
|
||||
static inline float ggml_compute_softplus_f32(float input) {
|
||||
return (input > 20.0f) ? input : logf(1 + expf(input));
|
||||
}
|
||||
|
||||
@@ -264,15 +264,26 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_TRI:
|
||||
case GGML_OP_DIAG:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_GLU:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_UNARY:
|
||||
case GGML_OP_GET_ROWS:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_SET_ROWS:
|
||||
case GGML_OP_SET:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_REPEAT:
|
||||
return true;
|
||||
default:
|
||||
return ggml_op_is_empty(op);
|
||||
@@ -312,7 +323,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
|
||||
h_add(mrs1, node0);
|
||||
|
||||
// that many nodes forward to search for a concurrent node
|
||||
constexpr int N_FORWARD = 8;
|
||||
constexpr int N_FORWARD = 64;
|
||||
|
||||
for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
|
||||
if (used[i1]) {
|
||||
|
||||
@@ -212,61 +212,69 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_meta
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
const int64_t n = ggml_nelements(op);
|
||||
int op_num = -1;
|
||||
|
||||
const char * op_str = "undefined";
|
||||
switch (op->op) {
|
||||
case GGML_OP_SCALE: op_str = "scale"; break;
|
||||
case GGML_OP_FILL: op_str = "fill"; break;
|
||||
case GGML_OP_CLAMP: op_str = "clamp"; break;
|
||||
case GGML_OP_SQR: op_str = "sqr"; break;
|
||||
case GGML_OP_SQRT: op_str = "sqrt"; break;
|
||||
case GGML_OP_SIN: op_str = "sin"; break;
|
||||
case GGML_OP_COS: op_str = "cos"; break;
|
||||
case GGML_OP_LOG: op_str = "log"; break;
|
||||
case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break;
|
||||
case GGML_OP_SCALE: op_num = OP_UNARY_NUM_SCALE; break;
|
||||
case GGML_OP_FILL: op_num = OP_UNARY_NUM_FILL; break;
|
||||
case GGML_OP_CLAMP: op_num = OP_UNARY_NUM_CLAMP; break;
|
||||
case GGML_OP_SQR: op_num = OP_UNARY_NUM_SQR; break;
|
||||
case GGML_OP_SQRT: op_num = OP_UNARY_NUM_SQRT; break;
|
||||
case GGML_OP_SIN: op_num = OP_UNARY_NUM_SIN; break;
|
||||
case GGML_OP_COS: op_num = OP_UNARY_NUM_COS; break;
|
||||
case GGML_OP_LOG: op_num = OP_UNARY_NUM_LOG; break;
|
||||
case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_TANH: op_str = "tanh"; break;
|
||||
case GGML_UNARY_OP_RELU: op_str = "relu"; break;
|
||||
case GGML_UNARY_OP_SIGMOID: op_str = "sigmoid"; break;
|
||||
case GGML_UNARY_OP_GELU: op_str = "gelu"; break;
|
||||
case GGML_UNARY_OP_GELU_ERF: op_str = "gelu_erf"; break;
|
||||
case GGML_UNARY_OP_GELU_QUICK: op_str = "gelu_quick"; break;
|
||||
case GGML_UNARY_OP_SILU: op_str = "silu"; break;
|
||||
case GGML_UNARY_OP_ELU: op_str = "elu"; break;
|
||||
case GGML_UNARY_OP_NEG: op_str = "neg"; break;
|
||||
case GGML_UNARY_OP_ABS: op_str = "abs"; break;
|
||||
case GGML_UNARY_OP_SGN: op_str = "sgn"; break;
|
||||
case GGML_UNARY_OP_STEP: op_str = "step"; break;
|
||||
case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break;
|
||||
case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
|
||||
case GGML_UNARY_OP_EXP: op_str = "exp"; break;
|
||||
case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break;
|
||||
case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break;
|
||||
case GGML_UNARY_OP_TANH: op_num = OP_UNARY_NUM_TANH; break;
|
||||
case GGML_UNARY_OP_RELU: op_num = OP_UNARY_NUM_RELU; break;
|
||||
case GGML_UNARY_OP_SIGMOID: op_num = OP_UNARY_NUM_SIGMOID; break;
|
||||
case GGML_UNARY_OP_GELU: op_num = OP_UNARY_NUM_GELU; break;
|
||||
case GGML_UNARY_OP_GELU_ERF: op_num = OP_UNARY_NUM_GELU_ERF; break;
|
||||
case GGML_UNARY_OP_GELU_QUICK: op_num = OP_UNARY_NUM_GELU_QUICK; break;
|
||||
case GGML_UNARY_OP_SILU: op_num = OP_UNARY_NUM_SILU; break;
|
||||
case GGML_UNARY_OP_ELU: op_num = OP_UNARY_NUM_ELU; break;
|
||||
case GGML_UNARY_OP_NEG: op_num = OP_UNARY_NUM_NEG; break;
|
||||
case GGML_UNARY_OP_ABS: op_num = OP_UNARY_NUM_ABS; break;
|
||||
case GGML_UNARY_OP_SGN: op_num = OP_UNARY_NUM_SGN; break;
|
||||
case GGML_UNARY_OP_STEP: op_num = OP_UNARY_NUM_STEP; break;
|
||||
case GGML_UNARY_OP_HARDSWISH: op_num = OP_UNARY_NUM_HARDSWISH; break;
|
||||
case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break;
|
||||
case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break;
|
||||
case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break;
|
||||
case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
} break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
const char * suffix = "";
|
||||
if (n % 4 == 0) {
|
||||
suffix = "_4";
|
||||
}
|
||||
const char * t0_str = ggml_type_name(op->src[0]->type);
|
||||
const char * t_str = ggml_type_name(op->type);
|
||||
|
||||
snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix);
|
||||
snprintf(name, 256, "%s", base);
|
||||
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
|
||||
const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768;
|
||||
|
||||
snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
|
||||
snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0);
|
||||
ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
}
|
||||
|
||||
res.c4 = is_c4;
|
||||
res.cnt = is_cnt;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -320,31 +328,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
const char * op_str = "undefined";
|
||||
int op_num = -1;
|
||||
|
||||
switch (op->op) {
|
||||
case GGML_OP_SUM_ROWS:
|
||||
op_str = "sum_rows"; break;
|
||||
case GGML_OP_MEAN:
|
||||
op_str = "mean"; break;
|
||||
case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break;
|
||||
case GGML_OP_MEAN: op_num = OP_SUM_ROWS_NUM_MEAN; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
|
||||
const char * t0_str = ggml_type_name(op->src[0]->type);
|
||||
const char * t_str = ggml_type_name(op->type);
|
||||
|
||||
snprintf(name, 256, "%s", base);
|
||||
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
|
||||
|
||||
snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
|
||||
snprintf(name, 256, "%s_op=%d", base, op_num);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
}
|
||||
|
||||
res.smem = 32*sizeof(float);
|
||||
|
||||
if (is_c4) {
|
||||
res.smem *= 4;
|
||||
}
|
||||
|
||||
res.c4 = is_c4;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -1472,13 +1495,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_met
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_L2_NORM);
|
||||
|
||||
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
|
||||
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_l2_norm_f32");
|
||||
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
|
||||
|
||||
const char * t0_str = ggml_type_name(op->src[0]->type);
|
||||
const char * t_str = ggml_type_name(op->type);
|
||||
|
||||
snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
@@ -1486,6 +1511,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
}
|
||||
|
||||
res.c4 = is_c4;
|
||||
res.smem = 32*sizeof(float);
|
||||
|
||||
return res;
|
||||
|
||||
@@ -1011,6 +1011,15 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
}
|
||||
|
||||
switch (op->op) {
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_FILL:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_LOG:
|
||||
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_TANH:
|
||||
@@ -1030,7 +1039,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -1058,11 +1067,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_ADD_ID:
|
||||
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ACC:
|
||||
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_FILL:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
return true;
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
@@ -1070,14 +1077,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
|
||||
op->src[1]->type == GGML_TYPE_F32 &&
|
||||
op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_CLAMP:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_LOG:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SUM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_TRI:
|
||||
@@ -1087,9 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_L2_NORM:
|
||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
||||
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
return has_simdgroup_reduction &&
|
||||
op->src[0]->type == GGML_TYPE_I32 &&
|
||||
@@ -1161,6 +1159,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return has_simdgroup_reduction;
|
||||
case GGML_OP_SET:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CONT:
|
||||
|
||||
@@ -80,7 +80,9 @@
|
||||
#define FC_SSM_CONV 900
|
||||
#define FC_SOLVE_TRI 1000
|
||||
#define FC_COUNT_EQUAL 1100
|
||||
#define FC_BIN 1200
|
||||
#define FC_UNARY 1200
|
||||
#define FC_BIN 1300
|
||||
#define FC_SUM_ROWS 1400
|
||||
|
||||
// op-specific constants
|
||||
#define OP_FLASH_ATTN_EXT_NQPSG 8
|
||||
@@ -89,6 +91,37 @@
|
||||
#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
|
||||
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
|
||||
|
||||
#define OP_UNARY_NUM_SCALE 10
|
||||
#define OP_UNARY_NUM_FILL 11
|
||||
#define OP_UNARY_NUM_CLAMP 12
|
||||
#define OP_UNARY_NUM_SQR 13
|
||||
#define OP_UNARY_NUM_SQRT 14
|
||||
#define OP_UNARY_NUM_SIN 15
|
||||
#define OP_UNARY_NUM_COS 16
|
||||
#define OP_UNARY_NUM_LOG 17
|
||||
#define OP_UNARY_NUM_LEAKY_RELU 18
|
||||
|
||||
#define OP_UNARY_NUM_TANH 100
|
||||
#define OP_UNARY_NUM_RELU 101
|
||||
#define OP_UNARY_NUM_SIGMOID 102
|
||||
#define OP_UNARY_NUM_GELU 103
|
||||
#define OP_UNARY_NUM_GELU_ERF 104
|
||||
#define OP_UNARY_NUM_GELU_QUICK 105
|
||||
#define OP_UNARY_NUM_SILU 106
|
||||
#define OP_UNARY_NUM_ELU 107
|
||||
#define OP_UNARY_NUM_NEG 108
|
||||
#define OP_UNARY_NUM_ABS 109
|
||||
#define OP_UNARY_NUM_SGN 110
|
||||
#define OP_UNARY_NUM_STEP 111
|
||||
#define OP_UNARY_NUM_HARDSWISH 112
|
||||
#define OP_UNARY_NUM_HARDSIGMOID 113
|
||||
#define OP_UNARY_NUM_EXP 114
|
||||
#define OP_UNARY_NUM_SOFTPLUS 115
|
||||
#define OP_UNARY_NUM_EXPM1 116
|
||||
|
||||
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
|
||||
#define OP_SUM_ROWS_NUM_MEAN 11
|
||||
|
||||
// kernel argument structs
|
||||
//
|
||||
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
|
||||
@@ -124,6 +157,31 @@ typedef struct {
|
||||
int32_t dim;
|
||||
} ggml_metal_kargs_concat;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
float slope;
|
||||
float scale;
|
||||
float bias;
|
||||
float val;
|
||||
float min;
|
||||
float max;
|
||||
} ggml_metal_kargs_unary;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
@@ -181,20 +239,6 @@ typedef struct {
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_repeat;
|
||||
|
||||
typedef struct {
|
||||
float scale;
|
||||
float bias;
|
||||
} ggml_metal_kargs_scale;
|
||||
|
||||
typedef struct {
|
||||
float val;
|
||||
} ggml_metal_kargs_fill;
|
||||
|
||||
typedef struct {
|
||||
float min;
|
||||
float max;
|
||||
} ggml_metal_kargs_clamp;
|
||||
|
||||
typedef struct {
|
||||
int64_t nk0;
|
||||
int64_t ne00;
|
||||
@@ -498,8 +542,21 @@ typedef struct {
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne00_4;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
float eps;
|
||||
} ggml_metal_kargs_l2_norm;
|
||||
|
||||
@@ -881,10 +938,6 @@ typedef struct {
|
||||
int max_period;
|
||||
} ggml_metal_kargs_timestep_embedding;
|
||||
|
||||
typedef struct {
|
||||
float slope;
|
||||
} ggml_metal_kargs_leaky_relu;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
|
||||
@@ -287,17 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
n_fuse = ggml_metal_op_acc(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SCALE:
|
||||
{
|
||||
n_fuse = ggml_metal_op_scale(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_FILL:
|
||||
{
|
||||
n_fuse = ggml_metal_op_fill(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_CLAMP:
|
||||
{
|
||||
n_fuse = ggml_metal_op_clamp(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
@@ -426,10 +418,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_top_k(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
{
|
||||
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_TRI:
|
||||
{
|
||||
n_fuse = ggml_metal_op_tri(ctx, idx);
|
||||
@@ -438,6 +426,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SET:
|
||||
{
|
||||
n_fuse = ggml_metal_op_set(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
@@ -628,8 +620,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(op->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
|
||||
|
||||
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
|
||||
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
|
||||
@@ -679,10 +671,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
ggml_metal_kargs_bin args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.ne00 =*/ ne10,
|
||||
/*.ne01 =*/ ne11,
|
||||
/*.ne02 =*/ ne12,
|
||||
/*.ne03 =*/ ne13,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ pnb1,
|
||||
/*.nb02 =*/ pnb2,
|
||||
@@ -695,10 +687,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.ne0 =*/ ne10,
|
||||
/*.ne1 =*/ ne11,
|
||||
/*.ne2 =*/ ne12,
|
||||
/*.ne3 =*/ ne13,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ pnb1,
|
||||
/*.nb2 =*/ pnb2,
|
||||
@@ -715,126 +707,19 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
|
||||
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
int nth = 1;
|
||||
|
||||
while (2*nth < args.ne0 && nth < nth_max) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float scale;
|
||||
float bias;
|
||||
memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
|
||||
memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float));
|
||||
|
||||
ggml_metal_kargs_scale args = {
|
||||
/*.scale =*/ scale,
|
||||
/*.bias =*/ bias,
|
||||
};
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
|
||||
if (n % 4 == 0) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const float val = ggml_get_op_params_f32(op, 0);
|
||||
|
||||
ggml_metal_kargs_fill args = {
|
||||
/*.val =*/ val
|
||||
};
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
|
||||
if (n % 4 == 0) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float min;
|
||||
float max;
|
||||
memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float));
|
||||
memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
|
||||
|
||||
ggml_metal_kargs_clamp args = {
|
||||
/*.min =*/ min,
|
||||
/*.max =*/ max,
|
||||
};
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
|
||||
if (n % 4 == 0) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
@@ -846,19 +731,79 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
if (n % 4 == 0) {
|
||||
n /= 4;
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
ggml_metal_kargs_unary args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
/*.slope =*/ 0.0,
|
||||
/*.scale =*/ 0.0,
|
||||
/*.bias =*/ 0.0,
|
||||
/*.val =*/ 0.0,
|
||||
/*.min =*/ 0.0,
|
||||
/*.max =*/ 0.0,
|
||||
};
|
||||
|
||||
if (op->op == GGML_OP_LEAKY_RELU) {
|
||||
args.slope = ggml_get_op_params_f32(op, 0);
|
||||
}
|
||||
|
||||
if (op->op == GGML_OP_SCALE) {
|
||||
args.scale = ggml_get_op_params_f32(op, 0);
|
||||
args.bias = ggml_get_op_params_f32(op, 1);
|
||||
}
|
||||
|
||||
if (op->op == GGML_OP_FILL) {
|
||||
args.val = ggml_get_op_params_f32(op, 0);
|
||||
}
|
||||
|
||||
if (op->op == GGML_OP_CLAMP) {
|
||||
args.min = ggml_get_op_params_f32(op, 0);
|
||||
args.max = ggml_get_op_params_f32(op, 1);
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
|
||||
if (pipeline.c4) {
|
||||
args.ne00 = ne00/4;
|
||||
args.ne0 = ne0/4;
|
||||
}
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
if (pipeline.cnt) {
|
||||
const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
} else {
|
||||
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
const int nth = MIN(args.ne00, nth_max);
|
||||
|
||||
const int nk0 = (args.ne00 + nth - 1)/nth;
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
@@ -969,6 +914,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
ggml_metal_kargs_sum_rows args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
@@ -990,21 +940,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
|
||||
|
||||
if (pipeline.c4) {
|
||||
args.ne00 = ne00/4;
|
||||
args.ne0 = ne0/4;
|
||||
}
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
nth = std::min(nth, ne00);
|
||||
nth = std::min(nth, (int) args.ne00);
|
||||
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
@@ -1664,6 +1619,134 @@ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
|
||||
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
|
||||
const size_t pnb3 = ((const int32_t *) op->op_params)[2];
|
||||
const size_t offs = ((const int32_t *) op->op_params)[3];
|
||||
|
||||
const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
|
||||
|
||||
if (!inplace) {
|
||||
// run a separete kernel to cpy src->dst
|
||||
// not sure how to avoid this
|
||||
// TODO: make a simpler cpy_bytes kernel
|
||||
|
||||
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
|
||||
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
||||
|
||||
ggml_metal_kargs_cpy args = {
|
||||
/*.nk0 =*/ ne00,
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
|
||||
|
||||
GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
|
||||
|
||||
int64_t nk0 = ne10;
|
||||
if (ggml_is_quantized(op->src[1]->type)) {
|
||||
nk0 = ne10/16;
|
||||
} else if (ggml_is_quantized(op->type)) {
|
||||
nk0 = ne10/ggml_blck_size(op->type);
|
||||
}
|
||||
|
||||
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
// when rows are small, we can batch them together in a single threadgroup
|
||||
int nrptg = 1;
|
||||
|
||||
// TODO: relax this constraint in the future
|
||||
if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
|
||||
if (nth > nk0) {
|
||||
nrptg = (nth + nk0 - 1)/nk0;
|
||||
nth = nk0;
|
||||
|
||||
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nrptg--;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nth = std::min<int>(nth, nk0);
|
||||
|
||||
ggml_metal_kargs_cpy args = {
|
||||
/*.nk0 =*/ nk0,
|
||||
/*.ne00 =*/ ne10,
|
||||
/*.ne01 =*/ ne11,
|
||||
/*.ne02 =*/ ne12,
|
||||
/*.ne03 =*/ ne13,
|
||||
/*.nb00 =*/ nb10,
|
||||
/*.nb01 =*/ nb11,
|
||||
/*.nb02 =*/ nb12,
|
||||
/*.nb03 =*/ nb13,
|
||||
/*.ne0 =*/ ne10,
|
||||
/*.ne1 =*/ ne11,
|
||||
/*.ne2 =*/ ne12,
|
||||
/*.ne3 =*/ ne13,
|
||||
/*.nb0 =*/ ggml_element_size(op),
|
||||
/*.nb1 =*/ pnb1,
|
||||
/*.nb2 =*/ pnb2,
|
||||
/*.nb3 =*/ pnb3,
|
||||
};
|
||||
|
||||
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
|
||||
|
||||
bid_dst.offs += offs;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
@@ -3044,39 +3127,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, op->op_params, sizeof(float));
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
ggml_metal_kargs_l2_norm args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne00_4 =*/ ne00/4,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.eps =*/ eps,
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
/*.eps =*/ eps,
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
|
||||
|
||||
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
if (pipeline.c4) {
|
||||
args.ne00 = ne00/4;
|
||||
args.ne0 = ne0/4;
|
||||
}
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
nth = std::min(nth, ne00/4);
|
||||
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
const int64_t nrows = ggml_nrows(op->src[0]);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
@@ -4084,42 +4187,6 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float slope;
|
||||
memcpy(&slope, op->op_params, sizeof(float));
|
||||
|
||||
ggml_metal_kargs_leaky_relu args = {
|
||||
/*.slope =*/ slope
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
|
||||
if (n % 4 == 0) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
|
||||
@@ -46,9 +46,6 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
|
||||
int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_fill (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
|
||||
@@ -62,6 +59,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_set (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx);
|
||||
@@ -86,7 +84,6 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
@@ -77,6 +77,14 @@ static inline float dot(float x, float y) {
|
||||
return x*y;
|
||||
}
|
||||
|
||||
static inline float sum(float x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline float sum(float4 x) {
|
||||
return x[0] + x[1] + x[2] + x[3];
|
||||
}
|
||||
|
||||
// NOTE: this is not dequantizing - we are simply fitting the template
|
||||
template <typename type4x4>
|
||||
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
||||
@@ -895,6 +903,210 @@ enum ggml_sort_order {
|
||||
GGML_SORT_ORDER_DESC,
|
||||
};
|
||||
|
||||
constant float GELU_COEF_A = 0.044715f;
|
||||
constant float GELU_QUICK_COEF = -1.702f;
|
||||
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
||||
|
||||
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
|
||||
// ref: https://www.johndcook.com/blog/python_erf/
|
||||
constant float p_erf = 0.3275911f;
|
||||
constant float a1_erf = 0.254829592f;
|
||||
constant float a2_erf = -0.284496736f;
|
||||
constant float a3_erf = 1.421413741f;
|
||||
constant float a4_erf = -1.453152027f;
|
||||
constant float a5_erf = 1.061405429f;
|
||||
|
||||
template<typename T>
|
||||
inline T erf_approx(T x) {
|
||||
T sign_x = sign(x);
|
||||
x = fabs(x);
|
||||
T t = 1.0f / (1.0f + p_erf * x);
|
||||
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
|
||||
return sign_x * y;
|
||||
}
|
||||
|
||||
template<typename T> T elu_approx(T x);
|
||||
|
||||
template<> inline float elu_approx<float>(float x) {
|
||||
return (x > 0.f) ? x : (exp(x) - 1);
|
||||
}
|
||||
|
||||
template<> inline float4 elu_approx<float4>(float4 x) {
|
||||
float4 res;
|
||||
|
||||
res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
|
||||
res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
|
||||
res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
|
||||
res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
|
||||
constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
|
||||
|
||||
template <typename T0, typename T, typename TC>
|
||||
kernel void kernel_unary_impl(
|
||||
constant ggml_metal_kargs_unary & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
#define FC_OP FC_unary_op
|
||||
#define FC_CNT FC_unary_cnt
|
||||
|
||||
device const T0 * src0_ptr;
|
||||
device T * dst_ptr;
|
||||
|
||||
int i0;
|
||||
|
||||
if (FC_CNT) {
|
||||
i0 = tgpig.x;
|
||||
|
||||
src0_ptr = (device const T0 *) (src0);
|
||||
dst_ptr = (device T *) (dst);
|
||||
} else {
|
||||
const int i03 = tgpig.z;
|
||||
const int i02 = tgpig.y;
|
||||
const int k0 = tgpig.x/args.ne01;
|
||||
const int i01 = tgpig.x - k0*args.ne01;
|
||||
|
||||
i0 = k0*ntg.x + tpitg.x;
|
||||
|
||||
src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 );
|
||||
}
|
||||
|
||||
{
|
||||
//threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
if (!FC_CNT) {
|
||||
if (i0 >= args.ne0) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const TC x = (TC) src0_ptr[i0];
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SCALE) {
|
||||
dst_ptr[i0] = (T) (args.scale * x + args.bias);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_FILL) {
|
||||
dst_ptr[i0] = (T) args.val;
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_CLAMP) {
|
||||
dst_ptr[i0] = (T) clamp(x, args.min, args.max);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SQR) {
|
||||
dst_ptr[i0] = (T) (x * x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SQRT) {
|
||||
dst_ptr[i0] = (T) sqrt(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SIN) {
|
||||
dst_ptr[i0] = (T) sin(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_COS) {
|
||||
dst_ptr[i0] = (T) cos(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_LOG) {
|
||||
dst_ptr[i0] = (T) log(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
|
||||
dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_TANH) {
|
||||
dst_ptr[i0] = (T) precise::tanh(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_RELU) {
|
||||
dst_ptr[i0] = (T) fmax(0, x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SIGMOID) {
|
||||
dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_GELU) {
|
||||
dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
|
||||
dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
|
||||
dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SILU) {
|
||||
dst_ptr[i0] = (T) (x / (1 + exp(-x)));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_ELU) {
|
||||
dst_ptr[i0] = (T) elu_approx(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_NEG) {
|
||||
dst_ptr[i0] = (T) -x;
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_ABS) {
|
||||
dst_ptr[i0] = (T) fabs(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SGN) {
|
||||
dst_ptr[i0] = T(x > 0) - T(x < 0);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_STEP) {
|
||||
dst_ptr[i0] = T(x > 0);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
|
||||
dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
|
||||
dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_EXP) {
|
||||
dst_ptr[i0] = (T) exp(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
|
||||
dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_EXPM1) {
|
||||
// TODO: precise implementation
|
||||
dst_ptr[i0] = (T) (exp(x) - 1);
|
||||
}
|
||||
}
|
||||
|
||||
#undef FC_OP
|
||||
#undef FC_CNT
|
||||
}
|
||||
|
||||
typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
|
||||
|
||||
template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>;
|
||||
template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
|
||||
template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>;
|
||||
template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>;
|
||||
|
||||
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
|
||||
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
|
||||
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
|
||||
@@ -1114,414 +1326,6 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat
|
||||
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
|
||||
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
|
||||
|
||||
kernel void kernel_scale_f32(
|
||||
constant ggml_metal_kargs_scale & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * args.scale + args.bias;
|
||||
}
|
||||
|
||||
kernel void kernel_scale_f32_4(
|
||||
constant ggml_metal_kargs_scale & args,
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * args.scale + args.bias;
|
||||
}
|
||||
|
||||
kernel void kernel_fill_f32(
|
||||
constant ggml_metal_kargs_fill & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = args.val;
|
||||
}
|
||||
|
||||
kernel void kernel_fill_f32_4(
|
||||
constant ggml_metal_kargs_fill & args,
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = args.val;
|
||||
}
|
||||
|
||||
kernel void kernel_clamp_f32(
|
||||
constant ggml_metal_kargs_clamp & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = clamp(src0[tpig], args.min, args.max);
|
||||
}
|
||||
|
||||
kernel void kernel_clamp_f32_4(
|
||||
constant ggml_metal_kargs_clamp & args,
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = clamp(src0[tpig], args.min, args.max);
|
||||
}
|
||||
|
||||
kernel void kernel_relu_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = max(0.0f, src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_relu_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = max(0.0f, src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sigmoid_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
|
||||
}
|
||||
|
||||
kernel void kernel_sigmoid_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
|
||||
}
|
||||
|
||||
kernel void kernel_tanh_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = precise::tanh(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_tanh_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = precise::tanh(src0[tpig]);
|
||||
}
|
||||
|
||||
constant float GELU_COEF_A = 0.044715f;
|
||||
constant float GELU_QUICK_COEF = -1.702f;
|
||||
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
||||
|
||||
kernel void kernel_gelu_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
|
||||
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float4 & x = src0[tpig];
|
||||
|
||||
// BEWARE !!!
|
||||
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
||||
// This was observed with Falcon 7B and 40B models
|
||||
//
|
||||
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_quick_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
|
||||
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_quick_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float4 & x = src0[tpig];
|
||||
|
||||
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
||||
}
|
||||
|
||||
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
|
||||
// ref: https://www.johndcook.com/blog/python_erf/
|
||||
constant float p_erf = 0.3275911f;
|
||||
constant float a1_erf = 0.254829592f;
|
||||
constant float a2_erf = -0.284496736f;
|
||||
constant float a3_erf = 1.421413741f;
|
||||
constant float a4_erf = -1.453152027f;
|
||||
constant float a5_erf = 1.061405429f;
|
||||
|
||||
template<typename T>
|
||||
T erf_approx(T x) {
|
||||
T sign_x = sign(x);
|
||||
x = fabs(x);
|
||||
T t = 1.0f / (1.0f + p_erf * x);
|
||||
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
|
||||
return sign_x * y;
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_erf_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
|
||||
dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_erf_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float4 & x = src0[tpig];
|
||||
|
||||
dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
|
||||
}
|
||||
|
||||
kernel void kernel_silu_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
dst[tpig] = x / (1.0f + exp(-x));
|
||||
}
|
||||
|
||||
kernel void kernel_silu_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float4 & x = src0[tpig];
|
||||
dst[tpig] = x / (1.0f + exp(-x));
|
||||
}
|
||||
|
||||
kernel void kernel_elu_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const float x = src0[tpig];
|
||||
dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
|
||||
}
|
||||
|
||||
kernel void kernel_elu_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const float4 x = src0[tpig];
|
||||
dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
|
||||
dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
|
||||
dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
|
||||
dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
|
||||
}
|
||||
|
||||
kernel void kernel_sqr_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * src0[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_sqr_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * src0[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_sqrt_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = sqrt(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sqrt_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = sqrt(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sin_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = sin(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sin_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = sin(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_cos_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = cos(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_cos_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = cos(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_log_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = log(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_log_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = log(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_neg_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = -src0[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_neg_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = -src0[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_abs_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = fabs(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_abs_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = fabs(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sgn_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = sign(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sgn_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = sign(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_step_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = step(0.0f, src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_step_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = step(0.0f, src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_hardswish_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const float x = src0[tpig];
|
||||
dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
||||
}
|
||||
|
||||
kernel void kernel_hardswish_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const float4 x = src0[tpig];
|
||||
dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
||||
}
|
||||
|
||||
kernel void kernel_hardsigmoid_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const float x = src0[tpig];
|
||||
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
||||
}
|
||||
|
||||
kernel void kernel_hardsigmoid_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const float4 x = src0[tpig];
|
||||
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
||||
}
|
||||
|
||||
kernel void kernel_exp_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = exp(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_exp_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = exp(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_softplus_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
|
||||
}
|
||||
|
||||
kernel void kernel_softplus_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float4 & x = src0[tpig];
|
||||
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
|
||||
}
|
||||
|
||||
kernel void kernel_expm1_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = exp(src0[tpig]) - 1.0f;
|
||||
}
|
||||
|
||||
kernel void kernel_expm1_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = exp(src0[tpig]) - 1.0f;
|
||||
}
|
||||
|
||||
kernel void kernel_reglu_f32(
|
||||
constant ggml_metal_kargs_glu & args,
|
||||
device const char * src0,
|
||||
@@ -1705,33 +1509,35 @@ kernel void kernel_op_sum_f32(
|
||||
}
|
||||
}
|
||||
|
||||
template <bool norm>
|
||||
kernel void kernel_sum_rows(
|
||||
constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
|
||||
|
||||
template <typename T0, typename T>
|
||||
kernel void kernel_sum_rows_impl(
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
int64_t i3 = tgpig.z;
|
||||
int64_t i2 = tgpig.y;
|
||||
int64_t i1 = tgpig.x;
|
||||
#define FC_OP FC_sum_rows_op
|
||||
|
||||
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
const int i3 = tgpig.z;
|
||||
const int i2 = tgpig.y;
|
||||
const int i1 = tgpig.x;
|
||||
|
||||
threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
shmem_t[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
float sumf = 0;
|
||||
T0 sumf = T0(0.0f);
|
||||
|
||||
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
sumf += src_row[i0];
|
||||
@@ -1742,23 +1548,33 @@ kernel void kernel_sum_rows(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
shmem_t[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = shmem_t[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
dst_row[0] = norm ? sumf / args.ne00 : sumf;
|
||||
if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
|
||||
if (is_same<float4, T0>::value) {
|
||||
dst_row[0] = sum(sumf) / (4*args.ne00);
|
||||
} else {
|
||||
dst_row[0] = sum(sumf) / args.ne00;
|
||||
}
|
||||
} else {
|
||||
dst_row[0] = sum(sumf);
|
||||
}
|
||||
}
|
||||
|
||||
#undef FC_OP
|
||||
}
|
||||
|
||||
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
||||
typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
|
||||
|
||||
template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
||||
template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
||||
template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
|
||||
template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_cumsum_blk(
|
||||
@@ -2639,9 +2455,6 @@ kernel void kernel_solve_tri_f32(
|
||||
const short K = FC_solve_tri_k;
|
||||
const short NP = PAD2(N, NW);
|
||||
|
||||
const int32_t ne02 = args.ne02;
|
||||
const int32_t ne03 = args.ne03;
|
||||
|
||||
const int32_t i03 = tgpig.z;
|
||||
const int32_t i02 = tgpig.y;
|
||||
const int32_t i01 = tgpig.x*NSG + sgitg;
|
||||
@@ -2928,26 +2741,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f
|
||||
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
|
||||
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
|
||||
|
||||
kernel void kernel_l2_norm_f32(
|
||||
template <typename T0, typename T>
|
||||
kernel void kernel_l2_norm_impl(
|
||||
constant ggml_metal_kargs_l2_norm & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort ntg[[threads_per_threadgroup]]) {
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig.z;
|
||||
const int i02 = tgpig.y;
|
||||
const int i01 = tgpig.x;
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
|
||||
device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
// parallel sum
|
||||
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
||||
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
|
||||
sumf += dot(x[i00], x[i00]);
|
||||
}
|
||||
sumf = simd_sum(sumf);
|
||||
@@ -2965,12 +2784,16 @@ kernel void kernel_l2_norm_f32(
|
||||
|
||||
const float scale = 1.0f/sqrt(max(sumf, args.eps));
|
||||
|
||||
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
|
||||
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
||||
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
|
||||
y[i00] = x[i00] * scale;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
|
||||
|
||||
template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
|
||||
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
|
||||
|
||||
kernel void kernel_group_norm_f32(
|
||||
constant ggml_metal_kargs_group_norm & args,
|
||||
device const float * src0,
|
||||
@@ -5072,24 +4895,6 @@ kernel void kernel_argsort_merge_f32_i32(
|
||||
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
|
||||
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
|
||||
|
||||
kernel void kernel_leaky_relu_f32(
|
||||
constant ggml_metal_kargs_leaky_relu & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const float x = src0[tpig];
|
||||
dst[tpig] = x > 0.0f ? x : x * args.slope;
|
||||
}
|
||||
|
||||
kernel void kernel_leaky_relu_f32_4(
|
||||
constant ggml_metal_kargs_leaky_relu & args,
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const float4 x = src0[tpig];
|
||||
dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
|
||||
}
|
||||
|
||||
constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
|
||||
|
||||
constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
|
||||
@@ -6161,7 +5966,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
|
||||
static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
|
||||
|
||||
const short T = PK + NSG*SH; // shared memory size per query in (half)
|
||||
//const short T = PK + NSG*SH; // shared memory size per query in (half)
|
||||
|
||||
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
|
||||
@@ -8749,7 +8554,9 @@ kernel void kernel_mul_mm(
|
||||
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
||||
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
||||
|
||||
#ifdef GGML_METAL_HAS_TENSOR
|
||||
threadgroup float * sc = (threadgroup float *)(shmem);
|
||||
#endif
|
||||
|
||||
constexpr int NR0 = 64;
|
||||
constexpr int NR1 = 32;
|
||||
@@ -8872,8 +8679,8 @@ kernel void kernel_mul_mm(
|
||||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
const short dx = sx;
|
||||
const short dy = sy;
|
||||
//const short dx = sx;
|
||||
//const short dy = sy;
|
||||
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
|
||||
@@ -9122,7 +8929,9 @@ kernel void kernel_mul_mm_id(
|
||||
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
||||
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
||||
|
||||
#ifdef GGML_METAL_HAS_TENSOR
|
||||
threadgroup float * sc = (threadgroup float *)(shmem);
|
||||
#endif
|
||||
|
||||
constexpr int NR0 = 64;
|
||||
constexpr int NR1 = 32;
|
||||
@@ -9257,8 +9066,8 @@ kernel void kernel_mul_mm_id(
|
||||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
const short dx = sx;
|
||||
const short dy = sy;
|
||||
//const short dx = sx;
|
||||
//const short dy = sy;
|
||||
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
|
||||
@@ -9939,7 +9748,7 @@ kernel void kernel_opt_step_sgd_f32(
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_memset(
|
||||
constant ggml_metal_kargs_fill & args,
|
||||
constant ggml_metal_kargs_memset & args,
|
||||
device T * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = args.val;
|
||||
|
||||
@@ -85,6 +85,9 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mv_q4_0_f32_8x_flat
|
||||
mul_mv_q4_0_f32_1d_8x_flat
|
||||
mul_mv_q4_0_f32_1d_16x_flat
|
||||
mul_mv_q4_1_f32
|
||||
mul_mv_q4_1_f32_flat
|
||||
mul_mv_q4_k_f32
|
||||
mul_mv_q6_k_f32
|
||||
mul_mv_q6_k_f32_flat
|
||||
mul_mv_q8_0_f32
|
||||
@@ -100,7 +103,10 @@ set(GGML_OPENCL_KERNELS
|
||||
gemv_moe_mxfp4_f32
|
||||
mul_mm_f32_f32_l4_lm
|
||||
mul_mm_f16_f32_l4_lm
|
||||
mul_mm_q4_0_f32_l4_lm
|
||||
mul_mm_q4_1_f32_l4_lm
|
||||
mul_mm_q8_0_f32_l4_lm
|
||||
mul_mm_q6_k_f32_l4_lm
|
||||
mul_mm_q8_0_f32_8x4
|
||||
gemv_noshuffle_general_q8_0_f32
|
||||
mul
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -46,6 +46,15 @@ struct block_q4_0
|
||||
uint8_t qs[QK4_0 / 2];
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q4_1
|
||||
//------------------------------------------------------------------------------
|
||||
struct block_q4_1 {
|
||||
half d; // delta
|
||||
half m; // min
|
||||
uchar qs[QK4_1 / 2]; // nibbles / quants
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q6_K
|
||||
//------------------------------------------------------------------------------
|
||||
@@ -148,6 +157,48 @@ kernel void kernel_restore_block_q4_0_noshuffle(
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// kernel_convert_block_q4_1
|
||||
// Convert the block_q4_1 format to 2 separate arrays (AOS -> SOA).
|
||||
// This kernel does not deshuffle the bits.
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_convert_block_q4_1(
|
||||
global struct block_q4_1 * src0,
|
||||
global uchar * dst_q,
|
||||
global half * dst_d,
|
||||
global half * dst_m
|
||||
) {
|
||||
global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0);
|
||||
global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0);
|
||||
global half * d = (global half *) dst_d + get_global_id(0);
|
||||
global half * m = (global half *) dst_m + get_global_id(0);
|
||||
|
||||
*d = b->d;
|
||||
*m = b->m;
|
||||
|
||||
for (int i = 0; i < QK4_1/2; ++i) {
|
||||
q[i] = b->qs[i];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_restore_block_q4_1(
|
||||
global uchar * src_q,
|
||||
global half * src_d,
|
||||
global half * src_m,
|
||||
global struct block_q4_1 * dst
|
||||
) {
|
||||
global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0);
|
||||
global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0);
|
||||
global half * d = (global half *) src_d + get_global_id(0);
|
||||
global half * m = (global half *) src_m + get_global_id(0);
|
||||
|
||||
b->d = *d;
|
||||
b->m = *m;
|
||||
for (int i = 0; i < QK4_1/2; ++i) {
|
||||
b->qs[i] = q[i];
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_mxfp4
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
@@ -3,80 +3,111 @@
|
||||
//------------------------------------------------------------------------------
|
||||
// expm1
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_expm1_f32_nd(
|
||||
global void * p_src0_base,
|
||||
ulong off_src0_abs,
|
||||
global void * p_dst_base,
|
||||
ulong off_dst_abs,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
|
||||
kernel void kernel_expm1_f32(
|
||||
global const float * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global float*)((global char*)src0 + offset0);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f;
|
||||
}
|
||||
|
||||
kernel void kernel_expm1_f32_4(
|
||||
global const float4 * src0,
|
||||
ulong offset0,
|
||||
global float4 * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global float4*)((global char*)src0 + offset0);
|
||||
dst = (global float4*)((global char*)dst + offsetd);
|
||||
|
||||
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f;
|
||||
}
|
||||
|
||||
kernel void kernel_expm1_f16(
|
||||
global const half * src0,
|
||||
ulong offset0,
|
||||
global half * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global half*)((global char*)src0 + offset0);
|
||||
dst = (global half*)((global char*)dst + offsetd);
|
||||
|
||||
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h;
|
||||
}
|
||||
|
||||
kernel void kernel_expm1_f16_4(
|
||||
global const half4 * src0,
|
||||
ulong offset0,
|
||||
global half4 * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global half4*)((global char*)src0 + offset0);
|
||||
dst = (global half4*)((global char*)dst + offsetd);
|
||||
|
||||
dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h;
|
||||
}
|
||||
|
||||
kernel void kernel_expm1_f32_nc(
|
||||
global const char * src0,
|
||||
ulong offset0,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
int ne13,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
int i0 = get_global_id(0);
|
||||
int i1 = get_global_id(1);
|
||||
int i2 = get_global_id(2);
|
||||
src0 = src0 + offset0;
|
||||
dst = dst + offsetd;
|
||||
|
||||
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
|
||||
for (int i3 = 0; i3 < ne13; ++i3) {
|
||||
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
|
||||
global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
|
||||
const int i3 = get_group_id(2);
|
||||
const int i2 = get_group_id(1);
|
||||
const int i1 = get_group_id(0);
|
||||
|
||||
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
|
||||
global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
|
||||
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
|
||||
global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
*dst_val_ptr = exp(*src_val_ptr) - 1;
|
||||
}
|
||||
*y = exp(*x) - 1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_expm1_f16_nd(
|
||||
global void * p_src0_base,
|
||||
ulong off_src0_abs,
|
||||
global void * p_dst_base,
|
||||
ulong off_dst_abs,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
kernel void kernel_expm1_f16_nc(
|
||||
global const char * src0,
|
||||
ulong offset0,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
int ne13,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
int i0 = get_global_id(0);
|
||||
int i1 = get_global_id(1);
|
||||
int i2 = get_global_id(2);
|
||||
src0 = src0 + offset0;
|
||||
dst = dst + offsetd;
|
||||
|
||||
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
|
||||
for (int i3 = 0; i3 < ne13; ++i3) {
|
||||
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
|
||||
global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
|
||||
const int i3 = get_group_id(2);
|
||||
const int i2 = get_group_id(1);
|
||||
const int i1 = get_group_id(0);
|
||||
|
||||
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
|
||||
global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
|
||||
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
|
||||
global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
*dst_val_ptr = exp(*src_val_ptr) - 1;
|
||||
}
|
||||
*y = exp(*x) - 1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
|
||||
// Most devices have max workgroup size of 1024, so this is enough for subgroup
|
||||
// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes
|
||||
#define MAX_SUBGROUPS 64
|
||||
kernel void kernel_mean_f32(
|
||||
global float * src0,
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
@@ -15,25 +20,121 @@ kernel void kernel_mean_f32(
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = (global float *)((global char *)src0 + offset0);
|
||||
dst = (global float *)((global char *)dst + offsetd);
|
||||
src0 = src0 + offset0;
|
||||
dst = dst + offsetd;
|
||||
|
||||
int i3 = get_global_id(2);
|
||||
int i2 = get_global_id(1);
|
||||
int i1 = get_global_id(0);
|
||||
const int i3 = get_group_id(2);
|
||||
const int i2 = get_group_id(1);
|
||||
const int i1 = get_group_id(0);
|
||||
|
||||
const int lid = get_local_id(0);
|
||||
const int lsize = get_local_size(0);
|
||||
|
||||
const uint sg_size = get_sub_group_size();
|
||||
const uint sg_id = get_sub_group_id();
|
||||
const uint sg_lid = get_sub_group_local_id();
|
||||
|
||||
__local float lmem[MAX_SUBGROUPS];
|
||||
|
||||
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
||||
global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
|
||||
float row_sum = 0;
|
||||
|
||||
for (int i0 = 0; i0 < ne00; i0++) {
|
||||
row_sum += src_row[i0];
|
||||
if(sg_id == 0){
|
||||
lmem[sg_lid] = 0.0f;
|
||||
}
|
||||
|
||||
dst_row[0] = row_sum / ne00;
|
||||
global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
||||
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int i0 = lid; i0 < ne00; i0 += lsize) {
|
||||
sumf += src_row[i0];
|
||||
}
|
||||
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if(sg_lid == 0){
|
||||
lmem[sg_id] = sumf;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
sumf = lmem[sg_lid];
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
if (lid == 0) {
|
||||
dst_row[0] = sumf / ne00;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_mean_f32_4(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
dst = dst + offsetd;
|
||||
|
||||
const int i3 = get_group_id(2);
|
||||
const int i2 = get_group_id(1);
|
||||
const int i1 = get_group_id(0);
|
||||
|
||||
const int lid = get_local_id(0);
|
||||
const int lsize = get_local_size(0);
|
||||
|
||||
const uint sg_size = get_sub_group_size();
|
||||
const uint sg_id = get_sub_group_id();
|
||||
const uint sg_lid = get_sub_group_local_id();
|
||||
|
||||
__local float lmem[MAX_SUBGROUPS];
|
||||
|
||||
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
if(sg_id == 0){
|
||||
lmem[sg_lid] = 0.0f;
|
||||
}
|
||||
|
||||
global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
||||
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
|
||||
float4 sum_vec = (float4)0.0f;
|
||||
|
||||
for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {
|
||||
sum_vec += src_row[i0];
|
||||
}
|
||||
|
||||
float sumf = dot(sum_vec, (float4)(1.0f));
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if(sg_lid == 0){
|
||||
lmem[sg_id] = sumf;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
sumf = lmem[sg_lid];
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
if (lid == 0) {
|
||||
dst_row[0] = sumf / ne00;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#define LOAD_VEC_A 8
|
||||
#define LOAD_VEC_B 4
|
||||
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#define BK 32
|
||||
#define TM 4
|
||||
#define TN 8
|
||||
|
||||
kernel void kernel_mul_mm_q4_0_f32_l4_lm(
|
||||
global uchar4 * src0_q,
|
||||
global half * src0_d,
|
||||
global float4 * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne11,
|
||||
int ne12,
|
||||
|
||||
int stride_a,
|
||||
int stride_b,
|
||||
int stride_d,
|
||||
|
||||
int batch_stride_a,
|
||||
int batch_stride_b,
|
||||
int batch_stride_d,
|
||||
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global float4*)((global char*)src1 + offset1);
|
||||
dst = (global float *)((global char*)dst + offsetd);
|
||||
|
||||
local float buf_a[BM * BK];
|
||||
local float buf_b[BN * BK];
|
||||
|
||||
const int batch_idx = get_global_id(2);
|
||||
|
||||
const int i13 = batch_idx / ne12;
|
||||
const int i12 = batch_idx % ne12;
|
||||
|
||||
const int i03 = i13 / r3;
|
||||
const int i02 = i12 / r2;
|
||||
|
||||
const int batch_idx_a = i03 * ne02 + i02;
|
||||
|
||||
const int ir = get_group_id(0);
|
||||
const int ic = get_group_id(1);
|
||||
|
||||
const int tid = get_local_id(0);
|
||||
const int th_r = tid % (BM / TM);
|
||||
const int th_c = tid / (BM / TM);
|
||||
|
||||
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
|
||||
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
|
||||
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
|
||||
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
|
||||
|
||||
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
|
||||
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
|
||||
|
||||
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
|
||||
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
|
||||
|
||||
float sums[TM * TN];
|
||||
float cache_a[TM];
|
||||
float cache_b[TN];
|
||||
|
||||
for (int i = 0; i < TM * TN; i++) {
|
||||
sums[i] = 0.0f;
|
||||
}
|
||||
|
||||
for (int block = 0; block < ne00; block += BK) {
|
||||
for (int l = 0; l < BM; l += loadstride_a) {
|
||||
if (ir*BM + loadc_a + l < ne01) {
|
||||
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
|
||||
int ib = idx / 4;
|
||||
int iqs = idx % 4;
|
||||
|
||||
float d = (float)src0_d[ib];
|
||||
global uchar4 * qs = src0_q + ib*4 + iqs;
|
||||
uchar4 q = *qs;
|
||||
float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)) - 8.0f)*d;
|
||||
float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)) - 8.0f)*d;
|
||||
|
||||
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0;
|
||||
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1;
|
||||
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2;
|
||||
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3;
|
||||
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
|
||||
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
|
||||
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
|
||||
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
|
||||
} else {
|
||||
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int l = 0; l < BN; l += loadstride_b) {
|
||||
if (ic*BN + loadc_b + l < ne11) {
|
||||
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
|
||||
} else {
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
pos_a += BK / LOAD_VEC_A;
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
|
||||
for (int i = 0; i < BK; i++) {
|
||||
for (int j = 0; j < TM; j++) {
|
||||
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
|
||||
}
|
||||
|
||||
for (int j = 0; j < TN; j++) {
|
||||
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
|
||||
}
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
const int sums_idx = cc*TM + cr;
|
||||
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
|
||||
const int dr = ir * BM + th_r * TM;
|
||||
const int dc = ic * BN + th_c * TN;
|
||||
|
||||
const int offsets = batch_idx * batch_stride_d;
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
if (dr + cr < ne01 && dc + cc < ne11) {
|
||||
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,165 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#define LOAD_VEC_A 8
|
||||
#define LOAD_VEC_B 4
|
||||
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#define BK 32
|
||||
#define TM 4
|
||||
#define TN 8
|
||||
|
||||
kernel void kernel_mul_mm_q4_1_f32_l4_lm(
|
||||
global uchar4 * src0_q,
|
||||
global half * src0_d,
|
||||
global half * src0_m,
|
||||
global float4 * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne11,
|
||||
int ne12,
|
||||
|
||||
int stride_a,
|
||||
int stride_b,
|
||||
int stride_d,
|
||||
|
||||
int batch_stride_a,
|
||||
int batch_stride_b,
|
||||
int batch_stride_d,
|
||||
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global float4*)((global char*)src1 + offset1);
|
||||
dst = (global float *)((global char*)dst + offsetd);
|
||||
|
||||
local float buf_a[BM * BK];
|
||||
local float buf_b[BN * BK];
|
||||
|
||||
const int batch_idx = get_global_id(2);
|
||||
|
||||
const int i13 = batch_idx / ne12;
|
||||
const int i12 = batch_idx % ne12;
|
||||
|
||||
const int i03 = i13 / r3;
|
||||
const int i02 = i12 / r2;
|
||||
|
||||
const int batch_idx_a = i03 * ne02 + i02;
|
||||
|
||||
const int ir = get_group_id(0);
|
||||
const int ic = get_group_id(1);
|
||||
|
||||
const int tid = get_local_id(0);
|
||||
const int th_r = tid % (BM / TM);
|
||||
const int th_c = tid / (BM / TM);
|
||||
|
||||
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
|
||||
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
|
||||
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
|
||||
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
|
||||
|
||||
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
|
||||
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
|
||||
|
||||
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
|
||||
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
|
||||
|
||||
float sums[TM * TN];
|
||||
float cache_a[TM];
|
||||
float cache_b[TN];
|
||||
|
||||
for (int i = 0; i < TM * TN; i++) {
|
||||
sums[i] = 0.0f;
|
||||
}
|
||||
|
||||
for (int block = 0; block < ne00; block += BK) {
|
||||
for (int l = 0; l < BM; l += loadstride_a) {
|
||||
if (ir*BM + loadc_a + l < ne01) {
|
||||
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
|
||||
int ib = idx / 4;
|
||||
int iqs = idx % 4;
|
||||
|
||||
float d = (float)src0_d[ib];
|
||||
float m = (float)src0_m[ib];
|
||||
global uchar4 * qs = src0_q + ib*4 + iqs;
|
||||
uchar4 q = *qs;
|
||||
float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)))*d + m;
|
||||
float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)))*d + m;
|
||||
|
||||
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0;
|
||||
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1;
|
||||
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2;
|
||||
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3;
|
||||
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
|
||||
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
|
||||
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
|
||||
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
|
||||
} else {
|
||||
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int l = 0; l < BN; l += loadstride_b) {
|
||||
if (ic*BN + loadc_b + l < ne11) {
|
||||
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
|
||||
} else {
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
pos_a += BK / LOAD_VEC_A;
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
|
||||
for (int i = 0; i < BK; i++) {
|
||||
for (int j = 0; j < TM; j++) {
|
||||
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
|
||||
}
|
||||
|
||||
for (int j = 0; j < TN; j++) {
|
||||
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
|
||||
}
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
const int sums_idx = cc*TM + cr;
|
||||
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
|
||||
const int dr = ir * BM + th_r * TM;
|
||||
const int dc = ic * BN + th_c * TN;
|
||||
|
||||
const int offsets = batch_idx * batch_stride_d;
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
if (dr + cr < ne01 && dc + cc < ne11) {
|
||||
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#define LOAD_VEC_A 2
|
||||
#define LOAD_VEC_B 4
|
||||
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#define BK 32
|
||||
#define TM 4
|
||||
#define TN 8
|
||||
|
||||
kernel void kernel_mul_mm_q6_k_f32_l4_lm(
|
||||
global uchar * src0_ql,
|
||||
global uchar * src0_qh,
|
||||
global char * src0_s,
|
||||
global half * src0_d,
|
||||
global float4 * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne11,
|
||||
int ne12,
|
||||
|
||||
int stride_a,
|
||||
int stride_b,
|
||||
int stride_d,
|
||||
|
||||
int batch_stride_a,
|
||||
int batch_stride_b,
|
||||
int batch_stride_d,
|
||||
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global float4*)((global char*)src1 + offset1);
|
||||
dst = (global float *)((global char*)dst + offsetd);
|
||||
|
||||
local float buf_a[BM * BK];
|
||||
local float buf_b[BN * BK];
|
||||
|
||||
const int batch_idx = get_global_id(2);
|
||||
|
||||
const int i13 = batch_idx / ne12;
|
||||
const int i12 = batch_idx % ne12;
|
||||
|
||||
const int i03 = i13 / r3;
|
||||
const int i02 = i12 / r2;
|
||||
|
||||
const int batch_idx_a = i03 * ne02 + i02;
|
||||
|
||||
const int ir = get_group_id(0);
|
||||
const int ic = get_group_id(1);
|
||||
|
||||
const int tid = get_local_id(0);
|
||||
const int th_r = tid % (BM / TM);
|
||||
const int th_c = tid / (BM / TM);
|
||||
|
||||
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
|
||||
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
|
||||
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
|
||||
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
|
||||
|
||||
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
|
||||
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
|
||||
|
||||
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
|
||||
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
|
||||
|
||||
float sums[TM * TN];
|
||||
float cache_a[TM];
|
||||
float cache_b[TN];
|
||||
|
||||
for (int i = 0; i < TM * TN; i++) {
|
||||
sums[i] = 0.0f;
|
||||
}
|
||||
|
||||
for (int block = 0; block < ne00; block += BK) {
|
||||
for (int l = 0; l < BM; l += loadstride_a) {
|
||||
if (ir*BM + loadc_a + l < ne01) {
|
||||
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
|
||||
|
||||
int ib = idx / 128; // 2 values per idx
|
||||
int iqs = idx % 128; // 0..127
|
||||
|
||||
int n = iqs / 64; // 0,1
|
||||
int b = (iqs % 64) / 32; // 0,1
|
||||
int is_b = (iqs % 16) / 8; // 0,1
|
||||
int qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
|
||||
int is = 8 * n + qhshift + is_b; // 0..15
|
||||
int qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
|
||||
int qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
|
||||
|
||||
float dscale = (float)src0_d[ib] * (float)src0_s[ib*16 + is];
|
||||
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32);
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32);
|
||||
} else {
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int l = 0; l < BN; l += loadstride_b) {
|
||||
if (ic*BN + loadc_b + l < ne11) {
|
||||
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
|
||||
} else {
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
pos_a += BK / LOAD_VEC_A;
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
|
||||
for (int i = 0; i < BK; i++) {
|
||||
for (int j = 0; j < TM; j++) {
|
||||
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
|
||||
}
|
||||
|
||||
for (int j = 0; j < TN; j++) {
|
||||
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
|
||||
}
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
const int sums_idx = cc*TM + cr;
|
||||
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
|
||||
const int dr = ir * BM + th_r * TM;
|
||||
const int dc = ic * BN + th_c * TN;
|
||||
|
||||
const int offsets = batch_idx * batch_stride_d;
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
if (dr + cr < ne01 && dc + cc < ne11) {
|
||||
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,219 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK4_1 32
|
||||
|
||||
struct block_q4_1 {
|
||||
half d; // delta
|
||||
half m; // min
|
||||
uchar qs[QK4_1 / 2]; // nibbles / quants
|
||||
};
|
||||
|
||||
inline float block_q4_1_dot_y(
|
||||
global const struct block_q4_1 * qb_curr,
|
||||
float sumy,
|
||||
float16 yl,
|
||||
int il
|
||||
) {
|
||||
float d = qb_curr->d;
|
||||
float m = qb_curr->m;
|
||||
|
||||
float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
|
||||
|
||||
global const ushort * qs = ((global const ushort *) qb_curr + 2 + il/2);
|
||||
|
||||
acc.s0 += yl.s0 * (qs[0] & 0x000F);
|
||||
acc.s0 += yl.s1 * (qs[0] & 0x0F00);
|
||||
acc.s0 += yl.s8 * (qs[0] & 0x00F0);
|
||||
acc.s3 += yl.s9 * (qs[0] & 0xF000);
|
||||
|
||||
acc.s0 += yl.s2 * (qs[1] & 0x000F);
|
||||
acc.s1 += yl.s3 * (qs[1] & 0x0F00);
|
||||
acc.s2 += yl.sa * (qs[1] & 0x00F0);
|
||||
acc.s3 += yl.sb * (qs[1] & 0xF000);
|
||||
|
||||
acc.s0 += yl.s4 * (qs[2] & 0x000F);
|
||||
acc.s1 += yl.s5 * (qs[2] & 0x0F00);
|
||||
acc.s2 += yl.sc * (qs[2] & 0x00F0);
|
||||
acc.s3 += yl.sd * (qs[2] & 0xF000);
|
||||
|
||||
acc.s0 += yl.s6 * (qs[3] & 0x000F);
|
||||
acc.s1 += yl.s7 * (qs[3] & 0x0F00);
|
||||
acc.s2 += yl.se * (qs[3] & 0x00F0);
|
||||
acc.s3 += yl.sf * (qs[3] & 0xF000);
|
||||
|
||||
return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m;
|
||||
}
|
||||
|
||||
#undef N_DST
|
||||
#undef N_SIMDGROUP
|
||||
#undef N_SIMDWIDTH
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_DST 4 // each subgroup works on 4 rows
|
||||
#define N_SIMDGROUP 1 // number of subgroups in a thread group
|
||||
#define N_SIMDWIDTH 16 // assuming subgroup size is 16
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_DST 4
|
||||
#define N_SIMDGROUP 1
|
||||
#define N_SIMDWIDTH 64
|
||||
#endif
|
||||
|
||||
inline void mul_vec_q_n_f32(
|
||||
global void * src0,
|
||||
global float * src1,
|
||||
global float * dst,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
const ulong nb = ne00/QK4_1;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
|
||||
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
|
||||
global struct block_q4_1 * x = (global struct block_q4_1 *) src0 + offset0;
|
||||
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float16 yl;
|
||||
float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
|
||||
|
||||
int ix = get_sub_group_local_id()/2;
|
||||
int il = 8*(get_sub_group_local_id()%2);
|
||||
|
||||
global float * yb = y + ix * QK4_1 + il;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
|
||||
float sumy = 0;
|
||||
|
||||
sumy += yb[0];
|
||||
sumy += yb[1];
|
||||
sumy += yb[2];
|
||||
sumy += yb[3];
|
||||
sumy += yb[4];
|
||||
sumy += yb[5];
|
||||
sumy += yb[6];
|
||||
sumy += yb[7];
|
||||
|
||||
sumy += yb[16];
|
||||
sumy += yb[17];
|
||||
sumy += yb[18];
|
||||
sumy += yb[19];
|
||||
sumy += yb[20];
|
||||
sumy += yb[21];
|
||||
sumy += yb[22];
|
||||
sumy += yb[23];
|
||||
|
||||
|
||||
yl.s0 = yb[0];
|
||||
yl.s1 = yb[1]/256.f;
|
||||
|
||||
yl.s2 = yb[2];
|
||||
yl.s3 = yb[3]/256.f;
|
||||
|
||||
yl.s4 = yb[4];
|
||||
yl.s5 = yb[5]/256.f;
|
||||
|
||||
yl.s6 = yb[6];
|
||||
yl.s7 = yb[7]/256.f;
|
||||
|
||||
yl.s8 = yb[16]/16.f;
|
||||
yl.s9 = yb[17]/4096.f;
|
||||
|
||||
yl.sa = yb[18]/16.f;
|
||||
yl.sb = yb[19]/4096.f;
|
||||
|
||||
yl.sc = yb[20]/16.f;
|
||||
yl.sd = yb[21]/4096.f;
|
||||
|
||||
yl.se = yb[22]/16.f;
|
||||
yl.sf = yb[23]/4096.f;
|
||||
|
||||
sumf.s0 += block_q4_1_dot_y(x+ib+0*nb, sumy, yl, il);
|
||||
sumf.s1 += block_q4_1_dot_y(x+ib+1*nb, sumy, yl, il);
|
||||
sumf.s2 += block_q4_1_dot_y(x+ib+2*nb, sumy, yl, il);
|
||||
sumf.s3 += block_q4_1_dot_y(x+ib+3*nb, sumy, yl, il);
|
||||
|
||||
yb += QK4_1 * (N_SIMDWIDTH/2);
|
||||
}
|
||||
|
||||
float4 tot = (float4)(
|
||||
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
|
||||
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
|
||||
);
|
||||
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
if (first_row + 0 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
|
||||
}
|
||||
if (first_row + 1 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
|
||||
}
|
||||
if (first_row + 2 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
|
||||
}
|
||||
if (first_row + 3 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mv_q4_1_f32(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
|
||||
}
|
||||
@@ -0,0 +1,229 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK4_1 32
|
||||
|
||||
struct block_q4_1 {
|
||||
half d; // delta
|
||||
half m; // min
|
||||
uchar qs[QK4_1 / 2]; // nibbles / quants
|
||||
};
|
||||
|
||||
inline float block_q4_1_dot_y_flat(
|
||||
global const uchar * x,
|
||||
global const half * dh,
|
||||
global const half * mh,
|
||||
float sumy,
|
||||
float16 yl,
|
||||
int il
|
||||
) {
|
||||
float d = *dh;
|
||||
float m = *mh;
|
||||
global const ushort * qs = ((global const ushort *) x + il/2);
|
||||
|
||||
float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
|
||||
|
||||
acc.s0 += yl.s0 * (qs[0] & 0x000F);
|
||||
acc.s0 += yl.s1 * (qs[0] & 0x0F00);
|
||||
acc.s0 += yl.s8 * (qs[0] & 0x00F0);
|
||||
acc.s3 += yl.s9 * (qs[0] & 0xF000);
|
||||
|
||||
acc.s0 += yl.s2 * (qs[1] & 0x000F);
|
||||
acc.s1 += yl.s3 * (qs[1] & 0x0F00);
|
||||
acc.s2 += yl.sa * (qs[1] & 0x00F0);
|
||||
acc.s3 += yl.sb * (qs[1] & 0xF000);
|
||||
|
||||
acc.s0 += yl.s4 * (qs[2] & 0x000F);
|
||||
acc.s1 += yl.s5 * (qs[2] & 0x0F00);
|
||||
acc.s2 += yl.sc * (qs[2] & 0x00F0);
|
||||
acc.s3 += yl.sd * (qs[2] & 0xF000);
|
||||
|
||||
acc.s0 += yl.s6 * (qs[3] & 0x000F);
|
||||
acc.s1 += yl.s7 * (qs[3] & 0x0F00);
|
||||
acc.s2 += yl.se * (qs[3] & 0x00F0);
|
||||
acc.s3 += yl.sf * (qs[3] & 0xF000);
|
||||
|
||||
return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m;
|
||||
}
|
||||
|
||||
#undef N_DST
|
||||
#undef N_SIMDGROUP
|
||||
#undef N_SIMDWIDTH
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_DST 4 // each subgroup works on 4 rows
|
||||
#define N_SIMDGROUP 1 // number of subgroups in a thread group
|
||||
#define N_SIMDWIDTH 16 // assuming subgroup size is 16
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_DST 4
|
||||
#define N_SIMDGROUP 1
|
||||
#define N_SIMDWIDTH 64
|
||||
#endif
|
||||
|
||||
inline void mul_vec_q_n_f32_flat(
|
||||
global void * src0_q,
|
||||
global void * src0_d,
|
||||
global void * src0_m,
|
||||
global float * src1,
|
||||
global float * dst,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
const ulong nb = ne00/QK4_1;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
|
||||
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
|
||||
// The number of scales/mins is the same as the number of blocks.
|
||||
ulong offset0_dm = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02));
|
||||
// Each block contains QK4_1/2 uchars, hence offset for qs is as follows.
|
||||
ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_1/2;
|
||||
|
||||
global uchar * x = (global uchar *) src0_q + offset0_q;
|
||||
global half * d = (global half *) src0_d + offset0_dm;
|
||||
global half * m = (global half *) src0_m + offset0_dm;
|
||||
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float16 yl;
|
||||
float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
|
||||
|
||||
int ix = get_sub_group_local_id()/2;
|
||||
int il = 8*(get_sub_group_local_id()%2);
|
||||
|
||||
global float * yb = y + ix * QK4_1 + il;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
|
||||
float sumy = 0;
|
||||
|
||||
sumy += yb[0];
|
||||
sumy += yb[1];
|
||||
sumy += yb[2];
|
||||
sumy += yb[3];
|
||||
sumy += yb[4];
|
||||
sumy += yb[5];
|
||||
sumy += yb[6];
|
||||
sumy += yb[7];
|
||||
|
||||
sumy += yb[16];
|
||||
sumy += yb[17];
|
||||
sumy += yb[18];
|
||||
sumy += yb[19];
|
||||
sumy += yb[20];
|
||||
sumy += yb[21];
|
||||
sumy += yb[22];
|
||||
sumy += yb[23];
|
||||
|
||||
|
||||
yl.s0 = yb[0];
|
||||
yl.s1 = yb[1]/256.f;
|
||||
|
||||
yl.s2 = yb[2];
|
||||
yl.s3 = yb[3]/256.f;
|
||||
|
||||
yl.s4 = yb[4];
|
||||
yl.s5 = yb[5]/256.f;
|
||||
|
||||
yl.s6 = yb[6];
|
||||
yl.s7 = yb[7]/256.f;
|
||||
|
||||
yl.s8 = yb[16]/16.f;
|
||||
yl.s9 = yb[17]/4096.f;
|
||||
|
||||
yl.sa = yb[18]/16.f;
|
||||
yl.sb = yb[19]/4096.f;
|
||||
|
||||
yl.sc = yb[20]/16.f;
|
||||
yl.sd = yb[21]/4096.f;
|
||||
|
||||
yl.se = yb[22]/16.f;
|
||||
yl.sf = yb[23]/4096.f;
|
||||
|
||||
sumf.s0 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 0*nb*QK4_1/2, d + ib + 0*nb, m + ib + 0*nb, sumy, yl, il);
|
||||
sumf.s1 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 1*nb*QK4_1/2, d + ib + 1*nb, m + ib + 1*nb, sumy, yl, il);
|
||||
sumf.s2 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 2*nb*QK4_1/2, d + ib + 2*nb, m + ib + 2*nb, sumy, yl, il);
|
||||
sumf.s3 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 3*nb*QK4_1/2, d + ib + 3*nb, m + ib + 3*nb, sumy, yl, il);
|
||||
|
||||
yb += QK4_1 * (N_SIMDWIDTH/2);
|
||||
}
|
||||
|
||||
float4 tot = (float4)(
|
||||
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
|
||||
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
|
||||
);
|
||||
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
if (first_row + 0 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
|
||||
}
|
||||
if (first_row + 1 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
|
||||
}
|
||||
if (first_row + 2 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
|
||||
}
|
||||
if (first_row + 3 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mv_q4_1_f32_flat(
|
||||
global void * src0_q,
|
||||
global void * src0_d,
|
||||
global void * src0_m,
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
mul_vec_q_n_f32_flat(src0_q, src0_d, src0_m, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
|
||||
}
|
||||
@@ -0,0 +1,180 @@
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q4_K
|
||||
//------------------------------------------------------------------------------
|
||||
#define QK_K 256
|
||||
#define K_SCALE_SIZE 12
|
||||
|
||||
// 8 blocks of 32 elements each
|
||||
// weight is represented as x = a * q + b
|
||||
typedef struct {
|
||||
half d; // super-block scale for quantized scales
|
||||
half dmin; // super-block scale for quantized mins
|
||||
|
||||
uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
|
||||
uchar qs[QK_K/2]; // 4-bit quants
|
||||
} block_q4_K;
|
||||
|
||||
#undef N_DST
|
||||
#undef N_SIMDGROUP
|
||||
#undef N_SIMDWIDTH
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_DST 4 // number of rows each SIMD group works on
|
||||
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
|
||||
#define N_SIMDWIDTH 16 // SIMD group size
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_DST 4
|
||||
#define N_SIMDGROUP 1
|
||||
#define N_SIMDWIDTH 64
|
||||
#endif
|
||||
|
||||
#undef BLOCK_STRIDE
|
||||
// number of (super) blocks each subgroup processes
|
||||
// each thread in a subgroup processes a block (32 weights)
|
||||
#define BLOCK_STRIDE (N_SIMDWIDTH/8)
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mv_q4_K_f32(
|
||||
global char * src0,
|
||||
int offset0,
|
||||
global char * src1,
|
||||
int offset1,
|
||||
global char * dst,
|
||||
int offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne12,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
dst = dst + offsetd;
|
||||
|
||||
ushort kmask1 = 0x3f3f;
|
||||
ushort kmask2 = 0x0f0f;
|
||||
ushort kmask3 = 0xc0c0;
|
||||
|
||||
int ix = get_sub_group_local_id()/8; // super block index
|
||||
int it = get_sub_group_local_id()%8; // block index (inside super block)
|
||||
int iq = it/4; // 0 or 1 - first or second half of the super block
|
||||
int ir = it%4; // 0...3 - block index in the half super block
|
||||
|
||||
int nb = ne00/QK_K;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
||||
int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
||||
|
||||
global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0);
|
||||
global float * y = (global float *) (src1 + offset_src1);
|
||||
|
||||
float yl[16];
|
||||
float yh[16];
|
||||
float sumf[N_DST] = {0.f};
|
||||
float all_sum;
|
||||
|
||||
global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
|
||||
|
||||
ushort sc16[4];
|
||||
uchar * sc8 = (uchar *)sc16;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
|
||||
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
yl[i+0] = y4[i+0];
|
||||
sumy.s0 += yl[i+0];
|
||||
|
||||
yl[i+8] = y4[i+32];
|
||||
sumy.s1 += yl[i+8];
|
||||
|
||||
yh[i+0] = y4[i+128];
|
||||
sumy.s2 += yh[i+0];
|
||||
|
||||
yh[i+8] = y4[i+160];
|
||||
sumy.s3 += yh[i+8];
|
||||
}
|
||||
|
||||
global ushort * sc = (global ushort *)x[ib].scales + iq;
|
||||
global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir;
|
||||
global half * dh = &x[ib].d;
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
sc16[0] = sc[0] & kmask1;
|
||||
sc16[1] = sc[2] & kmask1;
|
||||
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
|
||||
sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
|
||||
|
||||
global ushort * q2 = q1 + 32;
|
||||
|
||||
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
||||
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int i = 0; i < 8; i += 2) {
|
||||
acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F);
|
||||
acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00);
|
||||
acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0);
|
||||
acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000);
|
||||
acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F);
|
||||
acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00);
|
||||
acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0);
|
||||
acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000);
|
||||
}
|
||||
|
||||
float dall = dh[0];
|
||||
float dmin = dh[1];
|
||||
sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] +
|
||||
(acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f +
|
||||
(acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] +
|
||||
(acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) -
|
||||
dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]);
|
||||
|
||||
q1 += nb01/2;
|
||||
sc += nb01/2;
|
||||
dh += nb01/2;
|
||||
}
|
||||
|
||||
y4 += BLOCK_STRIDE * QK_K;
|
||||
}
|
||||
|
||||
global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = sub_group_reduce_add(sumf[row]);
|
||||
if (first_row + row < ne01) {
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
dst_f32[first_row + row] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,86 +3,114 @@
|
||||
//------------------------------------------------------------------------------
|
||||
// softplus
|
||||
//------------------------------------------------------------------------------
|
||||
inline float softplus_f32(float x){
|
||||
float ax = fabs(x);
|
||||
float m = fmax(x, 0.0f);
|
||||
return log1p(exp(-ax)) + m;
|
||||
|
||||
kernel void kernel_softplus_f32(
|
||||
global const float * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global float*)((global char*)src0 + offset0);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)]));
|
||||
}
|
||||
|
||||
kernel void kernel_softplus_f32_nd(
|
||||
global void * p_src0_base,
|
||||
ulong off_src0_abs,
|
||||
global void * p_dst_base,
|
||||
ulong off_dst_abs,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
kernel void kernel_softplus_f32_4(
|
||||
global const float4 * src0,
|
||||
ulong offset0,
|
||||
global float4 * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global float4*)((global char*)src0 + offset0);
|
||||
dst = (global float4*)((global char*)dst + offsetd);
|
||||
|
||||
dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)]));
|
||||
}
|
||||
|
||||
kernel void kernel_softplus_f16(
|
||||
global const half * src0,
|
||||
ulong offset0,
|
||||
global half * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global half*)((global char*)src0 + offset0);
|
||||
dst = (global half*)((global char*)dst + offsetd);
|
||||
|
||||
const float x = convert_float(src0[get_global_id(0)]);
|
||||
dst[get_global_id(0)] = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
|
||||
}
|
||||
|
||||
kernel void kernel_softplus_f16_4(
|
||||
global const half4 * src0,
|
||||
ulong offset0,
|
||||
global half4 * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global half4*)((global char*)src0 + offset0);
|
||||
dst = (global half4*)((global char*)dst + offsetd);
|
||||
|
||||
const float4 x = convert_float4(src0[get_global_id(0)]);
|
||||
dst[get_global_id(0)] = convert_half4_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
|
||||
}
|
||||
|
||||
kernel void kernel_softplus_f32_nc(
|
||||
global const char * src0,
|
||||
ulong offset0,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
int ne13,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
int i0 = get_global_id(0);
|
||||
int i1 = get_global_id(1);
|
||||
int i2 = get_global_id(2);
|
||||
src0 = src0 + offset0;
|
||||
dst = dst + offsetd;
|
||||
|
||||
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
|
||||
for (int i3 = 0; i3 < ne13; ++i3) {
|
||||
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
|
||||
global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
|
||||
const int i3 = get_group_id(2);
|
||||
const int i2 = get_group_id(1);
|
||||
const int i1 = get_group_id(0);
|
||||
|
||||
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
|
||||
global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
|
||||
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
|
||||
global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
*dst_val_ptr = softplus_f32(*src_val_ptr);
|
||||
}
|
||||
*y = (*x > 20.0f) ? *x : log(1.0f + exp(*x));
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_softplus_f16_nd(
|
||||
global void * p_src0_base,
|
||||
ulong off_src0_abs,
|
||||
global void * p_dst_base,
|
||||
ulong off_dst_abs,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
kernel void kernel_softplus_f16_nc(
|
||||
global const char * src0,
|
||||
ulong offset0,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
int ne13,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
int i0 = get_global_id(0);
|
||||
int i1 = get_global_id(1);
|
||||
int i2 = get_global_id(2);
|
||||
src0 = src0 + offset0;
|
||||
dst = dst + offsetd;
|
||||
|
||||
if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
|
||||
for (int i3 = 0; i3 < ne13; ++i3) {
|
||||
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
|
||||
global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
|
||||
const int i3 = get_group_id(2);
|
||||
const int i2 = get_group_id(1);
|
||||
const int i1 = get_group_id(0);
|
||||
|
||||
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
|
||||
global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
|
||||
for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
|
||||
global const half * hx = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
global half * hy = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
*dst_val_ptr = (half)(softplus_f32((float)(*src_val_ptr)));
|
||||
}
|
||||
const float x = convert_float(*hx);
|
||||
*hy = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
|
||||
// Most devices have max workgroup size of 1024, so this is enough for subgroup
|
||||
// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes
|
||||
#define MAX_SUBGROUPS 64
|
||||
kernel void kernel_sum_rows_f32(
|
||||
global float * src0,
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
@@ -15,25 +20,121 @@ kernel void kernel_sum_rows_f32(
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = (global float *)((global char *)src0 + offset0);
|
||||
dst = (global float *)((global char *)dst + offsetd);
|
||||
src0 = src0 + offset0;
|
||||
dst = dst + offsetd;
|
||||
|
||||
int i3 = get_global_id(2);
|
||||
int i2 = get_global_id(1);
|
||||
int i1 = get_global_id(0);
|
||||
const int i3 = get_group_id(2);
|
||||
const int i2 = get_group_id(1);
|
||||
const int i1 = get_group_id(0);
|
||||
|
||||
const int lid = get_local_id(0);
|
||||
const int lsize = get_local_size(0);
|
||||
|
||||
const uint sg_size = get_sub_group_size();
|
||||
const uint sg_id = get_sub_group_id();
|
||||
const uint sg_lid = get_sub_group_local_id();
|
||||
|
||||
__local float lmem[MAX_SUBGROUPS];
|
||||
|
||||
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
||||
global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
|
||||
float row_sum = 0;
|
||||
|
||||
for (int i0 = 0; i0 < ne00; i0++) {
|
||||
row_sum += src_row[i0];
|
||||
if(sg_id == 0){
|
||||
lmem[sg_lid] = 0.0f;
|
||||
}
|
||||
|
||||
dst_row[0] = row_sum;
|
||||
global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
||||
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int i0 = lid; i0 < ne00; i0 += lsize) {
|
||||
sumf += src_row[i0];
|
||||
}
|
||||
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if(sg_lid == 0){
|
||||
lmem[sg_id] = sumf;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
sumf = lmem[sg_lid];
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
if (lid == 0) {
|
||||
dst_row[0] = sumf;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_sum_rows_f32_4(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
dst = dst + offsetd;
|
||||
|
||||
const int i3 = get_group_id(2);
|
||||
const int i2 = get_group_id(1);
|
||||
const int i1 = get_group_id(0);
|
||||
|
||||
const int lid = get_local_id(0);
|
||||
const int lsize = get_local_size(0);
|
||||
|
||||
const uint sg_size = get_sub_group_size();
|
||||
const uint sg_id = get_sub_group_id();
|
||||
const uint sg_lid = get_sub_group_local_id();
|
||||
|
||||
__local float lmem[MAX_SUBGROUPS];
|
||||
|
||||
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
if(sg_id == 0){
|
||||
lmem[sg_lid] = 0.0f;
|
||||
}
|
||||
|
||||
global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
|
||||
global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3);
|
||||
|
||||
float4 sum_vec = (float4)0.0f;
|
||||
|
||||
for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {
|
||||
sum_vec += src_row[i0];
|
||||
}
|
||||
|
||||
float sumf = dot(sum_vec, (float4)(1.0f));
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if(sg_lid == 0){
|
||||
lmem[sg_id] = sumf;
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
sumf = lmem[sg_lid];
|
||||
sumf = sub_group_reduce_add(sumf);
|
||||
|
||||
if (lid == 0) {
|
||||
dst_row[0] = sumf;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,6 +92,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
|
||||
#define VK_VENDOR_ID_APPLE 0x106b
|
||||
#define VK_VENDOR_ID_INTEL 0x8086
|
||||
#define VK_VENDOR_ID_NVIDIA 0x10de
|
||||
#define VK_VENDOR_ID_QUALCOMM 0x5143
|
||||
|
||||
#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
|
||||
|
||||
@@ -687,6 +688,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_acc_f32;
|
||||
vk_pipeline pipeline_set_f32;
|
||||
|
||||
// [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
|
||||
vk_pipeline pipeline_add[2][2][2];
|
||||
@@ -942,6 +944,7 @@ struct vk_mat_mat_push_constants {
|
||||
uint32_t M; uint32_t N; uint32_t K;
|
||||
uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
|
||||
uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
|
||||
uint32_t base_work_group_z; uint32_t num_batches;
|
||||
uint32_t k_split;
|
||||
uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
|
||||
uint32_t padded_N;
|
||||
@@ -961,6 +964,7 @@ struct vk_mat_vec_push_constants {
|
||||
uint32_t batch_stride_b;
|
||||
uint32_t batch_stride_d;
|
||||
uint32_t fusion_flags;
|
||||
uint32_t base_work_group_y;
|
||||
uint32_t ne02;
|
||||
uint32_t ne12;
|
||||
uint32_t broadcast2;
|
||||
@@ -4080,7 +4084,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
@@ -4181,7 +4185,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
@@ -5641,6 +5646,10 @@ static void ggml_vk_instance_init() {
|
||||
driver_priorities[vk::DriverId::eMesaNvk] = 2;
|
||||
#endif
|
||||
break;
|
||||
case VK_VENDOR_ID_QUALCOMM:
|
||||
driver_priorities[vk::DriverId::eQualcommProprietary] = 1;
|
||||
driver_priorities[vk::DriverId::eMesaTurnip] = 2;
|
||||
break;
|
||||
}
|
||||
driver_priorities[vk::DriverId::eMesaDozen] = 100;
|
||||
|
||||
@@ -6766,8 +6775,16 @@ static void ggml_vk_matmul(
|
||||
uint32_t padded_n) {
|
||||
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
|
||||
if (split_k == 1) {
|
||||
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
|
||||
|
||||
uint32_t base_work_group_z = 0;
|
||||
while (base_work_group_z < batch) {
|
||||
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
||||
|
||||
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n };
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z });
|
||||
base_work_group_z += groups_z;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -6781,9 +6798,17 @@ static void ggml_vk_matmul(
|
||||
uint32_t k_split = CEIL_DIV(k, split_k);
|
||||
k_split = ROUNDUP_POW2(k_split, 256);
|
||||
|
||||
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
|
||||
// Make sure enough workgroups get assigned for split k to work
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
|
||||
|
||||
uint32_t base_work_group_z = 0;
|
||||
while (base_work_group_z < batch) {
|
||||
uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
||||
|
||||
const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
|
||||
// Make sure enough workgroups get assigned for split k to work
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z });
|
||||
base_work_group_z += groups_z;
|
||||
}
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
|
||||
@@ -7179,7 +7204,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
}
|
||||
|
||||
// Request descriptor sets
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||
if (qx_needs_dequant) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
|
||||
}
|
||||
@@ -7477,7 +7501,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
if (quantize_y) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
|
||||
}
|
||||
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
|
||||
}
|
||||
|
||||
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
|
||||
@@ -7572,22 +7595,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
|
||||
}
|
||||
|
||||
// compute
|
||||
const vk_mat_vec_push_constants pc = {
|
||||
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
|
||||
stride_batch_x, stride_batch_y, stride_batch_d,
|
||||
fusion_flags,
|
||||
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
|
||||
};
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
|
||||
{
|
||||
d_X,
|
||||
d_Y,
|
||||
d_D,
|
||||
d_F0,
|
||||
d_F1,
|
||||
},
|
||||
pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
|
||||
ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1]));
|
||||
|
||||
uint32_t base_work_group_y = 0;
|
||||
while (base_work_group_y < ne12 * ne13) {
|
||||
|
||||
uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||
const vk_mat_vec_push_constants pc = {
|
||||
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
|
||||
stride_batch_x, stride_batch_y, stride_batch_d,
|
||||
fusion_flags, base_work_group_y,
|
||||
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
|
||||
};
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
|
||||
{
|
||||
d_X,
|
||||
d_Y,
|
||||
d_D,
|
||||
d_F0,
|
||||
d_F1,
|
||||
},
|
||||
pc, { groups_x, groups_y, groups_z });
|
||||
base_work_group_y += groups_y;
|
||||
}
|
||||
|
||||
if (x_non_contig) {
|
||||
ctx->prealloc_x_need_sync = true;
|
||||
@@ -7825,10 +7855,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
||||
src1->nb[2] <= src1->nb[1] &&
|
||||
src1->nb[1] <= src1->nb[3] &&
|
||||
src0->ne[3] == 1 &&
|
||||
src1->ne[3] == 1) {
|
||||
src1->ne[3] == 1 &&
|
||||
src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
|
||||
src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
|
||||
ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx);
|
||||
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
|
||||
!ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
|
||||
!ggml_is_permuted(src0) && !ggml_is_permuted(src1) &&
|
||||
src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
|
||||
src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
|
||||
src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
|
||||
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);
|
||||
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
|
||||
// when ne12 and ne13 are one.
|
||||
@@ -8422,6 +8457,8 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
|
||||
const uint32_t acctype = f32acc ? 4 : 2;
|
||||
const uint32_t f16vec4 = 8;
|
||||
|
||||
const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
|
||||
|
||||
const uint32_t qstride = hsk_pad / 4 + 2;
|
||||
const uint32_t Qf = Br * qstride * f16vec4;
|
||||
|
||||
@@ -8438,7 +8475,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
|
||||
|
||||
const uint32_t slope = Br * acctype;
|
||||
|
||||
const uint32_t total_size = Qf + Psh + sfsh + ksh + slope;
|
||||
const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
|
||||
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
|
||||
@@ -8815,6 +8852,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_acc_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SET:
|
||||
if (src0->type == src1->type && src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) {
|
||||
return ctx->device->pipeline_set_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
@@ -9801,16 +9844,16 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
||||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
|
||||
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
|
||||
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
|
||||
int offset = dst->op_params[3] / 4; // offset in bytes
|
||||
int nb1 = dst->op_params[0] / src0_type_size; // 4 bytes of float32
|
||||
int nb2 = dst->op_params[1] / src0_type_size; // 4 bytes of float32
|
||||
int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32
|
||||
int offset = dst->op_params[3] / src0_type_size; // offset in bytes
|
||||
|
||||
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
|
||||
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
|
||||
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
|
||||
0,
|
||||
0.0f, 0.0f, offset,
|
||||
});
|
||||
@@ -10624,8 +10667,10 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
}
|
||||
|
||||
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
float * op_params = (float *)dst->op_params;
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
|
||||
const float * op_params = (const float *)dst->op_params;
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
p.param1 = op_params[0];
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, std::move(p));
|
||||
}
|
||||
|
||||
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
@@ -11543,7 +11588,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
||||
}
|
||||
}
|
||||
|
||||
ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
|
||||
if (split_k > 1) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
|
||||
|
||||
@@ -12052,7 +12096,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
||||
// y[i] = i % k;
|
||||
}
|
||||
|
||||
ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
|
||||
if (split_k > 1) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
|
||||
|
||||
@@ -12500,6 +12543,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
|
||||
break;
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_SET:
|
||||
ggml_vk_acc(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
@@ -14896,8 +14940,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
return true;
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_L2_NORM:
|
||||
return ggml_is_contiguous_rows(op->src[0]) &&
|
||||
op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
@@ -14960,7 +15006,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
}
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ACC:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SET:
|
||||
return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type &&
|
||||
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32);
|
||||
case GGML_OP_CONCAT:
|
||||
return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
|
||||
case GGML_OP_ADD1:
|
||||
@@ -15611,6 +15660,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||
tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
|
||||
} else if (tensor->op == GGML_OP_ACC) {
|
||||
tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
|
||||
} else if (tensor->op == GGML_OP_SET) {
|
||||
tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
|
||||
} else if (tensor->op == GGML_OP_NORM) {
|
||||
tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
|
||||
} else if (tensor->op == GGML_OP_GROUP_NORM) {
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
#include "types.glsl"
|
||||
#include "generic_binary_head.glsl"
|
||||
|
||||
// false for SET, true for ACC
|
||||
layout(constant_id = 1) const bool ACC = true;
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
@@ -13,17 +16,22 @@ void main() {
|
||||
|
||||
const uint offset = p.param3;
|
||||
const uint src1_i = idx - offset;
|
||||
const uint oz = src1_i / p.nb02;
|
||||
const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
|
||||
const uint ox = src1_i % p.nb01;
|
||||
const uint i3 = src1_i / p.nb03;
|
||||
const uint rem2 = src1_i - i3 * p.nb03;
|
||||
const uint i2 = rem2 / p.nb02;
|
||||
const uint rem1 = rem2 - i2 * p.nb02;
|
||||
const uint i1 = rem1 / p.nb01;
|
||||
const uint i0 = rem1 % p.nb01;
|
||||
|
||||
uint i00, i01, i02, i03;
|
||||
get_indices(idx, i00, i01, i02, i03);
|
||||
|
||||
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
|
||||
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
|
||||
if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) {
|
||||
if (ACC) {
|
||||
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
|
||||
} else {
|
||||
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
|
||||
}
|
||||
} else {
|
||||
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]));
|
||||
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -130,6 +130,7 @@ void main() {
|
||||
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
float max_mask = NEG_FLT_MAX_OVER_2;
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
@@ -137,12 +138,25 @@ void main() {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
masksh[c][r] = m;
|
||||
max_mask = max(max_mask, m);
|
||||
} else {
|
||||
masksh[c][r] = float(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
// skip the block if the mask is entirely -inf
|
||||
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
||||
max_mask = max(max_mask, tmpsh[s]);
|
||||
}
|
||||
if (max_mask <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
float Sf[Br][cols_per_thread];
|
||||
@@ -260,6 +274,9 @@ void main() {
|
||||
barrier();
|
||||
}
|
||||
|
||||
// prevent race on tmpsh
|
||||
barrier();
|
||||
|
||||
// reduce across threads
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
|
||||
@@ -42,6 +42,8 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
|
||||
return elem;
|
||||
}
|
||||
|
||||
shared float tmpsh[row_split];
|
||||
|
||||
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
|
||||
shared f16vec4 Qf[Br * qstride];
|
||||
|
||||
@@ -213,6 +215,19 @@ void main() {
|
||||
}
|
||||
}
|
||||
}
|
||||
// skip the block if the mask is entirely -inf
|
||||
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
||||
max_mask = max(max_mask, tmpsh[s]);
|
||||
}
|
||||
if (max_mask <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -176,7 +176,14 @@ void main() {
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
// skip the block if the mask is entirely -inf
|
||||
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
|
||||
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
||||
// Don't clamp against nem1 when GQA is enabled
|
||||
@@ -184,7 +191,14 @@ void main() {
|
||||
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
// skip the block if the mask is entirely -inf
|
||||
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
|
||||
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
@@ -8,19 +8,22 @@
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
shared FLOAT_TYPE sum[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
const uint i3 = row / (p.ne11 * p.ne12);
|
||||
const uint i3_offset = i3 * p.ne12 * p.ne11;
|
||||
const uint i2 = (row - i3_offset) / p.ne11;
|
||||
const uint i2_offset = i2 * p.ne11;
|
||||
const uint i1 = row - i3_offset - i2_offset;
|
||||
|
||||
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
|
||||
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]);
|
||||
sum[tid] += xi * xi;
|
||||
}
|
||||
|
||||
@@ -35,7 +38,7 @@ void main() {
|
||||
|
||||
const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
|
||||
|
||||
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||
data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]));
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user