mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-04 19:45:57 +02:00
Compare commits
81 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3d804dec76 | |||
| ffd0821c57 | |||
| 4314e56c4f | |||
| 496e5bf46b | |||
| 7919256c57 | |||
| e0449763a4 | |||
| eb7cf15a80 | |||
| 66ee4f297c | |||
| e51c47b401 | |||
| 2711d0215f | |||
| f0d4b29edf | |||
| 815857791d | |||
| 1a0e87d291 | |||
| d2e518e9b4 | |||
| b636228c0a | |||
| 325afb370a | |||
| 794fe23f29 | |||
| cf8cc856d7 | |||
| d0c08040b6 | |||
| be5ef7963f | |||
| cae9fb4361 | |||
| 7fee2889e6 | |||
| d7d1eccacc | |||
| 4bf3119d61 | |||
| f643120bad | |||
| 6e84b0ab8e | |||
| 2b8525d5c8 | |||
| a4417ddda9 | |||
| d6d24cd9ed | |||
| a5203b4465 | |||
| df984e0147 | |||
| acd38efee3 | |||
| caf773f249 | |||
| 178a7eb952 | |||
| 6f53d8a6b4 | |||
| 19f65187cb | |||
| 1d8ee06000 | |||
| 2cc9b8c32c | |||
| f35726c2fb | |||
| 4a75d19376 | |||
| 26771a1491 | |||
| ca6baf76c1 | |||
| 6e264a905b | |||
| 49b0e3cec4 | |||
| 20a758155b | |||
| 00c24acb2a | |||
| 466ea66f33 | |||
| 5f0db9522f | |||
| c5d9effb49 | |||
| 9fbadaef4f | |||
| 9755129c27 | |||
| a07c2c8a52 | |||
| 8137b4bb2b | |||
| 1af6945eb0 | |||
| 01f37edf1a | |||
| c07e87f38b | |||
| 564804b79b | |||
| 05f63cc9ee | |||
| f7fb43cd0b | |||
| 5845661640 | |||
| f211d1dc10 | |||
| 955a6c2d91 | |||
| 1971adf55e | |||
| 5245729e33 | |||
| 6152129d05 | |||
| 16d3df7ab0 | |||
| 12c2bdf2de | |||
| c64d2becb1 | |||
| 96f4053934 | |||
| a94f3b2727 | |||
| 3e3357fd77 | |||
| 6171c9d258 | |||
| e28245f35f | |||
| 6da5bec81c | |||
| 2e2f8f093c | |||
| 2139667ec4 | |||
| 80d0d6b4b7 | |||
| aea8ddd516 | |||
| 9f7add1cde | |||
| 90d987b105 | |||
| a4251edd6f |
+12
-1
@@ -2,6 +2,10 @@ ARG UBUNTU_VERSION=22.04
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
ARG TARGETARCH
|
||||
|
||||
ARG GGML_CPU_ARM_ARCH=armv8-a
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential git cmake libcurl4-openssl-dev
|
||||
|
||||
@@ -9,7 +13,14 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN cmake -S . -B build -DGGML_BACKEND_DL=ON -DGGML_NATIVE=OFF -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_CURL=ON -DCMAKE_BUILD_TYPE=Release && \
|
||||
RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DLLAMA_CURL=ON -DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON; \
|
||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DLLAMA_CURL=ON -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=${GGML_CPU_ARM_ARCH}; \
|
||||
else \
|
||||
echo "Unsupported architecture"; \
|
||||
exit 1; \
|
||||
fi && \
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
RUN mkdir -p /app/lib && \
|
||||
|
||||
+9
-1
@@ -13,9 +13,13 @@ elif [[ "$arg1" == '--quantize' || "$arg1" == '-q' ]]; then
|
||||
exec ./llama-quantize "$@"
|
||||
elif [[ "$arg1" == '--run' || "$arg1" == '-r' ]]; then
|
||||
exec ./llama-cli "$@"
|
||||
elif [[ "$arg1" == '--bench' || "$arg1" == '-b' ]]; then
|
||||
exec ./llama-bench "$@"
|
||||
elif [[ "$arg1" == '--perplexity' || "$arg1" == '-p' ]]; then
|
||||
exec ./llama-perplexity "$@"
|
||||
elif [[ "$arg1" == '--all-in-one' || "$arg1" == '-a' ]]; then
|
||||
echo "Converting PTH to GGML..."
|
||||
for i in `ls $1/$2/ggml-model-f16.bin*`; do
|
||||
for i in $(ls $1/$2/ggml-model-f16.bin*); do
|
||||
if [ -f "${i/f16/q4_0}" ]; then
|
||||
echo "Skip model quantization, it already exists: ${i/f16/q4_0}"
|
||||
else
|
||||
@@ -30,6 +34,10 @@ else
|
||||
echo "Available commands: "
|
||||
echo " --run (-r): Run a model previously converted into ggml"
|
||||
echo " ex: -m /models/7B/ggml-model-q4_0.bin -p \"Building a website can be done in 10 simple steps:\" -n 512"
|
||||
echo " --bench (-b): Benchmark the performance of the inference for various parameters."
|
||||
echo " ex: -m model.gguf"
|
||||
echo " --perplexity (-p): Measure the perplexity of a model over a given text."
|
||||
echo " ex: -m model.gguf -f file.txt"
|
||||
echo " --convert (-c): Convert a llama model into ggml"
|
||||
echo " ex: --outtype f16 \"/models/7B/\" "
|
||||
echo " --quantize (-q): Optimize with quantization process ggml"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
ARG UBUNTU_VERSION=jammy
|
||||
ARG UBUNTU_VERSION=24.04
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
@@ -7,7 +7,7 @@ RUN apt update && apt install -y git build-essential cmake wget
|
||||
|
||||
# Install Vulkan SDK and cURL
|
||||
RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \
|
||||
wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list && \
|
||||
wget -qO /etc/apt/sources.list.d/lunarg-vulkan-noble.list https://packages.lunarg.com/vulkan/lunarg-vulkan-noble.list && \
|
||||
apt update -y && \
|
||||
apt-get install -y vulkan-sdk libcurl4-openssl-dev curl
|
||||
|
||||
@@ -34,7 +34,7 @@ RUN mkdir -p /app/full \
|
||||
FROM ubuntu:$UBUNTU_VERSION AS base
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl\
|
||||
&& apt-get install -y libgomp1 curl libvulkan-dev \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
@@ -55,8 +55,9 @@ RUN apt-get update \
|
||||
git \
|
||||
python3 \
|
||||
python3-pip \
|
||||
&& pip install --upgrade pip setuptools wheel \
|
||||
&& pip install -r requirements.txt \
|
||||
python3-wheel \
|
||||
&& pip install --break-system-packages --upgrade setuptools \
|
||||
&& pip install --break-system-packages -r requirements.txt \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
|
||||
+102
-52
@@ -56,6 +56,7 @@ jobs:
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. \
|
||||
-DCMAKE_BUILD_RPATH="@loader_path" \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_CURL=ON \
|
||||
-DGGML_METAL_USE_BF16=ON \
|
||||
@@ -120,6 +121,7 @@ jobs:
|
||||
# Metal is disabled due to intermittent failures with Github runners not having a GPU:
|
||||
# https://github.com/ggerganov/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313
|
||||
cmake -B build \
|
||||
-DCMAKE_BUILD_RPATH="@loader_path" \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_CURL=ON \
|
||||
-DGGML_METAL=OFF \
|
||||
@@ -160,8 +162,8 @@ jobs:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip
|
||||
name: llama-bin-macos-x64.zip
|
||||
|
||||
ubuntu-latest-cmake:
|
||||
runs-on: ubuntu-latest
|
||||
ubuntu-cpu-cmake:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -181,7 +183,10 @@ jobs:
|
||||
run: |
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_RPC=ON
|
||||
cmake .. \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_CURL=ON \
|
||||
-DGGML_RPC=ON
|
||||
cmake --build . --config Release -j $(nproc)
|
||||
|
||||
- name: Test
|
||||
@@ -256,7 +261,10 @@ jobs:
|
||||
run: |
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
|
||||
cmake .. \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
|
||||
cmake --build . --config ${{ matrix.build_type }} -j $(nproc)
|
||||
|
||||
- name: Build (no OpenMP)
|
||||
@@ -265,7 +273,11 @@ jobs:
|
||||
run: |
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} -DGGML_OPENMP=OFF
|
||||
cmake .. \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
|
||||
-DGGML_OPENMP=OFF
|
||||
cmake --build . --config ${{ matrix.build_type }} -j $(nproc)
|
||||
|
||||
- name: Test
|
||||
@@ -295,7 +307,8 @@ jobs:
|
||||
run: |
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -DGGML_RPC=ON ..
|
||||
cmake .. \
|
||||
-DGGML_RPC=ON
|
||||
cmake --build . --config Release -j $(nproc)
|
||||
|
||||
- name: Test
|
||||
@@ -325,14 +338,16 @@ jobs:
|
||||
run: |
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -DGGML_VULKAN=ON ..
|
||||
cmake .. \
|
||||
-DGGML_VULKAN=ON
|
||||
cmake --build . --config Release -j $(nproc)
|
||||
|
||||
- name: Test
|
||||
id: cmake_test
|
||||
run: |
|
||||
cd build
|
||||
ctest -L main --verbose --timeout 900
|
||||
# This is using llvmpipe and runs slower than other backends
|
||||
ctest -L main --verbose --timeout 1800
|
||||
|
||||
ubuntu-22-cmake-hip:
|
||||
runs-on: ubuntu-22.04
|
||||
@@ -352,13 +367,18 @@ jobs:
|
||||
- name: Build with native CMake HIP support
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -S . -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" -DGGML_HIP=ON
|
||||
cmake -B build -S . \
|
||||
-DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
|
||||
-DGGML_HIP=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
- name: Build with legacy HIP support
|
||||
id: cmake_build_legacy_hip
|
||||
run: |
|
||||
cmake -B build2 -S . -DCMAKE_C_COMPILER=hipcc -DCMAKE_CXX_COMPILER=hipcc -DGGML_HIP=ON
|
||||
cmake -B build2 -S . \
|
||||
-DCMAKE_C_COMPILER=hipcc \
|
||||
-DCMAKE_CXX_COMPILER=hipcc \
|
||||
-DGGML_HIP=ON
|
||||
cmake --build build2 --config Release -j $(nproc)
|
||||
|
||||
ubuntu-22-cmake-musa:
|
||||
@@ -379,7 +399,8 @@ jobs:
|
||||
- name: Build with native CMake MUSA support
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -S . -DGGML_MUSA=ON
|
||||
cmake -B build -S . \
|
||||
-DGGML_MUSA=ON
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
|
||||
ubuntu-22-cmake-sycl:
|
||||
@@ -420,7 +441,10 @@ jobs:
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx ..
|
||||
cmake .. \
|
||||
-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx
|
||||
cmake --build . --config Release -j $(nproc)
|
||||
|
||||
ubuntu-22-cmake-sycl-fp16:
|
||||
@@ -461,42 +485,13 @@ jobs:
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON ..
|
||||
cmake .. \
|
||||
-DGGML_SYCL=ON \
|
||||
-DCMAKE_C_COMPILER=icx \
|
||||
-DCMAKE_CXX_COMPILER=icpx \
|
||||
-DGGML_SYCL_F16=ON
|
||||
cmake --build . --config Release -j $(nproc)
|
||||
|
||||
# TODO: build with GGML_METAL=OFF because test-backend-ops fail on "Apple Paravirtual device" and I don't know
|
||||
# how to debug it.
|
||||
# ref: https://github.com/ggerganov/llama.cpp/actions/runs/7132125951/job/19422043567?pr=4359#step:5:6584
|
||||
# would be great if we fix these
|
||||
macOS-latest-cmake:
|
||||
runs-on: macos-latest
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
run: |
|
||||
brew update
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
sysctl -a
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL=OFF ..
|
||||
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
|
||||
|
||||
- name: Test
|
||||
id: cmake_test
|
||||
run: |
|
||||
cd build
|
||||
ctest -L main --verbose --timeout 900
|
||||
|
||||
macOS-latest-cmake-ios:
|
||||
runs-on: macos-latest
|
||||
|
||||
@@ -619,6 +614,7 @@ jobs:
|
||||
msystem: ${{matrix.sys}}
|
||||
install: >-
|
||||
base-devel
|
||||
git
|
||||
mingw-w64-${{matrix.env}}-toolchain
|
||||
mingw-w64-${{matrix.env}}-cmake
|
||||
mingw-w64-${{matrix.env}}-openblas
|
||||
@@ -827,7 +823,13 @@ jobs:
|
||||
|
||||
- name: Build with CMake
|
||||
run: |
|
||||
cmake -S . -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=89-real -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined -DLLAMA_FATAL_WARNINGS=ON
|
||||
cmake -S . -B build -G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DCMAKE_CUDA_ARCHITECTURES=89-real \
|
||||
-DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DGGML_CUDA=ON
|
||||
cmake --build build
|
||||
|
||||
windows-2019-cmake-cuda:
|
||||
@@ -916,7 +918,11 @@ jobs:
|
||||
shell: cmd
|
||||
run: |
|
||||
call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
|
||||
cmake -S . -B build -G "Ninja Multi-Config" -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_CUDA=ON -DGGML_RPC=ON
|
||||
cmake -S . -B build -G "Ninja Multi-Config" ^
|
||||
-DLLAMA_BUILD_SERVER=ON ^
|
||||
-DGGML_NATIVE=OFF ^
|
||||
-DGGML_CUDA=ON ^
|
||||
-DGGML_RPC=ON
|
||||
set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
|
||||
cmake --build build --config Release -j %NINJA_JOBS% -t ggml
|
||||
cmake --build build --config Release
|
||||
@@ -1069,7 +1075,12 @@ jobs:
|
||||
run: |
|
||||
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
|
||||
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
|
||||
cmake -G "Unix Makefiles" -B build -S . -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DGGML_HIP=ON -DCMAKE_BUILD_TYPE=Release -DGGML_RPC=ON
|
||||
cmake -G "Unix Makefiles" -B build -S . `
|
||||
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
|
||||
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
|
||||
-DCMAKE_BUILD_TYPE=Release `
|
||||
-DGGML_HIP=ON `
|
||||
-DGGML_RPC=ON
|
||||
cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
|
||||
|
||||
windows-latest-cmake-hip-release:
|
||||
@@ -1107,7 +1118,13 @@ jobs:
|
||||
run: |
|
||||
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
|
||||
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
|
||||
cmake -G "Unix Makefiles" -B build -S . -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DGGML_HIP=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS=${{ matrix.gpu_target }} -DGGML_RPC=ON
|
||||
cmake -G "Unix Makefiles" -B build -S . `
|
||||
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
|
||||
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
|
||||
-DCMAKE_BUILD_TYPE=Release `
|
||||
-DAMDGPU_TARGETS=${{ matrix.gpu_target }} `
|
||||
-DGGML_HIP=ON `
|
||||
-DGGML_RPC=ON
|
||||
cmake --build build -j ${env:NUMBER_OF_PROCESSORS}
|
||||
md "build\bin\rocblas\library\"
|
||||
cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\"
|
||||
@@ -1201,8 +1218,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
needs:
|
||||
- ubuntu-latest-cmake
|
||||
- macOS-latest-cmake
|
||||
- ubuntu-cpu-cmake
|
||||
- windows-latest-cmake
|
||||
- windows-2019-cmake-cuda
|
||||
- windows-latest-cmake-hip-release
|
||||
@@ -1461,3 +1477,37 @@ jobs:
|
||||
# popd
|
||||
# emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }}
|
||||
# make
|
||||
|
||||
openEuler-latest-cmake-cann:
|
||||
if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -el {0}
|
||||
runs-on: ubuntu-24.04-arm
|
||||
strategy:
|
||||
matrix:
|
||||
cann:
|
||||
- '8.0.rc3.beta1-910b-openeuler22.03-py3.10'
|
||||
device:
|
||||
- 'ascend910b3'
|
||||
build:
|
||||
- 'Release'
|
||||
container: ascendai/cann:${{ matrix.cann }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Dependencies
|
||||
run: |
|
||||
yum update -y
|
||||
yum install -y git gcc gcc-c++ make cmake
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH}
|
||||
|
||||
cmake -S . -B build \
|
||||
-DCMAKE_BUILD_TYPE=${{ matrix.build }} \
|
||||
-DGGML_CANN=on \
|
||||
-DSOC_TYPE=${{ matrix.device }}
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
@@ -28,10 +28,11 @@ jobs:
|
||||
push_to_registry:
|
||||
name: Push Docker image to Docker Hub
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-22.04
|
||||
env:
|
||||
COMMIT_SHA: ${{ github.sha }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
# Multi-stage build
|
||||
|
||||
+8
-18
@@ -16,6 +16,7 @@ endif()
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
|
||||
if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
||||
set(LLAMA_STANDALONE ON)
|
||||
@@ -49,6 +50,8 @@ endif()
|
||||
if (MSVC)
|
||||
add_compile_options("$<$<COMPILE_LANGUAGE:C>:/utf-8>")
|
||||
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/utf-8>")
|
||||
add_compile_options("$<$<COMPILE_LANGUAGE:C>:/bigobj>")
|
||||
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/bigobj>")
|
||||
endif()
|
||||
|
||||
#
|
||||
@@ -185,27 +188,14 @@ set(LLAMA_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location o
|
||||
set(LLAMA_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files")
|
||||
set(LLAMA_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files")
|
||||
|
||||
# At the moment some compile definitions are placed within the ggml/src
|
||||
# directory but not exported on the `ggml` target. This could be improved by
|
||||
# determining _precisely_ which defines are necessary for the llama-config
|
||||
# package.
|
||||
#
|
||||
set(GGML_TRANSIENT_DEFINES)
|
||||
get_target_property(GGML_DIRECTORY ggml SOURCE_DIR)
|
||||
get_directory_property(GGML_DIR_DEFINES DIRECTORY ${GGML_DIRECTORY} COMPILE_DEFINITIONS)
|
||||
if (GGML_DIR_DEFINES)
|
||||
list(APPEND GGML_TRANSIENT_DEFINES ${GGML_DIR_DEFINES})
|
||||
endif()
|
||||
get_target_property(GGML_TARGET_DEFINES ggml COMPILE_DEFINITIONS)
|
||||
if (GGML_TARGET_DEFINES)
|
||||
list(APPEND GGML_TRANSIENT_DEFINES ${GGML_TARGET_DEFINES})
|
||||
endif()
|
||||
get_target_property(GGML_LINK_LIBRARIES ggml LINK_LIBRARIES)
|
||||
# all public headers
|
||||
set(LLAMA_PUBLIC_HEADERS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/include/llama.h
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/include/llama-cpp.h)
|
||||
set_target_properties(llama PROPERTIES PUBLIC_HEADER "${LLAMA_PUBLIC_HEADERS}")
|
||||
|
||||
set_target_properties(llama
|
||||
PROPERTIES
|
||||
PUBLIC_HEADER "${LLAMA_PUBLIC_HEADERS}")
|
||||
|
||||
install(TARGETS llama LIBRARY PUBLIC_HEADER)
|
||||
|
||||
configure_package_config_file(
|
||||
|
||||
@@ -1361,7 +1361,9 @@ llama-server: \
|
||||
examples/server/httplib.h \
|
||||
examples/server/index.html.hpp \
|
||||
examples/server/loading.html.hpp \
|
||||
common/chat-template.hpp \
|
||||
common/json.hpp \
|
||||
common/minja.hpp \
|
||||
$(OBJ_ALL)
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
|
||||
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
|
||||
|
||||
@@ -16,7 +16,10 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
|
||||
|
||||
## Hot topics
|
||||
|
||||
- **Introducing GGUF-my-LoRA** https://github.com/ggerganov/llama.cpp/discussions/10123
|
||||
- **How to use [MTLResidencySet](https://developer.apple.com/documentation/metal/mtlresidencyset?language=objc) to keep the GPU memory active?** https://github.com/ggerganov/llama.cpp/pull/11427
|
||||
- **VS Code extension for FIM completions:** https://github.com/ggml-org/llama.vscode
|
||||
- Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim
|
||||
- Introducing GGUF-my-LoRA https://github.com/ggerganov/llama.cpp/discussions/10123
|
||||
- Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggerganov/llama.cpp/discussions/9669
|
||||
- Hugging Face GGUF editor: [discussion](https://github.com/ggerganov/llama.cpp/discussions/9268) | [tool](https://huggingface.co/spaces/CISCai/gguf-editor)
|
||||
|
||||
@@ -419,7 +422,7 @@ To learn more about model quantization, [read this documentation](examples/quant
|
||||
|
||||
</details>
|
||||
|
||||
[^1]: [examples/perplexity/README.md](examples/perplexity/README.md)
|
||||
[^1]: [examples/perplexity/README.md](./examples/perplexity/README.md)
|
||||
[^2]: [https://huggingface.co/docs/transformers/perplexity](https://huggingface.co/docs/transformers/perplexity)
|
||||
|
||||
## [`llama-bench`](examples/llama-bench)
|
||||
|
||||
@@ -44,7 +44,7 @@ if(MSVC)
|
||||
set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME})
|
||||
else()
|
||||
execute_process(
|
||||
COMMAND sh -c "$@ --version | head -1" _ ${CMAKE_C_COMPILER}
|
||||
COMMAND sh -c "\"$@\" --version | head -1" _ ${CMAKE_C_COMPILER}
|
||||
OUTPUT_VARIABLE OUT
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
|
||||
+4
-152
@@ -3,159 +3,13 @@ set(LLAMA_BUILD_COMMIT @LLAMA_BUILD_COMMIT@)
|
||||
set(LLAMA_BUILD_NUMBER @LLAMA_BUILD_NUMBER@)
|
||||
set(LLAMA_SHARED_LIB @BUILD_SHARED_LIBS@)
|
||||
|
||||
set(GGML_STATIC @GGML_STATIC@)
|
||||
set(GGML_NATIVE @GGML_NATIVE@)
|
||||
set(GGML_LTO @GGML_LTO@)
|
||||
set(GGML_CCACHE @GGML_CCACHE@)
|
||||
set(GGML_AVX @GGML_AVX@)
|
||||
set(GGML_AVX2 @GGML_AVX2@)
|
||||
set(GGML_AVX512 @GGML_AVX512@)
|
||||
set(GGML_AVX512_VBMI @GGML_AVX512_VBMI@)
|
||||
set(GGML_AVX512_VNNI @GGML_AVX512_VNNI@)
|
||||
set(GGML_AVX512_BF16 @GGML_AVX512_BF16@)
|
||||
set(GGML_AMX_TILE @GGML_AMX_TILE@)
|
||||
set(GGML_AMX_INT8 @GGML_AMX_INT8@)
|
||||
set(GGML_AMX_BF16 @GGML_AMX_BF16@)
|
||||
set(GGML_FMA @GGML_FMA@)
|
||||
set(GGML_LASX @GGML_LASX@)
|
||||
set(GGML_LSX @GGML_LSX@)
|
||||
set(GGML_RVV @GGML_RVV@)
|
||||
set(GGML_SVE @GGML_SVE@)
|
||||
|
||||
set(GGML_ACCELERATE @GGML_ACCELERATE@)
|
||||
set(GGML_OPENMP @GGML_OPENMP@)
|
||||
set(GGML_CPU_HBM @GGML_CPU_HBM@)
|
||||
set(GGML_BLAS_VENDOR @GGML_BLAS_VENDOR@)
|
||||
|
||||
set(GGML_CUDA_FORCE_MMQ @GGML_CUDA_FORCE_MMQ@)
|
||||
set(GGML_CUDA_FORCE_CUBLAS @GGML_CUDA_FORCE_CUBLAS@)
|
||||
set(GGML_CUDA_F16 @GGML_CUDA_F16@)
|
||||
set(GGML_CUDA_PEER_MAX_BATCH_SIZE @GGML_CUDA_PEER_MAX_BATCH_SIZE@)
|
||||
set(GGML_CUDA_NO_PEER_COPY @GGML_CUDA_NO_PEER_COPY@)
|
||||
set(GGML_CUDA_NO_VMM @GGML_CUDA_NO_VMM@)
|
||||
set(GGML_CUDA_FA_ALL_QUANTS @GGML_CUDA_FA_ALL_QUANTS@)
|
||||
set(GGML_CUDA_GRAPHS @GGML_CUDA_GRAPHS@)
|
||||
|
||||
set(GGML_HIP_UMA @GGML_HIP_UMA@)
|
||||
|
||||
set(GGML_VULKAN_CHECK_RESULTS @GGML_VULKAN_CHECK_RESULTS@)
|
||||
set(GGML_VULKAN_DEBUG @GGML_VULKAN_DEBUG@)
|
||||
set(GGML_VULKAN_MEMORY_DEBUG @GGML_VULKAN_MEMORY_DEBUG@)
|
||||
set(GGML_VULKAN_SHADER_DEBUG_INFO @GGML_VULKAN_SHADER_DEBUG_INFO@)
|
||||
set(GGML_VULKAN_PERF @GGML_VULKAN_PERF@)
|
||||
set(GGML_VULKAN_VALIDATE @GGML_VULKAN_VALIDATE@)
|
||||
set(GGML_VULKAN_RUN_TESTS @GGML_VULKAN_RUN_TESTS@)
|
||||
|
||||
set(GGML_METAL_USE_BF16 @GGML_METAL_USE_BF16@)
|
||||
set(GGML_METAL_NDEBUG @GGML_METAL_NDEBUG@)
|
||||
set(GGML_METAL_SHADER_DEBUG @GGML_METAL_SHADER_DEBUG@)
|
||||
set(GGML_METAL_EMBED_LIBRARY @GGML_METAL_EMBED_LIBRARY@)
|
||||
set(GGML_METAL_MACOSX_VERSION_MIN @GGML_METAL_MACOSX_VERSION_MIN@)
|
||||
set(GGML_METAL_STD @GGML_METAL_STD@)
|
||||
|
||||
set(GGML_SYCL_F16 @GGML_SYCL_F16@)
|
||||
set(GGML_SYCL_TARGET @GGML_SYCL_TARGET@)
|
||||
set(GGML_SYCL_DEVICE_ARCH @GGML_SYCL_DEVICE_ARCH@)
|
||||
|
||||
|
||||
@PACKAGE_INIT@
|
||||
|
||||
set_and_check(LLAMA_INCLUDE_DIR "@PACKAGE_LLAMA_INCLUDE_INSTALL_DIR@")
|
||||
set_and_check(LLAMA_LIB_DIR "@PACKAGE_LLAMA_LIB_INSTALL_DIR@")
|
||||
set_and_check(LLAMA_BIN_DIR "@PACKAGE_LLAMA_BIN_INSTALL_DIR@")
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
set(_llama_transient_defines "@GGML_TRANSIENT_DEFINES@")
|
||||
set(_llama_link_deps "")
|
||||
set(_llama_link_opts "")
|
||||
foreach(_ggml_lib ggml ggml-base)
|
||||
string(REPLACE "-" "_" _ggml_lib_var "${_ggml_lib}_LIBRARY")
|
||||
find_library(${_ggml_lib_var} ${_ggml_lib}
|
||||
REQUIRED
|
||||
HINTS ${LLAMA_LIB_DIR}
|
||||
NO_CMAKE_FIND_ROOT_PATH
|
||||
)
|
||||
list(APPEND _llama_link_deps "${${_ggml_lib_var}}")
|
||||
message(STATUS "Found ${${_ggml_lib_var}}")
|
||||
endforeach()
|
||||
|
||||
foreach(backend amx blas cann cpu cuda hip kompute metal musa rpc sycl vulkan)
|
||||
string(TOUPPER "GGML_${backend}" backend_id)
|
||||
set(_ggml_lib "ggml-${backend}")
|
||||
string(REPLACE "-" "_" _ggml_lib_var "${_ggml_lib}_LIBRARY")
|
||||
|
||||
find_library(${_ggml_lib_var} ${_ggml_lib}
|
||||
HINTS ${LLAMA_LIB_DIR}
|
||||
NO_CMAKE_FIND_ROOT_PATH
|
||||
)
|
||||
if(${_ggml_lib_var})
|
||||
list(APPEND _llama_link_deps "${${_ggml_lib_var}}")
|
||||
set(${backend_id} ON)
|
||||
message(STATUS "Found backend ${${_ggml_lib_var}}")
|
||||
else()
|
||||
set(${backend_id} OFF)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if (NOT LLAMA_SHARED_LIB)
|
||||
if (APPLE AND GGML_ACCELERATE)
|
||||
find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED)
|
||||
list(APPEND _llama_link_deps ${ACCELERATE_FRAMEWORK})
|
||||
endif()
|
||||
|
||||
if (GGML_OPENMP)
|
||||
find_package(OpenMP REQUIRED)
|
||||
list(APPEND _llama_link_deps OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
|
||||
endif()
|
||||
|
||||
if (GGML_CPU_HBM)
|
||||
find_library(memkind memkind REQUIRED)
|
||||
list(APPEND _llama_link_deps memkind)
|
||||
endif()
|
||||
|
||||
if (GGML_BLAS)
|
||||
find_package(BLAS REQUIRED)
|
||||
list(APPEND _llama_link_deps ${BLAS_LIBRARIES})
|
||||
list(APPEND _llama_link_opts ${BLAS_LINKER_FLAGS})
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA)
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
endif()
|
||||
|
||||
if (GGML_METAL)
|
||||
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
|
||||
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
||||
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
||||
list(APPEND _llama_link_deps ${FOUNDATION_LIBRARY}
|
||||
${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK})
|
||||
endif()
|
||||
|
||||
if (GGML_VULKAN)
|
||||
find_package(Vulkan REQUIRED)
|
||||
list(APPEND _llama_link_deps Vulkan::Vulkan)
|
||||
endif()
|
||||
|
||||
if (GGML_HIP)
|
||||
find_package(hip REQUIRED)
|
||||
find_package(hipblas REQUIRED)
|
||||
find_package(rocblas REQUIRED)
|
||||
list(APPEND _llama_link_deps hip::host roc::rocblas roc::hipblas)
|
||||
endif()
|
||||
|
||||
if (GGML_SYCL)
|
||||
find_package(DNNL)
|
||||
if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
|
||||
list(APPEND _llama_link_deps DNNL::dnnl)
|
||||
endif()
|
||||
if (WIN32)
|
||||
find_package(IntelSYCL REQUIRED)
|
||||
find_package(MKL REQUIRED)
|
||||
list(APPEND _llama_link_deps IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
find_package(ggml REQUIRED HINTS ${LLAMA_LIB_DIR}/cmake)
|
||||
|
||||
find_library(llama_LIBRARY llama
|
||||
REQUIRED
|
||||
@@ -167,12 +21,10 @@ add_library(llama UNKNOWN IMPORTED)
|
||||
set_target_properties(llama
|
||||
PROPERTIES
|
||||
INTERFACE_INCLUDE_DIRECTORIES "${LLAMA_INCLUDE_DIR}"
|
||||
INTERFACE_LINK_LIBRARIES "${_llama_link_deps}"
|
||||
INTERFACE_LINK_OPTIONS "${_llama_link_opts}"
|
||||
INTERFACE_COMPILE_DEFINITIONS "${_llama_transient_defines}"
|
||||
INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base;"
|
||||
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
|
||||
IMPORTED_LOCATION "${llama_LIBRARY}"
|
||||
INTERFACE_COMPILE_FEATURES cxx_std_11
|
||||
POSITION_INDEPENDENT_CODE ON )
|
||||
INTERFACE_COMPILE_FEATURES c_std_90
|
||||
POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
check_required_components(Llama)
|
||||
|
||||
@@ -56,6 +56,7 @@ add_library(${TARGET} STATIC
|
||||
arg.cpp
|
||||
arg.h
|
||||
base64.hpp
|
||||
chat-template.hpp
|
||||
common.cpp
|
||||
common.h
|
||||
console.cpp
|
||||
@@ -64,6 +65,7 @@ add_library(${TARGET} STATIC
|
||||
json.hpp
|
||||
log.cpp
|
||||
log.h
|
||||
minja.hpp
|
||||
ngram-cache.cpp
|
||||
ngram-cache.h
|
||||
sampling.cpp
|
||||
|
||||
+49
-12
@@ -133,7 +133,8 @@ static void common_params_handle_model_default(
|
||||
const std::string & model_url,
|
||||
std::string & hf_repo,
|
||||
std::string & hf_file,
|
||||
const std::string & hf_token) {
|
||||
const std::string & hf_token,
|
||||
const std::string & model_default) {
|
||||
if (!hf_repo.empty()) {
|
||||
// short-hand to avoid specifying --hf-file -> default it to --model
|
||||
if (hf_file.empty()) {
|
||||
@@ -163,7 +164,7 @@ static void common_params_handle_model_default(
|
||||
model = fs_get_cache_file(string_split<std::string>(f, '/').back());
|
||||
}
|
||||
} else if (model.empty()) {
|
||||
model = DEFAULT_MODEL_PATH;
|
||||
model = model_default;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -299,8 +300,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
}
|
||||
|
||||
// TODO: refactor model params in a common struct
|
||||
common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token);
|
||||
common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token);
|
||||
common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token, DEFAULT_MODEL_PATH);
|
||||
common_params_handle_model_default(params.speculative.model, params.speculative.model_url, params.speculative.hf_repo, params.speculative.hf_file, params.hf_token, "");
|
||||
common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token, "");
|
||||
|
||||
if (params.escape) {
|
||||
string_process_escapes(params.prompt);
|
||||
@@ -323,6 +325,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
|
||||
}
|
||||
|
||||
if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
|
||||
throw std::runtime_error(string_format(
|
||||
"error: the supplied chat template is not supported: %s%s\n",
|
||||
params.chat_template.c_str(),
|
||||
params.use_jinja ? "" : "\nnote: llama.cpp was started without --jinja, we only support commonly used templates"
|
||||
));
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -867,7 +877,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.warmup = false;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING}));
|
||||
add_opt(common_arg(
|
||||
{"--spm-infill"},
|
||||
string_format(
|
||||
@@ -1629,6 +1639,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.hf_repo = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_HF_REPO"));
|
||||
add_opt(common_arg(
|
||||
{"-hfd", "-hfrd", "--hf-repo-draft"}, "<user>/<model>[:quant]",
|
||||
"Same as --hf-repo, but for the draft model (default: unused)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.hf_repo = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_HFD_REPO"));
|
||||
add_opt(common_arg(
|
||||
{"-hff", "--hf-file"}, "FILE",
|
||||
"Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)",
|
||||
@@ -1938,24 +1955,44 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--jinja"},
|
||||
"use jinja template for chat (default: disabled)",
|
||||
[](common_params & params) {
|
||||
params.use_jinja = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
|
||||
add_opt(common_arg(
|
||||
{"--chat-template"}, "JINJA_TEMPLATE",
|
||||
string_format(
|
||||
"set custom jinja chat template (default: template taken from model's metadata)\n"
|
||||
"if suffix/prefix are specified, template will be disabled\n"
|
||||
"only commonly used templates are accepted (unless --jinja is set before this flag):\n"
|
||||
"list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
|
||||
),
|
||||
[](common_params & params, const std::string & value) {
|
||||
if (!common_chat_verify_template(value)) {
|
||||
throw std::runtime_error(string_format(
|
||||
"error: the supplied chat template is not supported: %s\n"
|
||||
"note: llama.cpp does not use jinja parser, we only support commonly used templates\n",
|
||||
value.c_str()
|
||||
));
|
||||
}
|
||||
params.chat_template = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
|
||||
add_opt(common_arg(
|
||||
{"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
|
||||
string_format(
|
||||
"set custom jinja chat template file (default: template taken from model's metadata)\n"
|
||||
"if suffix/prefix are specified, template will be disabled\n"
|
||||
"only commonly used templates are accepted (unless --jinja is set before this flag):\n"
|
||||
"list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
|
||||
),
|
||||
[](common_params & params, const std::string & value) {
|
||||
std::ifstream file(value);
|
||||
if (!file) {
|
||||
throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
|
||||
}
|
||||
std::copy(
|
||||
std::istreambuf_iterator<char>(file),
|
||||
std::istreambuf_iterator<char>(),
|
||||
std::back_inserter(params.chat_template));
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
|
||||
add_opt(common_arg(
|
||||
{"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
|
||||
string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),
|
||||
|
||||
@@ -0,0 +1,366 @@
|
||||
/*
|
||||
Copyright 2024 Google LLC
|
||||
|
||||
Use of this source code is governed by an MIT-style
|
||||
license that can be found in the LICENSE file or at
|
||||
https://opensource.org/licenses/MIT.
|
||||
*/
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "minja.hpp"
|
||||
#include <json.hpp>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
namespace minja {
|
||||
|
||||
struct chat_template_caps {
|
||||
bool supports_tools = false;
|
||||
bool supports_tool_calls = false;
|
||||
bool supports_tool_responses = false;
|
||||
bool supports_system_role = false;
|
||||
bool supports_parallel_tool_calls = false;
|
||||
bool supports_tool_call_id = false;
|
||||
// meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object.
|
||||
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
|
||||
bool requires_object_arguments = false;
|
||||
// CohereForAI/c4ai-command-r-plus simple variant
|
||||
bool requires_non_null_content = false;
|
||||
// MiniMaxAI/MiniMax-Text-01 special
|
||||
bool requires_typed_content = false;
|
||||
};
|
||||
|
||||
class chat_template {
|
||||
|
||||
private:
|
||||
chat_template_caps caps_;
|
||||
std::string source_;
|
||||
std::string bos_token_;
|
||||
std::string eos_token_;
|
||||
std::shared_ptr<minja::TemplateNode> template_root_;
|
||||
|
||||
std::string try_raw_render(
|
||||
const nlohmann::ordered_json & messages,
|
||||
const nlohmann::ordered_json & tools,
|
||||
bool add_generation_prompt,
|
||||
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
|
||||
{
|
||||
try {
|
||||
auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false);
|
||||
// fprintf(stderr, "try_raw_render: %s\n", prompt.c_str());
|
||||
return prompt;
|
||||
} catch (const std::exception & e) {
|
||||
// fprintf(stderr, "try_raw_render error: %s\n", e.what());
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
|
||||
: source_(source), bos_token_(bos_token), eos_token_(eos_token)
|
||||
{
|
||||
template_root_ = minja::Parser::parse(source_, {
|
||||
/* .trim_blocks = */ true,
|
||||
/* .lstrip_blocks = */ true,
|
||||
/* .keep_trailing_newline = */ false,
|
||||
});
|
||||
|
||||
auto contains = [](const std::string & haystack, const std::string & needle) {
|
||||
return haystack.find(needle) != std::string::npos;
|
||||
};
|
||||
|
||||
const std::string user_needle = "<User Needle>";
|
||||
const std::string sys_needle = "<System Needle>";
|
||||
const json dummy_str_user_msg = {{"role", "user"}, {"content", user_needle}};
|
||||
const json dummy_typed_user_msg = {{"role", "user"}, {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}};
|
||||
|
||||
caps_.requires_typed_content =
|
||||
!contains(try_raw_render(json::array({dummy_str_user_msg}), {}, false), user_needle)
|
||||
&& contains(try_raw_render(json::array({dummy_typed_user_msg}), {}, false), user_needle);
|
||||
|
||||
const auto dummy_user_msg = caps_.requires_typed_content
|
||||
? dummy_typed_user_msg
|
||||
: dummy_str_user_msg;
|
||||
const json needle_system_msg = {
|
||||
{"role", "system"},
|
||||
{"content", caps_.requires_typed_content ? json::array({{{"type", "text"}, {"text", sys_needle}}}) : json(sys_needle)},
|
||||
};
|
||||
|
||||
caps_.supports_system_role = contains(try_raw_render({needle_system_msg, dummy_user_msg,}, {}, false), sys_needle);
|
||||
|
||||
auto out = try_raw_render(json::array({
|
||||
dummy_user_msg
|
||||
}), json::array({
|
||||
{
|
||||
{"name", "some_tool"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", "some_tool"},
|
||||
{"description", "Some tool."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"arg", {
|
||||
{"type", "string"},
|
||||
{"description", "Some argument."},
|
||||
}},
|
||||
}},
|
||||
{"required", json::array({ "arg" })},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
}), false);
|
||||
caps_.supports_tools = contains(out, "some_tool");
|
||||
|
||||
auto make_tool_calls_msg = [&](const json & tool_calls) {
|
||||
return json {
|
||||
{"role", "assistant"},
|
||||
{"content", nullptr},
|
||||
{"tool_calls", tool_calls},
|
||||
};
|
||||
};
|
||||
auto make_tool_call = [](const std::string & tool_name, const json & arguments) {
|
||||
return json {
|
||||
{"id", "call_1___"},
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"arguments", arguments},
|
||||
{"name", tool_name},
|
||||
}},
|
||||
};
|
||||
};
|
||||
const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}};
|
||||
|
||||
// Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
|
||||
out = try_raw_render(json::array({
|
||||
dummy_user_msg,
|
||||
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
|
||||
}), {}, false);
|
||||
auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
|
||||
out = try_raw_render(json::array({
|
||||
dummy_user_msg,
|
||||
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
|
||||
}), {}, false);
|
||||
auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
|
||||
|
||||
caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
|
||||
caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
|
||||
auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false);
|
||||
auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false);
|
||||
caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle);
|
||||
|
||||
if (caps_.supports_tool_calls) {
|
||||
auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump());
|
||||
auto tc1 = make_tool_call("test_tool1", dummy_args);
|
||||
auto tc2 = make_tool_call("test_tool2", dummy_args);
|
||||
auto out = try_raw_render(json::array({
|
||||
dummy_user_msg,
|
||||
make_tool_calls_msg(json::array({tc1, tc2})),
|
||||
}), {}, false);
|
||||
caps_.supports_parallel_tool_calls = contains(out, "test_tool1") && contains(out, "test_tool2");
|
||||
|
||||
out = try_raw_render(json::array({
|
||||
dummy_user_msg,
|
||||
make_tool_calls_msg(json::array({tc1})),
|
||||
{
|
||||
{"role", "tool"},
|
||||
{"name", "test_tool1"},
|
||||
{"content", "Some response!"},
|
||||
{"tool_call_id", "call_911_"},
|
||||
}
|
||||
}), {}, false);
|
||||
caps_.supports_tool_responses = contains(out, "Some response!");
|
||||
caps_.supports_tool_call_id = contains(out, "call_911_");
|
||||
}
|
||||
}
|
||||
|
||||
const std::string & source() const { return source_; }
|
||||
const std::string & bos_token() const { return bos_token_; }
|
||||
const std::string & eos_token() const { return eos_token_; }
|
||||
const chat_template_caps & original_caps() const { return caps_; }
|
||||
|
||||
std::string apply(
|
||||
const nlohmann::ordered_json & messages,
|
||||
const nlohmann::ordered_json & tools,
|
||||
bool add_generation_prompt,
|
||||
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
|
||||
bool adjust_inputs = true) const
|
||||
{
|
||||
json actual_messages;
|
||||
|
||||
auto needs_adjustments = adjust_inputs && (false
|
||||
|| !caps_.supports_system_role
|
||||
|| !caps_.supports_tools
|
||||
|| !caps_.supports_tool_responses
|
||||
|| !caps_.supports_tool_calls
|
||||
|| caps_.requires_object_arguments
|
||||
|| caps_.requires_typed_content
|
||||
);
|
||||
if (needs_adjustments) {
|
||||
actual_messages = json::array();
|
||||
|
||||
auto add_message = [&](const json & msg) {
|
||||
if (caps_.requires_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
|
||||
actual_messages.push_back({
|
||||
{"role", msg.at("role")},
|
||||
{"content", {{
|
||||
{"type", "text"},
|
||||
{"text", msg.at("content")},
|
||||
}}},
|
||||
});
|
||||
} else {
|
||||
actual_messages.push_back(msg);
|
||||
}
|
||||
};
|
||||
|
||||
std::string pending_system;
|
||||
auto flush_sys = [&]() {
|
||||
if (!pending_system.empty()) {
|
||||
add_message({
|
||||
{"role", "user"},
|
||||
{"content", pending_system},
|
||||
});
|
||||
pending_system.clear();
|
||||
}
|
||||
};
|
||||
auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !caps_.supports_tools;
|
||||
|
||||
for (const auto & message_ : needs_tools_in_system ? add_system(messages, "Available tools: " + tools.dump(2)) : messages) {
|
||||
auto message = message_;
|
||||
if (!message.contains("role") || !message.contains("content")) {
|
||||
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
|
||||
}
|
||||
std::string role = message.at("role");
|
||||
|
||||
if (message.contains("tool_calls")) {
|
||||
if (caps_.requires_object_arguments || !caps_.supports_tool_calls) {
|
||||
for (auto & tool_call : message.at("tool_calls")) {
|
||||
if (tool_call["type"] == "function") {
|
||||
auto & function = tool_call.at("function");
|
||||
auto & arguments = function.at("arguments");
|
||||
if (arguments.is_string()) {
|
||||
try {
|
||||
arguments = json::parse(arguments.get<std::string>());
|
||||
} catch (const std::exception & ecvt) {
|
||||
fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!caps_.supports_tool_calls) {
|
||||
auto content = message.at("content");
|
||||
auto tool_calls = json::array();
|
||||
for (const auto & tool_call : message.at("tool_calls")) {
|
||||
if (tool_call.at("type") != "function") {
|
||||
continue;
|
||||
}
|
||||
const auto & function = tool_call.at("function");
|
||||
auto tc = json {
|
||||
{"name", function.at("name")},
|
||||
{"arguments", function.at("arguments")},
|
||||
};
|
||||
if (tool_call.contains("id")) {
|
||||
tc["id"] = tool_call["id"];
|
||||
}
|
||||
tool_calls.push_back(tc);
|
||||
}
|
||||
auto obj = json {
|
||||
{"tool_calls", tool_calls},
|
||||
};
|
||||
if (!content.is_null() && content != "") {
|
||||
obj["content"] = content;
|
||||
}
|
||||
message["content"] = obj.dump(2);
|
||||
message.erase("tool_calls");
|
||||
}
|
||||
}
|
||||
if (!caps_.supports_tool_responses && role == "tool") {
|
||||
message["role"] = "user";
|
||||
auto obj = json {
|
||||
{"tool_response", {
|
||||
{"tool", message.at("name")},
|
||||
{"content", message.at("content")},
|
||||
}},
|
||||
};
|
||||
if (message.contains("tool_call_id")) {
|
||||
obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
|
||||
}
|
||||
message["content"] = obj.dump(2);
|
||||
message.erase("name");
|
||||
}
|
||||
|
||||
if (!message["content"].is_null() && !caps_.supports_system_role) {
|
||||
std::string content = message.at("content");
|
||||
if (role == "system") {
|
||||
if (!pending_system.empty()) pending_system += "\n";
|
||||
pending_system += content;
|
||||
continue;
|
||||
} else {
|
||||
if (role == "user") {
|
||||
if (!pending_system.empty()) {
|
||||
message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
|
||||
pending_system.clear();
|
||||
}
|
||||
} else {
|
||||
flush_sys();
|
||||
}
|
||||
}
|
||||
}
|
||||
add_message(message);
|
||||
}
|
||||
if (!caps_.supports_system_role) {
|
||||
flush_sys();
|
||||
}
|
||||
} else {
|
||||
actual_messages = messages;
|
||||
}
|
||||
|
||||
auto context = minja::Context::make(json({
|
||||
{"messages", actual_messages},
|
||||
{"add_generation_prompt", add_generation_prompt},
|
||||
{"bos_token", bos_token_},
|
||||
{"eos_token", eos_token_},
|
||||
}));
|
||||
|
||||
if (!tools.is_null()) {
|
||||
auto tools_val = minja::Value(tools);
|
||||
context->set("tools", tools_val);
|
||||
}
|
||||
if (!extra_context.is_null()) {
|
||||
for (auto & kv : extra_context.items()) {
|
||||
minja::Value val(kv.value());
|
||||
context->set(kv.key(), val);
|
||||
}
|
||||
}
|
||||
|
||||
auto ret = template_root_->render(context);
|
||||
// fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str());
|
||||
// fprintf(stderr, "apply: %s\n\n", ret.c_str());
|
||||
return ret;
|
||||
}
|
||||
|
||||
static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
|
||||
json messages_with_system = messages;
|
||||
|
||||
if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
|
||||
std::string existing_system = messages_with_system.at(0).at("content");
|
||||
messages_with_system[0] = json {
|
||||
{"role", "system"},
|
||||
{"content", existing_system + "\n" + system_prompt},
|
||||
};
|
||||
} else {
|
||||
messages_with_system.insert(messages_with_system.begin(), json {
|
||||
{"role", "system"},
|
||||
{"content", system_prompt},
|
||||
});
|
||||
}
|
||||
return messages_with_system;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace minja
|
||||
+136
-32
@@ -12,6 +12,7 @@
|
||||
#include "json.hpp"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "llama.h"
|
||||
#include "chat-template.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
@@ -483,6 +484,48 @@ void string_replace_all(std::string & s, const std::string & search, const std::
|
||||
s = std::move(builder);
|
||||
}
|
||||
|
||||
std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
|
||||
std::ostringstream result;
|
||||
for (size_t i = 0; i < values.size(); ++i) {
|
||||
if (i > 0) {
|
||||
result << separator;
|
||||
}
|
||||
result << values[i];
|
||||
}
|
||||
return result.str();
|
||||
}
|
||||
|
||||
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter) {
|
||||
std::vector<std::string> parts;
|
||||
size_t start = 0;
|
||||
size_t end = str.find(delimiter);
|
||||
|
||||
while (end != std::string::npos) {
|
||||
parts.push_back(str.substr(start, end - start));
|
||||
start = end + delimiter.length();
|
||||
end = str.find(delimiter, start);
|
||||
}
|
||||
|
||||
parts.push_back(str.substr(start));
|
||||
|
||||
return parts;
|
||||
}
|
||||
|
||||
std::string string_repeat(const std::string & str, size_t n) {
|
||||
if (n == 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string result;
|
||||
result.reserve(str.length() * n);
|
||||
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
result += str;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string string_from(bool value) {
|
||||
return value ? "true" : "false";
|
||||
}
|
||||
@@ -1728,67 +1771,75 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
|
||||
// Chat template utils
|
||||
//
|
||||
|
||||
std::string common_get_builtin_chat_template(const struct llama_model * model) {
|
||||
const char * ptr_tmpl = llama_model_chat_template(model);
|
||||
return ptr_tmpl == nullptr ? "" : ptr_tmpl;
|
||||
}
|
||||
|
||||
bool common_chat_verify_template(const std::string & tmpl) {
|
||||
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
try {
|
||||
auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
|
||||
chat_template.apply({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}}, json(), true);
|
||||
return true;
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
llama_chat_message chat[] = {{"user", "test"}};
|
||||
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
|
||||
return res >= 0;
|
||||
}
|
||||
|
||||
std::string common_chat_apply_template(const struct llama_model * model,
|
||||
const std::string & tmpl,
|
||||
std::string common_chat_apply_template(
|
||||
const common_chat_template & tmpl,
|
||||
const std::vector<common_chat_msg> & msgs,
|
||||
bool add_ass) {
|
||||
bool add_ass,
|
||||
bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
auto messages = json::array();
|
||||
for (const auto & msg : msgs) {
|
||||
messages.push_back({{"role", msg.role}, {"content", msg.content}});
|
||||
}
|
||||
return tmpl.apply(messages, /* tools= */ json(), add_ass);
|
||||
}
|
||||
|
||||
int alloc_size = 0;
|
||||
bool fallback = false; // indicate if we must fallback to default chatml
|
||||
std::vector<llama_chat_message> chat;
|
||||
for (const auto & msg : msgs) {
|
||||
chat.push_back({msg.role.c_str(), msg.content.c_str()});
|
||||
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
|
||||
}
|
||||
|
||||
const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model) : tmpl.c_str();
|
||||
std::vector<char> buf(alloc_size);
|
||||
|
||||
// run the first time to get the total output length
|
||||
int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
|
||||
// error: chat template is not supported
|
||||
if (res < 0) {
|
||||
if (ptr_tmpl != nullptr) {
|
||||
// if the custom "tmpl" is not supported, we throw an error
|
||||
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
|
||||
throw std::runtime_error("this custom template is not supported");
|
||||
}
|
||||
|
||||
// If the built-in template is not supported, we default to chatml
|
||||
res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
fallback = true;
|
||||
// if the custom "tmpl" is not supported, we throw an error
|
||||
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
|
||||
throw std::runtime_error("this custom template is not supported");
|
||||
}
|
||||
|
||||
// if it turns out that our buffer is too small, we resize it
|
||||
if ((size_t) res > buf.size()) {
|
||||
buf.resize(res);
|
||||
res = llama_chat_apply_template(
|
||||
fallback ? "chatml" : ptr_tmpl,
|
||||
chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
|
||||
}
|
||||
|
||||
std::string formatted_chat(buf.data(), res);
|
||||
return formatted_chat;
|
||||
}
|
||||
|
||||
std::string common_chat_format_single(const struct llama_model * model,
|
||||
const std::string & tmpl,
|
||||
std::string common_chat_format_single(
|
||||
const common_chat_template & tmpl,
|
||||
const std::vector<common_chat_msg> & past_msg,
|
||||
const common_chat_msg & new_msg,
|
||||
bool add_ass) {
|
||||
bool add_ass,
|
||||
bool use_jinja) {
|
||||
std::ostringstream ss;
|
||||
auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false);
|
||||
auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
|
||||
std::vector<common_chat_msg> chat_new(past_msg);
|
||||
// if the past_msg ends with a newline, we must preserve it in the formatted version
|
||||
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
|
||||
@@ -1796,21 +1847,74 @@ std::string common_chat_format_single(const struct llama_model * model,
|
||||
};
|
||||
// format chat with new_msg
|
||||
chat_new.push_back(new_msg);
|
||||
auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass);
|
||||
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
|
||||
// get the diff part
|
||||
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
std::string common_chat_format_example(const struct llama_model * model,
|
||||
const std::string & tmpl) {
|
||||
std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
|
||||
std::vector<common_chat_msg> msgs = {
|
||||
{"system", "You are a helpful assistant"},
|
||||
{"user", "Hello"},
|
||||
{"assistant", "Hi there"},
|
||||
{"user", "How are you?"},
|
||||
};
|
||||
return common_chat_apply_template(model, tmpl, msgs, true);
|
||||
return common_chat_apply_template(tmpl, msgs, true, use_jinja);
|
||||
}
|
||||
|
||||
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
|
||||
{
|
||||
auto vocab = llama_model_get_vocab(model);
|
||||
std::string default_template_src = chat_template_override;
|
||||
std::string template_tool_use_src = chat_template_override;
|
||||
bool has_explicit_template = !chat_template_override.empty();
|
||||
if (chat_template_override.empty()) {
|
||||
auto str = llama_model_chat_template(model, /* name */ nullptr);
|
||||
if (str) {
|
||||
default_template_src = str;
|
||||
has_explicit_template = true;
|
||||
}
|
||||
str = llama_model_chat_template(model, /* name */ "tool_use");
|
||||
if (str) {
|
||||
template_tool_use_src = str;
|
||||
has_explicit_template = true;
|
||||
}
|
||||
}
|
||||
if (default_template_src.empty() || default_template_src == "chatml") {
|
||||
if (!template_tool_use_src.empty()) {
|
||||
default_template_src = template_tool_use_src;
|
||||
} else {
|
||||
default_template_src = R"(
|
||||
{%- for message in messages -%}
|
||||
{{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{- "<|im_start|>assistant\n" -}}
|
||||
{%- endif -%}
|
||||
)";
|
||||
}
|
||||
}
|
||||
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
|
||||
if (token == LLAMA_TOKEN_NULL) {
|
||||
if (default_template_src.find(jinja_variable_name) != std::string::npos
|
||||
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
|
||||
LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
|
||||
}
|
||||
return std::string();
|
||||
} else {
|
||||
return common_token_to_piece(vocab, token, true);
|
||||
}
|
||||
};
|
||||
auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
|
||||
auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
|
||||
return {
|
||||
has_explicit_template,
|
||||
std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
|
||||
template_tool_use_src.empty()
|
||||
? nullptr
|
||||
: std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos)
|
||||
};
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
+37
-13
@@ -175,7 +175,11 @@ struct common_params_speculative {
|
||||
struct cpu_params cpuparams;
|
||||
struct cpu_params cpuparams_batch;
|
||||
|
||||
std::string model = ""; // draft model for speculative decoding // NOLINT
|
||||
std::string hf_repo = ""; // HF repo // NOLINT
|
||||
std::string hf_file = ""; // HF file // NOLINT
|
||||
|
||||
std::string model = ""; // draft model for speculative decoding // NOLINT
|
||||
std::string model_url = ""; // model url to download // NOLINT
|
||||
};
|
||||
|
||||
struct common_params_vocoder {
|
||||
@@ -330,6 +334,7 @@ struct common_params {
|
||||
std::string hostname = "127.0.0.1";
|
||||
std::string public_path = ""; // NOLINT
|
||||
std::string chat_template = ""; // NOLINT
|
||||
bool use_jinja = false; // NOLINT
|
||||
bool enable_chat_template = true;
|
||||
|
||||
std::vector<std::string> api_keys;
|
||||
@@ -424,6 +429,10 @@ std::string string_format(const char * fmt, ...);
|
||||
std::string string_strip(const std::string & str);
|
||||
std::string string_get_sortable_timestamp();
|
||||
|
||||
std::string string_join(const std::vector<std::string> & values, const std::string & separator);
|
||||
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter);
|
||||
std::string string_repeat(const std::string & str, size_t n);
|
||||
|
||||
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
|
||||
|
||||
template<class T>
|
||||
@@ -508,12 +517,14 @@ struct llama_model * common_load_model_from_url(
|
||||
const std::string & local_path,
|
||||
const std::string & hf_token,
|
||||
const struct llama_model_params & params);
|
||||
|
||||
struct llama_model * common_load_model_from_hf(
|
||||
const std::string & repo,
|
||||
const std::string & remote_path,
|
||||
const std::string & local_path,
|
||||
const std::string & hf_token,
|
||||
const struct llama_model_params & params);
|
||||
|
||||
std::pair<std::string, std::string> common_get_hf_file(
|
||||
const std::string & hf_repo_with_tag,
|
||||
const std::string & hf_token);
|
||||
@@ -597,30 +608,43 @@ struct common_chat_msg {
|
||||
std::string content;
|
||||
};
|
||||
|
||||
// Get the built-in chat template for the model. Return empty string if not present.
|
||||
std::string common_get_builtin_chat_template(const struct llama_model * model);
|
||||
|
||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
bool common_chat_verify_template(const std::string & tmpl);
|
||||
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
|
||||
|
||||
namespace minja {
|
||||
class chat_template;
|
||||
}
|
||||
|
||||
typedef minja::chat_template common_chat_template;
|
||||
|
||||
struct common_chat_templates {
|
||||
bool has_explicit_template; // Model had builtin template or template overridde was specified.
|
||||
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
|
||||
std::unique_ptr<common_chat_template> template_tool_use;
|
||||
};
|
||||
|
||||
// CPP wrapper for llama_chat_apply_template
|
||||
// If the built-in template is not supported, we default to chatml
|
||||
// If the custom "tmpl" is not supported, we throw an error
|
||||
std::string common_chat_apply_template(const struct llama_model * model,
|
||||
const std::string & tmpl,
|
||||
std::string common_chat_apply_template(
|
||||
const common_chat_template & tmpl,
|
||||
const std::vector<common_chat_msg> & chat,
|
||||
bool add_ass);
|
||||
bool add_ass,
|
||||
bool use_jinja);
|
||||
|
||||
// Format single message, while taking into account the position of that message in chat history
|
||||
std::string common_chat_format_single(const struct llama_model * model,
|
||||
const std::string & tmpl,
|
||||
std::string common_chat_format_single(
|
||||
const common_chat_template & tmpl,
|
||||
const std::vector<common_chat_msg> & past_msg,
|
||||
const common_chat_msg & new_msg,
|
||||
bool add_ass);
|
||||
bool add_ass,
|
||||
bool use_jinja);
|
||||
|
||||
// Returns an example of formatted chat
|
||||
std::string common_chat_format_example(const struct llama_model * model,
|
||||
const std::string & tmpl);
|
||||
std::string common_chat_format_example(
|
||||
const common_chat_template & tmpl, bool use_jinja);
|
||||
|
||||
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
|
||||
|
||||
//
|
||||
// KV cache utils
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "common.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
@@ -11,11 +13,6 @@
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
template <typename Iterator>
|
||||
static std::string join(Iterator begin, Iterator end, const std::string & separator);
|
||||
|
||||
static std::string repeat(const std::string & str, size_t n);
|
||||
|
||||
static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
|
||||
auto has_max = max_items != std::numeric_limits<int>::max();
|
||||
|
||||
@@ -128,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
|
||||
if (sub_len > 0) {
|
||||
auto from_sub = from.substr(i + 1);
|
||||
auto to_sub = to.substr(i + 1);
|
||||
auto sub_zeros = repeat("0", sub_len);
|
||||
auto sub_nines = repeat("9", sub_len);
|
||||
auto sub_zeros = string_repeat("0", sub_len);
|
||||
auto sub_nines = string_repeat("9", sub_len);
|
||||
|
||||
auto to_reached = false;
|
||||
out << "(";
|
||||
@@ -188,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
|
||||
auto max_digits = max_s.length();
|
||||
|
||||
for (auto digits = min_digits; digits < max_digits; digits++) {
|
||||
uniform_range(min_s, repeat("9", digits));
|
||||
min_s = "1" + repeat("0", digits);
|
||||
uniform_range(min_s, string_repeat("9", digits));
|
||||
min_s = "1" + string_repeat("0", digits);
|
||||
out << " | ";
|
||||
}
|
||||
uniform_range(min_s, max_s);
|
||||
@@ -318,49 +315,6 @@ std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
|
||||
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
|
||||
std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
|
||||
|
||||
template <typename Iterator>
|
||||
std::string join(Iterator begin, Iterator end, const std::string & separator) {
|
||||
std::ostringstream result;
|
||||
if (begin != end) {
|
||||
result << *begin;
|
||||
for (Iterator it = begin + 1; it != end; ++it) {
|
||||
result << separator << *it;
|
||||
}
|
||||
}
|
||||
return result.str();
|
||||
}
|
||||
|
||||
static std::vector<std::string> split(const std::string & str, const std::string & delimiter) {
|
||||
std::vector<std::string> tokens;
|
||||
size_t start = 0;
|
||||
size_t end = str.find(delimiter);
|
||||
|
||||
while (end != std::string::npos) {
|
||||
tokens.push_back(str.substr(start, end - start));
|
||||
start = end + delimiter.length();
|
||||
end = str.find(delimiter, start);
|
||||
}
|
||||
|
||||
tokens.push_back(str.substr(start));
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
static std::string repeat(const std::string & str, size_t n) {
|
||||
if (n == 0) {
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string result;
|
||||
result.reserve(str.length() * n);
|
||||
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
result += str;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
|
||||
std::smatch match;
|
||||
std::string result;
|
||||
@@ -389,6 +343,7 @@ static std::string format_literal(const std::string & literal) {
|
||||
|
||||
class SchemaConverter {
|
||||
private:
|
||||
friend std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);
|
||||
std::function<json(const std::string &)> _fetch_json;
|
||||
bool _dotall;
|
||||
std::map<std::string, std::string> _rules;
|
||||
@@ -418,7 +373,7 @@ private:
|
||||
for (size_t i = 0; i < alt_schemas.size(); i++) {
|
||||
rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
|
||||
}
|
||||
return join(rules.begin(), rules.end(), " | ");
|
||||
return string_join(rules, " | ");
|
||||
}
|
||||
|
||||
std::string _visit_pattern(const std::string & pattern, const std::string & name) {
|
||||
@@ -481,7 +436,7 @@ private:
|
||||
for (const auto & item : ret) {
|
||||
results.push_back(to_rule(item));
|
||||
}
|
||||
return std::make_pair(join(results.begin(), results.end(), " "), false);
|
||||
return std::make_pair(string_join(results, " "), false);
|
||||
};
|
||||
|
||||
while (i < length) {
|
||||
@@ -539,7 +494,7 @@ private:
|
||||
}
|
||||
curly_brackets += '}';
|
||||
i++;
|
||||
auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
|
||||
auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
|
||||
int min_times = 0;
|
||||
int max_times = std::numeric_limits<int>::max();
|
||||
try {
|
||||
@@ -854,7 +809,7 @@ public:
|
||||
return;
|
||||
}
|
||||
std::string pointer = ref.substr(ref.find('#') + 1);
|
||||
std::vector<std::string> tokens = split(pointer, "/");
|
||||
std::vector<std::string> tokens = string_split(pointer, "/");
|
||||
for (size_t i = 1; i < tokens.size(); ++i) {
|
||||
std::string sel = tokens[i];
|
||||
if (target.is_null() || !target.contains(sel)) {
|
||||
@@ -905,7 +860,7 @@ public:
|
||||
for (const auto & v : schema["enum"]) {
|
||||
enum_values.push_back(_generate_constant_rule(v));
|
||||
}
|
||||
return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space");
|
||||
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
|
||||
} else if ((schema_type.is_null() || schema_type == "object")
|
||||
&& (schema.contains("properties") ||
|
||||
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
|
||||
@@ -1019,10 +974,10 @@ public:
|
||||
|
||||
void check_errors() {
|
||||
if (!_errors.empty()) {
|
||||
throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n"));
|
||||
throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
|
||||
}
|
||||
if (!_warnings.empty()) {
|
||||
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str());
|
||||
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1036,10 +991,27 @@ public:
|
||||
};
|
||||
|
||||
std::string json_schema_to_grammar(const json & schema) {
|
||||
SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
|
||||
auto copy = schema;
|
||||
converter.resolve_refs(copy, "input");
|
||||
converter.visit(copy, "");
|
||||
return build_grammar([&](const llama_grammar_builder & callbacks) {
|
||||
auto copy = schema;
|
||||
callbacks.resolve_refs(copy);
|
||||
callbacks.add_schema("", copy);
|
||||
});
|
||||
}
|
||||
|
||||
std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb) {
|
||||
SchemaConverter converter([&](const std::string &) { return json(); }, /* dotall= */ false);
|
||||
llama_grammar_builder builder {
|
||||
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
|
||||
return converter._add_rule(name, rule);
|
||||
},
|
||||
/* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) {
|
||||
return converter.visit(schema, name == "root" ? "" : name);
|
||||
},
|
||||
/* .resolve_refs = */ [&](nlohmann::ordered_json & schema) {
|
||||
converter.resolve_refs(schema, "");
|
||||
}
|
||||
};
|
||||
cb(builder);
|
||||
converter.check_errors();
|
||||
return converter.format_grammar();
|
||||
}
|
||||
|
||||
@@ -5,4 +5,12 @@
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "json.hpp"
|
||||
|
||||
std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);
|
||||
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
|
||||
|
||||
struct llama_grammar_builder {
|
||||
std::function<std::string(const std::string &, const std::string &)> add_rule;
|
||||
std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
|
||||
std::function<void(nlohmann::ordered_json &)> resolve_refs;
|
||||
};
|
||||
|
||||
std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);
|
||||
|
||||
+2819
File diff suppressed because it is too large
Load Diff
@@ -133,7 +133,7 @@ The docker build option is currently limited to *intel GPU* targets.
|
||||
### Build image
|
||||
```sh
|
||||
# Using FP16
|
||||
docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" -f .devops/llama-cli-intel.Dockerfile .
|
||||
docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f .devops/intel.Dockerfile .
|
||||
```
|
||||
|
||||
*Notes*:
|
||||
|
||||
+1
-1
@@ -286,7 +286,7 @@ You don't need to install Vulkan SDK. It will be installed inside the container.
|
||||
|
||||
```sh
|
||||
# Build the image
|
||||
docker build -t llama-cpp-vulkan -f .devops/llama-cli-vulkan.Dockerfile .
|
||||
docker build -t llama-cpp-vulkan --target light -f .devops/vulkan.Dockerfile .
|
||||
|
||||
# Then, use it:
|
||||
docker run -it --rm -v "$(pwd):/app:Z" --device /dev/dri/renderD128:/dev/dri/renderD128 --device /dev/dri/card1:/dev/dri/card1 llama-cpp-vulkan -m "/app/models/YOUR_MODEL_FILE" -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33
|
||||
|
||||
+6
-6
@@ -60,9 +60,9 @@ Assuming one has the [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia
|
||||
## Building Docker locally
|
||||
|
||||
```bash
|
||||
docker build -t local/llama.cpp:full-cuda -f .devops/full-cuda.Dockerfile .
|
||||
docker build -t local/llama.cpp:light-cuda -f .devops/llama-cli-cuda.Dockerfile .
|
||||
docker build -t local/llama.cpp:server-cuda -f .devops/llama-server-cuda.Dockerfile .
|
||||
docker build -t local/llama.cpp:full-cuda --target full -f .devops/cuda.Dockerfile .
|
||||
docker build -t local/llama.cpp:light-cuda --target light -f .devops/cuda.Dockerfile .
|
||||
docker build -t local/llama.cpp:server-cuda --target server -f .devops/cuda.Dockerfile .
|
||||
```
|
||||
|
||||
You may want to pass in some different `ARGS`, depending on the CUDA environment supported by your container host, as well as the GPU architecture.
|
||||
@@ -95,9 +95,9 @@ Assuming one has the [mt-container-toolkit](https://developer.mthreads.com/musa/
|
||||
## Building Docker locally
|
||||
|
||||
```bash
|
||||
docker build -t local/llama.cpp:full-musa -f .devops/full-musa.Dockerfile .
|
||||
docker build -t local/llama.cpp:light-musa -f .devops/llama-cli-musa.Dockerfile .
|
||||
docker build -t local/llama.cpp:server-musa -f .devops/llama-server-musa.Dockerfile .
|
||||
docker build -t local/llama.cpp:full-musa --target full -f .devops/musa.Dockerfile .
|
||||
docker build -t local/llama.cpp:light-musa --target light -f .devops/musa.Dockerfile .
|
||||
docker build -t local/llama.cpp:server-musa --target server -f .devops/musa.Dockerfile .
|
||||
```
|
||||
|
||||
You may want to pass in some different `ARGS`, depending on the MUSA environment supported by your container host, as well as the GPU architecture.
|
||||
|
||||
@@ -345,8 +345,18 @@ struct lora_merge_ctx {
|
||||
gf = ggml_new_graph(ctx0);
|
||||
struct ggml_tensor * cur = inp_base;
|
||||
for (size_t i = 0; i < adapters.size(); ++i) {
|
||||
struct ggml_tensor * a_T = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32)));
|
||||
struct ggml_tensor * delta = ggml_mul_mat(ctx0, a_T, ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32));
|
||||
struct ggml_tensor * delta;
|
||||
bool is_tok_embd = string_starts_with(name_base, "token_embd");
|
||||
if (is_tok_embd) {
|
||||
printf("%s : detected token embeddings tensor\n", __func__);
|
||||
delta = ggml_mul_mat(ctx0,
|
||||
ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32),
|
||||
ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32));
|
||||
} else {
|
||||
delta = ggml_mul_mat(ctx0,
|
||||
ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32))),
|
||||
ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32));
|
||||
}
|
||||
// scale
|
||||
const float alpha = adapters[i]->alpha;
|
||||
const float rank = (float) inp_b[i]->ne[0];
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
## MiniCPM-o 2.6
|
||||
Currently, this readme only supports minicpm-omni's image capabilities, and we will update the full-mode support as soon as possible.
|
||||
|
||||
### Prepare models and code
|
||||
|
||||
Download [MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) PyTorch model from huggingface to "MiniCPM-o-2_6" folder.
|
||||
|
||||
Clone llama.cpp:
|
||||
```bash
|
||||
git clone git@github.com:OpenBMB/llama.cpp.git
|
||||
cd llama.cpp
|
||||
git checkout minicpm-omni
|
||||
```
|
||||
|
||||
### Usage of MiniCPM-o 2.6
|
||||
|
||||
Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-o-2_6-gguf) by us)
|
||||
|
||||
```bash
|
||||
python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-o-2_6
|
||||
python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-o-2_6 --minicpmv-projector ../MiniCPM-o-2_6/minicpmv.projector --output-dir ../MiniCPM-o-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 4
|
||||
python ./convert_hf_to_gguf.py ../MiniCPM-o-2_6/model
|
||||
|
||||
# quantize int4 version
|
||||
./llama-quantize ../MiniCPM-o-2_6/model/ggml-model-f16.gguf ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M
|
||||
```
|
||||
|
||||
Build llama.cpp using `CMake`:
|
||||
https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md
|
||||
|
||||
```bash
|
||||
cmake -B build
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
Inference on Linux or Mac
|
||||
```
|
||||
# run f16 version
|
||||
./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
|
||||
# run quantized int4 version
|
||||
./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
|
||||
|
||||
# or run in interactive mode
|
||||
./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i
|
||||
```
|
||||
+27
-2
@@ -718,6 +718,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
else if (ctx->minicpmv_version == 3) {
|
||||
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
|
||||
}
|
||||
else if (ctx->minicpmv_version == 4) {
|
||||
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
|
||||
}
|
||||
ggml_set_name(pos_embed, "pos_embed");
|
||||
ggml_set_input(pos_embed);
|
||||
}
|
||||
@@ -1053,6 +1056,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
n_head = hidden_size/d_head;
|
||||
num_query = 64;
|
||||
}
|
||||
else if (ctx->minicpmv_version == 4) {
|
||||
hidden_size = 3584;
|
||||
n_head = hidden_size/d_head;
|
||||
num_query = 64;
|
||||
}
|
||||
|
||||
struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
|
||||
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
|
||||
@@ -2041,6 +2049,7 @@ static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_imag
|
||||
images[images.size()-1].push_back(patch);
|
||||
}
|
||||
}
|
||||
clip_image_u8_free(refine_image);
|
||||
}
|
||||
return images;
|
||||
}
|
||||
@@ -2079,6 +2088,13 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
|
||||
clip_image_f32_free(res);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < imgs.size(); ++i) {
|
||||
for (size_t j = 0; j < imgs[i].size(); ++j) {
|
||||
if (imgs[i][j] != nullptr) {
|
||||
clip_image_u8_free(imgs[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
else if (ctx->has_qwen2vl_merger) {
|
||||
@@ -2335,6 +2351,9 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
|
||||
else if (ctx->minicpmv_version == 3) {
|
||||
n_patches = 64;
|
||||
}
|
||||
else if (ctx->minicpmv_version == 4) {
|
||||
n_patches = 64;
|
||||
}
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
int patch_size = params.patch_size * 2;
|
||||
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
|
||||
@@ -2514,8 +2533,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
|
||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
||||
int* positions_data = (int*)malloc(ggml_nbytes(positions));
|
||||
int bucket_coords_h[70];
|
||||
int bucket_coords_w[70];
|
||||
int bucket_coords_h[1024];
|
||||
int bucket_coords_w[1024];
|
||||
for (int i = 0; i < pos_h; i++){
|
||||
bucket_coords_h[i] = std::floor(70.0*i/pos_h);
|
||||
}
|
||||
@@ -2543,6 +2562,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
else if (ctx->minicpmv_version == 3) {
|
||||
embed_dim = 3584;
|
||||
}
|
||||
else if (ctx->minicpmv_version == 4) {
|
||||
embed_dim = 3584;
|
||||
}
|
||||
auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
|
||||
|
||||
float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
|
||||
@@ -2786,6 +2808,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
else if (ctx->minicpmv_version == 3) {
|
||||
return 3584;
|
||||
}
|
||||
else if (ctx->minicpmv_version == 4) {
|
||||
return 3584;
|
||||
}
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
return ctx->vision_model.mm_1_b->ne[0];
|
||||
|
||||
@@ -216,7 +216,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
|
||||
return true;
|
||||
}
|
||||
|
||||
static clip_image_f32 * only_v2_5_reshape_by_patch(clip_image_f32 * image, int patch_size) {
|
||||
static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) {
|
||||
int width = image->nx;
|
||||
int height = image->ny;
|
||||
int num_patches = (height / patch_size) * (width / patch_size);
|
||||
@@ -277,13 +277,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
||||
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
|
||||
}
|
||||
else {
|
||||
int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
|
||||
if (has_minicpmv_projector == 2) {
|
||||
encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
|
||||
}
|
||||
else if (has_minicpmv_projector == 3) {
|
||||
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
|
||||
}
|
||||
encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
|
||||
}
|
||||
|
||||
if (!encoded) {
|
||||
@@ -313,6 +307,9 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
||||
load_image_size->height = img->ny;
|
||||
clip_add_load_image_size(ctx_clip, load_image_size);
|
||||
LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height);
|
||||
delete[] img_res_v.data;
|
||||
img_res_v.size = 0;
|
||||
img_res_v.data = nullptr;
|
||||
}
|
||||
else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
|
||||
// flat / default llava-1.5 type embedding
|
||||
|
||||
@@ -140,6 +140,9 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
|
||||
else if (has_minicpmv_projector == 3) {
|
||||
system_prompt = "<|im_start|>user\n";
|
||||
}
|
||||
else if (has_minicpmv_projector == 4) {
|
||||
system_prompt = "<|im_start|>user\n";
|
||||
}
|
||||
LOG_INF("%s: image token past: %d\n", __func__, n_past);
|
||||
eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, false);
|
||||
process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
|
||||
@@ -227,6 +230,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm
|
||||
else if (has_minicpmv_projector == 3) {
|
||||
user_prompt = "<|im_start|>user\n" + prompt;
|
||||
}
|
||||
else if (has_minicpmv_projector == 4) {
|
||||
user_prompt = "<|im_start|>user\n" + prompt;
|
||||
}
|
||||
}
|
||||
|
||||
eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
|
||||
@@ -236,6 +242,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm
|
||||
else if (has_minicpmv_projector == 3) {
|
||||
eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
|
||||
}
|
||||
else if (has_minicpmv_projector == 4) {
|
||||
eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
|
||||
}
|
||||
|
||||
// generate the response
|
||||
|
||||
@@ -308,7 +317,6 @@ int main(int argc, char ** argv) {
|
||||
const auto * tmp = llama_loop(ctx_llava, smpl, n_past);
|
||||
response += tmp;
|
||||
if (strcmp(tmp, "</s>") == 0) break;
|
||||
if (strstr(tmp, "###")) break; // Yi-VL behavior
|
||||
printf("%s", tmp);// mistral llava-1.6
|
||||
if (strstr(response.c_str(), "<user>")) break; // minicpm-v
|
||||
fflush(stdout);
|
||||
|
||||
@@ -501,7 +501,7 @@ default_image_mean = [0.48145466, 0.4578275, 0.40821073]
|
||||
default_image_std = [0.26862954, 0.26130258, 0.27577711]
|
||||
ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
|
||||
ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
|
||||
ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3', default=2)
|
||||
ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4', default=2)
|
||||
|
||||
# with proper
|
||||
args = ap.parse_args()
|
||||
@@ -545,12 +545,19 @@ if args.use_f32:
|
||||
|
||||
minicpmv_version = args.minicpmv_version
|
||||
emb_dim = 4096
|
||||
block_count = 26
|
||||
if minicpmv_version == 1:
|
||||
emb_dim = 2304
|
||||
block_count = 26
|
||||
elif minicpmv_version == 2:
|
||||
emb_dim = 4096
|
||||
block_count = 27
|
||||
elif minicpmv_version == 3:
|
||||
emb_dim = 3584
|
||||
block_count = 27
|
||||
elif minicpmv_version == 4:
|
||||
emb_dim = 3584
|
||||
block_count = 27
|
||||
|
||||
default_vision_config = {
|
||||
"hidden_size": 1152,
|
||||
@@ -567,6 +574,9 @@ model = Idefics2VisionTransformer(vision_config)
|
||||
if minicpmv_version == 3:
|
||||
vision_config = SiglipVisionConfig(**default_vision_config)
|
||||
model = SiglipVisionTransformer(vision_config)
|
||||
elif minicpmv_version == 4:
|
||||
vision_config = SiglipVisionConfig(**default_vision_config)
|
||||
model = SiglipVisionTransformer(vision_config)
|
||||
|
||||
processor = None
|
||||
# if model.attn_pool is not None:
|
||||
@@ -587,7 +597,7 @@ elif args.minicpmv_projector is not None:
|
||||
fname_middle = "mmproj-"
|
||||
has_text_encoder = False
|
||||
has_minicpmv_projector = True
|
||||
minicpmv_version = 3
|
||||
minicpmv_version = 4
|
||||
elif args.vision_only:
|
||||
fname_middle = "vision-"
|
||||
has_text_encoder = False
|
||||
@@ -625,7 +635,6 @@ if has_vision_encoder:
|
||||
fout.add_uint32("clip.vision.projection_dim", 0)
|
||||
fout.add_uint32(add_key_str(KEY_ATTENTION_HEAD_COUNT, VISION), 16)
|
||||
fout.add_float32(add_key_str(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
|
||||
block_count = 26
|
||||
fout.add_uint32(add_key_str(KEY_BLOCK_COUNT, VISION), block_count)
|
||||
|
||||
if processor is not None:
|
||||
|
||||
@@ -8,7 +8,7 @@ ap.add_argument("-m", "--model", help="Path to MiniCPM-V model")
|
||||
args = ap.parse_args()
|
||||
|
||||
# find the model part that includes the the multimodal projector weights
|
||||
model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True)
|
||||
model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True, torch_dtype=torch.bfloat16)
|
||||
checkpoint = model.state_dict()
|
||||
|
||||
# get a list of mm tensor names
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
cmake_minimum_required(VERSION 3.12)
|
||||
project("llama-cli-cmake-pkg" C CXX)
|
||||
set(TARGET llama-cli-cmake-pkg)
|
||||
|
||||
find_package(Llama 0.0.1 REQUIRED)
|
||||
|
||||
# Bake common functionality in with target. Because applications
|
||||
# using the relocatable Llama package should be outside of the
|
||||
# source tree, llama-cli-cmake-pkg pretends the dependencies are built-in.
|
||||
set(_common_path "${CMAKE_CURRENT_LIST_DIR}/../../common")
|
||||
add_library(common OBJECT)
|
||||
file(GLOB _common_files
|
||||
"${_common_path}/*.h"
|
||||
"${_common_path}/*.cpp"
|
||||
)
|
||||
target_sources(common PRIVATE ${_common_files})
|
||||
|
||||
# If the common project was part of "llama-cli-cmake-pkg" the transient
|
||||
# defines would automatically be attached. Because the common func-
|
||||
# tionality is separate, but dependent upon the defines, it must be
|
||||
# explicitly extracted from the "llama" target.
|
||||
#
|
||||
get_target_property(_llama_transient_defines llama
|
||||
INTERFACE_COMPILE_DEFINITIONS)
|
||||
|
||||
target_compile_definitions(common PRIVATE "${_llama_transient_defines}")
|
||||
|
||||
add_executable(${TARGET} ${CMAKE_CURRENT_LIST_DIR}/../main/main.cpp)
|
||||
target_include_directories(${TARGET} PRIVATE ${_common_path})
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
@@ -1,31 +0,0 @@
|
||||
# llama.cpp/example/main-cmake-pkg
|
||||
|
||||
This program builds [llama-cli](../main) using a relocatable CMake package. It serves as an example of using the `find_package()` CMake command to conveniently include [llama.cpp](https://github.com/ggerganov/llama.cpp) in projects which live outside of the source tree.
|
||||
|
||||
## Building
|
||||
|
||||
Because this example is "outside of the source tree", it is important to first build/install llama.cpp using CMake. An example is provided here, but please see the [llama.cpp build instructions](../..) for more detailed build instructions.
|
||||
|
||||
### Considerations
|
||||
|
||||
When hardware acceleration libraries are used (e.g. CUDA, Metal, etc.), CMake must be able to locate the associated CMake package.
|
||||
|
||||
### Build llama.cpp and install to C:\LlamaCPP directory
|
||||
|
||||
```cmd
|
||||
git clone https://github.com/ggerganov/llama.cpp
|
||||
cd llama.cpp
|
||||
cmake -B build -DBUILD_SHARED_LIBS=OFF -G "Visual Studio 17 2022" -A x64
|
||||
cmake --build build --config Release
|
||||
cmake --install build --prefix C:/LlamaCPP
|
||||
```
|
||||
|
||||
### Build llama-cli-cmake-pkg
|
||||
|
||||
|
||||
```cmd
|
||||
cd ..\examples\main-cmake-pkg
|
||||
cmake -B build -DBUILD_SHARED_LIBS=OFF -DCMAKE_PREFIX_PATH="C:/LlamaCPP/lib/cmake/Llama" -G "Visual Studio 17 2022" -A x64
|
||||
cmake --build build --config Release
|
||||
cmake --install build --prefix C:/MyLlamaApp
|
||||
```
|
||||
@@ -310,9 +310,9 @@ These options help improve the performance and memory usage of the LLaMA models.
|
||||
|
||||
### Batch Size
|
||||
|
||||
- `-b N, --batch-size N`: Set the batch size for prompt processing (default: `2048`). This large batch size benefits users who have BLAS installed and enabled it during the build. If you don't have BLAS enabled ("BLAS=0"), you can use a smaller number, such as 8, to see the prompt progress as it's evaluated in some situations.
|
||||
- `-ub N`, `--ubatch-size N`: Physical batch size. This is the maximum number of tokens that may be processed at a time. Increasing this value may improve performance during prompt processing, at the expense of higher memory usage. Default: `512`.
|
||||
|
||||
- `-ub N`, `--ubatch-size N`: physical maximum batch size. This is for pipeline parallelization. Default: `512`.
|
||||
- `-b N`, `--batch-size N`: Logical batch size. Increasing this value above the value of the physical batch size may improve prompt processing performance when using multiple GPUs with pipeline parallelism. Default: `2048`.
|
||||
|
||||
### Prompt Caching
|
||||
|
||||
|
||||
+15
-13
@@ -4,6 +4,7 @@
|
||||
#include "log.h"
|
||||
#include "sampling.h"
|
||||
#include "llama.h"
|
||||
#include "chat-template.hpp"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
@@ -84,14 +85,6 @@ static void sigint_handler(int signo) {
|
||||
}
|
||||
#endif
|
||||
|
||||
static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
|
||||
common_chat_msg new_msg{role, content};
|
||||
auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
|
||||
chat_msgs.push_back({role, content});
|
||||
LOG_DBG("formatted: '%s'\n", formatted.c_str());
|
||||
return formatted;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
g_params = ¶ms;
|
||||
@@ -165,6 +158,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
auto chat_templates = common_chat_templates_from_model(model, params.chat_template);
|
||||
|
||||
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
|
||||
|
||||
@@ -207,7 +201,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// auto enable conversation mode if chat template is available
|
||||
const bool has_chat_template = !common_get_builtin_chat_template(model).empty() || !params.chat_template.empty();
|
||||
const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default;
|
||||
if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) {
|
||||
if (has_chat_template) {
|
||||
LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
|
||||
@@ -225,7 +219,7 @@ int main(int argc, char ** argv) {
|
||||
// print chat template example in conversation mode
|
||||
if (params.conversation_mode) {
|
||||
if (params.enable_chat_template) {
|
||||
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str());
|
||||
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str());
|
||||
} else {
|
||||
LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
|
||||
}
|
||||
@@ -269,10 +263,18 @@ int main(int argc, char ** argv) {
|
||||
|
||||
std::vector<llama_token> embd_inp;
|
||||
|
||||
auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
|
||||
common_chat_msg new_msg{role, content};
|
||||
auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja);
|
||||
chat_msgs.push_back({role, content});
|
||||
LOG_DBG("formatted: '%s'\n", formatted.c_str());
|
||||
return formatted;
|
||||
};
|
||||
|
||||
{
|
||||
auto prompt = (params.conversation_mode && params.enable_chat_template)
|
||||
// format the system prompt in conversation mode (fallback to default if empty)
|
||||
? chat_add_and_format(model, chat_msgs, "system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
|
||||
? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
|
||||
// otherwise use the prompt as is
|
||||
: params.prompt;
|
||||
if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
|
||||
@@ -779,7 +781,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
if (params.enable_chat_template) {
|
||||
chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
|
||||
chat_add_and_format("assistant", assistant_ss.str());
|
||||
}
|
||||
is_interacting = true;
|
||||
LOG("\n");
|
||||
@@ -844,7 +846,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
bool format_chat = params.conversation_mode && params.enable_chat_template;
|
||||
std::string user_inp = format_chat
|
||||
? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
|
||||
? chat_add_and_format("user", std::move(buffer))
|
||||
: std::move(buffer);
|
||||
// TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
|
||||
const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);
|
||||
|
||||
@@ -3,11 +3,10 @@
|
||||
The purpose of this example is to demonstrate a minimal usage of llama.cpp for running models.
|
||||
|
||||
```bash
|
||||
llama-run granite-code
|
||||
llama-run granite3-moe
|
||||
```
|
||||
|
||||
```bash
|
||||
llama-run -h
|
||||
Description:
|
||||
Runs a llm
|
||||
|
||||
@@ -17,7 +16,7 @@ Usage:
|
||||
Options:
|
||||
-c, --context-size <value>
|
||||
Context size (default: 2048)
|
||||
-n, --ngl <value>
|
||||
-n, -ngl, --ngl <value>
|
||||
Number of GPU layers (default: 0)
|
||||
--temp <value>
|
||||
Temperature (default: 0.8)
|
||||
|
||||
@@ -103,24 +103,26 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#include <termios.h>
|
||||
#include <unistd.h>
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include <errno.h>
|
||||
#include <string.h>
|
||||
#include <stdlib.h>
|
||||
#include <ctype.h>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <sys/ioctl.h>
|
||||
#include <unistd.h>
|
||||
#include <vector>
|
||||
#include "linenoise.h"
|
||||
# include "linenoise.h"
|
||||
|
||||
#define LINENOISE_DEFAULT_HISTORY_MAX_LEN 100
|
||||
#define LINENOISE_MAX_LINE 4096
|
||||
static std::vector<const char*> unsupported_term = {"dumb","cons25","emacs",nullptr};
|
||||
# include <ctype.h>
|
||||
# include <errno.h>
|
||||
# include <stdio.h>
|
||||
# include <string.h>
|
||||
# include <sys/file.h>
|
||||
# include <sys/ioctl.h>
|
||||
# include <sys/stat.h>
|
||||
# include <sys/types.h>
|
||||
# include <termios.h>
|
||||
# include <unistd.h>
|
||||
|
||||
# include <memory>
|
||||
# include <string>
|
||||
# include <vector>
|
||||
|
||||
# define LINENOISE_DEFAULT_HISTORY_MAX_LEN 100
|
||||
# define LINENOISE_MAX_LINE 4096
|
||||
static std::vector<const char *> unsupported_term = { "dumb", "cons25", "emacs" };
|
||||
static linenoiseCompletionCallback *completionCallback = NULL;
|
||||
static linenoiseHintsCallback *hintsCallback = NULL;
|
||||
static linenoiseFreeHintsCallback *freeHintsCallback = NULL;
|
||||
@@ -166,21 +168,58 @@ int linenoiseHistoryAdd(const char *line);
|
||||
#define REFRESH_ALL (REFRESH_CLEAN|REFRESH_WRITE) // Do both.
|
||||
static void refreshLine(struct linenoiseState *l);
|
||||
|
||||
class File {
|
||||
public:
|
||||
FILE * file = nullptr;
|
||||
|
||||
FILE * open(const std::string & filename, const char * mode) {
|
||||
file = fopen(filename.c_str(), mode);
|
||||
|
||||
return file;
|
||||
}
|
||||
|
||||
int lock() {
|
||||
if (file) {
|
||||
fd = fileno(file);
|
||||
if (flock(fd, LOCK_EX | LOCK_NB) != 0) {
|
||||
fd = -1;
|
||||
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
~File() {
|
||||
if (fd >= 0) {
|
||||
flock(fd, LOCK_UN);
|
||||
}
|
||||
|
||||
if (file) {
|
||||
fclose(file);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int fd = -1;
|
||||
};
|
||||
|
||||
__attribute__((format(printf, 1, 2)))
|
||||
/* Debugging function. */
|
||||
#if 0
|
||||
static void lndebug(const char *fmt, ...) {
|
||||
static FILE *lndebug_fp = NULL;
|
||||
if (lndebug_fp == NULL) {
|
||||
lndebug_fp = fopen("/tmp/lndebug.txt", "a");
|
||||
static File file;
|
||||
if (file.file == nullptr) {
|
||||
file.open("/tmp/lndebug.txt", "a");
|
||||
}
|
||||
|
||||
if (lndebug_fp != NULL) {
|
||||
if (file.file != nullptr) {
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
vfprintf(lndebug_fp, fmt, args);
|
||||
vfprintf(file.file, fmt, args);
|
||||
va_end(args);
|
||||
fflush(lndebug_fp);
|
||||
fflush(file.file);
|
||||
}
|
||||
}
|
||||
#else
|
||||
@@ -213,8 +252,11 @@ void linenoiseSetMultiLine(int ml) {
|
||||
static int isUnsupportedTerm(void) {
|
||||
char *term = getenv("TERM");
|
||||
if (term == NULL) return 0;
|
||||
for (int j = 0; unsupported_term[j]; ++j)
|
||||
if (!strcasecmp(term, unsupported_term[j])) return 1;
|
||||
for (size_t j = 0; j < unsupported_term.size(); ++j) {
|
||||
if (!strcasecmp(term, unsupported_term[j])) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -334,17 +376,6 @@ static void linenoiseBeep(void) {
|
||||
fflush(stderr);
|
||||
}
|
||||
|
||||
/* ============================== Completion ================================ */
|
||||
|
||||
/* Free a list of completion option populated by linenoiseAddCompletion(). */
|
||||
static void freeCompletions(linenoiseCompletions *lc) {
|
||||
size_t i;
|
||||
for (i = 0; i < lc->len; i++)
|
||||
free(lc->cvec[i]);
|
||||
if (lc->cvec != NULL)
|
||||
free(lc->cvec);
|
||||
}
|
||||
|
||||
/* Called by completeLine() and linenoiseShow() to render the current
|
||||
* edited line with the proposed completion. If the current completion table
|
||||
* is already available, it is passed as second argument, otherwise the
|
||||
@@ -353,9 +384,9 @@ static void freeCompletions(linenoiseCompletions *lc) {
|
||||
* Flags are the same as refreshLine*(), that is REFRESH_* macros. */
|
||||
static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseCompletions *lc, int flags) {
|
||||
/* Obtain the table of completions if the caller didn't provide one. */
|
||||
linenoiseCompletions ctable = { 0, NULL };
|
||||
linenoiseCompletions ctable;
|
||||
if (lc == NULL) {
|
||||
completionCallback(ls->buf,&ctable);
|
||||
completionCallback(ls->buf, &ctable);
|
||||
lc = &ctable;
|
||||
}
|
||||
|
||||
@@ -364,16 +395,17 @@ static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseComple
|
||||
struct linenoiseState saved = *ls;
|
||||
ls->len = ls->pos = strlen(lc->cvec[ls->completion_idx]);
|
||||
ls->buf = lc->cvec[ls->completion_idx];
|
||||
refreshLineWithFlags(ls,flags);
|
||||
refreshLineWithFlags(ls, flags);
|
||||
ls->len = saved.len;
|
||||
ls->pos = saved.pos;
|
||||
ls->buf = saved.buf;
|
||||
} else {
|
||||
refreshLineWithFlags(ls,flags);
|
||||
refreshLineWithFlags(ls, flags);
|
||||
}
|
||||
|
||||
/* Free the completions table if needed. */
|
||||
if (lc != &ctable) freeCompletions(&ctable);
|
||||
if (lc == &ctable) {
|
||||
ctable.to_free = false;
|
||||
}
|
||||
}
|
||||
|
||||
/* This is an helper function for linenoiseEdit*() and is called when the
|
||||
@@ -391,11 +423,11 @@ static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseComple
|
||||
* possible completions, and the caller should read for the next characters
|
||||
* from stdin. */
|
||||
static int completeLine(struct linenoiseState *ls, int keypressed) {
|
||||
linenoiseCompletions lc = { 0, NULL };
|
||||
linenoiseCompletions lc;
|
||||
int nwritten;
|
||||
char c = keypressed;
|
||||
|
||||
completionCallback(ls->buf,&lc);
|
||||
completionCallback(ls->buf, &lc);
|
||||
if (lc.len == 0) {
|
||||
linenoiseBeep();
|
||||
ls->in_completion = 0;
|
||||
@@ -406,7 +438,7 @@ static int completeLine(struct linenoiseState *ls, int keypressed) {
|
||||
ls->in_completion = 1;
|
||||
ls->completion_idx = 0;
|
||||
} else {
|
||||
ls->completion_idx = (ls->completion_idx+1) % (lc.len+1);
|
||||
ls->completion_idx = (ls->completion_idx + 1) % (lc.len + 1);
|
||||
if (ls->completion_idx == lc.len) linenoiseBeep();
|
||||
}
|
||||
c = 0;
|
||||
@@ -420,8 +452,7 @@ static int completeLine(struct linenoiseState *ls, int keypressed) {
|
||||
default:
|
||||
/* Update buffer and return */
|
||||
if (ls->completion_idx < lc.len) {
|
||||
nwritten = snprintf(ls->buf,ls->buflen,"%s",
|
||||
lc.cvec[ls->completion_idx]);
|
||||
nwritten = snprintf(ls->buf, ls->buflen, "%s", lc.cvec[ls->completion_idx]);
|
||||
ls->len = ls->pos = nwritten;
|
||||
}
|
||||
ls->in_completion = 0;
|
||||
@@ -430,13 +461,12 @@ static int completeLine(struct linenoiseState *ls, int keypressed) {
|
||||
|
||||
/* Show completion or original buffer */
|
||||
if (ls->in_completion && ls->completion_idx < lc.len) {
|
||||
refreshLineWithCompletion(ls,&lc,REFRESH_ALL);
|
||||
refreshLineWithCompletion(ls, &lc, REFRESH_ALL);
|
||||
} else {
|
||||
refreshLine(ls);
|
||||
}
|
||||
}
|
||||
|
||||
freeCompletions(&lc);
|
||||
return c; /* Return last read character */
|
||||
}
|
||||
|
||||
@@ -462,53 +492,25 @@ void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *fn) {
|
||||
* user typed <tab>. See the example.c source code for a very easy to
|
||||
* understand example. */
|
||||
void linenoiseAddCompletion(linenoiseCompletions *lc, const char *str) {
|
||||
size_t len = strlen(str);
|
||||
char *copy, **cvec;
|
||||
|
||||
copy = (char*) malloc(len + 1);
|
||||
if (copy == NULL) return;
|
||||
memcpy(copy,str,len+1);
|
||||
cvec = (char**) realloc(lc->cvec,sizeof(char*)*(lc->len+1));
|
||||
if (cvec == NULL) {
|
||||
free(copy);
|
||||
const size_t len = strlen(str);
|
||||
auto copy = std::make_unique<char[]>(len + 1);
|
||||
if (!copy) {
|
||||
return;
|
||||
}
|
||||
|
||||
memcpy(copy.get(), str, len + 1);
|
||||
char ** cvec = static_cast<char **>(std::realloc(lc->cvec, sizeof(char *) * (lc->len + 1)));
|
||||
if (cvec == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
lc->cvec = cvec;
|
||||
lc->cvec[lc->len++] = copy;
|
||||
}
|
||||
|
||||
/* =========================== Line editing ================================= */
|
||||
|
||||
/* We define a very simple "append buffer" structure, that is an heap
|
||||
* allocated string where we can append to. This is useful in order to
|
||||
* write all the escape sequences in a buffer and flush them to the standard
|
||||
* output in a single call, to avoid flickering effects. */
|
||||
struct abuf {
|
||||
char *b;
|
||||
int len;
|
||||
};
|
||||
|
||||
static void abInit(struct abuf *ab) {
|
||||
ab->b = NULL;
|
||||
ab->len = 0;
|
||||
}
|
||||
|
||||
static void abAppend(struct abuf *ab, const char *s, int len) {
|
||||
char *new_ptr = (char*) realloc(ab->b,ab->len+len);
|
||||
|
||||
if (new_ptr == NULL) return;
|
||||
memcpy(new_ptr+ab->len,s,len);
|
||||
ab->b = new_ptr;
|
||||
ab->len += len;
|
||||
}
|
||||
|
||||
static void abFree(struct abuf *ab) {
|
||||
free(ab->b);
|
||||
lc->cvec[lc->len++] = copy.release();
|
||||
}
|
||||
|
||||
/* Helper of refreshSingleLine() and refreshMultiLine() to show hints
|
||||
* to the right of the prompt. */
|
||||
static void refreshShowHints(struct abuf * ab, struct linenoiseState * l, int plen) {
|
||||
static void refreshShowHints(std::string & ab, struct linenoiseState * l, int plen) {
|
||||
char seq[64];
|
||||
if (hintsCallback && plen+l->len < l->cols) {
|
||||
int color = -1, bold = 0;
|
||||
@@ -522,10 +524,11 @@ static void refreshShowHints(struct abuf * ab, struct linenoiseState * l, int pl
|
||||
snprintf(seq,64,"\033[%d;%d;49m",bold,color);
|
||||
else
|
||||
seq[0] = '\0';
|
||||
abAppend(ab,seq,strlen(seq));
|
||||
abAppend(ab,hint,hintlen);
|
||||
ab.append(seq);
|
||||
ab.append(hint, hintlen);
|
||||
if (color != -1 || bold != 0)
|
||||
abAppend(ab,"\033[0m",4);
|
||||
ab.append("\033[0m");
|
||||
|
||||
/* Call the function to free the hint returned. */
|
||||
if (freeHintsCallback) freeHintsCallback(hint);
|
||||
}
|
||||
@@ -546,8 +549,7 @@ static void refreshSingleLine(struct linenoiseState *l, int flags) {
|
||||
char *buf = l->buf;
|
||||
size_t len = l->len;
|
||||
size_t pos = l->pos;
|
||||
struct abuf ab;
|
||||
|
||||
std::string ab;
|
||||
while((plen+pos) >= l->cols) {
|
||||
buf++;
|
||||
len--;
|
||||
@@ -557,35 +559,34 @@ static void refreshSingleLine(struct linenoiseState *l, int flags) {
|
||||
len--;
|
||||
}
|
||||
|
||||
abInit(&ab);
|
||||
/* Cursor to left edge */
|
||||
snprintf(seq,sizeof(seq),"\r");
|
||||
abAppend(&ab,seq,strlen(seq));
|
||||
ab.append(seq);
|
||||
|
||||
if (flags & REFRESH_WRITE) {
|
||||
/* Write the prompt and the current buffer content */
|
||||
abAppend(&ab,l->prompt,strlen(l->prompt));
|
||||
ab.append(l->prompt);
|
||||
if (maskmode == 1) {
|
||||
while (len--) abAppend(&ab,"*",1);
|
||||
while (len--) {
|
||||
ab.append("*");
|
||||
}
|
||||
} else {
|
||||
abAppend(&ab,buf,len);
|
||||
ab.append(buf, len);
|
||||
}
|
||||
/* Show hits if any. */
|
||||
refreshShowHints(&ab,l,plen);
|
||||
refreshShowHints(ab, l, plen);
|
||||
}
|
||||
|
||||
/* Erase to right */
|
||||
snprintf(seq,sizeof(seq),"\x1b[0K");
|
||||
abAppend(&ab,seq,strlen(seq));
|
||||
|
||||
ab.append(seq);
|
||||
if (flags & REFRESH_WRITE) {
|
||||
/* Move cursor to original position. */
|
||||
snprintf(seq,sizeof(seq),"\r\x1b[%dC", (int)(pos+plen));
|
||||
abAppend(&ab,seq,strlen(seq));
|
||||
ab.append(seq);
|
||||
}
|
||||
|
||||
if (write(fd,ab.b,ab.len) == -1) {} /* Can't recover from write error. */
|
||||
abFree(&ab);
|
||||
(void) !write(fd, ab.c_str(), ab.size()); /* Can't recover from write error. */
|
||||
}
|
||||
|
||||
/* Multi line low level line refresh.
|
||||
@@ -604,26 +605,23 @@ static void refreshMultiLine(struct linenoiseState *l, int flags) {
|
||||
int col; /* colum position, zero-based. */
|
||||
int old_rows = l->oldrows;
|
||||
int fd = l->ofd, j;
|
||||
struct abuf ab;
|
||||
|
||||
std::string ab;
|
||||
l->oldrows = rows;
|
||||
|
||||
/* First step: clear all the lines used before. To do so start by
|
||||
* going to the last row. */
|
||||
abInit(&ab);
|
||||
|
||||
if (flags & REFRESH_CLEAN) {
|
||||
if (old_rows-rpos > 0) {
|
||||
lndebug("go down %d", old_rows-rpos);
|
||||
snprintf(seq,64,"\x1b[%dB", old_rows-rpos);
|
||||
abAppend(&ab,seq,strlen(seq));
|
||||
ab.append(seq);
|
||||
}
|
||||
|
||||
/* Now for every row clear it, go up. */
|
||||
for (j = 0; j < old_rows-1; j++) {
|
||||
lndebug("clear+up");
|
||||
snprintf(seq,64,"\r\x1b[0K\x1b[1A");
|
||||
abAppend(&ab,seq,strlen(seq));
|
||||
ab.append(seq);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -631,21 +629,22 @@ static void refreshMultiLine(struct linenoiseState *l, int flags) {
|
||||
/* Clean the top line. */
|
||||
lndebug("clear");
|
||||
snprintf(seq,64,"\r\x1b[0K");
|
||||
abAppend(&ab,seq,strlen(seq));
|
||||
ab.append(seq);
|
||||
}
|
||||
|
||||
if (flags & REFRESH_WRITE) {
|
||||
/* Write the prompt and the current buffer content */
|
||||
abAppend(&ab,l->prompt,strlen(l->prompt));
|
||||
ab.append(l->prompt);
|
||||
if (maskmode == 1) {
|
||||
unsigned int i;
|
||||
for (i = 0; i < l->len; i++) abAppend(&ab,"*",1);
|
||||
for (unsigned int i = 0; i < l->len; ++i) {
|
||||
ab.append("*");
|
||||
}
|
||||
} else {
|
||||
abAppend(&ab,l->buf,l->len);
|
||||
ab.append(l->buf, l->len);
|
||||
}
|
||||
|
||||
/* Show hits if any. */
|
||||
refreshShowHints(&ab,l,plen);
|
||||
refreshShowHints(ab, l, plen);
|
||||
|
||||
/* If we are at the very end of the screen with our prompt, we need to
|
||||
* emit a newline and move the prompt to the first column. */
|
||||
@@ -654,9 +653,9 @@ static void refreshMultiLine(struct linenoiseState *l, int flags) {
|
||||
(l->pos+plen) % l->cols == 0)
|
||||
{
|
||||
lndebug("<newline>");
|
||||
abAppend(&ab,"\n",1);
|
||||
ab.append("\n");
|
||||
snprintf(seq,64,"\r");
|
||||
abAppend(&ab,seq,strlen(seq));
|
||||
ab.append(seq);
|
||||
rows++;
|
||||
if (rows > (int)l->oldrows) l->oldrows = rows;
|
||||
}
|
||||
@@ -669,7 +668,7 @@ static void refreshMultiLine(struct linenoiseState *l, int flags) {
|
||||
if (rows-rpos2 > 0) {
|
||||
lndebug("go-up %d", rows-rpos2);
|
||||
snprintf(seq,64,"\x1b[%dA", rows-rpos2);
|
||||
abAppend(&ab,seq,strlen(seq));
|
||||
ab.append(seq);
|
||||
}
|
||||
|
||||
/* Set column. */
|
||||
@@ -679,14 +678,12 @@ static void refreshMultiLine(struct linenoiseState *l, int flags) {
|
||||
snprintf(seq,64,"\r\x1b[%dC", col);
|
||||
else
|
||||
snprintf(seq,64,"\r");
|
||||
abAppend(&ab,seq,strlen(seq));
|
||||
ab.append(seq);
|
||||
}
|
||||
|
||||
lndebug("\n");
|
||||
l->oldpos = l->pos;
|
||||
|
||||
if (write(fd,ab.b,ab.len) == -1) {} /* Can't recover from write error. */
|
||||
abFree(&ab);
|
||||
(void) !write(fd, ab.c_str(), ab.size()); /* Can't recover from write error. */
|
||||
}
|
||||
|
||||
/* Calls the two low level functions refreshSingleLine() or
|
||||
@@ -1313,16 +1310,17 @@ int linenoiseHistorySetMaxLen(int len) {
|
||||
* otherwise -1 is returned. */
|
||||
int linenoiseHistorySave(const char *filename) {
|
||||
mode_t old_umask = umask(S_IXUSR|S_IRWXG|S_IRWXO);
|
||||
FILE *fp;
|
||||
int j;
|
||||
|
||||
fp = fopen(filename,"w");
|
||||
File file;
|
||||
file.open(filename, "w");
|
||||
umask(old_umask);
|
||||
if (fp == NULL) return -1;
|
||||
if (file.file == NULL) {
|
||||
return -1;
|
||||
}
|
||||
chmod(filename,S_IRUSR|S_IWUSR);
|
||||
for (j = 0; j < history_len; j++)
|
||||
fprintf(fp,"%s\n",history[j]);
|
||||
fclose(fp);
|
||||
for (int j = 0; j < history_len; ++j) {
|
||||
fprintf(file.file, "%s\n", history[j]);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -1332,12 +1330,14 @@ int linenoiseHistorySave(const char *filename) {
|
||||
* If the file exists and the operation succeeded 0 is returned, otherwise
|
||||
* on error -1 is returned. */
|
||||
int linenoiseHistoryLoad(const char *filename) {
|
||||
FILE *fp = fopen(filename,"r");
|
||||
File file;
|
||||
file.open(filename, "r");
|
||||
char buf[LINENOISE_MAX_LINE];
|
||||
if (file.file == NULL) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (fp == NULL) return -1;
|
||||
|
||||
while (fgets(buf,LINENOISE_MAX_LINE,fp) != NULL) {
|
||||
while (fgets(buf, LINENOISE_MAX_LINE, file.file) != NULL) {
|
||||
char *p;
|
||||
|
||||
p = strchr(buf,'\r');
|
||||
@@ -1345,7 +1345,6 @@ int linenoiseHistoryLoad(const char *filename) {
|
||||
if (p) *p = '\0';
|
||||
linenoiseHistoryAdd(buf);
|
||||
}
|
||||
fclose(fp);
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -45,6 +45,7 @@ extern "C" {
|
||||
#endif
|
||||
|
||||
#include <stddef.h> /* For size_t. */
|
||||
#include <stdlib.h>
|
||||
|
||||
extern const char *linenoiseEditMore;
|
||||
|
||||
@@ -69,10 +70,23 @@ struct linenoiseState {
|
||||
int history_index; /* The history index we are currently editing. */
|
||||
};
|
||||
|
||||
typedef struct linenoiseCompletions {
|
||||
size_t len;
|
||||
char **cvec;
|
||||
} linenoiseCompletions;
|
||||
struct linenoiseCompletions {
|
||||
size_t len = 0;
|
||||
char ** cvec = nullptr;
|
||||
bool to_free = true;
|
||||
|
||||
~linenoiseCompletions() {
|
||||
if (!to_free) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
free(cvec[i]);
|
||||
}
|
||||
|
||||
free(cvec);
|
||||
}
|
||||
};
|
||||
|
||||
/* Non blocking API. */
|
||||
int linenoiseEditStart(struct linenoiseState *l, int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt);
|
||||
|
||||
+172
-69
@@ -28,6 +28,7 @@
|
||||
#include "json.hpp"
|
||||
#include "linenoise.cpp/linenoise.h"
|
||||
#include "llama-cpp.h"
|
||||
#include "chat-template.hpp"
|
||||
|
||||
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32)
|
||||
[[noreturn]] static void sigint_handler(int) {
|
||||
@@ -105,6 +106,7 @@ class Opt {
|
||||
llama_model_params model_params;
|
||||
std::string model_;
|
||||
std::string user;
|
||||
bool use_jinja = false;
|
||||
int context_size = -1, ngl = -1;
|
||||
float temperature = -1;
|
||||
bool verbose = false;
|
||||
@@ -145,7 +147,8 @@ class Opt {
|
||||
if (handle_option_with_value(argc, argv, i, context_size) == 1) {
|
||||
return 1;
|
||||
}
|
||||
} else if (options_parsing && (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0)) {
|
||||
} else if (options_parsing &&
|
||||
(strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "-ngl") == 0 || strcmp(argv[i], "--ngl") == 0)) {
|
||||
if (handle_option_with_value(argc, argv, i, ngl) == 1) {
|
||||
return 1;
|
||||
}
|
||||
@@ -156,6 +159,8 @@ class Opt {
|
||||
} else if (options_parsing &&
|
||||
(parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
|
||||
verbose = true;
|
||||
} else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
|
||||
use_jinja = true;
|
||||
} else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
|
||||
help = true;
|
||||
return 0;
|
||||
@@ -176,6 +181,10 @@ class Opt {
|
||||
}
|
||||
}
|
||||
|
||||
if (model_.empty()){
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -190,7 +199,7 @@ class Opt {
|
||||
"Options:\n"
|
||||
" -c, --context-size <value>\n"
|
||||
" Context size (default: %d)\n"
|
||||
" -n, --ngl <value>\n"
|
||||
" -n, -ngl, --ngl <value>\n"
|
||||
" Number of GPU layers (default: %d)\n"
|
||||
" --temp <value>\n"
|
||||
" Temperature (default: %.1f)\n"
|
||||
@@ -314,6 +323,10 @@ class HttpClient {
|
||||
public:
|
||||
int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
|
||||
const bool progress, std::string * response_str = nullptr) {
|
||||
if (std::filesystem::exists(output_file)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::string output_file_partial;
|
||||
curl = curl_easy_init();
|
||||
if (!curl) {
|
||||
@@ -341,7 +354,11 @@ class HttpClient {
|
||||
data.file_size = set_resume_point(output_file_partial);
|
||||
set_progress_options(progress, data);
|
||||
set_headers(headers);
|
||||
perform(url);
|
||||
CURLcode res = perform(url);
|
||||
if (res != CURLE_OK){
|
||||
printe("Fetching resource '%s' failed: %s\n", url.c_str(), curl_easy_strerror(res));
|
||||
return 1;
|
||||
}
|
||||
if (!output_file.empty()) {
|
||||
std::filesystem::rename(output_file_partial, output_file);
|
||||
}
|
||||
@@ -406,16 +423,12 @@ class HttpClient {
|
||||
}
|
||||
}
|
||||
|
||||
void perform(const std::string & url) {
|
||||
CURLcode res;
|
||||
CURLcode perform(const std::string & url) {
|
||||
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
|
||||
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
|
||||
curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https");
|
||||
curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L);
|
||||
res = curl_easy_perform(curl);
|
||||
if (res != CURLE_OK) {
|
||||
printe("curl_easy_perform() failed: %s\n", curl_easy_strerror(res));
|
||||
}
|
||||
return curl_easy_perform(curl);
|
||||
}
|
||||
|
||||
static std::string human_readable_time(double seconds) {
|
||||
@@ -553,13 +566,14 @@ class LlamaData {
|
||||
}
|
||||
|
||||
sampler = initialize_sampler(opt);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
#ifdef LLAMA_USE_CURL
|
||||
int download(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
|
||||
const bool progress, std::string * response_str = nullptr) {
|
||||
int download(const std::string & url, const std::string & output_file, const bool progress,
|
||||
const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
|
||||
HttpClient http;
|
||||
if (http.init(url, headers, output_file, progress, response_str)) {
|
||||
return 1;
|
||||
@@ -568,48 +582,85 @@ class LlamaData {
|
||||
return 0;
|
||||
}
|
||||
#else
|
||||
int download(const std::string &, const std::vector<std::string> &, const std::string &, const bool,
|
||||
int download(const std::string &, const std::string &, const bool, const std::vector<std::string> & = {},
|
||||
std::string * = nullptr) {
|
||||
printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
|
||||
|
||||
return 1;
|
||||
}
|
||||
#endif
|
||||
|
||||
int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
|
||||
// Find the second occurrence of '/' after protocol string
|
||||
size_t pos = model.find('/');
|
||||
pos = model.find('/', pos + 1);
|
||||
if (pos == std::string::npos) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
const std::string hfr = model.substr(0, pos);
|
||||
const std::string hff = model.substr(pos + 1);
|
||||
const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
|
||||
return download(url, headers, bn, true);
|
||||
}
|
||||
|
||||
int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
|
||||
if (model.find('/') == std::string::npos) {
|
||||
model = "library/" + model;
|
||||
}
|
||||
|
||||
std::string model_tag = "latest";
|
||||
size_t colon_pos = model.find(':');
|
||||
// Helper function to handle model tag extraction and URL construction
|
||||
std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
|
||||
std::string model_tag = "latest";
|
||||
const size_t colon_pos = model.find(':');
|
||||
if (colon_pos != std::string::npos) {
|
||||
model_tag = model.substr(colon_pos + 1);
|
||||
model = model.substr(0, colon_pos);
|
||||
}
|
||||
|
||||
std::string manifest_url = "https://registry.ollama.ai/v2/" + model + "/manifests/" + model_tag;
|
||||
std::string url = base_url + model + "/manifests/" + model_tag;
|
||||
|
||||
return { model, url };
|
||||
}
|
||||
|
||||
// Helper function to download and parse the manifest
|
||||
int download_and_parse_manifest(const std::string & url, const std::vector<std::string> & headers,
|
||||
nlohmann::json & manifest) {
|
||||
std::string manifest_str;
|
||||
const int ret = download(manifest_url, headers, "", false, &manifest_str);
|
||||
int ret = download(url, "", false, headers, &manifest_str);
|
||||
if (ret) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
nlohmann::json manifest = nlohmann::json::parse(manifest_str);
|
||||
std::string layer;
|
||||
manifest = nlohmann::json::parse(manifest_str);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int huggingface_dl(std::string & model, const std::string & bn) {
|
||||
// Find the second occurrence of '/' after protocol string
|
||||
size_t pos = model.find('/');
|
||||
pos = model.find('/', pos + 1);
|
||||
std::string hfr, hff;
|
||||
std::vector<std::string> headers = { "User-Agent: llama-cpp", "Accept: application/json" };
|
||||
std::string url;
|
||||
|
||||
if (pos == std::string::npos) {
|
||||
auto [model_name, manifest_url] = extract_model_and_tag(model, "https://huggingface.co/v2/");
|
||||
hfr = model_name;
|
||||
|
||||
nlohmann::json manifest;
|
||||
int ret = download_and_parse_manifest(manifest_url, headers, manifest);
|
||||
if (ret) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
hff = manifest["ggufFile"]["rfilename"];
|
||||
} else {
|
||||
hfr = model.substr(0, pos);
|
||||
hff = model.substr(pos + 1);
|
||||
}
|
||||
|
||||
url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
|
||||
|
||||
return download(url, bn, true, headers);
|
||||
}
|
||||
|
||||
int ollama_dl(std::string & model, const std::string & bn) {
|
||||
const std::vector<std::string> headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" };
|
||||
if (model.find('/') == std::string::npos) {
|
||||
model = "library/" + model;
|
||||
}
|
||||
|
||||
auto [model_name, manifest_url] = extract_model_and_tag(model, "https://registry.ollama.ai/v2/");
|
||||
nlohmann::json manifest;
|
||||
int ret = download_and_parse_manifest(manifest_url, {}, manifest);
|
||||
if (ret) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::string layer;
|
||||
for (const auto & l : manifest["layers"]) {
|
||||
if (l["mediaType"] == "application/vnd.ollama.image.model") {
|
||||
layer = l["digest"];
|
||||
@@ -617,8 +668,34 @@ class LlamaData {
|
||||
}
|
||||
}
|
||||
|
||||
std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
|
||||
return download(blob_url, headers, bn, true);
|
||||
std::string blob_url = "https://registry.ollama.ai/v2/" + model_name + "/blobs/" + layer;
|
||||
|
||||
return download(blob_url, bn, true, headers);
|
||||
}
|
||||
|
||||
int github_dl(const std::string & model, const std::string & bn) {
|
||||
std::string repository = model;
|
||||
std::string branch = "main";
|
||||
const size_t at_pos = model.find('@');
|
||||
if (at_pos != std::string::npos) {
|
||||
repository = model.substr(0, at_pos);
|
||||
branch = model.substr(at_pos + 1);
|
||||
}
|
||||
|
||||
const std::vector<std::string> repo_parts = string_split(repository, "/");
|
||||
if (repo_parts.size() < 3) {
|
||||
printe("Invalid GitHub repository format\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
const std::string & org = repo_parts[0];
|
||||
const std::string & project = repo_parts[1];
|
||||
std::string url = "https://raw.githubusercontent.com/" + org + "/" + project + "/" + branch;
|
||||
for (size_t i = 2; i < repo_parts.size(); ++i) {
|
||||
url += "/" + repo_parts[i];
|
||||
}
|
||||
|
||||
return download(url, bn, true);
|
||||
}
|
||||
|
||||
std::string basename(const std::string & path) {
|
||||
@@ -630,37 +707,41 @@ class LlamaData {
|
||||
return path.substr(pos + 1);
|
||||
}
|
||||
|
||||
int remove_proto(std::string & model_) {
|
||||
const std::string::size_type pos = model_.find("://");
|
||||
int rm_until_substring(std::string & model_, const std::string & substring) {
|
||||
const std::string::size_type pos = model_.find(substring);
|
||||
if (pos == std::string::npos) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
model_ = model_.substr(pos + 3); // Skip past "://"
|
||||
model_ = model_.substr(pos + substring.size()); // Skip past the substring
|
||||
return 0;
|
||||
}
|
||||
|
||||
int resolve_model(std::string & model_) {
|
||||
int ret = 0;
|
||||
if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
|
||||
remove_proto(model_);
|
||||
rm_until_substring(model_, "://");
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
const std::string bn = basename(model_);
|
||||
const std::vector<std::string> headers = { "--header",
|
||||
"Accept: application/vnd.docker.distribution.manifest.v2+json" };
|
||||
if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
|
||||
remove_proto(model_);
|
||||
ret = huggingface_dl(model_, headers, bn);
|
||||
} else if (string_starts_with(model_, "ollama://")) {
|
||||
remove_proto(model_);
|
||||
ret = ollama_dl(model_, headers, bn);
|
||||
} else if (string_starts_with(model_, "https://")) {
|
||||
download(model_, headers, bn, true);
|
||||
} else {
|
||||
ret = ollama_dl(model_, headers, bn);
|
||||
const std::string bn = basename(model_);
|
||||
if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://") ||
|
||||
string_starts_with(model_, "hf.co/")) {
|
||||
rm_until_substring(model_, "hf.co/");
|
||||
rm_until_substring(model_, "://");
|
||||
ret = huggingface_dl(model_, bn);
|
||||
} else if ((string_starts_with(model_, "https://") || string_starts_with(model_, "http://")) &&
|
||||
!string_starts_with(model_, "https://ollama.com/library/")) {
|
||||
ret = download(model_, bn, true);
|
||||
} else if (string_starts_with(model_, "github:") || string_starts_with(model_, "github://")) {
|
||||
rm_until_substring(model_, "github:");
|
||||
rm_until_substring(model_, "://");
|
||||
ret = github_dl(model_, bn);
|
||||
} else { // ollama:// or nothing
|
||||
rm_until_substring(model_, "ollama.com/library/");
|
||||
rm_until_substring(model_, "://");
|
||||
ret = ollama_dl(model_, bn);
|
||||
}
|
||||
|
||||
model_ = bn;
|
||||
@@ -713,13 +794,31 @@ static void add_message(const char * role, const std::string & text, LlamaData &
|
||||
}
|
||||
|
||||
// Function to apply the chat template and resize `formatted` if needed
|
||||
static int apply_chat_template(LlamaData & llama_data, const bool append) {
|
||||
static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
|
||||
if (use_jinja) {
|
||||
json messages = json::array();
|
||||
for (const auto & msg : llama_data.messages) {
|
||||
messages.push_back({
|
||||
{"role", msg.role},
|
||||
{"content", msg.content},
|
||||
});
|
||||
}
|
||||
try {
|
||||
auto result = tmpl.apply(messages, /* tools= */ json(), append);
|
||||
llama_data.fmtted.resize(result.size() + 1);
|
||||
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
|
||||
return result.size();
|
||||
} catch (const std::exception & e) {
|
||||
printe("failed to render the chat template: %s\n", e.what());
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
int result = llama_chat_apply_template(
|
||||
llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
|
||||
tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
|
||||
append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
|
||||
if (append && result > static_cast<int>(llama_data.fmtted.size())) {
|
||||
llama_data.fmtted.resize(result);
|
||||
result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
|
||||
result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
|
||||
llama_data.messages.size(), append, llama_data.fmtted.data(),
|
||||
llama_data.fmtted.size());
|
||||
}
|
||||
@@ -729,10 +828,12 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) {
|
||||
|
||||
// Function to tokenize the prompt
|
||||
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
|
||||
std::vector<llama_token> & prompt_tokens) {
|
||||
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
|
||||
std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
|
||||
const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0;
|
||||
|
||||
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
|
||||
prompt_tokens.resize(n_prompt_tokens);
|
||||
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
|
||||
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
|
||||
true) < 0) {
|
||||
printe("failed to tokenize the prompt\n");
|
||||
return -1;
|
||||
@@ -778,7 +879,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
|
||||
const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get());
|
||||
|
||||
std::vector<llama_token> tokens;
|
||||
if (tokenize_prompt(vocab, prompt, tokens) < 0) {
|
||||
if (tokenize_prompt(vocab, prompt, tokens, llama_data) < 0) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -869,8 +970,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
|
||||
}
|
||||
|
||||
// Helper function to apply the chat template and handle errors
|
||||
static int apply_chat_template_with_error_handling(LlamaData & llama_data, const bool append, int & output_length) {
|
||||
const int new_len = apply_chat_template(llama_data, append);
|
||||
static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
|
||||
const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
|
||||
if (new_len < 0) {
|
||||
printe("failed to apply the chat template\n");
|
||||
return -1;
|
||||
@@ -929,9 +1030,11 @@ static int get_user_input(std::string & user_input, const std::string & user) {
|
||||
}
|
||||
|
||||
// Main chat loop function
|
||||
static int chat_loop(LlamaData & llama_data, const std::string & user) {
|
||||
static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
|
||||
int prev_len = 0;
|
||||
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
|
||||
auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), "");
|
||||
GGML_ASSERT(chat_templates.template_default);
|
||||
static const bool stdout_a_terminal = is_stdout_a_terminal();
|
||||
while (true) {
|
||||
// Get user input
|
||||
@@ -942,7 +1045,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
|
||||
|
||||
add_message("user", user.empty() ? user_input : user, llama_data);
|
||||
int new_len;
|
||||
if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) {
|
||||
if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -957,7 +1060,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
|
||||
}
|
||||
|
||||
add_message("assistant", response, llama_data);
|
||||
if (apply_chat_template_with_error_handling(llama_data, false, prev_len) < 0) {
|
||||
if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
@@ -1017,7 +1120,7 @@ int main(int argc, const char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (chat_loop(llama_data, opt.user)) {
|
||||
if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ The project is under active development, and we are [looking for feedback and co
|
||||
| `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') |
|
||||
| `--grammar-file FNAME` | file to read grammar from |
|
||||
| `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object<br/>For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead |
|
||||
|
||||
| `--jinja` | Enable experimental Jinja templating engine (needed for tool use) |
|
||||
|
||||
**Example-specific params**
|
||||
|
||||
@@ -236,9 +236,13 @@ npm i
|
||||
# to run the dev server
|
||||
npm run dev
|
||||
|
||||
# to build the public/index.html
|
||||
# to build the public/index.html.gz
|
||||
npm run build
|
||||
```
|
||||
After `public/index.html.gz` has been generated we need to generate the c++
|
||||
headers (like build/examples/server/index.html.gz.hpp) that will be included
|
||||
by server.cpp. This is done by building `llama-server` as described in the
|
||||
[build](#build) section above.
|
||||
|
||||
NOTE: if you are using the vite dev server, you can change the API base URL to llama.cpp. To do that, run this code snippet in browser's console:
|
||||
|
||||
@@ -456,7 +460,7 @@ These words will not be included in the completion, so make sure to add them to
|
||||
- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
|
||||
|
||||
- `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has a nested array `top_logprobs`. It contains at **maximum** `n_probs` elements:
|
||||
```json
|
||||
```
|
||||
{
|
||||
"content": "<the generated completion text>",
|
||||
"tokens": [ generated token ids if requested ],
|
||||
@@ -557,7 +561,7 @@ If `with_pieces` is `true`:
|
||||
```
|
||||
|
||||
With input 'á' (utf8 hex: C3 A1) on tinyllama/stories260k
|
||||
```json
|
||||
```
|
||||
{
|
||||
"tokens": [
|
||||
{"id": 198, "piece": [195]}, // hex C3
|
||||
@@ -572,6 +576,18 @@ With input 'á' (utf8 hex: C3 A1) on tinyllama/stories260k
|
||||
|
||||
`tokens`: Set the tokens to detokenize.
|
||||
|
||||
### POST `/apply-template`: Apply chat template to a conversation
|
||||
|
||||
Uses the server's prompt template formatting functionality to convert chat messages to a single string expected by a chat model as input, but does not perform inference. Instead, the prompt string is returned in the `prompt` field of the JSON response. The prompt can then be modified as desired (for example, to insert "Sure!" at the beginning of the model's response) before sending to `/completion` to generate the chat response.
|
||||
|
||||
*Options:*
|
||||
|
||||
`messages`: (Required) Chat turns in the same format as `/v1/chat/completions`.
|
||||
|
||||
**Response format**
|
||||
|
||||
Returns a JSON object with a field `prompt` containing a string of the input messages formatted according to the model's chat template format.
|
||||
|
||||
### POST `/embedding`: Generate embedding of a given text
|
||||
|
||||
> [!IMPORTANT]
|
||||
@@ -764,7 +780,7 @@ Same as the `/v1/embeddings` endpoint.
|
||||
|
||||
**Response format**
|
||||
|
||||
```json
|
||||
```
|
||||
[
|
||||
{
|
||||
"index": 0,
|
||||
|
||||
Binary file not shown.
+105
-23
@@ -14,7 +14,7 @@
|
||||
// mime type for sending response
|
||||
#define MIMETYPE_JSON "application/json; charset=utf-8"
|
||||
|
||||
// auto generated files (update with ./deps.sh)
|
||||
// auto generated files (see README.md for details)
|
||||
#include "index.html.gz.hpp"
|
||||
#include "loading.html.hpp"
|
||||
|
||||
@@ -267,6 +267,11 @@ struct server_task {
|
||||
params.speculative.n_min = std::max(params.speculative.n_min, 2);
|
||||
params.speculative.n_max = std::max(params.speculative.n_max, 0);
|
||||
|
||||
// Use OpenAI API logprobs only if n_probs wasn't provided
|
||||
if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
|
||||
params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
|
||||
}
|
||||
|
||||
if (data.contains("lora")) {
|
||||
if (data.at("lora").is_array()) {
|
||||
params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
|
||||
@@ -1422,6 +1427,10 @@ struct server_queue {
|
||||
int post(server_task task, bool front = false) {
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
GGML_ASSERT(task.id != -1);
|
||||
// if this is cancel task make sure to clean up pending tasks
|
||||
if (task.type == SERVER_TASK_TYPE_CANCEL) {
|
||||
cleanup_pending_task(task.id_target);
|
||||
}
|
||||
QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
|
||||
if (front) {
|
||||
queue_tasks.push_front(std::move(task));
|
||||
@@ -1439,6 +1448,10 @@ struct server_queue {
|
||||
if (task.id == -1) {
|
||||
task.id = id++;
|
||||
}
|
||||
// if this is cancel task make sure to clean up pending tasks
|
||||
if (task.type == SERVER_TASK_TYPE_CANCEL) {
|
||||
cleanup_pending_task(task.id_target);
|
||||
}
|
||||
QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
|
||||
if (front) {
|
||||
queue_tasks.push_front(std::move(task));
|
||||
@@ -1539,6 +1552,20 @@ struct server_queue {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void cleanup_pending_task(int id_target) {
|
||||
// no need lock because this is called exclusively by post()
|
||||
auto rm_func = [id_target](const server_task & task) {
|
||||
return task.id_target == id_target;
|
||||
};
|
||||
queue_tasks.erase(
|
||||
std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
|
||||
queue_tasks.end());
|
||||
queue_tasks_deferred.erase(
|
||||
std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
|
||||
queue_tasks_deferred.end());
|
||||
}
|
||||
};
|
||||
|
||||
struct server_response {
|
||||
@@ -1574,6 +1601,12 @@ struct server_response {
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
waiting_task_ids.erase(id_task);
|
||||
// make sure to clean up all pending results
|
||||
queue_results.erase(
|
||||
std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) {
|
||||
return res->id == id_task;
|
||||
}),
|
||||
queue_results.end());
|
||||
}
|
||||
|
||||
void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
|
||||
@@ -1593,7 +1626,7 @@ struct server_response {
|
||||
return !queue_results.empty();
|
||||
});
|
||||
|
||||
for (int i = 0; i < (int) queue_results.size(); i++) {
|
||||
for (size_t i = 0; i < queue_results.size(); i++) {
|
||||
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
|
||||
server_task_result_ptr res = std::move(queue_results[i]);
|
||||
queue_results.erase(queue_results.begin() + i);
|
||||
@@ -1610,12 +1643,6 @@ struct server_response {
|
||||
server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex_results);
|
||||
bool cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout), [&]{
|
||||
return !queue_results.empty();
|
||||
});
|
||||
if (!cr_res) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (int i = 0; i < (int) queue_results.size(); i++) {
|
||||
if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
|
||||
@@ -1624,6 +1651,11 @@ struct server_response {
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
|
||||
if (cr_res == std::cv_status::timeout) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// should never reach here
|
||||
@@ -1688,6 +1720,8 @@ struct server_context {
|
||||
// Necessary similarity of prompt for slot selection
|
||||
float slot_prompt_similarity = 0.0f;
|
||||
|
||||
common_chat_templates chat_templates;
|
||||
|
||||
~server_context() {
|
||||
// Clear any sampling context
|
||||
for (server_slot & slot : slots) {
|
||||
@@ -1728,13 +1762,16 @@ struct server_context {
|
||||
add_bos_token = llama_vocab_get_add_bos(vocab);
|
||||
has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
|
||||
|
||||
if (!params_base.speculative.model.empty()) {
|
||||
if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) {
|
||||
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
|
||||
|
||||
auto params_dft = params_base;
|
||||
|
||||
params_dft.devices = params_base.speculative.devices;
|
||||
params_dft.hf_file = params_base.speculative.hf_file;
|
||||
params_dft.hf_repo = params_base.speculative.hf_repo;
|
||||
params_dft.model = params_base.speculative.model;
|
||||
params_dft.model_url = params_base.speculative.model_url;
|
||||
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
|
||||
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
|
||||
params_dft.n_parallel = 1;
|
||||
@@ -1762,16 +1799,44 @@ struct server_context {
|
||||
// force F16 KV cache for the draft model for extra performance
|
||||
cparams_dft.type_k = GGML_TYPE_F16;
|
||||
cparams_dft.type_v = GGML_TYPE_F16;
|
||||
|
||||
// the context is not needed - we will create one for each slot
|
||||
llama_init_dft.context.reset();
|
||||
}
|
||||
|
||||
chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
|
||||
GGML_ASSERT(chat_templates.template_default.get() != nullptr);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool validate_builtin_chat_template() const {
|
||||
bool validate_builtin_chat_template(bool use_jinja) const {
|
||||
llama_chat_message chat[] = {{"user", "test"}};
|
||||
const char * tmpl = llama_model_chat_template(model);
|
||||
const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
|
||||
return chat_res > 0;
|
||||
|
||||
if (use_jinja) {
|
||||
auto templates = common_chat_templates_from_model(model, "");
|
||||
GGML_ASSERT(templates.template_default);
|
||||
try {
|
||||
templates.template_default->apply({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}}, json(), true);
|
||||
if (templates.template_tool_use) {
|
||||
templates.template_tool_use->apply({{
|
||||
{"role", "user"},
|
||||
{"content", "test"},
|
||||
}}, json(), true);
|
||||
}
|
||||
return true;
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("failed to apply template: %s\n", e.what());
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
|
||||
const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
|
||||
return chat_res > 0;
|
||||
}
|
||||
}
|
||||
|
||||
void init() {
|
||||
@@ -2338,8 +2403,8 @@ struct server_context {
|
||||
|
||||
server_task task(SERVER_TASK_TYPE_CANCEL);
|
||||
task.id_target = id_task;
|
||||
cancel_tasks.push_back(task);
|
||||
queue_results.remove_waiting_task_id(id_task);
|
||||
cancel_tasks.push_back(task);
|
||||
}
|
||||
// push to beginning of the queue, so it has highest priority
|
||||
queue_tasks.post(cancel_tasks, true);
|
||||
@@ -3656,9 +3721,12 @@ int main(int argc, char ** argv) {
|
||||
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
|
||||
{ "total_slots", ctx_server.params_base.n_parallel },
|
||||
{ "model_path", ctx_server.params_base.model },
|
||||
{ "chat_template", common_get_builtin_chat_template(ctx_server.model) },
|
||||
{ "chat_template", ctx_server.chat_templates.template_default->source() },
|
||||
{ "build_info", build_info },
|
||||
};
|
||||
if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) {
|
||||
data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source();
|
||||
}
|
||||
|
||||
res_ok(res, data);
|
||||
};
|
||||
@@ -3886,7 +3954,10 @@ int main(int argc, char ** argv) {
|
||||
return;
|
||||
}
|
||||
|
||||
json data = oaicompat_chat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
||||
auto body = json::parse(req.body);
|
||||
const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
|
||||
json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja);
|
||||
|
||||
return handle_completions_impl(
|
||||
SERVER_TASK_TYPE_COMPLETION,
|
||||
data,
|
||||
@@ -4053,6 +4124,14 @@ int main(int argc, char ** argv) {
|
||||
res_ok(res, root);
|
||||
};
|
||||
|
||||
const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||
auto body = json::parse(req.body);
|
||||
const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
|
||||
json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja);
|
||||
|
||||
res_ok(res, {{ "prompt", data.at("prompt") }});
|
||||
};
|
||||
|
||||
const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
|
||||
handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE);
|
||||
};
|
||||
@@ -4229,6 +4308,7 @@ int main(int argc, char ** argv) {
|
||||
svr->Post("/v1/reranking", handle_rerank);
|
||||
svr->Post("/tokenize", handle_tokenize);
|
||||
svr->Post("/detokenize", handle_detokenize);
|
||||
svr->Post("/apply-template", handle_apply_template);
|
||||
// LoRA adapters hotswap
|
||||
svr->Get ("/lora-adapters", handle_lora_adapters_list);
|
||||
svr->Post("/lora-adapters", handle_lora_adapters_apply);
|
||||
@@ -4296,7 +4376,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
|
||||
if (params.chat_template.empty()) {
|
||||
if (!ctx_server.validate_builtin_chat_template()) {
|
||||
if (!ctx_server.validate_builtin_chat_template(params.use_jinja)) {
|
||||
LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
|
||||
params.chat_template = "chatml";
|
||||
}
|
||||
@@ -4304,14 +4384,16 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// print sample chat example to make it clear which template is used
|
||||
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
|
||||
params.chat_template.empty() ? "(built-in)" : params.chat_template.c_str(),
|
||||
common_chat_format_example(ctx_server.model, params.chat_template).c_str());
|
||||
ctx_server.chat_templates.template_default->source().c_str(),
|
||||
common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
|
||||
|
||||
ctx_server.queue_tasks.on_new_task(std::bind(
|
||||
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
|
||||
ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
|
||||
ctx_server.process_single_task(task);
|
||||
});
|
||||
|
||||
ctx_server.queue_tasks.on_update_slots(std::bind(
|
||||
&server_context::update_slots, &ctx_server));
|
||||
ctx_server.queue_tasks.on_update_slots([&ctx_server]() {
|
||||
ctx_server.update_slots();
|
||||
});
|
||||
|
||||
shutdown_handler = [&](int) {
|
||||
ctx_server.queue_tasks.terminate();
|
||||
|
||||
@@ -4,22 +4,26 @@ from utils import *
|
||||
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
|
||||
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
|
||||
[
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False, None),
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True, None),
|
||||
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
|
||||
]
|
||||
)
|
||||
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
|
||||
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
|
||||
global server
|
||||
server.jinja = jinja
|
||||
server.chat_template = chat_template
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"model": model,
|
||||
@@ -117,6 +121,21 @@ def test_chat_template():
|
||||
assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
|
||||
|
||||
def test_apply_chat_template():
|
||||
global server
|
||||
server.chat_template = "command-r"
|
||||
server.start()
|
||||
res = server.make_request("POST", "/apply-template", data={
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a test."},
|
||||
{"role": "user", "content":"Hi there"},
|
||||
]
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "prompt" in res.body
|
||||
assert res.body["prompt"] == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
|
||||
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
|
||||
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
|
||||
|
||||
@@ -87,7 +87,7 @@ def test_completion_stream_vs_non_stream():
|
||||
assert content_stream == res_non_stream.body["content"]
|
||||
|
||||
|
||||
def test_completion_stream_with_openai_library():
|
||||
def test_completion_with_openai_library():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
@@ -102,7 +102,7 @@ def test_completion_stream_with_openai_library():
|
||||
assert match_regex("(going|bed)+", res.choices[0].text)
|
||||
|
||||
|
||||
def test_completion_with_openai_library():
|
||||
def test_completion_stream_with_openai_library():
|
||||
global server
|
||||
server.start()
|
||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||
|
||||
@@ -72,13 +72,14 @@ class ServerProcess:
|
||||
pooling: str | None = None
|
||||
draft: int | None = None
|
||||
api_key: str | None = None
|
||||
response_format: str | None = None
|
||||
lora_files: List[str] | None = None
|
||||
disable_ctx_shift: int | None = False
|
||||
draft_min: int | None = None
|
||||
draft_max: int | None = None
|
||||
no_webui: bool | None = None
|
||||
jinja: bool | None = None
|
||||
chat_template: str | None = None
|
||||
chat_template_file: str | None = None
|
||||
|
||||
# session variables
|
||||
process: subprocess.Popen | None = None
|
||||
@@ -169,8 +170,12 @@ class ServerProcess:
|
||||
server_args.extend(["--draft-min", self.draft_min])
|
||||
if self.no_webui:
|
||||
server_args.append("--no-webui")
|
||||
if self.jinja:
|
||||
server_args.append("--jinja")
|
||||
if self.chat_template:
|
||||
server_args.extend(["--chat-template", self.chat_template])
|
||||
if self.chat_template_file:
|
||||
server_args.extend(["--chat-template-file", self.chat_template_file])
|
||||
|
||||
args = [str(arg) for arg in [server_path, *server_args]]
|
||||
print(f"bench: starting server with: {' '.join(args)}")
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "json.hpp"
|
||||
#include "minja.hpp"
|
||||
#include "chat-template.hpp"
|
||||
|
||||
#include <random>
|
||||
#include <sstream>
|
||||
@@ -349,7 +351,7 @@ static llama_tokens format_infill(
|
||||
}
|
||||
|
||||
// Format given chat. If tmpl is empty, we take the template from model metadata
|
||||
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
|
||||
inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
|
||||
std::vector<common_chat_msg> chat;
|
||||
|
||||
for (size_t i = 0; i < messages.size(); ++i) {
|
||||
@@ -377,7 +379,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
|
||||
chat.push_back({role, content});
|
||||
}
|
||||
|
||||
const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true);
|
||||
const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
|
||||
LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
|
||||
|
||||
return formatted_chat;
|
||||
@@ -576,14 +578,23 @@ static json oaicompat_completion_params_parse(const json & body) {
|
||||
return llama_params;
|
||||
}
|
||||
|
||||
static json oaicompat_chat_completion_params_parse(
|
||||
const struct llama_model * model,
|
||||
const json & body, /* openai api json semantics */
|
||||
const std::string & chat_template) {
|
||||
static json oaicompat_completion_params_parse(
|
||||
const json & body, /* openai api json semantics */
|
||||
const common_chat_template & tmpl,
|
||||
bool use_jinja)
|
||||
{
|
||||
json llama_params;
|
||||
|
||||
// Apply chat template to the list of messages
|
||||
llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
|
||||
auto tools = json_value(body, "tools", json());
|
||||
auto has_tools = tools.is_array() && !tools.empty();
|
||||
|
||||
if (has_tools) {
|
||||
if (use_jinja) {
|
||||
LOG_WRN("tools param is not fully supported yet\n");
|
||||
} else {
|
||||
throw std::runtime_error("tools param requires --jinja flag");
|
||||
}
|
||||
}
|
||||
|
||||
// Handle "stop" field
|
||||
if (body.contains("stop") && body.at("stop").is_string()) {
|
||||
@@ -606,6 +617,13 @@ static json oaicompat_chat_completion_params_parse(
|
||||
}
|
||||
}
|
||||
|
||||
// Apply chat template to the list of messages
|
||||
if (use_jinja) {
|
||||
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
|
||||
} else {
|
||||
llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
|
||||
}
|
||||
|
||||
// Handle "n" field
|
||||
int n_choices = json_value(body, "n", 1);
|
||||
if (n_choices != 1) {
|
||||
@@ -621,7 +639,7 @@ static json oaicompat_chat_completion_params_parse(
|
||||
}
|
||||
|
||||
// Params supported by OAI but unsupported by llama.cpp
|
||||
static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
|
||||
static const std::vector<std::string> unsupported_params { "tool_choice" };
|
||||
for (const auto & param : unsupported_params) {
|
||||
if (body.contains(param)) {
|
||||
throw std::runtime_error("Unsupported param: " + param);
|
||||
|
||||
@@ -141,6 +141,7 @@
|
||||
:msg="pendingMsg"
|
||||
:key="pendingMsg.id"
|
||||
:is-generating="isGenerating"
|
||||
:show-thought-in-progress="config.showThoughtInProgress"
|
||||
:edit-user-msg-and-regenerate="() => {}"
|
||||
:regenerate-msg="() => {}"></message-bubble>
|
||||
</div>
|
||||
@@ -202,6 +203,20 @@
|
||||
</template>
|
||||
</div>
|
||||
</details>
|
||||
<!-- Section: Reasoning models -->
|
||||
<details class="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
|
||||
<summary class="collapse-title font-bold">Reasoning models</summary>
|
||||
<div class="collapse-content">
|
||||
<div class="flex flex-row items-center mb-2">
|
||||
<input type="checkbox" class="checkbox" v-model="config.showThoughtInProgress" />
|
||||
<span class="ml-4">Expand though process by default for generating message</span>
|
||||
</div>
|
||||
<div class="flex flex-row items-center mb-2">
|
||||
<input type="checkbox" class="checkbox" v-model="config.excludeThoughtOnReq" />
|
||||
<span class="ml-4">Exclude thought process when sending request to API (Recommended for DeepSeek-R1)</span>
|
||||
</div>
|
||||
</div>
|
||||
</details>
|
||||
<!-- Section: Advanced config -->
|
||||
<details class="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
|
||||
<summary class="collapse-title font-bold">Advanced config</summary>
|
||||
@@ -261,7 +276,17 @@
|
||||
<span v-if="msg.content === null" class="loading loading-dots loading-md"></span>
|
||||
<!-- render message as markdown -->
|
||||
<div v-else dir="auto">
|
||||
<vue-markdown :source="msg.content"></vue-markdown>
|
||||
<details v-if="msg.role === 'assistant' && splitMsgContent.cot" class="collapse bg-base-200 collapse-arrow mb-4" :open="splitMsgContent.isThinking && showThoughtInProgress">
|
||||
<summary class="collapse-title">
|
||||
<span v-if="splitMsgContent.isThinking">
|
||||
<span v-if="isGenerating" class="loading loading-spinner loading-md mr-2" style="vertical-align: middle;"></span>
|
||||
<b>Thinking</b>
|
||||
</span>
|
||||
<b v-else>Thought Process</b>
|
||||
</summary>
|
||||
<vue-markdown :source="splitMsgContent.cot" dir="auto" class="collapse-content"></vue-markdown>
|
||||
</details>
|
||||
<vue-markdown :source="splitMsgContent.content"></vue-markdown>
|
||||
</div>
|
||||
<!-- render timings if enabled -->
|
||||
<div class="dropdown dropdown-hover dropdown-top mt-2" v-if="timings && config.showTokensPerSecond">
|
||||
|
||||
@@ -17,6 +17,11 @@ import { asyncIterator } from '@sec-ant/readable-stream/ponyfill/asyncIterator';
|
||||
|
||||
const isDev = import.meta.env.MODE === 'development';
|
||||
|
||||
// types
|
||||
/** @typedef {{ id: number, role: 'user' | 'assistant', content: string, timings: any }} Message */
|
||||
/** @typedef {{ role: 'user' | 'assistant', content: string }} APIMessage */
|
||||
/** @typedef {{ id: string, lastModified: number, messages: Array<Message> }} Conversation */
|
||||
|
||||
// utility functions
|
||||
const isString = (x) => !!x.toLowerCase;
|
||||
const isBoolean = (x) => x === true || x === false;
|
||||
@@ -50,6 +55,8 @@ const CONFIG_DEFAULT = {
|
||||
apiKey: '',
|
||||
systemMessage: 'You are a helpful assistant.',
|
||||
showTokensPerSecond: false,
|
||||
showThoughtInProgress: false,
|
||||
excludeThoughtOnReq: true,
|
||||
// make sure these default values are in sync with `common.h`
|
||||
samplers: 'edkypmxt',
|
||||
temperature: 0.8,
|
||||
@@ -172,6 +179,7 @@ const MessageBubble = defineComponent({
|
||||
config: Object,
|
||||
msg: Object,
|
||||
isGenerating: Boolean,
|
||||
showThoughtInProgress: Boolean,
|
||||
editUserMsgAndRegenerate: Function,
|
||||
regenerateMsg: Function,
|
||||
},
|
||||
@@ -188,7 +196,31 @@ const MessageBubble = defineComponent({
|
||||
prompt_per_second: this.msg.timings.prompt_n / (this.msg.timings.prompt_ms / 1000),
|
||||
predicted_per_second: this.msg.timings.predicted_n / (this.msg.timings.predicted_ms / 1000),
|
||||
};
|
||||
}
|
||||
},
|
||||
splitMsgContent() {
|
||||
const content = this.msg.content;
|
||||
if (this.msg.role !== 'assistant') {
|
||||
return { content };
|
||||
}
|
||||
let actualContent = '';
|
||||
let cot = '';
|
||||
let isThinking = false;
|
||||
let thinkSplit = content.split('<think>', 2);
|
||||
actualContent += thinkSplit[0];
|
||||
while (thinkSplit[1] !== undefined) {
|
||||
// <think> tag found
|
||||
thinkSplit = thinkSplit[1].split('</think>', 2);
|
||||
cot += thinkSplit[0];
|
||||
isThinking = true;
|
||||
if (thinkSplit[1] !== undefined) {
|
||||
// </think> closing tag found
|
||||
isThinking = false;
|
||||
thinkSplit = thinkSplit[1].split('<think>', 2);
|
||||
actualContent += thinkSplit[0];
|
||||
}
|
||||
}
|
||||
return { content: actualContent, cot, isThinking };
|
||||
},
|
||||
},
|
||||
methods: {
|
||||
copyMsg() {
|
||||
@@ -208,7 +240,10 @@ const MessageBubble = defineComponent({
|
||||
// format: { [convId]: { id: string, lastModified: number, messages: [...] } }
|
||||
// convId is a string prefixed with 'conv-'
|
||||
const StorageUtils = {
|
||||
// manage conversations
|
||||
/**
|
||||
* manage conversations
|
||||
* @returns {Array<Conversation>}
|
||||
*/
|
||||
getAllConversations() {
|
||||
const res = [];
|
||||
for (const key in localStorage) {
|
||||
@@ -219,11 +254,19 @@ const StorageUtils = {
|
||||
res.sort((a, b) => b.lastModified - a.lastModified);
|
||||
return res;
|
||||
},
|
||||
// can return null if convId does not exist
|
||||
/**
|
||||
* can return null if convId does not exist
|
||||
* @param {string} convId
|
||||
* @returns {Conversation | null}
|
||||
*/
|
||||
getOneConversation(convId) {
|
||||
return JSON.parse(localStorage.getItem(convId) || 'null');
|
||||
},
|
||||
// if convId does not exist, create one
|
||||
/**
|
||||
* if convId does not exist, create one
|
||||
* @param {string} convId
|
||||
* @param {Message} msg
|
||||
*/
|
||||
appendMsg(convId, msg) {
|
||||
if (msg.content === null) return;
|
||||
const conv = StorageUtils.getOneConversation(convId) || {
|
||||
@@ -235,12 +278,24 @@ const StorageUtils = {
|
||||
conv.lastModified = Date.now();
|
||||
localStorage.setItem(convId, JSON.stringify(conv));
|
||||
},
|
||||
/**
|
||||
* Get new conversation id
|
||||
* @returns {string}
|
||||
*/
|
||||
getNewConvId() {
|
||||
return `conv-${Date.now()}`;
|
||||
},
|
||||
/**
|
||||
* remove conversation by id
|
||||
* @param {string} convId
|
||||
*/
|
||||
remove(convId) {
|
||||
localStorage.removeItem(convId);
|
||||
},
|
||||
/**
|
||||
* remove all conversations
|
||||
* @param {string} convId
|
||||
*/
|
||||
filterAndKeepMsgs(convId, predicate) {
|
||||
const conv = StorageUtils.getOneConversation(convId);
|
||||
if (!conv) return;
|
||||
@@ -248,6 +303,11 @@ const StorageUtils = {
|
||||
conv.lastModified = Date.now();
|
||||
localStorage.setItem(convId, JSON.stringify(conv));
|
||||
},
|
||||
/**
|
||||
* remove last message from conversation
|
||||
* @param {string} convId
|
||||
* @returns {Message | undefined}
|
||||
*/
|
||||
popMsg(convId) {
|
||||
const conv = StorageUtils.getOneConversation(convId);
|
||||
if (!conv) return;
|
||||
@@ -322,10 +382,12 @@ const mainApp = createApp({
|
||||
data() {
|
||||
return {
|
||||
conversations: StorageUtils.getAllConversations(),
|
||||
messages: [], // { id: number, role: 'user' | 'assistant', content: string }
|
||||
/** @type {Array<Message>} */
|
||||
messages: [],
|
||||
viewingConvId: StorageUtils.getNewConvId(),
|
||||
inputMsg: '',
|
||||
isGenerating: false,
|
||||
/** @type {Array<Message> | null} */
|
||||
pendingMsg: null, // the on-going message from assistant
|
||||
stopGeneration: () => {},
|
||||
selectedTheme: StorageUtils.getTheme(),
|
||||
@@ -333,6 +395,7 @@ const mainApp = createApp({
|
||||
showConfigDialog: false,
|
||||
// const
|
||||
themes: THEMES,
|
||||
/** @type {CONFIG_DEFAULT} */
|
||||
configDefault: {...CONFIG_DEFAULT},
|
||||
configInfo: {...CONFIG_INFO},
|
||||
isDev,
|
||||
@@ -425,42 +488,50 @@ const mainApp = createApp({
|
||||
this.isGenerating = true;
|
||||
|
||||
try {
|
||||
/** @type {CONFIG_DEFAULT} */
|
||||
const config = this.config;
|
||||
const abortController = new AbortController();
|
||||
this.stopGeneration = () => abortController.abort();
|
||||
/** @type {Array<APIMessage>} */
|
||||
let messages = [
|
||||
{ role: 'system', content: config.systemMessage },
|
||||
...normalizeMsgsForAPI(this.messages),
|
||||
];
|
||||
if (config.excludeThoughtOnReq) {
|
||||
messages = filterThoughtFromMsgs(messages);
|
||||
}
|
||||
if (isDev) console.log({messages});
|
||||
const params = {
|
||||
messages: [
|
||||
{ role: 'system', content: this.config.systemMessage },
|
||||
...this.messages,
|
||||
],
|
||||
messages,
|
||||
stream: true,
|
||||
cache_prompt: true,
|
||||
samplers: this.config.samplers,
|
||||
temperature: this.config.temperature,
|
||||
dynatemp_range: this.config.dynatemp_range,
|
||||
dynatemp_exponent: this.config.dynatemp_exponent,
|
||||
top_k: this.config.top_k,
|
||||
top_p: this.config.top_p,
|
||||
min_p: this.config.min_p,
|
||||
typical_p: this.config.typical_p,
|
||||
xtc_probability: this.config.xtc_probability,
|
||||
xtc_threshold: this.config.xtc_threshold,
|
||||
repeat_last_n: this.config.repeat_last_n,
|
||||
repeat_penalty: this.config.repeat_penalty,
|
||||
presence_penalty: this.config.presence_penalty,
|
||||
frequency_penalty: this.config.frequency_penalty,
|
||||
dry_multiplier: this.config.dry_multiplier,
|
||||
dry_base: this.config.dry_base,
|
||||
dry_allowed_length: this.config.dry_allowed_length,
|
||||
dry_penalty_last_n: this.config.dry_penalty_last_n,
|
||||
max_tokens: this.config.max_tokens,
|
||||
timings_per_token: !!this.config.showTokensPerSecond,
|
||||
...(this.config.custom.length ? JSON.parse(this.config.custom) : {}),
|
||||
samplers: config.samplers,
|
||||
temperature: config.temperature,
|
||||
dynatemp_range: config.dynatemp_range,
|
||||
dynatemp_exponent: config.dynatemp_exponent,
|
||||
top_k: config.top_k,
|
||||
top_p: config.top_p,
|
||||
min_p: config.min_p,
|
||||
typical_p: config.typical_p,
|
||||
xtc_probability: config.xtc_probability,
|
||||
xtc_threshold: config.xtc_threshold,
|
||||
repeat_last_n: config.repeat_last_n,
|
||||
repeat_penalty: config.repeat_penalty,
|
||||
presence_penalty: config.presence_penalty,
|
||||
frequency_penalty: config.frequency_penalty,
|
||||
dry_multiplier: config.dry_multiplier,
|
||||
dry_base: config.dry_base,
|
||||
dry_allowed_length: config.dry_allowed_length,
|
||||
dry_penalty_last_n: config.dry_penalty_last_n,
|
||||
max_tokens: config.max_tokens,
|
||||
timings_per_token: !!config.showTokensPerSecond,
|
||||
...(config.custom.length ? JSON.parse(config.custom) : {}),
|
||||
};
|
||||
const chunks = sendSSEPostRequest(`${BASE_URL}/v1/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...(this.config.apiKey ? {'Authorization': `Bearer ${this.config.apiKey}`} : {})
|
||||
...(config.apiKey ? {'Authorization': `Bearer ${config.apiKey}`} : {})
|
||||
},
|
||||
body: JSON.stringify(params),
|
||||
signal: abortController.signal,
|
||||
@@ -477,7 +548,7 @@ const mainApp = createApp({
|
||||
};
|
||||
}
|
||||
const timings = chunk.timings;
|
||||
if (timings && this.config.showTokensPerSecond) {
|
||||
if (timings && config.showTokensPerSecond) {
|
||||
// only extract what's really needed, to save some space
|
||||
this.pendingMsg.timings = {
|
||||
prompt_n: timings.prompt_n,
|
||||
@@ -598,3 +669,33 @@ try {
|
||||
<button class="btn" onClick="localStorage.clear(); window.location.reload();">Clear localStorage</button>
|
||||
</div>`;
|
||||
}
|
||||
|
||||
/**
|
||||
* filter out redundant fields upon sending to API
|
||||
* @param {Array<APIMessage>} messages
|
||||
* @returns {Array<APIMessage>}
|
||||
*/
|
||||
function normalizeMsgsForAPI(messages) {
|
||||
return messages.map((msg) => {
|
||||
return {
|
||||
role: msg.role,
|
||||
content: msg.content,
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* recommended for DeepsSeek-R1, filter out content between <think> and </think> tags
|
||||
* @param {Array<APIMessage>} messages
|
||||
* @returns {Array<APIMessage>}
|
||||
*/
|
||||
function filterThoughtFromMsgs(messages) {
|
||||
return messages.map((msg) => {
|
||||
return {
|
||||
role: msg.role,
|
||||
content: msg.role === 'assistant'
|
||||
? msg.content.split('</think>').at(-1).trim()
|
||||
: msg.content,
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
@@ -95,13 +95,15 @@ int main(int argc, char ** argv) {
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
|
||||
|
||||
// helper function to evaluate a prompt and generate a response
|
||||
auto generate = [&](const std::string & prompt, bool is_first) {
|
||||
auto generate = [&](const std::string & prompt) {
|
||||
std::string response;
|
||||
|
||||
const bool is_first = llama_get_kv_cache_used_cells(ctx) == 0;
|
||||
|
||||
// tokenize the prompt
|
||||
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
|
||||
std::vector<llama_token> prompt_tokens(n_prompt_tokens);
|
||||
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
|
||||
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0) {
|
||||
GGML_ABORT("failed to tokenize the prompt\n");
|
||||
}
|
||||
|
||||
@@ -161,7 +163,7 @@ int main(int argc, char ** argv) {
|
||||
break;
|
||||
}
|
||||
|
||||
const char * tmpl = llama_model_chat_template(model);
|
||||
const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
|
||||
|
||||
// add the user input to the message list and format it
|
||||
messages.push_back({"user", strdup(user.c_str())});
|
||||
@@ -180,7 +182,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// generate a response
|
||||
printf("\033[33m");
|
||||
std::string response = generate(prompt, prev_len == 0);
|
||||
std::string response = generate(prompt);
|
||||
printf("\n\033[0m");
|
||||
|
||||
// add the response to the messages
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
cmake_minimum_required(VERSION 3.12)
|
||||
project(llama-simple-cmake-pkg)
|
||||
|
||||
set(TARGET llama-simple-cmake-pkg)
|
||||
|
||||
find_package(Llama REQUIRED)
|
||||
|
||||
add_executable(${TARGET} ${CMAKE_CURRENT_LIST_DIR}/../simple/simple.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE llama ggml::all ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
@@ -0,0 +1,34 @@
|
||||
# llama.cpp/example/simple-cmake-pkg
|
||||
|
||||
This program builds [simple](../simple) using a relocatable CMake package. It serves as an example of using the `find_package()` CMake command to conveniently include [llama.cpp](https://github.com/ggerganov/llama.cpp) in projects which live outside of the source tree.
|
||||
|
||||
## Building
|
||||
|
||||
Because this example is "outside of the source tree", it is important to first build/install llama.cpp using CMake. An example is provided here, but please see the [llama.cpp build instructions](../..) for more detailed build instructions.
|
||||
|
||||
### Considerations
|
||||
|
||||
When hardware acceleration libraries are used (e.g. CUDA, Metal, Vulkan, etc.), the appropriate dependencies will be searched for automatically. So, for example, when finding a package
|
||||
|
||||
### Build llama.cpp and install to llama.cpp/inst
|
||||
|
||||
```sh
|
||||
git clone https://github.com/ggerganov/llama.cpp
|
||||
cd llama.cpp
|
||||
cmake -S . -B build
|
||||
cmake --build build
|
||||
cmake --install build --prefix inst
|
||||
|
||||
### Build simple-cmake-pkg
|
||||
|
||||
```sh
|
||||
cd examples/simple-cmake-pkg
|
||||
cmake -S . -B build -DCMAKE_PREFIX_PATH=../../inst/lib/cmake
|
||||
cmake --build build
|
||||
```
|
||||
|
||||
### Run simple-cmake-pkg
|
||||
|
||||
```sh
|
||||
./build/llama-simple-cmake-pkg -m ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is"
|
||||
```
|
||||
+75
-1
@@ -58,7 +58,8 @@ else()
|
||||
set(GGML_BLAS_VENDOR_DEFAULT "Generic")
|
||||
endif()
|
||||
|
||||
if (CMAKE_CROSSCOMPILING)
|
||||
if (CMAKE_CROSSCOMPILING OR DEFINED ENV{SOURCE_DATE_EPOCH})
|
||||
message(STATUS "Setting GGML_NATIVE_DEFAULT to OFF")
|
||||
set(GGML_NATIVE_DEFAULT OFF)
|
||||
else()
|
||||
set(GGML_NATIVE_DEFAULT ON)
|
||||
@@ -153,6 +154,8 @@ option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashA
|
||||
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
|
||||
|
||||
option(GGML_HIP "ggml: use HIP" OFF)
|
||||
option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
|
||||
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
|
||||
option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
|
||||
option(GGML_VULKAN "ggml: use Vulkan" OFF)
|
||||
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
|
||||
@@ -264,3 +267,74 @@ if (GGML_STANDALONE)
|
||||
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc
|
||||
DESTINATION share/pkgconfig)
|
||||
endif()
|
||||
|
||||
#
|
||||
# Create CMake package
|
||||
#
|
||||
|
||||
# Generate version info based on git commit.
|
||||
|
||||
find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH)
|
||||
execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
OUTPUT_VARIABLE GGML_BUILD_NUMBER
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
|
||||
if(GGML_BUILD_NUMBER EQUAL 1)
|
||||
message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.")
|
||||
endif()
|
||||
|
||||
execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
OUTPUT_VARIABLE GGML_BUILD_COMMIT
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
)
|
||||
|
||||
# Capture variables prefixed with GGML_.
|
||||
|
||||
set(variable_set_statements
|
||||
"
|
||||
####### Expanded from @GGML_VARIABLES_EXPANED@ by configure_package_config_file() #######
|
||||
####### Any changes to this file will be overwritten by the next CMake run #######
|
||||
|
||||
")
|
||||
|
||||
set(GGML_SHARED_LIB ${BUILD_SHARED_LIBS})
|
||||
|
||||
get_cmake_property(all_variables VARIABLES)
|
||||
foreach(variable_name IN LISTS all_variables)
|
||||
if(variable_name MATCHES "^GGML_")
|
||||
string(REPLACE ";" "\\;"
|
||||
variable_value "${${variable_name}}")
|
||||
|
||||
set(variable_set_statements
|
||||
"${variable_set_statements}set(${variable_name} \"${variable_value}\")\n")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
set(GGML_VARIABLES_EXPANDED ${variable_set_statements})
|
||||
|
||||
# Create the CMake package and set install location.
|
||||
|
||||
set(GGML_INSTALL_VERSION 0.0.${GGML_BUILD_NUMBER})
|
||||
set(GGML_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files")
|
||||
set(GGML_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files")
|
||||
set(GGML_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files")
|
||||
|
||||
configure_package_config_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cmake/ggml-config.cmake.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
|
||||
INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml
|
||||
PATH_VARS GGML_INCLUDE_INSTALL_DIR
|
||||
GGML_LIB_INSTALL_DIR
|
||||
GGML_BIN_INSTALL_DIR)
|
||||
|
||||
write_basic_package_version_file(
|
||||
${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
|
||||
VERSION ${GGML_INSTALL_VERSION}
|
||||
COMPATIBILITY SameMajorVersion)
|
||||
|
||||
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake
|
||||
${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml)
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
|
||||
@GGML_VARIABLES_EXPANDED@
|
||||
|
||||
@PACKAGE_INIT@
|
||||
|
||||
set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@")
|
||||
set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@")
|
||||
set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@")
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
find_library(GGML_LIBRARY ggml
|
||||
REQUIRED
|
||||
HINTS ${GGML_LIB_DIR}
|
||||
NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
add_library(ggml::ggml UNKNOWN IMPORTED)
|
||||
set_target_properties(ggml::ggml
|
||||
PROPERTIES
|
||||
IMPORTED_LOCATION "${GGML_LIBRARY}")
|
||||
|
||||
find_library(GGML_BASE_LIBRARY ggml-base
|
||||
REQUIRED
|
||||
HINTS ${GGML_LIB_DIR}
|
||||
NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
add_library(ggml::ggml-base UNKNOWN IMPORTED)
|
||||
set_target_properties(ggml::ggml-base
|
||||
PROPERTIES
|
||||
IMPORTED_LOCATION "${GGML_BASE_LIBRARY}")
|
||||
|
||||
if (NOT GGML_SHARED_LIB)
|
||||
if (APPLE AND GGML_ACCELERATE)
|
||||
find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED)
|
||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${ACCELERATE_FRAMEWORK})
|
||||
endif()
|
||||
|
||||
if (GGML_OPENMP)
|
||||
find_package(OpenMP REQUIRED)
|
||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
|
||||
endif()
|
||||
|
||||
if (GGML_CPU_HBM)
|
||||
find_library(memkind memkind REQUIRED)
|
||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES memkind)
|
||||
endif()
|
||||
|
||||
if (GGML_BLAS)
|
||||
find_package(BLAS REQUIRED)
|
||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES})
|
||||
list(APPEND GGML_CPU_INTERFACE_LINK_OPTIONS ${BLAS_LINKER_FLAGS})
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA)
|
||||
find_package(CUDAToolkit REQUIRED)
|
||||
endif()
|
||||
|
||||
if (GGML_METAL)
|
||||
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
|
||||
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
||||
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
||||
|
||||
list(APPEND GGML_METAL_INTERFACE_LINK_LIBRARIES
|
||||
${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK})
|
||||
endif()
|
||||
|
||||
if (GGML_VULKAN)
|
||||
find_package(Vulkan REQUIRED)
|
||||
list(APPEND GGML_VULKAN_INTERFACE_LINK_LIBRARIES Vulkan::Vulkan)
|
||||
endif()
|
||||
|
||||
if (GGML_HIP)
|
||||
find_package(hip REQUIRED)
|
||||
find_package(hipblas REQUIRED)
|
||||
find_package(rocblas REQUIRED)
|
||||
list(APPEND GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas)
|
||||
endif()
|
||||
|
||||
if (GGML_SYCL)
|
||||
find_package(DNNL)
|
||||
if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
|
||||
list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES DNNL::dnnl)
|
||||
endif()
|
||||
if (WIN32)
|
||||
find_package(IntelSYCL REQUIRED)
|
||||
find_package(MKL REQUIRED)
|
||||
list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(_ggml_all_targets "")
|
||||
foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
|
||||
string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}")
|
||||
string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx)
|
||||
|
||||
find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend}
|
||||
REQUIRED
|
||||
HINTS ${GGML_LIB_DIR}
|
||||
NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}")
|
||||
|
||||
add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED)
|
||||
set_target_properties(ggml::${_ggml_backend}
|
||||
PROPERTIES
|
||||
INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}"
|
||||
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
|
||||
IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}"
|
||||
INTERFACE_COMPILE_FEATURES c_std_90
|
||||
POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}")
|
||||
if(is_cpu_variant)
|
||||
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml" "ggml::ggml-base")
|
||||
set_target_properties(ggml::${_ggml_backend}
|
||||
PROPERTIES
|
||||
INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}")
|
||||
|
||||
if(GGML_CPU_INTERFACE_LINK_OPTIONS)
|
||||
set_target_properties(ggml::${_ggml_backend}
|
||||
PROPERTIES
|
||||
INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}")
|
||||
endif()
|
||||
|
||||
else()
|
||||
list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml" "ggml::ggml-base")
|
||||
set_target_properties(ggml::${_ggml_backend}
|
||||
PROPERTIES
|
||||
INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}")
|
||||
|
||||
if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS)
|
||||
set_target_properties(ggml::${_ggml_backend}
|
||||
PROPERTIES
|
||||
INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
list(APPEND _ggml_all_targets ggml::${_ggml_backend})
|
||||
endforeach()
|
||||
|
||||
add_library(ggml::all INTERFACE IMPORTED)
|
||||
set_target_properties(ggml::all
|
||||
PROPERTIES
|
||||
INTERFACE_LINK_LIBRARIES "${_ggml_all_targets}")
|
||||
|
||||
check_required_components(ggml)
|
||||
+12
-1
@@ -250,6 +250,17 @@ function(ggml_add_backend_library backend)
|
||||
target_compile_definitions(${backend} PRIVATE GGML_BACKEND_BUILD)
|
||||
target_compile_definitions(${backend} PUBLIC GGML_BACKEND_SHARED)
|
||||
endif()
|
||||
|
||||
if(NOT GGML_AVAILABLE_BACKENDS)
|
||||
set(GGML_AVAILABLE_BACKENDS "${backend}"
|
||||
CACHE INTERNAL "List of backends for cmake package")
|
||||
else()
|
||||
list(FIND GGML_AVAILABLE_BACKENDS "${backend}" has_backend)
|
||||
if(has_backend EQUAL -1)
|
||||
set(GGML_AVAILABLE_BACKENDS "${GGML_AVAILABLE_BACKENDS};${backend}"
|
||||
CACHE INTERNAL "List of backends for cmake package")
|
||||
endif()
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
function(ggml_add_backend backend)
|
||||
@@ -297,7 +308,7 @@ if (GGML_CPU_ALL_VARIANTS)
|
||||
# MSVC doesn't support AMX
|
||||
ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
|
||||
endif()
|
||||
else ()
|
||||
elseif (GGML_CPU)
|
||||
ggml_add_cpu_backend_variant_impl("")
|
||||
endif()
|
||||
|
||||
|
||||
@@ -1302,7 +1302,7 @@ struct ggml_threadpool {
|
||||
// these are atomic as an annotation for thread-sanitizer
|
||||
atomic_bool stop; // Used for stopping the threadpool altogether
|
||||
atomic_bool pause; // Used for pausing the threadpool or individual threads
|
||||
atomic_bool abort; // Used for aborting processing of a graph
|
||||
atomic_int abort; // Used for aborting processing of a graph
|
||||
|
||||
struct ggml_compute_state * workers; // per thread state
|
||||
int n_threads_max; // number of threads in the pool
|
||||
@@ -7883,7 +7883,7 @@ static void ggml_compute_forward_out_prod_f32(
|
||||
|
||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
|
||||
ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
|
||||
}
|
||||
@@ -7892,7 +7892,7 @@ static void ggml_compute_forward_out_prod_f32(
|
||||
|
||||
float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
|
||||
float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||
|
||||
ggml_vec_mad_f32(ne0, d, s0, *s1);
|
||||
}
|
||||
@@ -13851,14 +13851,14 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
/*.threadpool=*/ tp,
|
||||
};
|
||||
|
||||
for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) {
|
||||
for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
|
||||
struct ggml_tensor * node = cgraph->nodes[node_n];
|
||||
|
||||
ggml_compute_forward(¶ms, node);
|
||||
|
||||
if (state->ith == 0 && cplan->abort_callback &&
|
||||
cplan->abort_callback(cplan->abort_callback_data)) {
|
||||
tp->abort = true;
|
||||
atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);
|
||||
tp->ec = GGML_STATUS_ABORTED;
|
||||
}
|
||||
|
||||
@@ -14031,7 +14031,7 @@ static struct ggml_threadpool * ggml_threadpool_new_impl(
|
||||
threadpool->current_chunk = 0;
|
||||
threadpool->stop = false;
|
||||
threadpool->pause = tpp->paused;
|
||||
threadpool->abort = false;
|
||||
threadpool->abort = -1;
|
||||
threadpool->workers = NULL;
|
||||
threadpool->n_threads_max = tpp->n_threads;
|
||||
threadpool->n_threads_cur = tpp->n_threads;
|
||||
@@ -14110,7 +14110,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
|
||||
threadpool->cgraph = cgraph;
|
||||
threadpool->cplan = cplan;
|
||||
threadpool->current_chunk = 0;
|
||||
threadpool->abort = false;
|
||||
threadpool->abort = -1;
|
||||
threadpool->ec = GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
@@ -416,7 +416,8 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
|
||||
case GGML_OP_IM2COL_BACK:
|
||||
return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
|
||||
case GGML_OP_OUT_PROD:
|
||||
return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32;
|
||||
return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
|
||||
src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -93,26 +93,31 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
|
||||
|
||||
template <typename T>
|
||||
static __global__ void k_repeat_back(
|
||||
const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2) {
|
||||
const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const size_t s00, const size_t s01, const size_t s02, const size_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
|
||||
|
||||
const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
|
||||
const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
|
||||
const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
|
||||
const int64_t tid0 = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
|
||||
const int64_t tid1 = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
|
||||
const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
|
||||
const int64_t tid2 = tid23 % ne2;
|
||||
const int64_t tid3 = tid23 / ne2;
|
||||
|
||||
if (tid0 >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
T sum = 0;
|
||||
for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
|
||||
for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
|
||||
for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
|
||||
sum += src[i2*ne01*ne00 + i1*ne00 + i0];
|
||||
for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
|
||||
for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
|
||||
for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
|
||||
for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
|
||||
sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
|
||||
dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
|
||||
}
|
||||
|
||||
template<float (*bin_op)(const float, const float)>
|
||||
@@ -274,12 +279,14 @@ struct bin_bcast_cuda {
|
||||
|
||||
template <typename T>
|
||||
static void repeat_back_cuda(
|
||||
const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
|
||||
const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const size_t s00, const size_t s01, const size_t s02, const size_t s03,
|
||||
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
|
||||
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
|
||||
k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
|
||||
const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
|
||||
k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
|
||||
(src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
|
||||
}
|
||||
|
||||
template<class op>
|
||||
@@ -326,27 +333,26 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(src0->type == dst->type);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
GGML_ASSERT(ggml_can_repeat(dst, src0));
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
GGML_ASSERT(src0->ne[3] == 1);
|
||||
GGML_TENSOR_UNARY_OP_LOCALS;
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
const int64_t ne1 = dst->ne[1];
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
GGML_ASSERT(dst->ne[3] == 1);
|
||||
GGML_ASSERT(ne2*ne3 <= (1 << 15));
|
||||
|
||||
const size_t ts = ggml_type_size(src0->type);
|
||||
const size_t s00 = nb00 / ts;
|
||||
const size_t s01 = nb01 / ts;
|
||||
const size_t s02 = nb02 / ts;
|
||||
const size_t s03 = nb03 / ts;
|
||||
|
||||
switch (dst->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
repeat_back_cuda<float>(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
|
||||
repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ASSERT(false);
|
||||
|
||||
@@ -46,20 +46,20 @@
|
||||
#define GGML_CUDA_CC_VOLTA 700
|
||||
#define GGML_CUDA_CC_TURING 750
|
||||
#define GGML_CUDA_CC_AMPERE 800
|
||||
#define GGML_CUDA_CC_OFFSET_AMD 1000000
|
||||
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
|
||||
|
||||
// GCN/CNDA, wave size is 64
|
||||
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 803) // Tonga, Fiji, Polaris, minimum for fast fp16
|
||||
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 900) // Vega56/64, minimum for fp16 dual issue
|
||||
#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 906) // MI50/Radeon VII, minimum for dp4a
|
||||
#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 908) // MI100, minimum for MFMA, acc registers
|
||||
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 910) // MI210, minimum acc register renameing
|
||||
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 942) // MI300
|
||||
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
|
||||
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
|
||||
#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
|
||||
#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
|
||||
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing
|
||||
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300
|
||||
|
||||
// RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32
|
||||
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 1010) // RX 5000
|
||||
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
|
||||
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA
|
||||
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
|
||||
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
|
||||
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
|
||||
|
||||
#define GGML_CUDA_CC_QY1 210
|
||||
#define GGML_CUDA_CC_QY2 220
|
||||
@@ -131,6 +131,10 @@ typedef float dfloat; // dequantize float
|
||||
typedef float2 dfloat2;
|
||||
#endif // GGML_CUDA_F16
|
||||
|
||||
#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
|
||||
#define GGML_USE_VMM
|
||||
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
|
||||
|
||||
#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
||||
#define FP16_AVAILABLE
|
||||
#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
|
||||
@@ -588,7 +592,7 @@ struct ggml_tensor_extra_gpu {
|
||||
};
|
||||
|
||||
|
||||
#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)
|
||||
#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
|
||||
#define USE_CUDA_GRAPH
|
||||
#endif
|
||||
|
||||
|
||||
+164
-49
@@ -42,6 +42,7 @@
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <charconv>
|
||||
#include <cinttypes>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
@@ -62,7 +63,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||
[[noreturn]]
|
||||
void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
|
||||
int id = -1; // in case cudaGetDevice fails
|
||||
cudaGetDevice(&id);
|
||||
(void)cudaGetDevice(&id);
|
||||
|
||||
GGML_LOG_ERROR(GGML_CUDA_NAME " error: %s\n", msg);
|
||||
GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line);
|
||||
@@ -119,12 +120,78 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
static int ggml_cuda_parse_id(char devName[]) {
|
||||
// A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp
|
||||
// these values are not stable so this is susceptible to breakage
|
||||
// https://github.com/ROCm/clr/blob/amd-staging/rocclr/device/device.cpp
|
||||
int archMajor = 0x0;
|
||||
int archMinor = 0x0;
|
||||
int archNum = GGML_CUDA_CC_OFFSET_AMD;
|
||||
int archLen = strlen(devName);
|
||||
char archName[archLen + 1];
|
||||
|
||||
// strip leading 'gfx' while copying into our buffer
|
||||
if (archLen > 3) {
|
||||
strcpy(archName, &devName[3]);
|
||||
archLen -= 3;
|
||||
}
|
||||
|
||||
// trim trailing :xnack- or :sramecc- statuses
|
||||
archLen = strcspn(archName, ":");
|
||||
archName[archLen] = '\0';
|
||||
|
||||
// tease out the version information
|
||||
if (archLen > 8) {
|
||||
// versions labeled generic use '-' as delimiter
|
||||
// strip the trailing "-generic" then iterate through what remains
|
||||
if ((strstr(archName, "-generic"))) {
|
||||
archName[archLen - 8] = '\0';
|
||||
char * pch;
|
||||
if ((pch = strtok(archName, "-"))) {
|
||||
archMajor = (int)strtoul(pch, 0, 16);
|
||||
if ((pch = strtok(NULL, "-"))) {
|
||||
archMinor = 0x10 * (int)strtoul(pch, 0, 16);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (archLen >= 3) {
|
||||
// last two digits should be the minor * 0x10 + stepping
|
||||
archMinor = (int)strtoul(&archName[archLen - 2], 0, 16);
|
||||
archName[archLen - 2] = '\0';
|
||||
|
||||
// only the major version remains
|
||||
archMajor = (int)strtoul(archName, 0, 16);
|
||||
}
|
||||
archNum += archMajor * 0x100;
|
||||
archNum += archMinor;
|
||||
return archNum;
|
||||
}
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
|
||||
static ggml_cuda_device_info ggml_cuda_init() {
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
// Workaround for a rocBLAS bug when using multiple graphics cards:
|
||||
// https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
|
||||
rocblas_initialize();
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
{
|
||||
int major_version = 0;
|
||||
size_t version_length = 0;
|
||||
if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) {
|
||||
std::string version(version_length, '\0');
|
||||
if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) {
|
||||
version.resize(::strlen(version.c_str()));
|
||||
int parsed_value = 0;
|
||||
if (std::from_chars(version.c_str(), version.c_str() + version.length(), parsed_value).ec == std::errc()) {
|
||||
major_version = parsed_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (major_version < 4) {
|
||||
GGML_LOG_DEBUG(GGML_CUDA_NAME " calling rocblas_initialize as a workaround for a rocBLAS bug\n");
|
||||
rocblas_initialize();
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
ggml_cuda_device_info info = {};
|
||||
@@ -152,7 +219,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
for (int id = 0; id < info.device_count; ++id) {
|
||||
int device_vmm = 0;
|
||||
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
|
||||
#if defined(GGML_USE_VMM)
|
||||
CUdevice device;
|
||||
CU_CHECK(cuDeviceGet(&device, id));
|
||||
CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
|
||||
@@ -164,12 +231,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
alloc_prop.location.id = id;
|
||||
CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
|
||||
}
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
|
||||
#endif // defined(GGML_USE_VMM)
|
||||
info.devices[id].vmm = !!device_vmm;
|
||||
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
|
||||
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
|
||||
|
||||
info.default_tensor_split[id] = total_vram;
|
||||
total_vram += prop.totalGlobalMem;
|
||||
@@ -178,10 +244,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
info.devices[id].smpb = prop.sharedMemPerBlock;
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
info.devices[id].smpbo = prop.sharedMemPerBlock;
|
||||
info.devices[id].cc = 100*prop.major + 10*prop.minor + GGML_CUDA_CC_OFFSET_AMD;
|
||||
|
||||
info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName);
|
||||
if ((info.devices[id].cc & 0xff00) == 0x0) {
|
||||
GGML_LOG_WARN("invalid architecture ID received for device %d %s: %s cc %d.%d\n",
|
||||
id, prop.name, prop.gcnArchName, prop.major, prop.minor);
|
||||
|
||||
// Fallback to prop.major and prop.minor
|
||||
if (prop.major > 0) {
|
||||
info.devices[id].cc = GGML_CUDA_CC_OFFSET_AMD + prop.major * 0x100;
|
||||
info.devices[id].cc += prop.minor * 0x10;
|
||||
}
|
||||
}
|
||||
GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s\n",
|
||||
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, device_vmm ? "yes" : "no");
|
||||
#else
|
||||
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
|
||||
info.devices[id].cc = 100*prop.major + 10*prop.minor;
|
||||
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
|
||||
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
}
|
||||
|
||||
@@ -300,7 +381,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
|
||||
};
|
||||
|
||||
// pool with virtual memory
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
|
||||
#if defined(GGML_USE_VMM)
|
||||
struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
|
||||
|
||||
@@ -309,6 +390,9 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
size_t pool_used = 0;
|
||||
size_t pool_size = 0;
|
||||
size_t granularity;
|
||||
#if defined(GGML_USE_HIP)
|
||||
std::vector<std::pair<CUdeviceptr, size_t>> mappings;
|
||||
#endif
|
||||
|
||||
explicit ggml_cuda_pool_vmm(int device) :
|
||||
device(device),
|
||||
@@ -317,7 +401,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
|
||||
~ggml_cuda_pool_vmm() {
|
||||
if (pool_addr != 0) {
|
||||
#if defined(GGML_USE_HIP)
|
||||
// Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285
|
||||
for (std::pair<CUdeviceptr, size_t> & mapping : mappings) {
|
||||
CU_CHECK(cuMemUnmap(mapping.first, mapping.second));
|
||||
}
|
||||
#else
|
||||
CU_CHECK(cuMemUnmap(pool_addr, pool_size));
|
||||
#endif
|
||||
CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE));
|
||||
}
|
||||
}
|
||||
@@ -350,7 +441,11 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
}
|
||||
|
||||
// map at the end of the pool
|
||||
CU_CHECK(cuMemMap(pool_addr + pool_size, reserve_size, 0, handle, 0));
|
||||
CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size);
|
||||
CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, handle, 0));
|
||||
#if defined(GGML_USE_HIP)
|
||||
mappings.push_back({start_ptr, reserve_size});
|
||||
#endif
|
||||
|
||||
// the memory allocation handle is no longer needed after mapping
|
||||
CU_CHECK(cuMemRelease(handle));
|
||||
@@ -360,7 +455,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
access.location.id = device;
|
||||
access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
|
||||
CU_CHECK(cuMemSetAccess(pool_addr + pool_size, reserve_size, &access, 1));
|
||||
CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1));
|
||||
|
||||
// add to the pool
|
||||
pool_size += reserve_size;
|
||||
@@ -372,7 +467,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
|
||||
GGML_ASSERT(pool_addr != 0);
|
||||
|
||||
void * ptr = (void *) (pool_addr + pool_used);
|
||||
void * ptr = (void *) ((CUdeviceptr)((char *)(pool_addr) + pool_used));
|
||||
*actual_size = size;
|
||||
pool_used += size;
|
||||
|
||||
@@ -391,17 +486,17 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
pool_used -= size;
|
||||
|
||||
// all deallocations must be in reverse order of the allocations
|
||||
GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
|
||||
GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used));
|
||||
}
|
||||
};
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
|
||||
#endif // defined(GGML_USE_VMM)
|
||||
|
||||
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
|
||||
#if defined(GGML_USE_VMM)
|
||||
if (ggml_cuda_info().devices[device].vmm) {
|
||||
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
|
||||
}
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
|
||||
#endif // defined(GGML_USE_VMM)
|
||||
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
|
||||
}
|
||||
|
||||
@@ -547,7 +642,7 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
|
||||
cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
|
||||
if (err != cudaSuccess) {
|
||||
// clear the error
|
||||
cudaGetLastError();
|
||||
(void)cudaGetLastError();
|
||||
GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err));
|
||||
return nullptr;
|
||||
}
|
||||
@@ -962,7 +1057,7 @@ static void * ggml_cuda_host_malloc(size_t size) {
|
||||
cudaError_t err = cudaMallocHost((void **) &ptr, size);
|
||||
if (err != cudaSuccess) {
|
||||
// clear the error
|
||||
cudaGetLastError();
|
||||
(void)cudaGetLastError();
|
||||
GGML_LOG_DEBUG("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
|
||||
size / 1024.0 / 1024.0, cudaGetErrorString(err));
|
||||
return nullptr;
|
||||
@@ -1082,7 +1177,9 @@ static void ggml_cuda_op_mul_mat_cublas(
|
||||
|
||||
const int compute_capability = ggml_cuda_info().devices[id].cc;
|
||||
|
||||
if (compute_capability >= GGML_CUDA_CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
|
||||
|
||||
if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) {
|
||||
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
|
||||
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
|
||||
if (src0->type != GGML_TYPE_F16) {
|
||||
@@ -1103,28 +1200,38 @@ static void ggml_cuda_op_mul_mat_cublas(
|
||||
to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
|
||||
}
|
||||
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
|
||||
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
|
||||
|
||||
const half alpha_f16 = 1.0f;
|
||||
const half beta_f16 = 0.0f;
|
||||
|
||||
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
|
||||
if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
|
||||
cu_compute_type = CUBLAS_COMPUTE_32F;
|
||||
}
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
row_diff, src1_ncols, ne10,
|
||||
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
|
||||
src1_ptr, CUDA_R_16F, ne10,
|
||||
&beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
|
||||
cu_compute_type,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
|
||||
if (compute_capability == GGML_CUDA_CC_CDNA) {
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
row_diff, src1_ncols, ne10,
|
||||
&alpha, src0_ptr, CUDA_R_16F, ne00,
|
||||
src1_ptr, CUDA_R_16F, ne10,
|
||||
&beta, dst_dd_i, CUDA_R_32F, ldc,
|
||||
CUBLAS_COMPUTE_32F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
} else {
|
||||
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
|
||||
|
||||
const half alpha_f16 = 1.0f;
|
||||
const half beta_f16 = 0.0f;
|
||||
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
row_diff, src1_ncols, ne10,
|
||||
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
|
||||
src1_ptr, CUDA_R_16F, ne10,
|
||||
&beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
|
||||
CUBLAS_COMPUTE_16F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
|
||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
|
||||
to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
|
||||
}
|
||||
} else {
|
||||
ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
|
||||
ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));
|
||||
@@ -1197,7 +1304,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
|
||||
CUDA_CHECK(err);
|
||||
} else {
|
||||
// reset the error
|
||||
cudaGetLastError();
|
||||
(void)cudaGetLastError();
|
||||
}
|
||||
} else {
|
||||
cudaError_t err = cudaDeviceDisablePeerAccess(id_other);
|
||||
@@ -1205,7 +1312,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
|
||||
CUDA_CHECK(err);
|
||||
} else {
|
||||
// reset the error
|
||||
cudaGetLastError();
|
||||
(void)cudaGetLastError();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1613,10 +1720,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
||||
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
|
||||
cudaDataType_t cu_data_type = CUDA_R_16F;
|
||||
|
||||
if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
|
||||
cu_compute_type = CUBLAS_COMPUTE_32F;
|
||||
}
|
||||
|
||||
// dst strides
|
||||
size_t nbd2 = dst->nb[2];
|
||||
size_t nbd3 = dst->nb[3];
|
||||
@@ -1645,6 +1748,12 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
||||
beta = &beta_f32;
|
||||
}
|
||||
|
||||
if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
|
||||
cu_compute_type = CUBLAS_COMPUTE_32F;
|
||||
alpha = &alpha_f32;
|
||||
beta = &beta_f32;
|
||||
}
|
||||
|
||||
GGML_ASSERT(ne12 % ne02 == 0);
|
||||
GGML_ASSERT(ne13 % ne03 == 0);
|
||||
|
||||
@@ -2438,7 +2547,7 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto
|
||||
if (stat == cudaErrorInvalidDeviceFunction) {
|
||||
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
|
||||
// We don't need to update blas nodes, so clear error and move on.
|
||||
cudaGetLastError();
|
||||
(void)cudaGetLastError();
|
||||
} else {
|
||||
GGML_ASSERT(stat == cudaSuccess);
|
||||
}
|
||||
@@ -2493,14 +2602,20 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
|
||||
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
||||
|
||||
cudaGraphExecUpdateResultInfo result_info;
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
hipGraphNode_t errorNode;
|
||||
hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
|
||||
#else
|
||||
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
|
||||
#endif
|
||||
if (stat == cudaErrorGraphExecUpdateFailure) {
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
|
||||
#endif
|
||||
|
||||
// The pre-existing graph exec cannot be updated due to violated constraints
|
||||
// so instead clear error and re-instantiate
|
||||
cudaGetLastError();
|
||||
(void)cudaGetLastError();
|
||||
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
|
||||
cuda_ctx->cuda_graph->instance = nullptr;
|
||||
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
||||
@@ -2728,7 +2843,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
|
||||
cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
|
||||
if (err != cudaSuccess) {
|
||||
// clear the error
|
||||
cudaGetLastError();
|
||||
(void)cudaGetLastError();
|
||||
|
||||
GGML_LOG_DEBUG("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__,
|
||||
size / 1024.0 / 1024.0, cudaGetErrorString(err));
|
||||
@@ -2748,7 +2863,7 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
|
||||
cudaError_t err = cudaHostUnregister(buffer);
|
||||
if (err != cudaSuccess) {
|
||||
// clear the error
|
||||
cudaGetLastError();
|
||||
(void)cudaGetLastError();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3002,7 +3117,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
||||
} break;
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
|
||||
return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15);
|
||||
case GGML_OP_CONCAT:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
@@ -3216,7 +3331,7 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
|
||||
features.push_back({ "FORCE_CUBLAS", "1" });
|
||||
#endif
|
||||
|
||||
#ifdef GGML_CUDA_NO_VMM
|
||||
#ifndef GGML_USE_VMM
|
||||
features.push_back({ "NO_VMM", "1" });
|
||||
#endif
|
||||
|
||||
|
||||
@@ -142,7 +142,7 @@ static void mul_mat_vec_q_cuda(
|
||||
int64_t nwarps = 1;
|
||||
int64_t rows_per_cuda_block = 1;
|
||||
|
||||
if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_CDNA || ggml_cuda_info().devices[id].cc == GGML_CUDA_CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA
|
||||
if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
|
||||
switch(ncols_y) {
|
||||
case 1:
|
||||
nwarps = 4;
|
||||
@@ -166,6 +166,7 @@ static void mul_mat_vec_q_cuda(
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
|
||||
const dim3 block_nums(nblocks, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
||||
|
||||
@@ -34,6 +34,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(handle, stream));
|
||||
|
||||
const int64_t lda = nb01 / sizeof(float);
|
||||
const int64_t ldc = nb1 / sizeof(float);
|
||||
|
||||
const bool src1_T = ggml_is_transposed(src1);
|
||||
const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
|
||||
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
|
||||
@@ -57,9 +60,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
ne0, ne1, ne01,
|
||||
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00,
|
||||
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
|
||||
src1_d + i3 *s13 + i2 *s12, ldb,
|
||||
&beta, dst_d + i3 *s3 + i2 *s2, ne0));
|
||||
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,12 @@ __device__ float __forceinline__ t2f32<half>(half val) {
|
||||
return __half2float(val);
|
||||
}
|
||||
|
||||
// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
|
||||
// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wpass-failed"
|
||||
#endif
|
||||
template <bool use_shared, int ncols_template, int block_size_template, typename T>
|
||||
static __global__ void soft_max_f32(
|
||||
const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
|
||||
@@ -118,6 +124,9 @@ static __global__ void soft_max_f32(
|
||||
dst[col] = vals[col] * inv_sum;
|
||||
}
|
||||
}
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
|
||||
static __global__ void soft_max_back_f32(
|
||||
const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
|
||||
|
||||
Vendored
+43
@@ -19,6 +19,12 @@
|
||||
#define CUBLAS_TF32_TENSOR_OP_MATH 0
|
||||
#define CUDA_R_16F HIPBLAS_R_16F
|
||||
#define CUDA_R_32F HIPBLAS_R_32F
|
||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
|
||||
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
|
||||
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
|
||||
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
|
||||
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
|
||||
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
|
||||
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
||||
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
|
||||
#define cublasCreate hipblasCreate
|
||||
@@ -74,6 +80,21 @@
|
||||
#define cudaMemGetInfo hipMemGetInfo
|
||||
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
|
||||
#define cudaSetDevice hipSetDevice
|
||||
#define cuDeviceGet hipDeviceGet
|
||||
#define CUdevice hipDevice_t
|
||||
#define CUdeviceptr hipDeviceptr_t
|
||||
#define cuMemUnmap hipMemUnmap
|
||||
#define CUmemAccessDesc hipMemAccessDesc
|
||||
#define cuMemAddressFree hipMemAddressFree
|
||||
#define cuMemRelease hipMemRelease
|
||||
#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
|
||||
#define cuMemCreate hipMemCreate
|
||||
#define cuMemAddressReserve hipMemAddressReserve
|
||||
#define cuMemMap hipMemMap
|
||||
#define cuMemSetAccess hipMemSetAccess
|
||||
#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
|
||||
#define CUmemAllocationProp hipMemAllocationProp
|
||||
#define cuDeviceGetAttribute hipDeviceGetAttribute
|
||||
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
|
||||
#define cudaStreamDestroy hipStreamDestroy
|
||||
#define cudaStreamFireAndForget hipStreamFireAndForget
|
||||
@@ -81,6 +102,28 @@
|
||||
#define cudaStreamPerThread hipStreamPerThread
|
||||
#define cudaStreamSynchronize hipStreamSynchronize
|
||||
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
|
||||
#define cudaGraphExec_t hipGraphExec_t
|
||||
#define cudaGraphNode_t hipGraphNode_t
|
||||
#define cudaKernelNodeParams hipKernelNodeParams
|
||||
#define cudaKernelNodeParams hipKernelNodeParams
|
||||
#define cudaGraphExecDestroy hipGraphExecDestroy
|
||||
#define cudaGraphLaunch hipGraphLaunch
|
||||
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
|
||||
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
|
||||
#define cudaGraphNodeType hipGraphNodeType
|
||||
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
|
||||
#define cudaGraphInstantiate hipGraphInstantiate
|
||||
#define cudaStreamEndCapture hipStreamEndCapture
|
||||
#define cudaGraphDestroy hipGraphDestroy
|
||||
#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
|
||||
#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
|
||||
#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
|
||||
#define cudaGraphNodeGetType hipGraphNodeGetType
|
||||
#define cudaGraphGetNodes hipGraphGetNodes
|
||||
#define cudaGraphExecUpdate hipGraphExecUpdate
|
||||
#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
|
||||
#define cudaStreamBeginCapture hipStreamBeginCapture
|
||||
#define cudaGraph_t hipGraph_t
|
||||
#define cudaStream_t hipStream_t
|
||||
#define cudaSuccess hipSuccess
|
||||
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
|
||||
|
||||
@@ -92,6 +92,14 @@ if (GGML_CUDA_NO_PEER_COPY)
|
||||
add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
|
||||
endif()
|
||||
|
||||
if (GGML_HIP_GRAPHS)
|
||||
add_compile_definitions(GGML_HIP_GRAPHS)
|
||||
endif()
|
||||
|
||||
if (GGML_HIP_NO_VMM)
|
||||
add_compile_definitions(GGML_HIP_NO_VMM)
|
||||
endif()
|
||||
|
||||
if (CXX_IS_HIPCC)
|
||||
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
|
||||
target_link_libraries(ggml-hip PRIVATE hip::device)
|
||||
|
||||
@@ -19,7 +19,10 @@
|
||||
// max number of MTLCommandBuffer used to submit a graph for processing
|
||||
#define GGML_METAL_MAX_COMMAND_BUFFERS 8
|
||||
|
||||
#define UNUSED(x) (void)(x)
|
||||
// create residency sets only on macOS >= 15.0
|
||||
#if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000
|
||||
#define GGML_METAL_HAS_RESIDENCY_SETS 1
|
||||
#endif
|
||||
|
||||
// globals
|
||||
|
||||
@@ -39,6 +42,7 @@ static struct ggml_backend_metal_device_context {
|
||||
|
||||
bool has_simdgroup_reduction;
|
||||
bool has_simdgroup_mm;
|
||||
bool has_residency_sets;
|
||||
bool has_bfloat;
|
||||
bool use_bfloat;
|
||||
|
||||
@@ -48,6 +52,7 @@ static struct ggml_backend_metal_device_context {
|
||||
/*.mtl_device_ref_count =*/ 0,
|
||||
/*.has_simdgroup_reduction =*/ false,
|
||||
/*.has_simdgroup_mm =*/ false,
|
||||
/*.has_residency_sets =*/ false,
|
||||
/*.has_bfloat =*/ false,
|
||||
/*.use_bfloat =*/ false,
|
||||
/*.name =*/ "",
|
||||
@@ -59,12 +64,18 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
|
||||
|
||||
if (ctx->mtl_device == nil) {
|
||||
ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
||||
}
|
||||
|
||||
if (ctx->mtl_device) {
|
||||
ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
||||
ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
||||
|
||||
ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
|
||||
|
||||
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
|
||||
ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL;
|
||||
#endif
|
||||
|
||||
ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
||||
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
|
||||
|
||||
@@ -90,8 +101,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
|
||||
ctx->mtl_device_ref_count--;
|
||||
|
||||
if (ctx->mtl_device_ref_count == 0) {
|
||||
[ctx->mtl_device release];
|
||||
ctx->mtl_device = nil;
|
||||
if (ctx->mtl_device) {
|
||||
[ctx->mtl_device release];
|
||||
ctx->mtl_device = nil;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -483,6 +496,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
||||
|
||||
ctx->queue = [device newCommandQueue];
|
||||
if (ctx->queue == nil) {
|
||||
GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
||||
|
||||
id<MTLLibrary> metal_library;
|
||||
@@ -649,6 +667,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
|
||||
GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: has residency sets = %s\n", __func__, ctx_dev->has_residency_sets ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
|
||||
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
|
||||
@@ -1035,8 +1054,70 @@ struct ggml_backend_metal_buffer_context {
|
||||
// multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
|
||||
int n_buffers;
|
||||
struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
|
||||
|
||||
// optional MTLResidencySet
|
||||
id rset;
|
||||
};
|
||||
|
||||
// rset init
|
||||
static bool ggml_backend_metal_buffer_rset_init(
|
||||
struct ggml_backend_metal_buffer_context * ctx,
|
||||
struct ggml_backend_metal_device_context * ctx_dev,
|
||||
id<MTLDevice> device) {
|
||||
ctx->rset = nil;
|
||||
|
||||
if (!ctx_dev->has_residency_sets) {
|
||||
return true;
|
||||
}
|
||||
|
||||
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
|
||||
if (@available(macOS 15.0, *)) {
|
||||
MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init];
|
||||
desc.label = @"ggml_backend_metal";
|
||||
desc.initialCapacity = ctx->n_buffers;
|
||||
|
||||
NSError * error;
|
||||
ctx->rset = [device newResidencySetWithDescriptor:desc error:&error];
|
||||
if (error) {
|
||||
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
[desc release];
|
||||
return false;
|
||||
}
|
||||
|
||||
[desc release];
|
||||
|
||||
for (int i = 0; i < ctx->n_buffers; i++) {
|
||||
[ctx->rset addAllocation:ctx->buffers[i].metal];
|
||||
}
|
||||
|
||||
[ctx->rset commit];
|
||||
[ctx->rset requestResidency];
|
||||
|
||||
return true;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(ctx_dev);
|
||||
GGML_UNUSED(device);
|
||||
#endif
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// rset free
|
||||
static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer_context * ctx) {
|
||||
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
|
||||
if (@available(macOS 15.0, *)) {
|
||||
if (ctx->rset) {
|
||||
[ctx->rset endResidency];
|
||||
[ctx->rset removeAllAllocations];
|
||||
[ctx->rset release];
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(ctx);
|
||||
#endif
|
||||
}
|
||||
|
||||
// finds the Metal buffer that contains the tensor data on the GPU device
|
||||
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
|
||||
// Metal buffer based on the host memory pointer
|
||||
@@ -4176,6 +4257,8 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
|
||||
for (int i = 0; i < ctx->n_buffers; i++) {
|
||||
[ctx->buffers[i].metal release];
|
||||
}
|
||||
|
||||
ggml_backend_metal_buffer_rset_free(ctx);
|
||||
ggml_backend_metal_device_rel(buffer->buft->device->context);
|
||||
|
||||
if (ctx->owned) {
|
||||
@@ -4198,19 +4281,19 @@ static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||
static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
||||
memset((char *)tensor->data + offset, value, size);
|
||||
|
||||
UNUSED(buffer);
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
memcpy((char *)tensor->data + offset, data, size);
|
||||
|
||||
UNUSED(buffer);
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
memcpy(data, (const char *)tensor->data + offset, size);
|
||||
|
||||
UNUSED(buffer);
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
|
||||
@@ -4220,7 +4303,7 @@ static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, c
|
||||
}
|
||||
return false;
|
||||
|
||||
UNUSED(buffer);
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||
@@ -4246,7 +4329,7 @@ static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
|
||||
static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
|
||||
return "Metal";
|
||||
|
||||
UNUSED(buft);
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
|
||||
@@ -4270,8 +4353,8 @@ static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t s
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
UNUSED(device);
|
||||
UNUSED(size_aligned);
|
||||
GGML_UNUSED(device);
|
||||
GGML_UNUSED(size_aligned);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||
@@ -4284,7 +4367,8 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
||||
size_aligned += (size_page - (size_aligned % size_page));
|
||||
}
|
||||
|
||||
id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
|
||||
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
|
||||
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
||||
|
||||
ctx->all_data = ggml_metal_host_malloc(size_aligned);
|
||||
ctx->all_size = size_aligned;
|
||||
@@ -4307,7 +4391,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
||||
if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
|
||||
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
||||
free(ctx);
|
||||
ggml_backend_metal_device_rel(buft->device->context);
|
||||
ggml_backend_metal_device_rel(ctx_dev);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
||||
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
||||
free(ctx);
|
||||
ggml_backend_metal_device_rel(ctx_dev);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
@@ -4318,7 +4409,7 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||
return 32;
|
||||
UNUSED(buft);
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
|
||||
@@ -4328,13 +4419,13 @@ static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_ty
|
||||
|
||||
return max_size;
|
||||
|
||||
UNUSED(buft);
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
||||
return true;
|
||||
|
||||
UNUSED(buft);
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
||||
@@ -4357,7 +4448,7 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
|
||||
static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) {
|
||||
return "Metal_Mapped";
|
||||
|
||||
UNUSED(buft);
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) {
|
||||
@@ -4400,7 +4491,8 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
||||
size_aligned += (size_page - (size_aligned % size_page));
|
||||
}
|
||||
|
||||
id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
|
||||
struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
|
||||
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
||||
|
||||
// the buffer fits into the max buffer size allowed by the device
|
||||
if (size_aligned <= device.maxBufferLength) {
|
||||
@@ -4453,6 +4545,13 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
||||
}
|
||||
}
|
||||
|
||||
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
||||
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
||||
free(ctx);
|
||||
ggml_backend_metal_device_rel(ctx_dev);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
|
||||
}
|
||||
|
||||
@@ -4461,7 +4560,7 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
||||
static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
||||
return "Metal";
|
||||
|
||||
UNUSED(backend);
|
||||
GGML_UNUSED(backend);
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_free(ggml_backend_t backend) {
|
||||
@@ -4766,6 +4865,13 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
|
||||
}
|
||||
}
|
||||
|
||||
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
||||
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
||||
free(ctx);
|
||||
ggml_backend_metal_device_rel(ctx_dev);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
|
||||
}
|
||||
|
||||
@@ -4779,7 +4885,7 @@ static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml
|
||||
return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
|
||||
buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
|
||||
|
||||
UNUSED(dev);
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
|
||||
|
||||
@@ -4416,7 +4416,6 @@ void kernel_mul_mv_q2_K_f32_impl(
|
||||
device const half * dh = &x[ib].d;
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
|
||||
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
||||
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int i = 0; i < 8; i += 2) {
|
||||
@@ -4447,7 +4446,7 @@ void kernel_mul_mv_q2_K_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = all_sum;
|
||||
@@ -4613,7 +4612,7 @@ void kernel_mul_mv_q3_K_f32_impl(
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
if (tiisg == 0) {
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
|
||||
dst_f32[first_row + row] = sumf1[row];
|
||||
}
|
||||
}
|
||||
@@ -4729,7 +4728,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = all_sum;
|
||||
@@ -4861,7 +4860,7 @@ void kernel_mul_mv_q5_K_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
|
||||
const float tot = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = tot;
|
||||
@@ -4906,6 +4905,10 @@ void kernel_mul_mv_q6_K_f32_impl(
|
||||
|
||||
const int row = 2*r0 + sgitg;
|
||||
|
||||
if (row >= args.ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i12 = im%args.ne12;
|
||||
const uint i13 = im/args.ne12;
|
||||
|
||||
@@ -5061,7 +5064,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = all_sum * 0.25f;
|
||||
@@ -5179,7 +5182,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = all_sum * 0.25f;
|
||||
@@ -5289,7 +5292,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = all_sum * 0.5f;
|
||||
@@ -5401,7 +5404,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = all_sum;
|
||||
@@ -5514,7 +5517,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = all_sum * 0.25f;
|
||||
@@ -5614,7 +5617,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = all_sum;
|
||||
@@ -5709,7 +5712,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = all_sum;
|
||||
@@ -5799,7 +5802,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
|
||||
for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = all_sum;
|
||||
@@ -5888,7 +5891,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = all_sum;
|
||||
|
||||
@@ -181,7 +181,7 @@ struct ggml_backend_rpc_context {
|
||||
|
||||
struct ggml_backend_rpc_buffer_context {
|
||||
std::shared_ptr<socket_t> sock;
|
||||
std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
|
||||
void * base_ptr;
|
||||
uint64_t remote_ptr;
|
||||
};
|
||||
|
||||
@@ -423,16 +423,15 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
|
||||
static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
||||
if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
|
||||
return ctx->base_cache[buffer];
|
||||
if (ctx->base_ptr != nullptr) {
|
||||
return ctx->base_ptr;
|
||||
}
|
||||
rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
|
||||
rpc_msg_buffer_get_base_rsp response;
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
|
||||
GGML_ASSERT(status);
|
||||
void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
|
||||
ctx->base_cache[buffer] = base_ptr;
|
||||
return base_ptr;
|
||||
ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
|
||||
return ctx->base_ptr;
|
||||
}
|
||||
|
||||
static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
|
||||
@@ -557,7 +556,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
|
||||
if (response.remote_ptr != 0) {
|
||||
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
|
||||
ggml_backend_rpc_buffer_interface,
|
||||
new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr},
|
||||
new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},
|
||||
response.remote_size);
|
||||
return buffer;
|
||||
} else {
|
||||
|
||||
@@ -3878,10 +3878,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
|
||||
}
|
||||
|
||||
static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_soft_max);
|
||||
}
|
||||
|
||||
static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope);
|
||||
@@ -4090,7 +4086,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
|
||||
ggml_sycl_diag_mask_inf(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
ggml_sycl_soft_max(ctx, dst);
|
||||
ggml_sycl_op_soft_max(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ROPE:
|
||||
ggml_sycl_rope(ctx, dst);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "norm.hpp"
|
||||
#include "softmax.hpp"
|
||||
|
||||
template <bool vals_smem, int ncols_template, int block_size_template>
|
||||
static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
|
||||
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
|
||||
static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
|
||||
const int nrows_y, const float scale, const float max_bias, const float m0,
|
||||
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
|
||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
||||
@@ -29,7 +29,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
||||
slope = sycl::pow(base, float(exp));
|
||||
}
|
||||
|
||||
float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
|
||||
float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
|
||||
float max_val = -INFINITY;
|
||||
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
@@ -42,7 +42,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
||||
const int ix = rowx*ncols + col;
|
||||
const int iy = rowy*ncols + col;
|
||||
|
||||
const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
|
||||
const float val = x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f);
|
||||
|
||||
vals[col] = val;
|
||||
max_val = sycl::max(max_val, val);
|
||||
@@ -65,7 +65,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
max_val = buf[lane_id];
|
||||
for (size_t i = 1; i < nreduce; i += 1) {
|
||||
max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
|
||||
max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]);
|
||||
}
|
||||
max_val = warp_reduce_max(max_val, item_ct1);
|
||||
}
|
||||
@@ -122,8 +122,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
||||
}
|
||||
}
|
||||
|
||||
template <bool vals_smem, int ncols_template, int block_size_template>
|
||||
static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
|
||||
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
|
||||
static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par,
|
||||
const int nrows_y, const float scale, const float max_bias, const float m0,
|
||||
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
|
||||
const size_t n_local_scratch, queue_ptr stream) {
|
||||
@@ -141,7 +141,8 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float *
|
||||
});
|
||||
}
|
||||
|
||||
static void soft_max_f32_sycl(const float * x, const float * mask,
|
||||
template<typename T>
|
||||
static void soft_max_f32_sycl(const float * x, const T * mask,
|
||||
float * dst, const int ncols_x, const int nrows_x,
|
||||
const int nrows_y, const float scale, const float max_bias,
|
||||
queue_ptr stream, int device) {
|
||||
@@ -223,22 +224,16 @@ static void soft_max_f32_sycl(const float * x, const float * mask,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
|
||||
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
||||
GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows_x = ggml_nrows(src0);
|
||||
const int64_t nrows_y = src0->ne[1];
|
||||
const int64_t ne00 = dst->src[0]->ne[0];
|
||||
const int64_t nrows_x = ggml_nrows(dst->src[0]);
|
||||
const int64_t nrows_y = dst->src[0]->ne[1];
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
@@ -246,6 +241,21 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *s
|
||||
memcpy(&scale, dst->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
|
||||
|
||||
soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
|
||||
nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
ggml_sycl_set_device(ctx.device);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
|
||||
if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
|
||||
const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
|
||||
soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
|
||||
main_stream, ctx.device);
|
||||
} else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
|
||||
const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
|
||||
soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
|
||||
} else {
|
||||
/* mask unavailable */
|
||||
soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,10 +15,6 @@
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd,
|
||||
const queue_ptr &main_stream);
|
||||
void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst);
|
||||
|
||||
#endif // GGML_SYCL_SOFTMAX_HPP
|
||||
|
||||
@@ -29,8 +29,6 @@
|
||||
|
||||
#include "ggml-vulkan-shaders.hpp"
|
||||
|
||||
#define VK_API_VERSION VK_API_VERSION_1_2
|
||||
|
||||
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
|
||||
|
||||
#define VK_VENDOR_ID_AMD 0x1002
|
||||
@@ -87,6 +85,10 @@ struct vk_pipeline_struct {
|
||||
uint32_t parameter_count;
|
||||
std::array<uint32_t, 3> wg_denoms;
|
||||
uint32_t align;
|
||||
// set to true to request the pipeline is compiled after the dryrun
|
||||
bool needed {};
|
||||
// set to true when the shader has been compiled
|
||||
bool compiled {};
|
||||
};
|
||||
|
||||
typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline;
|
||||
@@ -188,8 +190,11 @@ struct vk_device_struct {
|
||||
bool mul_mat_id_m;
|
||||
bool mul_mat_id_s;
|
||||
|
||||
vk_matmul_pipeline pipeline_matmul_f32;
|
||||
vk_matmul_pipeline pipeline_matmul_f32_f16;
|
||||
// set to true to indicate that some shaders need to be compiled after the dryrun
|
||||
bool need_compiles {};
|
||||
|
||||
vk_matmul_pipeline pipeline_matmul_f32 {};
|
||||
vk_matmul_pipeline pipeline_matmul_f32_f16 {};
|
||||
vk_matmul_pipeline2 pipeline_matmul_f16;
|
||||
vk_matmul_pipeline2 pipeline_matmul_f16_f32;
|
||||
vk_pipeline pipeline_matmul_split_k_reduce;
|
||||
@@ -197,7 +202,7 @@ struct vk_device_struct {
|
||||
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
|
||||
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
|
||||
|
||||
vk_matmul_pipeline pipeline_matmul_id_f32;
|
||||
vk_matmul_pipeline pipeline_matmul_id_f32 {};
|
||||
vk_matmul_pipeline2 pipeline_matmul_id_f16;
|
||||
vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
|
||||
|
||||
@@ -769,22 +774,15 @@ static uint32_t compile_count = 0;
|
||||
static std::mutex compile_count_mutex;
|
||||
static std::condition_variable compile_count_cond;
|
||||
|
||||
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint,
|
||||
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
|
||||
uint32_t align, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
|
||||
VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size <<
|
||||
", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align <<
|
||||
", " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")");
|
||||
static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint,
|
||||
uint32_t parameter_count, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
|
||||
bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
|
||||
VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << pipeline->name << ", " << entrypoint << ", " << parameter_count <<
|
||||
", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " <<
|
||||
disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")");
|
||||
GGML_ASSERT(parameter_count > 0);
|
||||
GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
|
||||
|
||||
pipeline = std::make_shared<vk_pipeline_struct>();
|
||||
pipeline->name = name;
|
||||
pipeline->parameter_count = parameter_count;
|
||||
pipeline->push_constant_size = push_constant_size;
|
||||
pipeline->wg_denoms = wg_denoms;
|
||||
pipeline->align = align;
|
||||
|
||||
vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
|
||||
pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
|
||||
|
||||
@@ -866,7 +864,14 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
|
||||
compute_pipeline_create_info.setPNext(&rci);
|
||||
}
|
||||
|
||||
pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
|
||||
try {
|
||||
pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
|
||||
} catch (const vk::SystemError& e) {
|
||||
std::cerr << "ggml_vulkan: Compute pipeline creation failed for " << pipeline->name << std::endl;
|
||||
std::cerr << "ggml_vulkan: " << e.what() << std::endl;
|
||||
throw e;
|
||||
}
|
||||
pipeline->compiled = true;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(device->mutex);
|
||||
@@ -877,12 +882,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
|
||||
std::lock_guard<std::mutex> guard(compile_count_mutex);
|
||||
assert(compile_count > 0);
|
||||
compile_count--;
|
||||
|
||||
// "Progress bar" for shader compiles
|
||||
static uint32_t total_compile_count = 0;
|
||||
if ((total_compile_count++ % 10) == 0) {
|
||||
std::cerr << ".";
|
||||
}
|
||||
}
|
||||
compile_count_cond.notify_all();
|
||||
}
|
||||
@@ -908,6 +907,10 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline)
|
||||
static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) {
|
||||
VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
|
||||
device->pipeline_descriptor_set_requirements[pipeline->name] += n;
|
||||
if (!pipeline->compiled) {
|
||||
pipeline->needed = true;
|
||||
device->need_compiles = true;
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) {
|
||||
@@ -1390,8 +1393,6 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
||||
static void ggml_vk_load_shaders(vk_device& device) {
|
||||
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
|
||||
|
||||
std::cerr << "ggml_vulkan: Compiling shaders";
|
||||
|
||||
// some shaders have a minimum subgroup size
|
||||
const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
|
||||
const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
|
||||
@@ -1529,15 +1530,33 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
}
|
||||
}
|
||||
|
||||
device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
||||
device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
|
||||
|
||||
device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
||||
if (!device->pipeline_matmul_f32) {
|
||||
device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
||||
}
|
||||
if (!device->pipeline_matmul_f32_f16) {
|
||||
device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
|
||||
}
|
||||
if (!device->pipeline_matmul_id_f32) {
|
||||
device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
||||
}
|
||||
|
||||
std::vector<std::future<void>> compiles;
|
||||
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
|
||||
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
|
||||
uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
|
||||
|
||||
if (!pipeline) {
|
||||
pipeline = std::make_shared<vk_pipeline_struct>();
|
||||
pipeline->name = name;
|
||||
pipeline->parameter_count = parameter_count;
|
||||
pipeline->push_constant_size = push_constant_size;
|
||||
pipeline->wg_denoms = wg_denoms;
|
||||
pipeline->align = align;
|
||||
}
|
||||
|
||||
if (!pipeline->needed || pipeline->compiled) {
|
||||
return;
|
||||
}
|
||||
{
|
||||
// wait until fewer than N compiles are in progress
|
||||
uint32_t N = std::max(1u, std::thread::hardware_concurrency());
|
||||
@@ -1547,8 +1566,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
}
|
||||
compile_count++;
|
||||
}
|
||||
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint,
|
||||
parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness, require_full_subgroups, required_subgroup_size));
|
||||
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
|
||||
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
|
||||
};
|
||||
|
||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
@@ -1597,6 +1616,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
//CREATE_FA(GGML_TYPE_Q4_K, q4_k)
|
||||
//CREATE_FA(GGML_TYPE_Q5_K, q5_k)
|
||||
//CREATE_FA(GGML_TYPE_Q6_K, q6_k)
|
||||
//CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs)
|
||||
//CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs)
|
||||
//CREATE_FA(GGML_TYPE_IQ2_S, iq2_s)
|
||||
//CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs)
|
||||
//CREATE_FA(GGML_TYPE_IQ3_S, iq3_s)
|
||||
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl)
|
||||
#undef CREATE_FA
|
||||
|
||||
@@ -1614,11 +1638,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
||||
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
||||
|
||||
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
||||
|
||||
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
@@ -1629,23 +1649,30 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
||||
|
||||
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
||||
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
||||
#undef CREATE_MM
|
||||
#undef CREATE_MM2
|
||||
} else
|
||||
@@ -1682,31 +1709,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||
|
||||
if (device->coopmat_acc_f16_support) {
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
} else {
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
}
|
||||
|
||||
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
|
||||
@@ -1716,31 +1753,41 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||
|
||||
if (device->coopmat_acc_f16_support) {
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
} else {
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
}
|
||||
}
|
||||
#undef CREATE_MM2
|
||||
@@ -1784,7 +1831,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
|
||||
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
|
||||
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
|
||||
@@ -1803,7 +1855,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
}
|
||||
#undef CREATE_MM2
|
||||
#undef CREATE_MM
|
||||
@@ -1839,7 +1896,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
|
||||
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
|
||||
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
|
||||
@@ -1858,7 +1920,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
}
|
||||
#undef CREATE_MM
|
||||
}
|
||||
@@ -1889,7 +1956,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||
@@ -1903,7 +1975,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
||||
@@ -1918,7 +1995,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
|
||||
|
||||
// dequant shaders
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
||||
@@ -1932,7 +2014,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs", dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS], "dequant_iq2_xs", dequant_iq2_xs_len, dequant_iq2_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s", dequant_iq2_s_len, dequant_iq2_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_XXS], "dequant_iq3_xxs", dequant_iq3_xxs_len, dequant_iq3_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
|
||||
|
||||
// get_rows
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||
@@ -1942,7 +2029,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs", get_rows_iq2_xs_len, get_rows_iq2_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s", get_rows_iq2_s_len, get_rows_iq2_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs", get_rows_iq3_xxs_len, get_rows_iq3_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
||||
@@ -1951,7 +2043,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs_f32", get_rows_iq2_xs_f32_len, get_rows_iq2_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32", get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs_f32", get_rows_iq3_xxs_f32_len, get_rows_iq3_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
||||
|
||||
@@ -2021,7 +2118,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
|
||||
@@ -2059,7 +2156,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
for (auto &c : compiles) {
|
||||
c.wait();
|
||||
}
|
||||
std::cerr << "Done!" << std::endl;
|
||||
device->need_compiles = false;
|
||||
}
|
||||
|
||||
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
|
||||
@@ -2287,6 +2384,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
}
|
||||
#endif
|
||||
|
||||
VkPhysicalDeviceMaintenance4Features maint4_features {};
|
||||
maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
|
||||
if (maintenance4_support) {
|
||||
last_struct->pNext = (VkBaseOutStructure *)&maint4_features;
|
||||
last_struct = (VkBaseOutStructure *)&maint4_features;
|
||||
device_extensions.push_back("VK_KHR_maintenance4");
|
||||
}
|
||||
|
||||
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
|
||||
|
||||
device->fp16 = device->fp16 && vk12_features.shaderFloat16;
|
||||
@@ -2662,7 +2767,14 @@ void ggml_vk_instance_init() {
|
||||
|
||||
vk_instance_initialized = true;
|
||||
|
||||
vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION };
|
||||
uint32_t api_version = vk::enumerateInstanceVersion();
|
||||
|
||||
if (api_version < VK_API_VERSION_1_2) {
|
||||
std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." << std::endl;
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version };
|
||||
|
||||
const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
|
||||
const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions);
|
||||
@@ -2863,6 +2975,11 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
break;
|
||||
default:
|
||||
@@ -2911,6 +3028,11 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
break;
|
||||
default:
|
||||
@@ -2942,6 +3064,11 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
break;
|
||||
default:
|
||||
@@ -2972,7 +3099,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(src1_type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
|
||||
|
||||
switch (src0_type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
@@ -2985,6 +3112,11 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
break;
|
||||
default:
|
||||
@@ -3011,6 +3143,11 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
break;
|
||||
default:
|
||||
@@ -3812,8 +3949,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
src1_uma = d_Qy != nullptr;
|
||||
}
|
||||
|
||||
const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
|
||||
// Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf
|
||||
// Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
|
||||
const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
|
||||
!ggml_vk_dim01_contiguous(src0);
|
||||
const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
|
||||
!ggml_vk_dim01_contiguous(src1);
|
||||
|
||||
@@ -4393,8 +4531,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
ids_uma = d_ids != nullptr;
|
||||
}
|
||||
|
||||
const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
|
||||
const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
|
||||
// Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
|
||||
const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
|
||||
!ggml_vk_dim01_contiguous(src0);
|
||||
const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
|
||||
!ggml_vk_dim01_contiguous(src1);
|
||||
|
||||
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
||||
|
||||
@@ -4404,7 +4545,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
|
||||
|
||||
if (qx_needs_dequant) {
|
||||
GGML_ABORT("fatal error");
|
||||
// Fall back to dequant + f16 mulmat
|
||||
mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
|
||||
}
|
||||
|
||||
// Not implemented
|
||||
@@ -7645,6 +7787,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
|
||||
}
|
||||
if (ctx->device->need_compiles) {
|
||||
ggml_vk_load_shaders(ctx->device);
|
||||
}
|
||||
ggml_vk_preallocate_buffers(ctx);
|
||||
ggml_pipeline_allocate_descriptor_sets(ctx->device);
|
||||
|
||||
@@ -7872,6 +8017,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
break;
|
||||
default:
|
||||
@@ -7940,6 +8090,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
//case GGML_TYPE_Q4_K:
|
||||
//case GGML_TYPE_Q5_K:
|
||||
//case GGML_TYPE_Q6_K:
|
||||
//case GGML_TYPE_IQ2_XXS:
|
||||
//case GGML_TYPE_IQ2_XS:
|
||||
//case GGML_TYPE_IQ2_S:
|
||||
//case GGML_TYPE_IQ3_XXS:
|
||||
//case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
break;
|
||||
default:
|
||||
@@ -7957,6 +8112,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
return true;
|
||||
default:
|
||||
|
||||
@@ -12,8 +12,8 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
init_iq4nl_shmem();
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
if (gl_LocalInvocationIndex.x != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -217,8 +217,8 @@ void quantize(uint dst_idx, uint src_idx)
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
init_iq4nl_shmem();
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
if (gl_LocalInvocationIndex.x != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -88,6 +88,222 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ2_XXS)
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const uint ib32 = iqs / 32;
|
||||
const uint ib8 = (iqs / 8) % 4;
|
||||
const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8];
|
||||
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
|
||||
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2],
|
||||
data_a_packed16[a_offset + ib].qs[4 * ib32 + 3]));
|
||||
const float db = 0.25 * (0.5 + (signs >> 28));
|
||||
const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
|
||||
// Add parity bit
|
||||
const uint sign8 = sign7 | (bitCount(sign7) << 7);
|
||||
const uint sign = sign8 >> (iqs % 8);
|
||||
const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4)));
|
||||
bool sign0 = (sign & 1) != 0;
|
||||
bool sign1 = (sign & 2) != 0;
|
||||
return db * vec2(
|
||||
grid.x * (sign0 ? -1.0 : 1.0),
|
||||
grid.y * (sign1 ? -1.0 : 1.0)
|
||||
);
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
const uint ib32 = iqs / 32;
|
||||
const uint ib8 = (iqs / 8) % 4;
|
||||
const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8];
|
||||
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
|
||||
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2],
|
||||
data_a_packed16[a_offset + ib].qs[4 * ib32 + 3]));
|
||||
const float db = 0.25 * (0.5 + (signs >> 28));
|
||||
const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
|
||||
// Add parity bit
|
||||
const uint sign8 = sign7 | (bitCount(sign7) << 7);
|
||||
const uint sign = sign8 >> (iqs % 8);
|
||||
const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4)));
|
||||
bool sign0 = (sign & 1) != 0;
|
||||
bool sign1 = (sign & 2) != 0;
|
||||
bool sign2 = (sign & 4) != 0;
|
||||
bool sign3 = (sign & 8) != 0;
|
||||
return db * vec4(
|
||||
grid.x * (sign0 ? -1.0 : 1.0),
|
||||
grid.y * (sign1 ? -1.0 : 1.0),
|
||||
grid.z * (sign2 ? -1.0 : 1.0),
|
||||
grid.w * (sign3 ? -1.0 : 1.0)
|
||||
);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ2_XS)
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf;
|
||||
const uint qs = data_a[a_offset + ib].qs[iqs / 8];
|
||||
const float db = 0.25 * (0.5 + scale);
|
||||
const uint sign7 = qs >> 9;
|
||||
// Add parity bit
|
||||
const uint sign8 = sign7 | (bitCount(sign7) << 7);
|
||||
const uint sign = sign8 >> (iqs % 8);
|
||||
const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4)));
|
||||
bool sign0 = (sign & 1) != 0;
|
||||
bool sign1 = (sign & 2) != 0;
|
||||
return db * vec2(
|
||||
grid.x * (sign0 ? -1.0 : 1.0),
|
||||
grid.y * (sign1 ? -1.0 : 1.0)
|
||||
);
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf;
|
||||
const uint qs = data_a[a_offset + ib].qs[iqs / 8];
|
||||
const float db = 0.25 * (0.5 + scale);
|
||||
const uint sign7 = qs >> 9;
|
||||
// Add parity bit
|
||||
const uint sign8 = sign7 | (bitCount(sign7) << 7);
|
||||
const uint sign = sign8 >> (iqs % 8);
|
||||
const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4)));
|
||||
bool sign0 = (sign & 1) != 0;
|
||||
bool sign1 = (sign & 2) != 0;
|
||||
bool sign2 = (sign & 4) != 0;
|
||||
bool sign3 = (sign & 8) != 0;
|
||||
return db * vec4(
|
||||
grid.x * (sign0 ? -1.0 : 1.0),
|
||||
grid.y * (sign1 ? -1.0 : 1.0),
|
||||
grid.z * (sign2 ? -1.0 : 1.0),
|
||||
grid.w * (sign3 ? -1.0 : 1.0)
|
||||
);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ2_S)
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const uint ib32 = iqs / 32;
|
||||
const uint ib8 = iqs / 8;
|
||||
|
||||
const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf;
|
||||
const uint qs = data_a[a_offset + ib].qs[ib8];
|
||||
const uint qh = data_a[a_offset + ib].qh[ib32];
|
||||
const uint qhshift = 2 * (ib8 % 4);
|
||||
const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8);
|
||||
|
||||
const float db = 0.25 * (0.5 + scale);
|
||||
const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]);
|
||||
bool sign0 = (sign & 1) != 0;
|
||||
bool sign1 = (sign & 2) != 0;
|
||||
return db * vec2(
|
||||
grid[iqs % 4] * (sign0 ? -1.0 : 1.0),
|
||||
grid[(iqs % 4) + 1] * (sign1 ? -1.0 : 1.0)
|
||||
);
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
const uint ib32 = iqs / 32;
|
||||
const uint ib8 = iqs / 8;
|
||||
|
||||
const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf;
|
||||
const uint qs = data_a[a_offset + ib].qs[ib8];
|
||||
const uint qh = data_a[a_offset + ib].qh[ib32];
|
||||
const uint qhshift = 2 * (ib8 % 4);
|
||||
const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8);
|
||||
|
||||
const float db = 0.25 * (0.5 + scale);
|
||||
const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]);
|
||||
bool sign0 = (sign & 1) != 0;
|
||||
bool sign1 = (sign & 2) != 0;
|
||||
bool sign2 = (sign & 4) != 0;
|
||||
bool sign3 = (sign & 8) != 0;
|
||||
return db * vec4(
|
||||
grid.x * (sign0 ? -1.0 : 1.0),
|
||||
grid.y * (sign1 ? -1.0 : 1.0),
|
||||
grid.z * (sign2 ? -1.0 : 1.0),
|
||||
grid.w * (sign3 ? -1.0 : 1.0)
|
||||
);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ3_XXS)
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const uint ib4 = iqs / 4;
|
||||
const uint ib32 = iqs / 32;
|
||||
const uint is = QUANT_K / 4 + 4 * ib32;
|
||||
const uint qs = data_a[a_offset + ib].qs[ib4];
|
||||
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
|
||||
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2],
|
||||
data_a_packed16[a_offset + ib].qs[is / 2 + 1]));
|
||||
const float db = 0.5 * (0.5 + (signs >> 28));
|
||||
const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7);
|
||||
// Add parity bit
|
||||
const uint sign8 = sign7 | (bitCount(sign7) << 7);
|
||||
const uint sign = sign8 >> (iqs % 8);
|
||||
const u8vec4 grid = unpack8(iq3xxs_grid[qs] >> (8 * (iqs % 4)));
|
||||
bool sign0 = (sign & 1) != 0;
|
||||
bool sign1 = (sign & 2) != 0;
|
||||
return db * vec2(
|
||||
grid.x * (sign0 ? -1.0 : 1.0),
|
||||
grid.y * (sign1 ? -1.0 : 1.0)
|
||||
);
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
const uint ib4 = iqs / 4;
|
||||
const uint ib32 = iqs / 32;
|
||||
const uint is = QUANT_K / 4 + 4 * ib32;
|
||||
const uint qs = data_a[a_offset + ib].qs[ib4];
|
||||
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2],
|
||||
data_a_packed16[a_offset + ib].qs[is / 2 + 1]));
|
||||
const float db = 0.5 * (0.5 + (signs >> 28));
|
||||
const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7);
|
||||
// Add parity bit
|
||||
const uint sign8 = sign7 | (bitCount(sign7) << 7);
|
||||
const uint sign = sign8 >> (iqs % 8);
|
||||
const u8vec4 grid = unpack8(iq3xxs_grid[qs]);
|
||||
bool sign0 = (sign & 1) != 0;
|
||||
bool sign1 = (sign & 2) != 0;
|
||||
bool sign2 = (sign & 4) != 0;
|
||||
bool sign3 = (sign & 8) != 0;
|
||||
return db * vec4(
|
||||
grid.x * (sign0 ? -1.0 : 1.0),
|
||||
grid.y * (sign1 ? -1.0 : 1.0),
|
||||
grid.z * (sign2 ? -1.0 : 1.0),
|
||||
grid.w * (sign3 ? -1.0 : 1.0)
|
||||
);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ3_S)
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const uint qs = data_a[a_offset + ib].qs[iqs / 4];
|
||||
const uint qh = data_a[a_offset + ib].qh[iqs / 32];
|
||||
const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8);
|
||||
const uint scale = data_a[a_offset + ib].scales[iqs / 64];
|
||||
bool sign0 = (sign & 1) != 0;
|
||||
bool sign1 = (sign & 2) != 0;
|
||||
const float db = 1 + 2 * ((scale >> (4 * ((iqs / 32) & 1))) & 0xf);
|
||||
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ((iqs / 4) % 8))) & 256)] >> (8 * (iqs % 4));
|
||||
return db * vec2(
|
||||
int(grid & 0xFF) * (sign0 ? -1.0 : 1.0),
|
||||
int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0)
|
||||
);
|
||||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
const uint ib4 = iqs / 4;
|
||||
const uint ib32 = iqs / 32;
|
||||
const uint qs = data_a[a_offset + ib].qs[ib4];
|
||||
const uint qh = data_a[a_offset + ib].qh[ib32];
|
||||
const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8);
|
||||
const uint scale = data_a[a_offset + ib].scales[ib32 / 2];
|
||||
bool sign0 = (sign & 1) != 0;
|
||||
bool sign1 = (sign & 2) != 0;
|
||||
bool sign2 = (sign & 4) != 0;
|
||||
bool sign3 = (sign & 8) != 0;
|
||||
const float db = 1 + 2 * ((scale >> (4 * (ib32 & 1))) & 0xf);
|
||||
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ib4 % 8)) & 256)] >> (8 * (iqs % 4));
|
||||
return db * vec4(
|
||||
int(grid & 0xFF) * (sign0 ? -1.0 : 1.0),
|
||||
int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0),
|
||||
int((grid >> 16) & 0xFF) * (sign2 ? -1.0 : 1.0),
|
||||
int((grid >> 24) & 0xFF) * (sign3 ? -1.0 : 1.0)
|
||||
);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
|
||||
@@ -105,7 +321,7 @@ vec2 get_dm(uint ib, uint a_offset) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL)
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
|
||||
vec2 get_dm(uint ib, uint a_offset) {
|
||||
return vec2(float(data_a[a_offset + ib].d), 0);
|
||||
}
|
||||
|
||||
@@ -301,6 +301,160 @@ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2
|
||||
return ret;
|
||||
}
|
||||
|
||||
#if defined(DATA_A_IQ2_XXS)
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS {
|
||||
block_iq2_xxs block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS_packed16 {
|
||||
block_iq2_xxs_packed16 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl);
|
||||
const float16_t d = bl.block.d;
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
|
||||
const uint ib8 = (idx & 0x18) >> 3; // 0..3
|
||||
const uint iqs = 8 * ib32 + ib8;
|
||||
|
||||
const uint8_t qs = bl.block.qs[iqs];
|
||||
const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3]));
|
||||
|
||||
const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t(signscale >> 28));
|
||||
uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7);
|
||||
sign |= bitCount(sign) << 7;
|
||||
|
||||
const uint8_t g = unpack8(iq2xxs_grid[qs][(idx & 4) >> 2])[idx & 3];
|
||||
|
||||
float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
|
||||
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ2_XS)
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XS {
|
||||
block_iq2_xs block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
const uint idx = coordInBlock[1];
|
||||
|
||||
const uint is = (idx & 0xE0) >> 5; // 0..8
|
||||
const uint sshift = (idx & 0x10) >> 2; // 0,4
|
||||
const uint iqs = (idx & 0xF8) >> 3; // 0..63
|
||||
|
||||
const uint16_t qs = bl.block.qs[iqs];
|
||||
const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t((bl.block.scales[is] >> sshift) & 0xF));
|
||||
|
||||
uint sign = uint(qs >> 9);
|
||||
sign |= bitCount(sign) << 7;
|
||||
const uint8_t g = unpack8(iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2])[idx & 3];
|
||||
|
||||
float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ2_S)
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_S {
|
||||
block_iq2_s block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
uint idx = coordInBlock[1];
|
||||
uint lsb = idx & 1;
|
||||
idx /= 2;
|
||||
|
||||
const uint ib8 = (idx % 128) / 4; // 0..31
|
||||
const uint ib32 = ib8 / 4; // 0..7
|
||||
|
||||
const uint scale = (bl.block.scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
||||
const uint qs = bl.block.qs[ib8];
|
||||
const uint qh = bl.block.qh[ib32];
|
||||
const uint qhshift = 2 * (ib8 % 4);
|
||||
const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
|
||||
|
||||
const float d = float(bl.block.d);
|
||||
const float db = d * 0.25 * (0.5 + scale);
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid));
|
||||
return float16_t(v[lsb]);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ3_XXS)
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS {
|
||||
block_iq3_xxs block;
|
||||
};
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS_packed16 {
|
||||
block_iq3_xxs_packed16 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
uint idx = coordInBlock[1];
|
||||
uint lsb = idx & 1;
|
||||
idx /= 2;
|
||||
|
||||
const uint iqs = (idx % 128) / 2; // 0..63
|
||||
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
|
||||
|
||||
const float d = float(bl.block.d);
|
||||
const uint qs = bl.block.qs[iqs];
|
||||
const uint signs = pack32(u8vec4(
|
||||
bl.block.qs[is+0],
|
||||
bl.block.qs[is+1],
|
||||
bl.block.qs[is+2],
|
||||
bl.block.qs[is+3]
|
||||
));
|
||||
const float db = d * 0.5 * (0.5 + (signs >> 28));
|
||||
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
|
||||
return float16_t(v[lsb]);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ3_S)
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_S {
|
||||
block_iq3_s block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
uint idx = coordInBlock[1];
|
||||
uint lsb = idx & 1;
|
||||
idx /= 2;
|
||||
|
||||
const uint iqs = (idx % 128) / 2; // 0..63
|
||||
const uint iqh = iqs / 8;
|
||||
|
||||
const float d = float(bl.block.d);
|
||||
const uint qs = bl.block.qs[iqs];
|
||||
const uint qh = bl.block.qh[iqh];
|
||||
const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (2 * (idx % 4)));
|
||||
const uint scale = bl.block.scales[iqs / 16];
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
|
||||
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
|
||||
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
|
||||
|
||||
return float16_t(v[lsb]);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
|
||||
block_iq4_nl block;
|
||||
@@ -340,6 +494,16 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
|
||||
#define dequantFuncA dequantFuncQ5_K
|
||||
#elif defined(DATA_A_Q6_K)
|
||||
#define dequantFuncA dequantFuncQ6_K
|
||||
#elif defined(DATA_A_IQ2_XXS)
|
||||
#define dequantFuncA dequantFuncIQ2_XXS
|
||||
#elif defined(DATA_A_IQ2_XS)
|
||||
#define dequantFuncA dequantFuncIQ2_XS
|
||||
#elif defined(DATA_A_IQ2_S)
|
||||
#define dequantFuncA dequantFuncIQ2_S
|
||||
#elif defined(DATA_A_IQ3_XXS)
|
||||
#define dequantFuncA dequantFuncIQ3_XXS
|
||||
#elif defined(DATA_A_IQ3_S)
|
||||
#define dequantFuncA dequantFuncIQ3_S
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
#define dequantFuncA dequantFuncIQ4_NL
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
#version 450
|
||||
|
||||
#include "dequant_head.comp"
|
||||
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {block_iq2_s data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
|
||||
void main() {
|
||||
// Each thread handles 1 subblock (32 values with 2 scales)
|
||||
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
|
||||
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
|
||||
if (ib >= p.nel / 256) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint ib32 = gl_LocalInvocationID.x % 8;
|
||||
const uint b_idx = 256 * ib + 32 * ib32;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4);
|
||||
const vec2 db = d * (0.5 + scale) * 0.25;
|
||||
|
||||
uint qh = data_a[ib].qh[ib32];
|
||||
[[unroll]] for (uint l = 0; l < 4; ++l) {
|
||||
uint qs = data_a[ib].qs[4 * ib32 + l];
|
||||
const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l];
|
||||
qs |= (qh << (8 - 2 * l)) & 0x300;
|
||||
const uvec2 grid = iq2s_grid[qs & 511];
|
||||
const u8vec4 grid0 = unpack8(grid.x);
|
||||
const u8vec4 grid1 = unpack8(grid.y);
|
||||
data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign & 2) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign & 4) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign & 8) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign & 16) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign & 32) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign & 64) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign & 128) != 0 ? -1.0 : 1.0));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
#version 450
|
||||
|
||||
#include "dequant_head.comp"
|
||||
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {block_iq2_xs data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
|
||||
void main() {
|
||||
// Each thread handles 1 subblock (32 values with 2 scales)
|
||||
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
|
||||
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
|
||||
if (ib >= p.nel / 256) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint ib32 = gl_LocalInvocationID.x % 8;
|
||||
const uint b_idx = 256 * ib + 32 * ib32;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4);
|
||||
const vec2 db = d * (0.5 + scale) * 0.25;
|
||||
|
||||
[[unroll]] for (uint l = 0; l < 4; ++l) {
|
||||
uint16_t qs = data_a[ib].qs[4 * ib32 + l];
|
||||
const uint sign7 = qs >> 9;
|
||||
const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit
|
||||
const uvec2 grid = iq2xs_grid[qs & 511];
|
||||
const u8vec4 grid0 = unpack8(grid.x);
|
||||
const u8vec4 grid1 = unpack8(grid.y);
|
||||
data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
#version 450
|
||||
|
||||
#include "dequant_head.comp"
|
||||
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {block_iq2_xxs data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
|
||||
void main() {
|
||||
// Each thread handles 1 scale block (32 values)
|
||||
// Each block is described by 4 lattice indices, 4x7 sign bits and 4 scale bits
|
||||
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
|
||||
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
|
||||
if (ib >= p.nel / 256) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint is = gl_LocalInvocationID.x % 8;
|
||||
const uint b_idx = 256 * ib + 32 * is;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
uint signscale = pack32(u8vec4(
|
||||
data_a[ib].qs[8*is + 4],
|
||||
data_a[ib].qs[8*is + 5],
|
||||
data_a[ib].qs[8*is + 6],
|
||||
data_a[ib].qs[8*is + 7]
|
||||
));
|
||||
const float db = d * (0.5 + (signscale >> 28)) * 0.25;
|
||||
|
||||
[[unroll]] for (uint l = 0; l < 4; ++l) {
|
||||
const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
|
||||
const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit
|
||||
const uvec2 grid = iq2xxs_grid[data_a[ib].qs[8 * is + l]];
|
||||
const u8vec4 grid0 = unpack8(grid.x);
|
||||
const u8vec4 grid1 = unpack8(grid.y);
|
||||
data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
#version 450
|
||||
|
||||
#include "dequant_head.comp"
|
||||
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {block_iq3_s data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
|
||||
void main() {
|
||||
// Each thread handles 1 scale nibble.
|
||||
// Each block contains 4 scale bytes (8 scales) for 256 output values.
|
||||
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
|
||||
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
|
||||
if (ib >= p.nel / 256) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint is = gl_LocalInvocationID.x % 8;
|
||||
const uint b_idx = 256 * ib + 32 * is;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const float db = d * (1 + 2 * ((data_a[ib].scales[is] >> (4 * (is % 2))) & 0xf));
|
||||
|
||||
// We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes.
|
||||
uint qh = data_a[ib].qh[is];
|
||||
[[unroll]] for (uint l = 0; l < 8; ++l) {
|
||||
uint qs = data_a[ib].qs[8 * is + l];
|
||||
uint gidx = qs | ((qh << (8 - l)) & 256);
|
||||
uint8_t signs = data_a[ib].signs[8 * is + l / 2] >> (4 * (l & 1));
|
||||
u8vec4 grid = unpack8(iq3s_grid[gidx]);
|
||||
data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 4 * l + 3] = D_TYPE(db * grid.w * ((signs & 8) != 0 ? -1.0 : 1.0));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
#version 450
|
||||
|
||||
#include "dequant_head.comp"
|
||||
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {block_iq3_xxs data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
|
||||
void main() {
|
||||
// Each thread handles 1 scale block (32 values)
|
||||
// 8 threads handle 1 superblock
|
||||
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
|
||||
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
|
||||
if (ib >= p.nel / 256) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint is = gl_LocalInvocationID.x % 8;
|
||||
const uint b_idx = 256 * ib + 32 * is;
|
||||
const uint s_idx = QUANT_K / 4 + 4 * is;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
uint signscale = pack32(u8vec4(
|
||||
data_a[ib].qs[s_idx + 0],
|
||||
data_a[ib].qs[s_idx + 1],
|
||||
data_a[ib].qs[s_idx + 2],
|
||||
data_a[ib].qs[s_idx + 3]
|
||||
));
|
||||
const float db = d * (0.5 + (signscale >> 28)) * 0.5;
|
||||
|
||||
[[unroll]] for (uint l = 0; l < 4; ++l) {
|
||||
const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
|
||||
// Restore parity bit.
|
||||
const uint sign8 = sign7 | (bitCount(sign7) << 7);
|
||||
const u8vec4 grid0 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l]]);
|
||||
const u8vec4 grid1 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l + 1]]);
|
||||
data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
|
||||
data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||
void main() {
|
||||
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
|
||||
|
||||
init_iq4nl_shmem();
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
|
||||
const uint tid = gl_LocalInvocationID.x % 64;
|
||||
const uint il = tid/32;
|
||||
|
||||
@@ -12,7 +12,7 @@ layout (push_constant) uniform parameter
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
@@ -104,8 +104,8 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
init_iq4nl_shmem();
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
const uint32_t N = p.N;
|
||||
@@ -166,7 +166,7 @@ void main() {
|
||||
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
|
||||
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
|
||||
|
||||
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q;
|
||||
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
|
||||
|
||||
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
|
||||
|
||||
@@ -12,8 +12,8 @@ void main() {
|
||||
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
|
||||
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
|
||||
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
init_iq4nl_shmem();
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
if (i00 >= p.ne00) {
|
||||
|
||||
@@ -133,8 +133,8 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
void main() {
|
||||
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
|
||||
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
init_iq4nl_shmem();
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
// do NUM_ROWS at a time, unless there aren't enough remaining rows
|
||||
|
||||
@@ -95,8 +95,8 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
init_iq4nl_shmem();
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
@@ -343,10 +343,8 @@ void main() {
|
||||
const uint qsshift = halfsplit * 2; // 0,2,4,6
|
||||
const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
|
||||
|
||||
const int8_t us = int8_t(is < 4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) :
|
||||
is < 8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) :
|
||||
is < 12 ? (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) :
|
||||
(data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4));
|
||||
const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
|
||||
| (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
|
||||
const float dl = float(data_a[ib].d) * float(us - 32);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)));
|
||||
@@ -439,6 +437,118 @@ void main() {
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
|
||||
#elif defined(DATA_A_IQ2_XXS)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
const uint ib8 = (idx / 4) % 4;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint qs = data_a[ib].qs[8 * ib32 + ib8];
|
||||
const uint signs = pack32(u8vec4(
|
||||
data_a[ib].qs[8*ib32 + 4],
|
||||
data_a[ib].qs[8*ib32 + 5],
|
||||
data_a[ib].qs[8*ib32 + 6],
|
||||
data_a[ib].qs[8*ib32 + 7]
|
||||
));
|
||||
const float db = d * 0.25 * (0.5 + (signs >> 28));
|
||||
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
#elif defined(DATA_A_IQ2_XS)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||
const uint ib8 = (idx / 4) % 4; // 0..3
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
||||
const float db = d * 0.25 * (0.5 + scale);
|
||||
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
|
||||
const uint sign7 = qs >> 9;
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
#elif defined(DATA_A_IQ2_S)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint ib8 = (idx % 128) / 4; // 0..31
|
||||
const uint ib32 = ib8 / 4; // 0..7
|
||||
|
||||
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
||||
const uint qs = data_a[ib].qs[ib8];
|
||||
const uint qh = data_a[ib].qh[ib32];
|
||||
const uint qhshift = 2 * (ib8 % 4);
|
||||
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const float db = d * 0.25 * (0.5 + scale);
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid));
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
#elif defined(DATA_A_IQ3_XXS)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = (idx % 128) / 2; // 0..63
|
||||
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint qs = data_a[ib].qs[iqs];
|
||||
const uint signs = pack32(u8vec4(
|
||||
data_a[ib].qs[is+0],
|
||||
data_a[ib].qs[is+1],
|
||||
data_a[ib].qs[is+2],
|
||||
data_a[ib].qs[is+3]
|
||||
));
|
||||
const float db = d * 0.5 * (0.5 + (signs >> 28));
|
||||
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
|
||||
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
|
||||
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
#elif defined(DATA_A_IQ3_S)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
|
||||
|
||||
const uint ib = idx / 128; // 2 values per idx
|
||||
const uint iqs = (idx % 128) / 2; // 0..63
|
||||
const uint iqh = iqs / 8;
|
||||
|
||||
const float d = float(data_a[ib].d);
|
||||
const uint qs = data_a[ib].qs[iqs];
|
||||
const uint qh = data_a[ib].qh[iqh];
|
||||
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4)));
|
||||
const uint scale = data_a[ib].scales[iqs / 16];
|
||||
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
|
||||
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
|
||||
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
|
||||
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
|
||||
|
||||
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
|
||||
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
|
||||
|
||||
@@ -57,17 +57,13 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
#if QUANT_K > 1
|
||||
#define DECODEFUNCA , dequantFuncA
|
||||
#define MAT_A_TYPE float16_t
|
||||
|
||||
#include "dequant_funcs_cm2.comp"
|
||||
|
||||
#else
|
||||
#define DECODEFUNCA
|
||||
#define MAT_A_TYPE A_TYPE
|
||||
#endif
|
||||
|
||||
#define MAT_B_TYPE B_TYPE
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
||||
|
||||
@@ -110,8 +106,8 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
init_iq4nl_shmem();
|
||||
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
@@ -236,16 +232,13 @@ void main() {
|
||||
|
||||
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
|
||||
|
||||
coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
|
||||
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
|
||||
|
||||
sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
} else
|
||||
#endif // !defined(MUL_MAT_ID)
|
||||
@@ -261,10 +254,8 @@ void main() {
|
||||
[[dont_unroll]]
|
||||
for (uint block_k = start_k; block_k < end_k; block_k += BK) {
|
||||
|
||||
coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
|
||||
|
||||
// Clamping is expensive, so detect different code paths for each combination
|
||||
// of A and B needing clamping.
|
||||
@@ -281,16 +272,12 @@ void main() {
|
||||
#else
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
|
||||
#endif
|
||||
mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
|
||||
mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
|
||||
sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
} else if (unclampedA && !unclampedB) {
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
|
||||
mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
|
||||
mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
|
||||
sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
} else if (!unclampedA && unclampedB) {
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
#ifdef MUL_MAT_ID
|
||||
@@ -298,16 +285,12 @@ void main() {
|
||||
#else
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
|
||||
#endif
|
||||
mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
|
||||
mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
|
||||
sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
} else if (!unclampedA && !unclampedB) {
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
||||
|
||||
mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
|
||||
mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
|
||||
sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -294,6 +294,738 @@ struct block_q6_K_packed16
|
||||
|
||||
// IQuants
|
||||
|
||||
#define QUANT_K_IQ2_XXS 256
|
||||
#define QUANT_R_IQ2_XXS 1
|
||||
|
||||
struct block_iq2_xxs
|
||||
{
|
||||
float16_t d;
|
||||
uint8_t qs[QUANT_K_IQ2_XXS/4];
|
||||
};
|
||||
|
||||
struct block_iq2_xxs_packed16
|
||||
{
|
||||
float16_t d;
|
||||
uint16_t qs[QUANT_K_IQ2_XXS/8];
|
||||
};
|
||||
|
||||
#if defined(DATA_A_IQ2_XXS)
|
||||
|
||||
const uvec2[256] iq2xxs_grid_const = {
|
||||
uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808),
|
||||
uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x082b0808, 0x08080808),
|
||||
uvec2(0x082b082b, 0x08080808), uvec2(0x082b2b08, 0x08080808), uvec2(0x082b2b2b, 0x08080808), uvec2(0x19080819, 0x08080808),
|
||||
uvec2(0x19081908, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808),
|
||||
uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b082b2b, 0x08080808),
|
||||
uvec2(0x2b2b082b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819), uvec2(0x08190808, 0x08080819),
|
||||
uvec2(0x08191919, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x2b081908, 0x08080819), uvec2(0x2b192b08, 0x08080819),
|
||||
uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x082b082b, 0x0808082b), uvec2(0x2b08082b, 0x0808082b),
|
||||
uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x082b0819, 0x08081908),
|
||||
uvec2(0x082b1908, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19082b08, 0x08081908),
|
||||
uvec2(0x192b0808, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908),
|
||||
uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919), uvec2(0x08082b08, 0x08081919),
|
||||
uvec2(0x082b0808, 0x08081919), uvec2(0x1908192b, 0x08081919), uvec2(0x192b2b19, 0x08081919), uvec2(0x2b080808, 0x08081919),
|
||||
uvec2(0x2b190819, 0x08081919), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x19080808, 0x0808192b),
|
||||
uvec2(0x2b081908, 0x0808192b), uvec2(0x2b2b1908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x08081919, 0x08082b08),
|
||||
uvec2(0x08082b08, 0x08082b08), uvec2(0x08191908, 0x08082b08), uvec2(0x082b2b08, 0x08082b08), uvec2(0x19080819, 0x08082b08),
|
||||
uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x2b082b08, 0x08082b08),
|
||||
uvec2(0x08081908, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x0808082b, 0x08082b2b), uvec2(0x08191908, 0x08082b2b),
|
||||
uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x082b0819, 0x08190808),
|
||||
uvec2(0x19080808, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808),
|
||||
uvec2(0x2b191919, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x082b0808, 0x08190819),
|
||||
uvec2(0x19190808, 0x08190819), uvec2(0x19192b2b, 0x08190819), uvec2(0x2b080808, 0x08190819), uvec2(0x082b1908, 0x0819082b),
|
||||
uvec2(0x19081919, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x08082b08, 0x08191908), uvec2(0x082b0808, 0x08191908),
|
||||
uvec2(0x082b1919, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08192b08, 0x08191919),
|
||||
uvec2(0x192b082b, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x0819192b, 0x0819192b), uvec2(0x08080819, 0x08192b08),
|
||||
uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x2b080819, 0x08192b08),
|
||||
uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x2b2b0808, 0x08192b19), uvec2(0x19190819, 0x08192b2b),
|
||||
uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x19081908, 0x082b0808),
|
||||
uvec2(0x192b0819, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b08082b, 0x082b0808), uvec2(0x082b2b19, 0x082b0819),
|
||||
uvec2(0x19082b08, 0x082b0819), uvec2(0x08080808, 0x082b082b), uvec2(0x0808082b, 0x082b082b), uvec2(0x08080819, 0x082b1908),
|
||||
uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x19080808, 0x082b1908), uvec2(0x1919192b, 0x082b1908),
|
||||
uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x192b1908, 0x082b1919), uvec2(0x2b190808, 0x082b192b),
|
||||
uvec2(0x08082b08, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08), uvec2(0x2b191908, 0x082b2b08), uvec2(0x19081908, 0x082b2b2b),
|
||||
uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x08192b08, 0x19080808),
|
||||
uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x19080808, 0x19080808), uvec2(0x19082b08, 0x19080808),
|
||||
uvec2(0x1919192b, 0x19080808), uvec2(0x192b0808, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808),
|
||||
uvec2(0x2b190808, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x192b0819, 0x19080819),
|
||||
uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08190808, 0x1908082b),
|
||||
uvec2(0x19082b08, 0x1908082b), uvec2(0x1919192b, 0x1908082b), uvec2(0x192b2b08, 0x1908082b), uvec2(0x08080808, 0x19081908),
|
||||
uvec2(0x08082b08, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b192b19, 0x19081908),
|
||||
uvec2(0x0819082b, 0x19081919), uvec2(0x082b1908, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08080819, 0x19082b08),
|
||||
uvec2(0x08081908, 0x19082b08), uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08),
|
||||
uvec2(0x08080808, 0x19082b19), uvec2(0x19192b08, 0x19082b19), uvec2(0x192b0819, 0x19082b19), uvec2(0x2b08082b, 0x19082b19),
|
||||
uvec2(0x19081919, 0x19082b2b), uvec2(0x2b190808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x08082b08, 0x19190808),
|
||||
uvec2(0x08190819, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x2b080808, 0x19190808),
|
||||
uvec2(0x2b082b08, 0x19190808), uvec2(0x08081908, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x2b2b1908, 0x19190819),
|
||||
uvec2(0x2b190819, 0x1919082b), uvec2(0x2b190808, 0x19191908), uvec2(0x2b19082b, 0x19191908), uvec2(0x08082b2b, 0x19191919),
|
||||
uvec2(0x08080819, 0x1919192b), uvec2(0x19191908, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x08190819, 0x19192b08),
|
||||
uvec2(0x08192b19, 0x19192b08), uvec2(0x192b1908, 0x19192b08), uvec2(0x19080808, 0x19192b19), uvec2(0x08082b08, 0x19192b2b),
|
||||
uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x192b2b08, 0x192b0808),
|
||||
uvec2(0x08080808, 0x192b0819), uvec2(0x19191919, 0x192b0819), uvec2(0x08192b08, 0x192b082b), uvec2(0x192b0808, 0x192b082b),
|
||||
uvec2(0x08080808, 0x192b1908), uvec2(0x08081919, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x0819082b, 0x192b1919),
|
||||
uvec2(0x2b081908, 0x192b1919), uvec2(0x1908082b, 0x192b2b08), uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808),
|
||||
uvec2(0x08082b2b, 0x2b080808), uvec2(0x19080819, 0x2b080808), uvec2(0x2b08082b, 0x2b080808), uvec2(0x08081908, 0x2b080819),
|
||||
uvec2(0x08192b08, 0x2b080819), uvec2(0x19080808, 0x2b080819), uvec2(0x08190819, 0x2b08082b), uvec2(0x08080819, 0x2b081908),
|
||||
uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908),
|
||||
uvec2(0x192b0808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x1908192b, 0x2b081919), uvec2(0x2b191908, 0x2b081919),
|
||||
uvec2(0x08082b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x192b0808, 0x2b08192b), uvec2(0x0808082b, 0x2b082b08),
|
||||
uvec2(0x08081908, 0x2b082b19), uvec2(0x08190819, 0x2b082b2b), uvec2(0x08081908, 0x2b190808), uvec2(0x08190808, 0x2b190808),
|
||||
uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x2b2b0819, 0x2b190808), uvec2(0x0819192b, 0x2b190819),
|
||||
uvec2(0x2b080808, 0x2b190819), uvec2(0x19081919, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x082b082b, 0x2b191908),
|
||||
uvec2(0x19081908, 0x2b191908), uvec2(0x19190819, 0x2b191919), uvec2(0x2b080819, 0x2b192b08), uvec2(0x082b0808, 0x2b192b19),
|
||||
uvec2(0x0808082b, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b081919, 0x2b2b0808), uvec2(0x08082b19, 0x2b2b0819),
|
||||
uvec2(0x08080808, 0x2b2b082b), uvec2(0x08192b08, 0x2b2b1908), uvec2(0x19190808, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19)
|
||||
};
|
||||
|
||||
shared uvec2 iq2xxs_grid[256];
|
||||
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
for (uint i = gl_LocalInvocationIndex.x; i < iq2xxs_grid.length(); i += wgsize.x) {
|
||||
iq2xxs_grid[i] = iq2xxs_grid_const[i];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
#define QUANT_K QUANT_K_IQ2_XXS
|
||||
#define QUANT_R QUANT_R_IQ2_XXS
|
||||
#define A_TYPE block_iq2_xxs
|
||||
#define A_TYPE_PACKED16 block_iq2_xxs_packed16
|
||||
#endif
|
||||
|
||||
#define QUANT_K_IQ2_XS 256
|
||||
#define QUANT_R_IQ2_XS 1
|
||||
|
||||
struct block_iq2_xs
|
||||
{
|
||||
float16_t d;
|
||||
uint16_t qs[QUANT_K_IQ2_XS/8];
|
||||
uint8_t scales[QUANT_K_IQ2_XS/32];
|
||||
};
|
||||
|
||||
struct block_iq2_xs_packed16
|
||||
{
|
||||
float16_t d;
|
||||
uint16_t qs[QUANT_K_IQ2_XS/8];
|
||||
uint16_t scales[QUANT_K_IQ2_XS/64];
|
||||
};
|
||||
|
||||
#if defined(DATA_A_IQ2_XS)
|
||||
|
||||
const uvec2 iq2xs_grid_const[512] = {
|
||||
uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808),
|
||||
uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808),
|
||||
uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808),
|
||||
uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808),
|
||||
uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808),
|
||||
uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808),
|
||||
uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808), uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808),
|
||||
uvec2(0x2b191908, 0x08080808), uvec2(0x2b192b19, 0x08080808), uvec2(0x2b2b0808, 0x08080808), uvec2(0x08080819, 0x08080819),
|
||||
uvec2(0x08081908, 0x08080819), uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819),
|
||||
uvec2(0x0819082b, 0x08080819), uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x08192b2b, 0x08080819),
|
||||
uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819),
|
||||
uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819), uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819),
|
||||
uvec2(0x192b0808, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819), uvec2(0x2b081908, 0x08080819),
|
||||
uvec2(0x2b190808, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x08081919, 0x0808082b),
|
||||
uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b), uvec2(0x082b0808, 0x0808082b),
|
||||
uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b),
|
||||
uvec2(0x2b080808, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908),
|
||||
uvec2(0x0808192b, 0x08081908), uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908),
|
||||
uvec2(0x08191919, 0x08081908), uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908),
|
||||
uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908), uvec2(0x19082b08, 0x08081908),
|
||||
uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908), uvec2(0x1919192b, 0x08081908), uvec2(0x192b0808, 0x08081908),
|
||||
uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x08080808, 0x08081919),
|
||||
uvec2(0x0808082b, 0x08081919), uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08190819, 0x08081919),
|
||||
uvec2(0x08191908, 0x08081919), uvec2(0x082b0808, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919),
|
||||
uvec2(0x19190808, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x2b080808, 0x08081919), uvec2(0x08080819, 0x0808192b),
|
||||
uvec2(0x08081908, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x082b192b, 0x0808192b), uvec2(0x19080808, 0x0808192b),
|
||||
uvec2(0x1908082b, 0x0808192b), uvec2(0x2b081908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08),
|
||||
uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08082b2b, 0x08082b08), uvec2(0x08190819, 0x08082b08),
|
||||
uvec2(0x08191908, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08), uvec2(0x19080819, 0x08082b08),
|
||||
uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x19192b08, 0x08082b08), uvec2(0x2b080808, 0x08082b08),
|
||||
uvec2(0x2b2b0808, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19), uvec2(0x08081908, 0x08082b19),
|
||||
uvec2(0x08190808, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x2b080819, 0x08082b19), uvec2(0x2b082b19, 0x08082b19),
|
||||
uvec2(0x08080808, 0x08082b2b), uvec2(0x082b0808, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x2b19192b, 0x08082b2b),
|
||||
uvec2(0x2b2b0808, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x0808192b, 0x08190808),
|
||||
uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808), uvec2(0x08191919, 0x08190808),
|
||||
uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808), uvec2(0x19080808, 0x08190808),
|
||||
uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808), uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808),
|
||||
uvec2(0x19191908, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b2b2b, 0x08190808), uvec2(0x2b080819, 0x08190808),
|
||||
uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819),
|
||||
uvec2(0x08081919, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819),
|
||||
uvec2(0x082b0808, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819), uvec2(0x19190808, 0x08190819),
|
||||
uvec2(0x2b080808, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x2b19192b, 0x08190819), uvec2(0x08080819, 0x0819082b),
|
||||
uvec2(0x08081908, 0x0819082b), uvec2(0x0808192b, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x19080808, 0x0819082b),
|
||||
uvec2(0x192b0808, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908),
|
||||
uvec2(0x08082b08, 0x08191908), uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x082b0808, 0x08191908),
|
||||
uvec2(0x19080819, 0x08191908), uvec2(0x19081908, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908),
|
||||
uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919),
|
||||
uvec2(0x08190808, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x08191908, 0x0819192b),
|
||||
uvec2(0x19082b19, 0x0819192b), uvec2(0x08080819, 0x08192b08), uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08),
|
||||
uvec2(0x0819082b, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x19191908, 0x08192b08), uvec2(0x2b08192b, 0x08192b08),
|
||||
uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x192b192b, 0x08192b19), uvec2(0x19190819, 0x08192b2b),
|
||||
uvec2(0x2b2b2b19, 0x08192b2b), uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808),
|
||||
uvec2(0x08082b08, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808),
|
||||
uvec2(0x082b0808, 0x082b0808), uvec2(0x19080819, 0x082b0808), uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808),
|
||||
uvec2(0x2b080808, 0x082b0808), uvec2(0x2b2b0808, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819),
|
||||
uvec2(0x08190808, 0x082b0819), uvec2(0x19080808, 0x082b0819), uvec2(0x19082b08, 0x082b0819), uvec2(0x192b1919, 0x082b0819),
|
||||
uvec2(0x08080808, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x2b080808, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b),
|
||||
uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x082b2b19, 0x082b1908),
|
||||
uvec2(0x19080808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x1919082b, 0x082b1919),
|
||||
uvec2(0x2b192b19, 0x082b1919), uvec2(0x08080819, 0x082b192b), uvec2(0x08192b2b, 0x082b192b), uvec2(0x2b2b192b, 0x082b192b),
|
||||
uvec2(0x08080808, 0x082b2b08), uvec2(0x08082b08, 0x082b2b08), uvec2(0x08082b2b, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08),
|
||||
uvec2(0x19191919, 0x082b2b08), uvec2(0x2b082b08, 0x082b2b08), uvec2(0x2b2b082b, 0x082b2b08), uvec2(0x192b2b08, 0x082b2b19),
|
||||
uvec2(0x2b190808, 0x082b2b19), uvec2(0x08082b08, 0x082b2b2b), uvec2(0x082b0808, 0x082b2b2b), uvec2(0x2b08082b, 0x082b2b2b),
|
||||
uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808),
|
||||
uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x0819082b, 0x19080808),
|
||||
uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808),
|
||||
uvec2(0x19080808, 0x19080808), uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808),
|
||||
uvec2(0x19082b2b, 0x19080808), uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x192b0808, 0x19080808),
|
||||
uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808),
|
||||
uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819), uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819),
|
||||
uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x19080819, 0x19080819),
|
||||
uvec2(0x19081908, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819),
|
||||
uvec2(0x2b2b082b, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b), uvec2(0x08190808, 0x1908082b),
|
||||
uvec2(0x0819082b, 0x1908082b), uvec2(0x082b2b19, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x08080808, 0x19081908),
|
||||
uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908), uvec2(0x08082b08, 0x19081908), uvec2(0x08190819, 0x19081908),
|
||||
uvec2(0x08191908, 0x19081908), uvec2(0x08192b19, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x19080819, 0x19081908),
|
||||
uvec2(0x19081908, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b191908, 0x19081908),
|
||||
uvec2(0x08080819, 0x19081919), uvec2(0x08081908, 0x19081919), uvec2(0x08190808, 0x19081919), uvec2(0x082b1908, 0x19081919),
|
||||
uvec2(0x19080808, 0x19081919), uvec2(0x2b192b2b, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08082b2b, 0x1908192b),
|
||||
uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08),
|
||||
uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08), uvec2(0x19191908, 0x19082b08),
|
||||
uvec2(0x192b082b, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x19081908, 0x19082b19),
|
||||
uvec2(0x19190808, 0x19082b19), uvec2(0x192b2b19, 0x19082b19), uvec2(0x08081908, 0x19082b2b), uvec2(0x08080808, 0x19190808),
|
||||
uvec2(0x0808082b, 0x19190808), uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808),
|
||||
uvec2(0x08191908, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808),
|
||||
uvec2(0x19081908, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x2b080808, 0x19190808), uvec2(0x08080819, 0x19190819),
|
||||
uvec2(0x08081908, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x08191919, 0x19190819), uvec2(0x19080808, 0x19190819),
|
||||
uvec2(0x1908082b, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x2b2b2b2b, 0x1919082b),
|
||||
uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x082b0819, 0x19191908),
|
||||
uvec2(0x19080808, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b2b0819, 0x19191908),
|
||||
uvec2(0x08080808, 0x19191919), uvec2(0x08082b08, 0x19191919), uvec2(0x2b080808, 0x19191919), uvec2(0x2b082b08, 0x19191919),
|
||||
uvec2(0x082b0819, 0x1919192b), uvec2(0x192b2b08, 0x1919192b), uvec2(0x2b2b0819, 0x1919192b), uvec2(0x08080808, 0x19192b08),
|
||||
uvec2(0x08191908, 0x19192b08), uvec2(0x19080819, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x2b192b19, 0x19192b08),
|
||||
uvec2(0x08192b2b, 0x19192b19), uvec2(0x19080808, 0x19192b19), uvec2(0x1908082b, 0x19192b19), uvec2(0x2b081919, 0x19192b2b),
|
||||
uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808),
|
||||
uvec2(0x19191908, 0x192b0808), uvec2(0x192b082b, 0x192b0808), uvec2(0x2b08192b, 0x192b0808), uvec2(0x2b2b2b19, 0x192b0808),
|
||||
uvec2(0x08080808, 0x192b0819), uvec2(0x082b1908, 0x192b082b), uvec2(0x19082b2b, 0x192b082b), uvec2(0x2b19082b, 0x192b082b),
|
||||
uvec2(0x08080808, 0x192b1908), uvec2(0x0819192b, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x19080808, 0x192b1919),
|
||||
uvec2(0x19081919, 0x192b1919), uvec2(0x2b2b1908, 0x192b1919), uvec2(0x08080819, 0x192b2b08), uvec2(0x192b2b2b, 0x192b2b08),
|
||||
uvec2(0x082b1919, 0x192b2b19), uvec2(0x0808192b, 0x192b2b2b), uvec2(0x19191908, 0x192b2b2b), uvec2(0x192b082b, 0x192b2b2b),
|
||||
uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808),
|
||||
uvec2(0x08190819, 0x2b080808), uvec2(0x08191908, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b2b2b, 0x2b080808),
|
||||
uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x2b080808, 0x2b080808),
|
||||
uvec2(0x2b08082b, 0x2b080808), uvec2(0x2b2b2b08, 0x2b080808), uvec2(0x2b2b2b2b, 0x2b080808), uvec2(0x08080819, 0x2b080819),
|
||||
uvec2(0x08081908, 0x2b080819), uvec2(0x0808192b, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x19080808, 0x2b080819),
|
||||
uvec2(0x19190819, 0x2b080819), uvec2(0x19192b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x082b0808, 0x2b08082b),
|
||||
uvec2(0x2b080808, 0x2b08082b), uvec2(0x2b08082b, 0x2b08082b), uvec2(0x2b2b0808, 0x2b08082b), uvec2(0x2b2b2b08, 0x2b08082b),
|
||||
uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908),
|
||||
uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b082b19, 0x2b081908),
|
||||
uvec2(0x08080808, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x2b2b1919, 0x2b081919), uvec2(0x08192b08, 0x2b08192b),
|
||||
uvec2(0x192b2b2b, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08082b08, 0x2b082b08), uvec2(0x082b1919, 0x2b082b08),
|
||||
uvec2(0x19192b2b, 0x2b082b08), uvec2(0x2b080808, 0x2b082b08), uvec2(0x2b08082b, 0x2b082b08), uvec2(0x2b2b2b08, 0x2b082b08),
|
||||
uvec2(0x0808192b, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x2b080808, 0x2b082b2b), uvec2(0x2b082b08, 0x2b082b2b),
|
||||
uvec2(0x2b19192b, 0x2b082b2b), uvec2(0x2b2b2b08, 0x2b082b2b), uvec2(0x08080819, 0x2b190808), uvec2(0x08081908, 0x2b190808),
|
||||
uvec2(0x08190808, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x1919192b, 0x2b190808), uvec2(0x2b081908, 0x2b190808),
|
||||
uvec2(0x08080808, 0x2b190819), uvec2(0x082b082b, 0x2b190819), uvec2(0x192b1908, 0x2b190819), uvec2(0x1919192b, 0x2b19082b),
|
||||
uvec2(0x2b082b19, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x08081919, 0x2b191908), uvec2(0x19081908, 0x2b191908),
|
||||
uvec2(0x19190808, 0x2b191908), uvec2(0x19192b08, 0x2b191908), uvec2(0x082b2b19, 0x2b191919), uvec2(0x2b190808, 0x2b191919),
|
||||
uvec2(0x2b19082b, 0x2b191919), uvec2(0x19080819, 0x2b19192b), uvec2(0x19190819, 0x2b192b08), uvec2(0x2b2b192b, 0x2b192b08),
|
||||
uvec2(0x19082b19, 0x2b192b19), uvec2(0x08191919, 0x2b192b2b), uvec2(0x192b0808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808),
|
||||
uvec2(0x0808082b, 0x2b2b0808), uvec2(0x08082b08, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808), uvec2(0x082b0808, 0x2b2b0808),
|
||||
uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x2b2b0808, 0x2b2b0808), uvec2(0x19190819, 0x2b2b0819), uvec2(0x19192b19, 0x2b2b0819),
|
||||
uvec2(0x2b2b192b, 0x2b2b0819), uvec2(0x08080808, 0x2b2b082b), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b08, 0x2b2b082b),
|
||||
uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b080808, 0x2b2b082b), uvec2(0x2b2b0808, 0x2b2b082b), uvec2(0x19080808, 0x2b2b1908),
|
||||
uvec2(0x2b191919, 0x2b2b1908), uvec2(0x192b1919, 0x2b2b192b), uvec2(0x2b192b08, 0x2b2b192b), uvec2(0x08082b2b, 0x2b2b2b08),
|
||||
uvec2(0x082b0808, 0x2b2b2b08), uvec2(0x082b082b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b0808, 0x2b2b2b08),
|
||||
uvec2(0x2b2b2b08, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19), uvec2(0x2b081908, 0x2b2b2b19), uvec2(0x2b08192b, 0x2b2b2b19),
|
||||
uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x082b2b2b, 0x2b2b2b2b), uvec2(0x2b190819, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b),
|
||||
};
|
||||
|
||||
shared uvec2 iq2xs_grid[512];
|
||||
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
for (uint i = gl_LocalInvocationIndex.x; i < iq2xs_grid.length(); i += wgsize.x) {
|
||||
iq2xs_grid[i] = iq2xs_grid_const[i];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
#define QUANT_K QUANT_K_IQ2_XS
|
||||
#define QUANT_R QUANT_R_IQ2_XS
|
||||
#define A_TYPE block_iq2_xs
|
||||
#define A_TYPE_PACKED16 block_iq2_xs_packed16
|
||||
#endif
|
||||
|
||||
#define QUANT_K_IQ2_S 256
|
||||
#define QUANT_R_IQ2_S 1
|
||||
|
||||
struct block_iq2_s
|
||||
{
|
||||
float16_t d;
|
||||
uint8_t qs[QUANT_K_IQ2_S/4];
|
||||
uint8_t qh[QUANT_K_IQ2_S/32];
|
||||
uint8_t scales[QUANT_K_IQ2_S/32];
|
||||
};
|
||||
|
||||
#if defined(DATA_A_IQ2_S)
|
||||
|
||||
const uvec2 iq2s_grid_const[1024] = {
|
||||
uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808),
|
||||
uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808),
|
||||
uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808),
|
||||
uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808),
|
||||
uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808),
|
||||
uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x192b192b, 0x08080808),
|
||||
uvec2(0x192b2b19, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808),
|
||||
uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808), uvec2(0x2b191908, 0x08080808), uvec2(0x2b2b0808, 0x08080808),
|
||||
uvec2(0x2b2b1919, 0x08080808), uvec2(0x2b2b2b2b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819),
|
||||
uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819), uvec2(0x0819082b, 0x08080819),
|
||||
uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819),
|
||||
uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819), uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819),
|
||||
uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819), uvec2(0x1919192b, 0x08080819), uvec2(0x19192b19, 0x08080819),
|
||||
uvec2(0x192b0808, 0x08080819), uvec2(0x192b1919, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819),
|
||||
uvec2(0x2b081908, 0x08080819), uvec2(0x2b190808, 0x08080819), uvec2(0x2b19082b, 0x08080819), uvec2(0x2b191919, 0x08080819),
|
||||
uvec2(0x2b2b0819, 0x08080819), uvec2(0x2b2b1908, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b),
|
||||
uvec2(0x08081919, 0x0808082b), uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b),
|
||||
uvec2(0x082b0808, 0x0808082b), uvec2(0x082b2b2b, 0x0808082b), uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b),
|
||||
uvec2(0x1908192b, 0x0808082b), uvec2(0x19082b19, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b),
|
||||
uvec2(0x2b080808, 0x0808082b), uvec2(0x2b081919, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x2b191908, 0x0808082b),
|
||||
uvec2(0x2b2b082b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x0808192b, 0x08081908),
|
||||
uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908), uvec2(0x08191919, 0x08081908),
|
||||
uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908), uvec2(0x082b192b, 0x08081908),
|
||||
uvec2(0x082b2b19, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908),
|
||||
uvec2(0x19082b08, 0x08081908), uvec2(0x19082b2b, 0x08081908), uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908),
|
||||
uvec2(0x1919192b, 0x08081908), uvec2(0x19192b19, 0x08081908), uvec2(0x192b0808, 0x08081908), uvec2(0x192b082b, 0x08081908),
|
||||
uvec2(0x192b1919, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b08192b, 0x08081908),
|
||||
uvec2(0x2b082b19, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x2b191919, 0x08081908), uvec2(0x2b192b08, 0x08081908),
|
||||
uvec2(0x2b2b0819, 0x08081908), uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919),
|
||||
uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08082b2b, 0x08081919), uvec2(0x08190819, 0x08081919),
|
||||
uvec2(0x08191908, 0x08081919), uvec2(0x0819192b, 0x08081919), uvec2(0x08192b19, 0x08081919), uvec2(0x082b0808, 0x08081919),
|
||||
uvec2(0x082b1919, 0x08081919), uvec2(0x082b2b08, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919),
|
||||
uvec2(0x1908192b, 0x08081919), uvec2(0x19082b19, 0x08081919), uvec2(0x19190808, 0x08081919), uvec2(0x1919082b, 0x08081919),
|
||||
uvec2(0x19191919, 0x08081919), uvec2(0x19192b08, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x192b1908, 0x08081919),
|
||||
uvec2(0x2b080808, 0x08081919), uvec2(0x2b08082b, 0x08081919), uvec2(0x2b081919, 0x08081919), uvec2(0x2b082b08, 0x08081919),
|
||||
uvec2(0x2b190819, 0x08081919), uvec2(0x2b191908, 0x08081919), uvec2(0x2b2b0808, 0x08081919), uvec2(0x08080819, 0x0808192b),
|
||||
uvec2(0x08081908, 0x0808192b), uvec2(0x0808192b, 0x0808192b), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b),
|
||||
uvec2(0x08191919, 0x0808192b), uvec2(0x19080808, 0x0808192b), uvec2(0x19081919, 0x0808192b), uvec2(0x19082b08, 0x0808192b),
|
||||
uvec2(0x19190819, 0x0808192b), uvec2(0x19191908, 0x0808192b), uvec2(0x192b0808, 0x0808192b), uvec2(0x2b080819, 0x0808192b),
|
||||
uvec2(0x2b081908, 0x0808192b), uvec2(0x2b190808, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08),
|
||||
uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08190819, 0x08082b08), uvec2(0x08191908, 0x08082b08),
|
||||
uvec2(0x0819192b, 0x08082b08), uvec2(0x08192b19, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08),
|
||||
uvec2(0x082b2b2b, 0x08082b08), uvec2(0x19080819, 0x08082b08), uvec2(0x19081908, 0x08082b08), uvec2(0x1908192b, 0x08082b08),
|
||||
uvec2(0x19082b19, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x19191919, 0x08082b08),
|
||||
uvec2(0x19192b08, 0x08082b08), uvec2(0x192b0819, 0x08082b08), uvec2(0x192b1908, 0x08082b08), uvec2(0x2b080808, 0x08082b08),
|
||||
uvec2(0x2b081919, 0x08082b08), uvec2(0x2b191908, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19),
|
||||
uvec2(0x08081908, 0x08082b19), uvec2(0x08190808, 0x08082b19), uvec2(0x0819082b, 0x08082b19), uvec2(0x08191919, 0x08082b19),
|
||||
uvec2(0x08192b08, 0x08082b19), uvec2(0x082b0819, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x19081919, 0x08082b19),
|
||||
uvec2(0x19082b08, 0x08082b19), uvec2(0x19190819, 0x08082b19), uvec2(0x19191908, 0x08082b19), uvec2(0x192b0808, 0x08082b19),
|
||||
uvec2(0x2b080819, 0x08082b19), uvec2(0x2b190808, 0x08082b19), uvec2(0x08080808, 0x08082b2b), uvec2(0x08190819, 0x08082b2b),
|
||||
uvec2(0x08191908, 0x08082b2b), uvec2(0x082b082b, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x082b2b2b, 0x08082b2b),
|
||||
uvec2(0x19190808, 0x08082b2b), uvec2(0x2b192b19, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808),
|
||||
uvec2(0x0808192b, 0x08190808), uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808),
|
||||
uvec2(0x08191919, 0x08190808), uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808),
|
||||
uvec2(0x082b192b, 0x08190808), uvec2(0x19080808, 0x08190808), uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808),
|
||||
uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808), uvec2(0x19191908, 0x08190808), uvec2(0x1919192b, 0x08190808),
|
||||
uvec2(0x19192b19, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b082b, 0x08190808), uvec2(0x192b1919, 0x08190808),
|
||||
uvec2(0x192b2b08, 0x08190808), uvec2(0x2b080819, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b08192b, 0x08190808),
|
||||
uvec2(0x2b190808, 0x08190808), uvec2(0x2b191919, 0x08190808), uvec2(0x2b192b08, 0x08190808), uvec2(0x2b2b0819, 0x08190808),
|
||||
uvec2(0x2b2b1908, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819), uvec2(0x08081919, 0x08190819),
|
||||
uvec2(0x08082b08, 0x08190819), uvec2(0x08082b2b, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819),
|
||||
uvec2(0x0819192b, 0x08190819), uvec2(0x08192b19, 0x08190819), uvec2(0x082b0808, 0x08190819), uvec2(0x082b082b, 0x08190819),
|
||||
uvec2(0x082b1919, 0x08190819), uvec2(0x082b2b08, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819),
|
||||
uvec2(0x1908192b, 0x08190819), uvec2(0x19082b19, 0x08190819), uvec2(0x19190808, 0x08190819), uvec2(0x1919082b, 0x08190819),
|
||||
uvec2(0x19191919, 0x08190819), uvec2(0x19192b08, 0x08190819), uvec2(0x192b0819, 0x08190819), uvec2(0x192b1908, 0x08190819),
|
||||
uvec2(0x2b080808, 0x08190819), uvec2(0x2b08082b, 0x08190819), uvec2(0x2b081919, 0x08190819), uvec2(0x2b082b08, 0x08190819),
|
||||
uvec2(0x2b190819, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x08080819, 0x0819082b), uvec2(0x08081908, 0x0819082b),
|
||||
uvec2(0x08082b19, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x08191919, 0x0819082b), uvec2(0x082b0819, 0x0819082b),
|
||||
uvec2(0x082b1908, 0x0819082b), uvec2(0x19080808, 0x0819082b), uvec2(0x19081919, 0x0819082b), uvec2(0x19190819, 0x0819082b),
|
||||
uvec2(0x19191908, 0x0819082b), uvec2(0x2b080819, 0x0819082b), uvec2(0x2b081908, 0x0819082b), uvec2(0x2b190808, 0x0819082b),
|
||||
uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908), uvec2(0x08082b08, 0x08191908),
|
||||
uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x0819192b, 0x08191908), uvec2(0x08192b19, 0x08191908),
|
||||
uvec2(0x082b0808, 0x08191908), uvec2(0x082b1919, 0x08191908), uvec2(0x082b2b08, 0x08191908), uvec2(0x19080819, 0x08191908),
|
||||
uvec2(0x19081908, 0x08191908), uvec2(0x1908192b, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908),
|
||||
uvec2(0x1919082b, 0x08191908), uvec2(0x19191919, 0x08191908), uvec2(0x19192b08, 0x08191908), uvec2(0x192b0819, 0x08191908),
|
||||
uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x2b08082b, 0x08191908), uvec2(0x2b081919, 0x08191908),
|
||||
uvec2(0x2b082b08, 0x08191908), uvec2(0x2b190819, 0x08191908), uvec2(0x2b191908, 0x08191908), uvec2(0x2b2b0808, 0x08191908),
|
||||
uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919), uvec2(0x0808192b, 0x08191919), uvec2(0x08082b19, 0x08191919),
|
||||
uvec2(0x08190808, 0x08191919), uvec2(0x0819082b, 0x08191919), uvec2(0x08191919, 0x08191919), uvec2(0x08192b08, 0x08191919),
|
||||
uvec2(0x082b0819, 0x08191919), uvec2(0x082b1908, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x1908082b, 0x08191919),
|
||||
uvec2(0x19081919, 0x08191919), uvec2(0x19082b08, 0x08191919), uvec2(0x19190819, 0x08191919), uvec2(0x19191908, 0x08191919),
|
||||
uvec2(0x192b0808, 0x08191919), uvec2(0x2b080819, 0x08191919), uvec2(0x2b081908, 0x08191919), uvec2(0x2b190808, 0x08191919),
|
||||
uvec2(0x08080808, 0x0819192b), uvec2(0x08081919, 0x0819192b), uvec2(0x08082b08, 0x0819192b), uvec2(0x08190819, 0x0819192b),
|
||||
uvec2(0x08191908, 0x0819192b), uvec2(0x082b0808, 0x0819192b), uvec2(0x19080819, 0x0819192b), uvec2(0x19081908, 0x0819192b),
|
||||
uvec2(0x19190808, 0x0819192b), uvec2(0x2b080808, 0x0819192b), uvec2(0x2b2b2b2b, 0x0819192b), uvec2(0x08080819, 0x08192b08),
|
||||
uvec2(0x08081908, 0x08192b08), uvec2(0x0808192b, 0x08192b08), uvec2(0x08082b19, 0x08192b08), uvec2(0x08190808, 0x08192b08),
|
||||
uvec2(0x08191919, 0x08192b08), uvec2(0x08192b08, 0x08192b08), uvec2(0x082b0819, 0x08192b08), uvec2(0x19080808, 0x08192b08),
|
||||
uvec2(0x1908082b, 0x08192b08), uvec2(0x19081919, 0x08192b08), uvec2(0x19082b08, 0x08192b08), uvec2(0x19190819, 0x08192b08),
|
||||
uvec2(0x19191908, 0x08192b08), uvec2(0x192b0808, 0x08192b08), uvec2(0x2b080819, 0x08192b08), uvec2(0x2b081908, 0x08192b08),
|
||||
uvec2(0x08080808, 0x08192b19), uvec2(0x0808082b, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x08082b08, 0x08192b19),
|
||||
uvec2(0x08190819, 0x08192b19), uvec2(0x08191908, 0x08192b19), uvec2(0x082b0808, 0x08192b19), uvec2(0x19080819, 0x08192b19),
|
||||
uvec2(0x19081908, 0x08192b19), uvec2(0x19190808, 0x08192b19), uvec2(0x192b2b19, 0x08192b19), uvec2(0x2b2b082b, 0x08192b19),
|
||||
uvec2(0x08081908, 0x08192b2b), uvec2(0x08190808, 0x08192b2b), uvec2(0x19080808, 0x08192b2b), uvec2(0x1919192b, 0x08192b2b),
|
||||
uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808), uvec2(0x08082b08, 0x082b0808),
|
||||
uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808), uvec2(0x0819192b, 0x082b0808), uvec2(0x08192b19, 0x082b0808),
|
||||
uvec2(0x082b0808, 0x082b0808), uvec2(0x082b1919, 0x082b0808), uvec2(0x082b2b2b, 0x082b0808), uvec2(0x19080819, 0x082b0808),
|
||||
uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808), uvec2(0x1919082b, 0x082b0808), uvec2(0x19191919, 0x082b0808),
|
||||
uvec2(0x192b1908, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b082b2b, 0x082b0808), uvec2(0x2b191908, 0x082b0808),
|
||||
uvec2(0x2b2b2b2b, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819), uvec2(0x08190808, 0x082b0819),
|
||||
uvec2(0x0819082b, 0x082b0819), uvec2(0x08191919, 0x082b0819), uvec2(0x082b0819, 0x082b0819), uvec2(0x19080808, 0x082b0819),
|
||||
uvec2(0x1908082b, 0x082b0819), uvec2(0x19081919, 0x082b0819), uvec2(0x19190819, 0x082b0819), uvec2(0x19191908, 0x082b0819),
|
||||
uvec2(0x192b0808, 0x082b0819), uvec2(0x2b080819, 0x082b0819), uvec2(0x2b081908, 0x082b0819), uvec2(0x2b190808, 0x082b0819),
|
||||
uvec2(0x08080808, 0x082b082b), uvec2(0x08082b2b, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x082b2b08, 0x082b082b),
|
||||
uvec2(0x082b2b2b, 0x082b082b), uvec2(0x19081908, 0x082b082b), uvec2(0x19190808, 0x082b082b), uvec2(0x2b082b08, 0x082b082b),
|
||||
uvec2(0x2b082b2b, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b), uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908),
|
||||
uvec2(0x0808192b, 0x082b1908), uvec2(0x08082b19, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x08191919, 0x082b1908),
|
||||
uvec2(0x08192b08, 0x082b1908), uvec2(0x082b0819, 0x082b1908), uvec2(0x082b1908, 0x082b1908), uvec2(0x19080808, 0x082b1908),
|
||||
uvec2(0x1908082b, 0x082b1908), uvec2(0x19081919, 0x082b1908), uvec2(0x19082b08, 0x082b1908), uvec2(0x19190819, 0x082b1908),
|
||||
uvec2(0x19191908, 0x082b1908), uvec2(0x192b0808, 0x082b1908), uvec2(0x2b080819, 0x082b1908), uvec2(0x2b081908, 0x082b1908),
|
||||
uvec2(0x2b190808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x08081919, 0x082b1919), uvec2(0x08082b08, 0x082b1919),
|
||||
uvec2(0x08190819, 0x082b1919), uvec2(0x08191908, 0x082b1919), uvec2(0x082b0808, 0x082b1919), uvec2(0x19080819, 0x082b1919),
|
||||
uvec2(0x19081908, 0x082b1919), uvec2(0x19190808, 0x082b1919), uvec2(0x192b192b, 0x082b1919), uvec2(0x2b080808, 0x082b1919),
|
||||
uvec2(0x08080819, 0x082b192b), uvec2(0x08081908, 0x082b192b), uvec2(0x08190808, 0x082b192b), uvec2(0x19080808, 0x082b192b),
|
||||
uvec2(0x19192b19, 0x082b192b), uvec2(0x08080808, 0x082b2b08), uvec2(0x08081919, 0x082b2b08), uvec2(0x08190819, 0x082b2b08),
|
||||
uvec2(0x08191908, 0x082b2b08), uvec2(0x19080819, 0x082b2b08), uvec2(0x19081908, 0x082b2b08), uvec2(0x19190808, 0x082b2b08),
|
||||
uvec2(0x2b082b2b, 0x082b2b08), uvec2(0x2b2b2b2b, 0x082b2b08), uvec2(0x08080819, 0x082b2b19), uvec2(0x08081908, 0x082b2b19),
|
||||
uvec2(0x08190808, 0x082b2b19), uvec2(0x2b191919, 0x082b2b19), uvec2(0x08082b2b, 0x082b2b2b), uvec2(0x082b082b, 0x082b2b2b),
|
||||
uvec2(0x192b1908, 0x082b2b2b), uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808),
|
||||
uvec2(0x08081908, 0x19080808), uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808),
|
||||
uvec2(0x0819082b, 0x19080808), uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x08192b2b, 0x19080808),
|
||||
uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x082b192b, 0x19080808), uvec2(0x19080808, 0x19080808),
|
||||
uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808), uvec2(0x19082b2b, 0x19080808),
|
||||
uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x1919192b, 0x19080808), uvec2(0x19192b19, 0x19080808),
|
||||
uvec2(0x192b0808, 0x19080808), uvec2(0x192b082b, 0x19080808), uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808),
|
||||
uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808), uvec2(0x2b191919, 0x19080808), uvec2(0x2b192b08, 0x19080808),
|
||||
uvec2(0x2b2b0819, 0x19080808), uvec2(0x2b2b1908, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819),
|
||||
uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819), uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819),
|
||||
uvec2(0x0819192b, 0x19080819), uvec2(0x08192b19, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x082b082b, 0x19080819),
|
||||
uvec2(0x082b1919, 0x19080819), uvec2(0x19080819, 0x19080819), uvec2(0x19081908, 0x19080819), uvec2(0x1908192b, 0x19080819),
|
||||
uvec2(0x19082b19, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x1919082b, 0x19080819), uvec2(0x19191919, 0x19080819),
|
||||
uvec2(0x19192b08, 0x19080819), uvec2(0x192b0819, 0x19080819), uvec2(0x192b1908, 0x19080819), uvec2(0x2b080808, 0x19080819),
|
||||
uvec2(0x2b08082b, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x2b082b08, 0x19080819), uvec2(0x2b190819, 0x19080819),
|
||||
uvec2(0x2b191908, 0x19080819), uvec2(0x2b2b0808, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b),
|
||||
uvec2(0x08190808, 0x1908082b), uvec2(0x0819082b, 0x1908082b), uvec2(0x08191919, 0x1908082b), uvec2(0x08192b08, 0x1908082b),
|
||||
uvec2(0x082b1908, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x19081919, 0x1908082b), uvec2(0x19082b08, 0x1908082b),
|
||||
uvec2(0x19190819, 0x1908082b), uvec2(0x19191908, 0x1908082b), uvec2(0x192b0808, 0x1908082b), uvec2(0x2b080819, 0x1908082b),
|
||||
uvec2(0x2b081908, 0x1908082b), uvec2(0x08080808, 0x19081908), uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908),
|
||||
uvec2(0x08082b08, 0x19081908), uvec2(0x08082b2b, 0x19081908), uvec2(0x08190819, 0x19081908), uvec2(0x08191908, 0x19081908),
|
||||
uvec2(0x0819192b, 0x19081908), uvec2(0x08192b19, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x082b082b, 0x19081908),
|
||||
uvec2(0x082b1919, 0x19081908), uvec2(0x082b2b08, 0x19081908), uvec2(0x19080819, 0x19081908), uvec2(0x19081908, 0x19081908),
|
||||
uvec2(0x1908192b, 0x19081908), uvec2(0x19082b19, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x1919082b, 0x19081908),
|
||||
uvec2(0x19191919, 0x19081908), uvec2(0x19192b08, 0x19081908), uvec2(0x192b0819, 0x19081908), uvec2(0x192b1908, 0x19081908),
|
||||
uvec2(0x2b080808, 0x19081908), uvec2(0x2b08082b, 0x19081908), uvec2(0x2b081919, 0x19081908), uvec2(0x2b082b08, 0x19081908),
|
||||
uvec2(0x2b190819, 0x19081908), uvec2(0x2b191908, 0x19081908), uvec2(0x2b2b0808, 0x19081908), uvec2(0x08080819, 0x19081919),
|
||||
uvec2(0x08081908, 0x19081919), uvec2(0x0808192b, 0x19081919), uvec2(0x08082b19, 0x19081919), uvec2(0x08190808, 0x19081919),
|
||||
uvec2(0x0819082b, 0x19081919), uvec2(0x08191919, 0x19081919), uvec2(0x08192b08, 0x19081919), uvec2(0x082b0819, 0x19081919),
|
||||
uvec2(0x082b1908, 0x19081919), uvec2(0x19080808, 0x19081919), uvec2(0x1908082b, 0x19081919), uvec2(0x19081919, 0x19081919),
|
||||
uvec2(0x19082b08, 0x19081919), uvec2(0x19190819, 0x19081919), uvec2(0x19191908, 0x19081919), uvec2(0x192b0808, 0x19081919),
|
||||
uvec2(0x192b2b2b, 0x19081919), uvec2(0x2b080819, 0x19081919), uvec2(0x2b081908, 0x19081919), uvec2(0x2b190808, 0x19081919),
|
||||
uvec2(0x08080808, 0x1908192b), uvec2(0x0808082b, 0x1908192b), uvec2(0x08081919, 0x1908192b), uvec2(0x08082b08, 0x1908192b),
|
||||
uvec2(0x08190819, 0x1908192b), uvec2(0x08191908, 0x1908192b), uvec2(0x082b0808, 0x1908192b), uvec2(0x19080819, 0x1908192b),
|
||||
uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x2b080808, 0x1908192b), uvec2(0x2b2b1919, 0x1908192b),
|
||||
uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08), uvec2(0x08082b19, 0x19082b08), uvec2(0x08190808, 0x19082b08),
|
||||
uvec2(0x0819082b, 0x19082b08), uvec2(0x08191919, 0x19082b08), uvec2(0x08192b08, 0x19082b08), uvec2(0x082b0819, 0x19082b08),
|
||||
uvec2(0x082b1908, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x1908082b, 0x19082b08), uvec2(0x19081919, 0x19082b08),
|
||||
uvec2(0x19082b08, 0x19082b08), uvec2(0x19190819, 0x19082b08), uvec2(0x19191908, 0x19082b08), uvec2(0x192b0808, 0x19082b08),
|
||||
uvec2(0x2b081908, 0x19082b08), uvec2(0x2b190808, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x0808082b, 0x19082b19),
|
||||
uvec2(0x08081919, 0x19082b19), uvec2(0x08082b08, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x08191908, 0x19082b19),
|
||||
uvec2(0x082b0808, 0x19082b19), uvec2(0x19080819, 0x19082b19), uvec2(0x19081908, 0x19082b19), uvec2(0x19190808, 0x19082b19),
|
||||
uvec2(0x2b080808, 0x19082b19), uvec2(0x2b19192b, 0x19082b19), uvec2(0x08080819, 0x19082b2b), uvec2(0x08081908, 0x19082b2b),
|
||||
uvec2(0x08190808, 0x19082b2b), uvec2(0x19080808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x0808082b, 0x19190808),
|
||||
uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808), uvec2(0x08191908, 0x19190808),
|
||||
uvec2(0x0819192b, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b082b, 0x19190808),
|
||||
uvec2(0x082b1919, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808), uvec2(0x19081908, 0x19190808),
|
||||
uvec2(0x1908192b, 0x19190808), uvec2(0x19082b19, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x1919082b, 0x19190808),
|
||||
uvec2(0x19191919, 0x19190808), uvec2(0x19192b08, 0x19190808), uvec2(0x192b0819, 0x19190808), uvec2(0x192b1908, 0x19190808),
|
||||
uvec2(0x2b080808, 0x19190808), uvec2(0x2b08082b, 0x19190808), uvec2(0x2b081919, 0x19190808), uvec2(0x2b082b08, 0x19190808),
|
||||
uvec2(0x2b190819, 0x19190808), uvec2(0x2b191908, 0x19190808), uvec2(0x08080819, 0x19190819), uvec2(0x08081908, 0x19190819),
|
||||
uvec2(0x0808192b, 0x19190819), uvec2(0x08082b19, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x0819082b, 0x19190819),
|
||||
uvec2(0x08191919, 0x19190819), uvec2(0x08192b08, 0x19190819), uvec2(0x082b0819, 0x19190819), uvec2(0x082b1908, 0x19190819),
|
||||
uvec2(0x19080808, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x19081919, 0x19190819), uvec2(0x19082b08, 0x19190819),
|
||||
uvec2(0x19190819, 0x19190819), uvec2(0x19191908, 0x19190819), uvec2(0x192b0808, 0x19190819), uvec2(0x2b080819, 0x19190819),
|
||||
uvec2(0x2b081908, 0x19190819), uvec2(0x2b190808, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x08081919, 0x1919082b),
|
||||
uvec2(0x08082b08, 0x1919082b), uvec2(0x08190819, 0x1919082b), uvec2(0x08191908, 0x1919082b), uvec2(0x082b0808, 0x1919082b),
|
||||
uvec2(0x19080819, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x19190808, 0x1919082b), uvec2(0x192b2b19, 0x1919082b),
|
||||
uvec2(0x2b080808, 0x1919082b), uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x0808192b, 0x19191908),
|
||||
uvec2(0x08082b19, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x0819082b, 0x19191908), uvec2(0x08191919, 0x19191908),
|
||||
uvec2(0x08192b08, 0x19191908), uvec2(0x082b0819, 0x19191908), uvec2(0x082b1908, 0x19191908), uvec2(0x19080808, 0x19191908),
|
||||
uvec2(0x1908082b, 0x19191908), uvec2(0x19081919, 0x19191908), uvec2(0x19082b08, 0x19191908), uvec2(0x19190819, 0x19191908),
|
||||
uvec2(0x19191908, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b081908, 0x19191908),
|
||||
uvec2(0x2b190808, 0x19191908), uvec2(0x08080808, 0x19191919), uvec2(0x0808082b, 0x19191919), uvec2(0x08081919, 0x19191919),
|
||||
uvec2(0x08082b08, 0x19191919), uvec2(0x08190819, 0x19191919), uvec2(0x08191908, 0x19191919), uvec2(0x082b0808, 0x19191919),
|
||||
uvec2(0x19080819, 0x19191919), uvec2(0x19081908, 0x19191919), uvec2(0x19190808, 0x19191919), uvec2(0x2b080808, 0x19191919),
|
||||
uvec2(0x08080819, 0x1919192b), uvec2(0x08081908, 0x1919192b), uvec2(0x08190808, 0x1919192b), uvec2(0x082b192b, 0x1919192b),
|
||||
uvec2(0x19080808, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x0808082b, 0x19192b08), uvec2(0x08081919, 0x19192b08),
|
||||
uvec2(0x08082b08, 0x19192b08), uvec2(0x08190819, 0x19192b08), uvec2(0x08191908, 0x19192b08), uvec2(0x082b0808, 0x19192b08),
|
||||
uvec2(0x19080819, 0x19192b08), uvec2(0x19081908, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x19192b2b, 0x19192b08),
|
||||
uvec2(0x2b080808, 0x19192b08), uvec2(0x08080819, 0x19192b19), uvec2(0x08081908, 0x19192b19), uvec2(0x08190808, 0x19192b19),
|
||||
uvec2(0x19080808, 0x19192b19), uvec2(0x08080808, 0x19192b2b), uvec2(0x08192b19, 0x19192b2b), uvec2(0x2b081919, 0x19192b2b),
|
||||
uvec2(0x2b2b2b08, 0x19192b2b), uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x0808192b, 0x192b0808),
|
||||
uvec2(0x08190808, 0x192b0808), uvec2(0x0819082b, 0x192b0808), uvec2(0x08191919, 0x192b0808), uvec2(0x08192b08, 0x192b0808),
|
||||
uvec2(0x082b0819, 0x192b0808), uvec2(0x082b1908, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x19081919, 0x192b0808),
|
||||
uvec2(0x19082b08, 0x192b0808), uvec2(0x19190819, 0x192b0808), uvec2(0x19191908, 0x192b0808), uvec2(0x192b0808, 0x192b0808),
|
||||
uvec2(0x2b081908, 0x192b0808), uvec2(0x2b190808, 0x192b0808), uvec2(0x08080808, 0x192b0819), uvec2(0x0808082b, 0x192b0819),
|
||||
uvec2(0x08081919, 0x192b0819), uvec2(0x08082b08, 0x192b0819), uvec2(0x08190819, 0x192b0819), uvec2(0x08191908, 0x192b0819),
|
||||
uvec2(0x082b0808, 0x192b0819), uvec2(0x19080819, 0x192b0819), uvec2(0x19081908, 0x192b0819), uvec2(0x19190808, 0x192b0819),
|
||||
uvec2(0x2b080808, 0x192b0819), uvec2(0x2b192b19, 0x192b0819), uvec2(0x08081908, 0x192b082b), uvec2(0x08190808, 0x192b082b),
|
||||
uvec2(0x19080808, 0x192b082b), uvec2(0x1919192b, 0x192b082b), uvec2(0x2b2b0819, 0x192b082b), uvec2(0x08080808, 0x192b1908),
|
||||
uvec2(0x08081919, 0x192b1908), uvec2(0x08082b08, 0x192b1908), uvec2(0x08190819, 0x192b1908), uvec2(0x08191908, 0x192b1908),
|
||||
uvec2(0x082b0808, 0x192b1908), uvec2(0x19080819, 0x192b1908), uvec2(0x19081908, 0x192b1908), uvec2(0x19190808, 0x192b1908),
|
||||
uvec2(0x2b080808, 0x192b1908), uvec2(0x08080819, 0x192b1919), uvec2(0x08081908, 0x192b1919), uvec2(0x08190808, 0x192b1919),
|
||||
uvec2(0x19080808, 0x192b1919), uvec2(0x19082b2b, 0x192b1919), uvec2(0x192b2b08, 0x192b1919), uvec2(0x2b19082b, 0x192b1919),
|
||||
uvec2(0x08080808, 0x192b192b), uvec2(0x2b191908, 0x192b192b), uvec2(0x08080819, 0x192b2b08), uvec2(0x08081908, 0x192b2b08),
|
||||
uvec2(0x08190808, 0x192b2b08), uvec2(0x192b1919, 0x192b2b08), uvec2(0x2b192b08, 0x192b2b08), uvec2(0x08080808, 0x192b2b19),
|
||||
uvec2(0x082b2b2b, 0x192b2b19), uvec2(0x1908082b, 0x192b2b2b), uvec2(0x2b2b0819, 0x192b2b2b), uvec2(0x08080808, 0x2b080808),
|
||||
uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808), uvec2(0x08190819, 0x2b080808),
|
||||
uvec2(0x08191908, 0x2b080808), uvec2(0x08192b19, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b1919, 0x2b080808),
|
||||
uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x1919082b, 0x2b080808),
|
||||
uvec2(0x19191919, 0x2b080808), uvec2(0x19192b08, 0x2b080808), uvec2(0x192b0819, 0x2b080808), uvec2(0x2b080808, 0x2b080808),
|
||||
uvec2(0x2b081919, 0x2b080808), uvec2(0x2b190819, 0x2b080808), uvec2(0x2b191908, 0x2b080808), uvec2(0x08080819, 0x2b080819),
|
||||
uvec2(0x08081908, 0x2b080819), uvec2(0x08082b19, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x0819082b, 0x2b080819),
|
||||
uvec2(0x08191919, 0x2b080819), uvec2(0x08192b08, 0x2b080819), uvec2(0x082b0819, 0x2b080819), uvec2(0x082b1908, 0x2b080819),
|
||||
uvec2(0x19080808, 0x2b080819), uvec2(0x1908082b, 0x2b080819), uvec2(0x19081919, 0x2b080819), uvec2(0x19082b08, 0x2b080819),
|
||||
uvec2(0x19190819, 0x2b080819), uvec2(0x19191908, 0x2b080819), uvec2(0x2b080819, 0x2b080819), uvec2(0x2b081908, 0x2b080819),
|
||||
uvec2(0x2b190808, 0x2b080819), uvec2(0x2b2b2b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x08081919, 0x2b08082b),
|
||||
uvec2(0x08082b2b, 0x2b08082b), uvec2(0x08190819, 0x2b08082b), uvec2(0x08191908, 0x2b08082b), uvec2(0x19080819, 0x2b08082b),
|
||||
uvec2(0x19081908, 0x2b08082b), uvec2(0x19190808, 0x2b08082b), uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908),
|
||||
uvec2(0x0808192b, 0x2b081908), uvec2(0x08082b19, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908),
|
||||
uvec2(0x08191919, 0x2b081908), uvec2(0x08192b08, 0x2b081908), uvec2(0x082b0819, 0x2b081908), uvec2(0x19080808, 0x2b081908),
|
||||
uvec2(0x1908082b, 0x2b081908), uvec2(0x19081919, 0x2b081908), uvec2(0x19082b08, 0x2b081908), uvec2(0x19190819, 0x2b081908),
|
||||
uvec2(0x19191908, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b080819, 0x2b081908), uvec2(0x2b081908, 0x2b081908),
|
||||
uvec2(0x2b190808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x0808082b, 0x2b081919), uvec2(0x08081919, 0x2b081919),
|
||||
uvec2(0x08082b08, 0x2b081919), uvec2(0x08190819, 0x2b081919), uvec2(0x08191908, 0x2b081919), uvec2(0x082b0808, 0x2b081919),
|
||||
uvec2(0x19080819, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x19190808, 0x2b081919), uvec2(0x2b080808, 0x2b081919),
|
||||
uvec2(0x2b082b2b, 0x2b081919), uvec2(0x08080819, 0x2b08192b), uvec2(0x08081908, 0x2b08192b), uvec2(0x08190808, 0x2b08192b),
|
||||
uvec2(0x082b2b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08081919, 0x2b082b08),
|
||||
uvec2(0x08190819, 0x2b082b08), uvec2(0x08191908, 0x2b082b08), uvec2(0x19080819, 0x2b082b08), uvec2(0x19081908, 0x2b082b08),
|
||||
uvec2(0x19190808, 0x2b082b08), uvec2(0x2b2b082b, 0x2b082b08), uvec2(0x08080819, 0x2b082b19), uvec2(0x08081908, 0x2b082b19),
|
||||
uvec2(0x19080808, 0x2b082b19), uvec2(0x192b1919, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x19192b08, 0x2b082b2b),
|
||||
uvec2(0x19192b2b, 0x2b082b2b), uvec2(0x2b08082b, 0x2b082b2b), uvec2(0x2b2b082b, 0x2b082b2b), uvec2(0x08080819, 0x2b190808),
|
||||
uvec2(0x08081908, 0x2b190808), uvec2(0x08082b19, 0x2b190808), uvec2(0x08190808, 0x2b190808), uvec2(0x0819082b, 0x2b190808),
|
||||
uvec2(0x08191919, 0x2b190808), uvec2(0x08192b08, 0x2b190808), uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808),
|
||||
uvec2(0x1908082b, 0x2b190808), uvec2(0x19081919, 0x2b190808), uvec2(0x19082b08, 0x2b190808), uvec2(0x19190819, 0x2b190808),
|
||||
uvec2(0x19191908, 0x2b190808), uvec2(0x192b0808, 0x2b190808), uvec2(0x2b080819, 0x2b190808), uvec2(0x2b081908, 0x2b190808),
|
||||
uvec2(0x2b190808, 0x2b190808), uvec2(0x08080808, 0x2b190819), uvec2(0x08081919, 0x2b190819), uvec2(0x08190819, 0x2b190819),
|
||||
uvec2(0x08191908, 0x2b190819), uvec2(0x19080819, 0x2b190819), uvec2(0x19081908, 0x2b190819), uvec2(0x19190808, 0x2b190819),
|
||||
uvec2(0x19192b2b, 0x2b190819), uvec2(0x08080819, 0x2b19082b), uvec2(0x08081908, 0x2b19082b), uvec2(0x08190808, 0x2b19082b),
|
||||
uvec2(0x19080808, 0x2b19082b), uvec2(0x2b2b192b, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x0808082b, 0x2b191908),
|
||||
uvec2(0x08081919, 0x2b191908), uvec2(0x08082b08, 0x2b191908), uvec2(0x08190819, 0x2b191908), uvec2(0x08191908, 0x2b191908),
|
||||
uvec2(0x082b0808, 0x2b191908), uvec2(0x19080819, 0x2b191908), uvec2(0x19081908, 0x2b191908), uvec2(0x19190808, 0x2b191908),
|
||||
uvec2(0x2b080808, 0x2b191908), uvec2(0x2b19192b, 0x2b191908), uvec2(0x08080819, 0x2b191919), uvec2(0x08081908, 0x2b191919),
|
||||
uvec2(0x08190808, 0x2b191919), uvec2(0x19080808, 0x2b191919), uvec2(0x2b192b08, 0x2b191919), uvec2(0x2b2b0819, 0x2b191919),
|
||||
uvec2(0x08080808, 0x2b19192b), uvec2(0x1908192b, 0x2b19192b), uvec2(0x192b1908, 0x2b19192b), uvec2(0x08080819, 0x2b192b08),
|
||||
uvec2(0x08081908, 0x2b192b08), uvec2(0x08190808, 0x2b192b08), uvec2(0x082b192b, 0x2b192b08), uvec2(0x19080808, 0x2b192b08),
|
||||
uvec2(0x2b2b2b19, 0x2b192b08), uvec2(0x08080808, 0x2b192b19), uvec2(0x19082b19, 0x2b192b19), uvec2(0x1919082b, 0x2b192b19),
|
||||
uvec2(0x2b190808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808), uvec2(0x08081919, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808),
|
||||
uvec2(0x08191908, 0x2b2b0808), uvec2(0x082b082b, 0x2b2b0808), uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x19080819, 0x2b2b0808),
|
||||
uvec2(0x19081908, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b2b082b, 0x2b2b0808), uvec2(0x2b2b2b2b, 0x2b2b0808),
|
||||
uvec2(0x19080808, 0x2b2b0819), uvec2(0x192b1919, 0x2b2b0819), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b2b, 0x2b2b082b),
|
||||
uvec2(0x082b082b, 0x2b2b082b), uvec2(0x082b2b08, 0x2b2b082b), uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b08082b, 0x2b2b082b),
|
||||
uvec2(0x2b082b08, 0x2b2b082b), uvec2(0x2b082b2b, 0x2b2b082b), uvec2(0x2b2b2b08, 0x2b2b082b), uvec2(0x08080819, 0x2b2b1908),
|
||||
uvec2(0x08081908, 0x2b2b1908), uvec2(0x08190808, 0x2b2b1908), uvec2(0x19080808, 0x2b2b1908), uvec2(0x2b082b19, 0x2b2b1908),
|
||||
uvec2(0x2b2b1908, 0x2b2b1908), uvec2(0x08080808, 0x2b2b1919), uvec2(0x08192b19, 0x2b2b1919), uvec2(0x19190819, 0x2b2b192b),
|
||||
uvec2(0x08082b2b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b082b, 0x2b2b2b08), uvec2(0x19191908, 0x2b2b2b19),
|
||||
uvec2(0x2b08192b, 0x2b2b2b19), uvec2(0x08082b08, 0x2b2b2b2b), uvec2(0x08082b2b, 0x2b2b2b2b), uvec2(0x082b0808, 0x2b2b2b2b),
|
||||
uvec2(0x082b082b, 0x2b2b2b2b), uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x2b082b08, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b)
|
||||
};
|
||||
|
||||
shared uvec2 iq2s_grid[1024];
|
||||
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
for (uint i = gl_LocalInvocationIndex.x; i < iq2s_grid.length(); i += wgsize.x) {
|
||||
iq2s_grid[i] = iq2s_grid_const[i];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
#define QUANT_K QUANT_K_IQ2_S
|
||||
#define QUANT_R QUANT_R_IQ2_S
|
||||
#define A_TYPE block_iq2_s
|
||||
#endif
|
||||
|
||||
#define QUANT_K_IQ3_XXS 256
|
||||
#define QUANT_R_IQ3_XXS 1
|
||||
|
||||
struct block_iq3_xxs
|
||||
{
|
||||
float16_t d;
|
||||
uint8_t qs[QUANT_K_IQ3_XXS/4 + QUANT_K_IQ3_XXS/8];
|
||||
};
|
||||
|
||||
struct block_iq3_xxs_packed16
|
||||
{
|
||||
float16_t d;
|
||||
uint16_t qs[QUANT_K_IQ3_XXS/8 + QUANT_K_IQ3_XXS/16];
|
||||
};
|
||||
|
||||
#if defined(DATA_A_IQ3_XXS)
|
||||
|
||||
const uint32_t iq3xxs_grid_const[256] = {
|
||||
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
|
||||
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
|
||||
0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
|
||||
0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
|
||||
0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
|
||||
0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
|
||||
0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
|
||||
0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
|
||||
0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
|
||||
0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
|
||||
0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
|
||||
0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
|
||||
0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
|
||||
0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
|
||||
0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
|
||||
0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
|
||||
0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
|
||||
0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
|
||||
0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
|
||||
0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
|
||||
0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
|
||||
0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
|
||||
0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
|
||||
0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
|
||||
0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
|
||||
0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
|
||||
0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
|
||||
0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
|
||||
0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
|
||||
0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
|
||||
0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
|
||||
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
||||
};
|
||||
|
||||
shared uint32_t iq3xxs_grid[256];
|
||||
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
for (uint i = gl_LocalInvocationIndex.x; i < iq3xxs_grid.length(); i += wgsize.x) {
|
||||
iq3xxs_grid[i] = iq3xxs_grid_const[i];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
#define QUANT_K QUANT_K_IQ3_XXS
|
||||
#define QUANT_R QUANT_R_IQ3_XXS
|
||||
#define A_TYPE block_iq3_xxs
|
||||
#define A_TYPE_PACKED16 block_iq3_xxs_packed16
|
||||
#endif
|
||||
|
||||
#define QUANT_K_IQ3_S 256
|
||||
#define QUANT_R_IQ3_S 1
|
||||
|
||||
struct block_iq3_s
|
||||
{
|
||||
float16_t d;
|
||||
uint8_t qs[QUANT_K_IQ3_S/4];
|
||||
uint8_t qh[QUANT_K_IQ3_S/32];
|
||||
uint8_t signs[QUANT_K_IQ3_S/8];
|
||||
uint8_t scales[QUANT_K_IQ3_S/64];
|
||||
};
|
||||
|
||||
struct block_iq3_s_packed16
|
||||
{
|
||||
float16_t d;
|
||||
uint16_t qs[QUANT_K_IQ3_S/4/2];
|
||||
uint16_t qh[QUANT_K_IQ3_S/32/2];
|
||||
uint16_t signs[QUANT_K_IQ3_S/8/2];
|
||||
uint16_t scales[QUANT_K_IQ3_S/64/2];
|
||||
};
|
||||
|
||||
#if defined(DATA_A_IQ3_S)
|
||||
|
||||
const uint32_t iq3s_grid_const[512] = {
|
||||
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
|
||||
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
|
||||
0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
|
||||
0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
|
||||
0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
|
||||
0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
|
||||
0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
|
||||
0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
|
||||
0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
|
||||
0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
|
||||
0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
|
||||
0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
|
||||
0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
|
||||
0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
|
||||
0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
|
||||
0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
|
||||
0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
|
||||
0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
|
||||
0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
|
||||
0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
|
||||
0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
|
||||
0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
|
||||
0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
|
||||
0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
|
||||
0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
|
||||
0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
|
||||
0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
|
||||
0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
|
||||
0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
|
||||
0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
|
||||
0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
|
||||
0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
|
||||
0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
|
||||
0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
|
||||
0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
|
||||
0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
|
||||
0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
|
||||
0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
|
||||
0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
|
||||
0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
|
||||
0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
|
||||
0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
|
||||
0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
|
||||
0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
|
||||
0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
|
||||
0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
|
||||
0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
|
||||
0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
|
||||
0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
|
||||
0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
|
||||
0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
|
||||
0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
|
||||
0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
|
||||
0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
|
||||
0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
|
||||
0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
|
||||
0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
|
||||
0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
|
||||
0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
|
||||
0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
|
||||
0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
|
||||
0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
|
||||
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
|
||||
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
|
||||
};
|
||||
|
||||
shared uint32_t iq3s_grid[512];
|
||||
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
for (uint i = gl_LocalInvocationIndex.x; i < iq3s_grid.length(); i += wgsize.x) {
|
||||
iq3s_grid[i] = iq3s_grid_const[i];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
#define QUANT_K QUANT_K_IQ3_S
|
||||
#define QUANT_R QUANT_R_IQ3_S
|
||||
#define A_TYPE block_iq3_s
|
||||
#define A_TYPE_PACKED16 block_iq3_s_packed16
|
||||
#endif
|
||||
|
||||
#define QUANT_K_IQ4_NL 32
|
||||
#define QUANT_R_IQ4_NL 2
|
||||
|
||||
@@ -318,11 +1050,11 @@ const int8_t kvalues_iq4nl_const[16] = {
|
||||
|
||||
shared FLOAT_TYPE kvalues_iq4nl[16];
|
||||
|
||||
void init_iq4nl_shmem()
|
||||
void init_iq_shmem(uvec3 wgsize)
|
||||
{
|
||||
// copy the table into shared memory and sync
|
||||
if (gl_LocalInvocationIndex.x < 16) {
|
||||
kvalues_iq4nl[gl_LocalInvocationIndex.x] = FLOAT_TYPE(kvalues_iq4nl_const[gl_LocalInvocationIndex.x]);
|
||||
for (uint i = gl_LocalInvocationIndex.x; i < kvalues_iq4nl.length(); i += wgsize.x) {
|
||||
kvalues_iq4nl[i] = FLOAT_TYPE(kvalues_iq4nl_const[i]);
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
@@ -17,13 +17,13 @@
|
||||
#include <cstring>
|
||||
#include <cstdlib>
|
||||
#include <cassert>
|
||||
#include <algorithm>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
#include <direct.h> // For _mkdir on Windows
|
||||
#include <algorithm> // For std::replace on w64devkit
|
||||
#else
|
||||
#include <unistd.h>
|
||||
#include <sys/wait.h>
|
||||
@@ -55,6 +55,11 @@ const std::vector<std::string> type_names = {
|
||||
"q4_k",
|
||||
"q5_k",
|
||||
"q6_k",
|
||||
"iq2_xxs",
|
||||
"iq2_xs",
|
||||
"iq2_s",
|
||||
"iq3_xxs",
|
||||
"iq3_s",
|
||||
"iq4_nl"
|
||||
};
|
||||
|
||||
@@ -316,8 +321,11 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
||||
// For aligned matmul loads
|
||||
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
|
||||
|
||||
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
// don't generate f32 variants for coopmat2
|
||||
if (!coopmat2) {
|
||||
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
|
||||
if (tname != "f16" && tname != "f32") {
|
||||
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
@@ -499,6 +507,7 @@ void write_output_files() {
|
||||
fprintf(hdr, "#include <cstdint>\n\n");
|
||||
fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str());
|
||||
|
||||
std::sort(shader_fnames.begin(), shader_fnames.end());
|
||||
for (const auto& pair : shader_fnames) {
|
||||
const std::string& name = pair.first;
|
||||
#ifdef _WIN32
|
||||
|
||||
+23
-13
@@ -128,6 +128,10 @@ static void ggml_print_backtrace_symbols(void) {
|
||||
#endif
|
||||
|
||||
static void ggml_print_backtrace(void) {
|
||||
const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE");
|
||||
if (GGML_NO_BACKTRACE) {
|
||||
return;
|
||||
}
|
||||
char attach[32];
|
||||
snprintf(attach, sizeof(attach), "attach %d", getpid());
|
||||
int pid = fork();
|
||||
@@ -5339,7 +5343,7 @@ static void ggml_compute_backward(
|
||||
} break;
|
||||
case GGML_OP_MUL: {
|
||||
if (src0_needs_grads) {
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
|
||||
}
|
||||
if (src1_needs_grads) {
|
||||
struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
|
||||
@@ -5431,21 +5435,25 @@ static void ggml_compute_backward(
|
||||
// src1.shape [n,p,qq,rr]
|
||||
|
||||
if (src0_needs_grads) {
|
||||
struct ggml_tensor * s1_tg =
|
||||
GGML_ASSERT(grad->ne[2] == src1->ne[2]);
|
||||
GGML_ASSERT(grad->ne[3] == src1->ne[3]);
|
||||
struct ggml_tensor * tmp =
|
||||
ggml_out_prod(ctx, // [n,m,qq,rr]
|
||||
src1, // [n,p,qq,rr]
|
||||
grad); // [m,p,qq,rr]
|
||||
const int64_t qq = s1_tg->ne[2];
|
||||
const int64_t rr = s1_tg->ne[3];
|
||||
const int64_t q1 = src0->ne[2];
|
||||
const int64_t r1 = src0->ne[3];
|
||||
const bool ne2_broadcasted = qq > q1;
|
||||
const bool ne3_broadcasted = rr > r1;
|
||||
if (ne2_broadcasted || ne3_broadcasted) {
|
||||
// sum broadcast repetitions of s1_tg into shape of src0
|
||||
s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
|
||||
if (!ggml_are_same_shape(tmp, src0)) {
|
||||
GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
|
||||
GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
|
||||
GGML_ASSERT(tmp->ne[3] == 1);
|
||||
|
||||
const int64_t nr2 = tmp->ne[2] / src0->ne[2];
|
||||
const size_t nb2 = tmp->nb[2] * nr2;
|
||||
const size_t nb3 = tmp->nb[2];
|
||||
|
||||
tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
|
||||
tmp = ggml_repeat_back(ctx, tmp, src0);
|
||||
}
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, tmp);
|
||||
}
|
||||
if (src1_needs_grads) {
|
||||
ggml_add_or_set(ctx, cgraph, isrc1,
|
||||
@@ -5514,7 +5522,9 @@ static void ggml_compute_backward(
|
||||
if (src0_needs_grads) {
|
||||
GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
|
||||
GGML_ASSERT(ggml_is_contiguous(grad));
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, grad);
|
||||
GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
|
||||
ggml_add_or_set(ctx, cgraph, isrc0,
|
||||
ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_RESHAPE: {
|
||||
|
||||
+2
-1
@@ -510,7 +510,8 @@ extern "C" {
|
||||
LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
|
||||
|
||||
// Get the default chat template. Returns nullptr if not available
|
||||
LLAMA_API const char * llama_model_chat_template(const struct llama_model * model);
|
||||
// If name is NULL, returns the default chat template
|
||||
LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name);
|
||||
|
||||
// Returns the total number of parameters in the model
|
||||
LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
|
||||
|
||||
Executable
+77
@@ -0,0 +1,77 @@
|
||||
#!/usr/bin/env python
|
||||
'''
|
||||
Fetches the Jinja chat template of a HuggingFace model.
|
||||
If a model has multiple chat templates, you can specify the variant name.
|
||||
|
||||
Syntax:
|
||||
./scripts/get_hf_chat_template.py model_id [variant]
|
||||
|
||||
Examples:
|
||||
./scripts/get_hf_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
|
||||
./scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use
|
||||
./scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct
|
||||
'''
|
||||
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
|
||||
|
||||
def get_hf_chat_template(model_id, variant=None):
|
||||
try:
|
||||
# Use huggingface_hub library if available.
|
||||
# Allows access to gated models if the user has access and ran `huggingface-cli login`.
|
||||
from huggingface_hub import hf_hub_download
|
||||
with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
|
||||
config_str = f.read()
|
||||
except ImportError:
|
||||
import requests
|
||||
assert re.match(r"^[\w.-]+/[\w.-]+$", model_id), f"Invalid model ID: {model_id}"
|
||||
response = requests.get(f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json")
|
||||
if response.status_code == 401:
|
||||
raise Exception('Access to this model is gated, please request access, authenticate with `huggingface-cli login` and make sure to run `pip install huggingface_hub`')
|
||||
response.raise_for_status()
|
||||
config_str = response.text
|
||||
|
||||
try:
|
||||
config = json.loads(config_str)
|
||||
except json.JSONDecodeError:
|
||||
# Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
|
||||
# (Remove extra '}' near the end of the file)
|
||||
config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
|
||||
|
||||
chat_template = config['chat_template']
|
||||
if isinstance(chat_template, str):
|
||||
return chat_template
|
||||
else:
|
||||
variants = {
|
||||
ct['name']: ct['template']
|
||||
for ct in chat_template
|
||||
}
|
||||
|
||||
def format_variants():
|
||||
return ', '.join(f'"{v}"' for v in variants.keys())
|
||||
|
||||
if variant is None:
|
||||
if 'default' not in variants:
|
||||
raise Exception(f'Please specify a chat template variant (one of {format_variants()})')
|
||||
variant = 'default'
|
||||
sys.stderr.write(f'Note: picked "default" chat template variant (out of {format_variants()})\n')
|
||||
elif variant not in variants:
|
||||
raise Exception(f"Variant {variant} not found in chat template (found {format_variants()})")
|
||||
|
||||
return variants[variant]
|
||||
|
||||
|
||||
def main(args):
|
||||
if len(args) < 1:
|
||||
raise ValueError("Please provide a model ID and an optional variant name")
|
||||
model_id = args[0]
|
||||
variant = None if len(args) < 2 else args[1]
|
||||
|
||||
template = get_hf_chat_template(model_id, variant)
|
||||
sys.stdout.write(template)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main(sys.argv[1:])
|
||||
@@ -1 +1 @@
|
||||
d92321c0d151fe73a47d89738c7c3091ac904297
|
||||
32f0b85987396945afea2291d5f4c5862434292b
|
||||
|
||||
+1
-1
@@ -29,7 +29,7 @@ add_library(llama
|
||||
unicode-data.cpp
|
||||
)
|
||||
|
||||
target_include_directories(llama PUBLIC . ../include)
|
||||
target_include_directories(llama PUBLIC . ../include ../common)
|
||||
target_compile_features (llama PUBLIC cxx_std_17) # don't bump
|
||||
|
||||
target_link_libraries(llama PUBLIC ggml)
|
||||
|
||||
+4
-2
@@ -179,6 +179,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
|
||||
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
|
||||
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" },
|
||||
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" },
|
||||
{ LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
|
||||
{ LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
|
||||
@@ -1443,10 +1444,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
};
|
||||
|
||||
LLM_KV::LLM_KV(llm_arch arch) : arch(arch) {}
|
||||
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
|
||||
|
||||
std::string LLM_KV::operator()(llm_kv kv) const {
|
||||
return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
|
||||
return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
|
||||
: ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
|
||||
}
|
||||
|
||||
std::string LLM_TN_IMPL::str() const {
|
||||
|
||||
+3
-1
@@ -177,6 +177,7 @@ enum llm_kv {
|
||||
LLM_KV_TOKENIZER_HF_JSON,
|
||||
LLM_KV_TOKENIZER_RWKV,
|
||||
LLM_KV_TOKENIZER_CHAT_TEMPLATE,
|
||||
LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
|
||||
LLM_KV_TOKENIZER_FIM_PRE_ID,
|
||||
LLM_KV_TOKENIZER_FIM_SUF_ID,
|
||||
LLM_KV_TOKENIZER_FIM_MID_ID,
|
||||
@@ -335,9 +336,10 @@ enum llm_tensor_layer {
|
||||
};
|
||||
|
||||
struct LLM_KV {
|
||||
LLM_KV(llm_arch arch);
|
||||
LLM_KV(llm_arch arch, const char * suffix = nullptr);
|
||||
|
||||
llm_arch arch;
|
||||
const char * suffix;
|
||||
|
||||
std::string operator()(llm_kv kv) const;
|
||||
};
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <cstring>
|
||||
#include <climits>
|
||||
#include <stdexcept>
|
||||
#include <cerrno>
|
||||
|
||||
#ifdef __has_include
|
||||
#if __has_include(<unistd.h>)
|
||||
|
||||
@@ -819,7 +819,7 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps
|
||||
for (const auto & file : files) {
|
||||
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
|
||||
auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
|
||||
std::unique_ptr<llama_mmap> mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, is_numa_fn()));
|
||||
std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa_fn());
|
||||
mmaps_used.emplace_back(mapping->size(), 0);
|
||||
if (mlock_mmaps) {
|
||||
std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock());
|
||||
|
||||
+6
-2
@@ -1303,10 +1303,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1);
|
||||
auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
|
||||
if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) {
|
||||
LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(cpu_dev));
|
||||
return {cpu_dev, &pimpl->cpu_buft_list};
|
||||
}
|
||||
const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin();
|
||||
auto * dev = devices.at(layer_gpu);
|
||||
LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(dev));
|
||||
return {dev, &pimpl->gpu_buft_list.at(dev)};
|
||||
};
|
||||
|
||||
@@ -3955,8 +3957,10 @@ uint64_t llama_model_size(const struct llama_model * model) {
|
||||
return model->size();
|
||||
}
|
||||
|
||||
const char * llama_model_chat_template(const struct llama_model * model) {
|
||||
const auto & it = model->gguf_kv.find(LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE));
|
||||
const char * llama_model_chat_template(const struct llama_model * model, const char * name) {
|
||||
const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
|
||||
: LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
|
||||
const auto & it = model->gguf_kv.find(key);
|
||||
if (it == model->gguf_kv.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
+8
-3
@@ -1245,8 +1245,13 @@ struct llama_vocab::impl {
|
||||
|
||||
std::vector<llama_token> cache_special_tokens;
|
||||
std::vector<std::string> cache_token_to_piece; // llama_token_to_piece(special = true);
|
||||
|
||||
std::map<std::pair<std::string, std::string>, int> bpe_ranks;
|
||||
struct pair_hash {
|
||||
size_t operator()(const std::pair<std::string, std::string> & p) const {
|
||||
return std::hash<std::string>{}(p.first) ^ //create some hash for pair
|
||||
(std::hash<std::string>{}(p.second) << 1);
|
||||
}
|
||||
};
|
||||
std::unordered_map<std::pair<std::string, std::string>, int, pair_hash> bpe_ranks;
|
||||
|
||||
// set of all tokens that cause "end of generation"
|
||||
std::set<llama_token> special_eog_ids;
|
||||
@@ -1687,7 +1692,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
|
||||
linefeed_id = ids[0];
|
||||
} else {
|
||||
const std::vector<int> ids = tokenize("\xC4\x8A", false); // U+010A
|
||||
const std::vector<int> ids = tokenize("\n", false);
|
||||
|
||||
//GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
|
||||
if (ids.empty()) {
|
||||
|
||||
+154
-111
@@ -7700,17 +7700,13 @@ struct llm_build_context {
|
||||
1
|
||||
);
|
||||
|
||||
struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
|
||||
ggml_build_forward_expand(
|
||||
gf,
|
||||
ggml_cpy(
|
||||
ctx0,
|
||||
wkv_states,
|
||||
ggml_view_1d(
|
||||
ctx0,
|
||||
kv_self.v_l[il],
|
||||
hparams.n_embd_v_s() * n_seqs,
|
||||
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
|
||||
)
|
||||
ggml_view_1d(ctx0, last_norm_att, n_embd * n_seqs, 0),
|
||||
ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
|
||||
)
|
||||
);
|
||||
|
||||
@@ -8432,13 +8428,141 @@ static enum ggml_status llama_graph_compute(
|
||||
return status;
|
||||
}
|
||||
|
||||
static int llama_prepare_sbatch(
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch,
|
||||
uint32_t & n_outputs) {
|
||||
const auto & model = lctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
const auto & cparams = lctx.cparams;
|
||||
|
||||
const uint32_t n_tokens_all = batch.n_tokens;
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
||||
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
||||
if (batch.token) {
|
||||
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
||||
if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
|
||||
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
|
||||
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
|
||||
|
||||
lctx.n_queued_tokens += n_tokens_all;
|
||||
lctx.embd_seq.clear();
|
||||
|
||||
// count outputs
|
||||
if (batch.logits && !embd_pooled) {
|
||||
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
||||
n_outputs += batch.logits[i] != 0;
|
||||
}
|
||||
} else if (lctx.logits_all || embd_pooled) {
|
||||
n_outputs = n_tokens_all;
|
||||
} else {
|
||||
// keep last output only
|
||||
n_outputs = 1;
|
||||
}
|
||||
|
||||
lctx.sbatch.from_batch(batch, n_embd,
|
||||
/* simple_split */ !lctx.kv_self.recurrent,
|
||||
/* logits_all */ n_outputs == n_tokens_all);
|
||||
|
||||
// reserve output buffer
|
||||
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
|
||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
|
||||
return -2;
|
||||
};
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int llama_prepare_ubatch(
|
||||
llama_context & lctx,
|
||||
llama_kv_slot_restorer & kv_slot_restorer,
|
||||
llama_ubatch & ubatch,
|
||||
const uint32_t n_outputs,
|
||||
const uint32_t n_tokens_all) {
|
||||
GGML_ASSERT(lctx.sbatch.n_tokens > 0);
|
||||
|
||||
auto & kv_self = lctx.kv_self;
|
||||
const auto & cparams = lctx.cparams;
|
||||
const auto & hparams = lctx.model.hparams;
|
||||
|
||||
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
||||
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
if (lctx.kv_self.recurrent) {
|
||||
if (embd_pooled) {
|
||||
// Pooled embeddings cannot be split across ubatches (yet)
|
||||
ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
|
||||
} else {
|
||||
// recurrent model architectures are easier to implement
|
||||
// with equal-length sequences
|
||||
ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
|
||||
}
|
||||
} else {
|
||||
ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
|
||||
}
|
||||
|
||||
// count the outputs in this u_batch
|
||||
{
|
||||
int32_t n_outputs_new = 0;
|
||||
|
||||
if (n_outputs == n_tokens_all) {
|
||||
n_outputs_new = ubatch.n_tokens;
|
||||
} else {
|
||||
GGML_ASSERT(ubatch.output);
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
||||
n_outputs_new += int32_t(ubatch.output[i] != 0);
|
||||
}
|
||||
}
|
||||
|
||||
// needs to happen before the graph is built
|
||||
lctx.n_outputs = n_outputs_new;
|
||||
}
|
||||
|
||||
// non-causal masks do not use the KV cache
|
||||
if (hparams.causal_attn) {
|
||||
llama_kv_cache_update(&lctx);
|
||||
|
||||
// if we have enough unused cells before the current head ->
|
||||
// better to start searching from the beginning of the cache, hoping to fill it
|
||||
if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
|
||||
kv_self.head = 0;
|
||||
}
|
||||
|
||||
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
||||
if (!slot) {
|
||||
return 1;
|
||||
}
|
||||
kv_slot_restorer.save(slot);
|
||||
|
||||
if (!kv_self.recurrent) {
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// after enough generations, the benefit from this heuristic disappears
|
||||
// if we start defragmenting the cache, the benefit from this will be more important
|
||||
const uint32_t pad = llama_kv_cache_get_padding(cparams);
|
||||
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
|
||||
//kv_self.n = llama_kv_cache_cell_max(kv_self);
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// decode a batch of tokens by evaluating the transformer
|
||||
// in case of unsuccessful decoding (error or warning),
|
||||
// the kv_cache state will be returned to its original state
|
||||
// (for non-recurrent models) or cleaned (for recurrent models)
|
||||
//
|
||||
// - lctx: llama context
|
||||
// - batch: batch to evaluate
|
||||
// - inp_batch: batch to evaluate
|
||||
//
|
||||
// return 0 on success
|
||||
// return positive int on warning
|
||||
@@ -8455,37 +8579,18 @@ static int llama_decode_impl(
|
||||
return -1;
|
||||
}
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
// temporarily allocate memory for the input batch if needed
|
||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
|
||||
|
||||
const llama_batch & batch = batch_allocr.batch;
|
||||
const uint32_t n_tokens_all = batch.n_tokens;
|
||||
|
||||
const auto & model = lctx.model;
|
||||
const auto & vocab = model.vocab;
|
||||
const auto & hparams = model.hparams;
|
||||
const auto & cparams = lctx.cparams;
|
||||
|
||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
||||
|
||||
if (batch.token) {
|
||||
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
||||
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
||||
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
|
||||
|
||||
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
|
||||
|
||||
if (lctx.t_compute_start_us == 0) {
|
||||
lctx.t_compute_start_us = ggml_time_us();
|
||||
}
|
||||
lctx.n_queued_tokens += n_tokens_all;
|
||||
|
||||
auto & kv_self = lctx.kv_self;
|
||||
llama_kv_slot_restorer kv_slot_restorer(kv_self);
|
||||
|
||||
@@ -8495,99 +8600,27 @@ static int llama_decode_impl(
|
||||
uint32_t n_outputs = 0;
|
||||
uint32_t n_outputs_prev = 0;
|
||||
|
||||
const auto n_ubatch = cparams.n_ubatch;
|
||||
|
||||
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
||||
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
lctx.embd_seq.clear();
|
||||
|
||||
// count outputs
|
||||
if (batch.logits && !embd_pooled) {
|
||||
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
||||
n_outputs += batch.logits[i] != 0;
|
||||
{
|
||||
const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
} else if (lctx.logits_all || embd_pooled) {
|
||||
n_outputs = n_tokens_all;
|
||||
} else {
|
||||
// keep last output only
|
||||
n_outputs = 1;
|
||||
}
|
||||
|
||||
lctx.sbatch.from_batch(batch, n_embd,
|
||||
/* simple_split */ !kv_self.recurrent,
|
||||
/* logits_all */ n_outputs == n_tokens_all);
|
||||
|
||||
// reserve output buffer
|
||||
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
|
||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
|
||||
return -2;
|
||||
};
|
||||
|
||||
while (lctx.sbatch.n_tokens > 0) {
|
||||
llama_ubatch ubatch;
|
||||
if (kv_self.recurrent) {
|
||||
if (embd_pooled) {
|
||||
// Pooled embeddings cannot be split across ubatches (yet)
|
||||
ubatch = lctx.sbatch.split_seq(n_ubatch);
|
||||
} else {
|
||||
// recurrent model architectures are easier to implement
|
||||
// with equal-length sequences
|
||||
ubatch = lctx.sbatch.split_equal(n_ubatch);
|
||||
}
|
||||
} else {
|
||||
ubatch = lctx.sbatch.split_simple(n_ubatch);
|
||||
}
|
||||
const uint32_t n_tokens = ubatch.n_tokens;
|
||||
|
||||
// count the outputs in this u_batch
|
||||
{
|
||||
int32_t n_outputs_new = 0;
|
||||
|
||||
if (n_outputs == n_tokens_all) {
|
||||
n_outputs_new = n_tokens;
|
||||
} else {
|
||||
GGML_ASSERT(ubatch.output);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
|
||||
}
|
||||
const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
// needs to happen before the graph is built
|
||||
lctx.n_outputs = n_outputs_new;
|
||||
}
|
||||
|
||||
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
|
||||
ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
|
||||
const int n_threads = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
|
||||
ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
|
||||
|
||||
GGML_ASSERT(n_threads > 0);
|
||||
|
||||
// non-causal masks do not use the KV cache
|
||||
if (hparams.causal_attn) {
|
||||
llama_kv_cache_update(&lctx);
|
||||
|
||||
// if we have enough unused cells before the current head ->
|
||||
// better to start searching from the beginning of the cache, hoping to fill it
|
||||
if (kv_self.head > kv_self.used + 2*n_tokens) {
|
||||
kv_self.head = 0;
|
||||
}
|
||||
|
||||
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
||||
if (!slot) {
|
||||
return 1;
|
||||
}
|
||||
kv_slot_restorer.save(slot);
|
||||
|
||||
if (!kv_self.recurrent) {
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// after enough generations, the benefit from this heuristic disappears
|
||||
// if we start defragmenting the cache, the benefit from this will be more important
|
||||
const uint32_t pad = llama_kv_cache_get_padding(cparams);
|
||||
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
|
||||
//kv_self.n = llama_kv_cache_cell_max(kv_self);
|
||||
}
|
||||
}
|
||||
|
||||
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
|
||||
|
||||
ggml_backend_sched_reset(lctx.sched.get());
|
||||
@@ -8640,7 +8673,7 @@ static int llama_decode_impl(
|
||||
|
||||
// update the kv ring buffer
|
||||
{
|
||||
kv_self.head += n_tokens;
|
||||
kv_self.head += ubatch.n_tokens;
|
||||
|
||||
// Ensure kv cache head points to a valid index.
|
||||
if (kv_self.head >= kv_self.size) {
|
||||
@@ -9405,6 +9438,7 @@ static struct llama_model * llama_model_load_from_file_impl(
|
||||
model->devices.push_back(*dev);
|
||||
}
|
||||
} else {
|
||||
std::vector<ggml_backend_dev_t> rpc_servers;
|
||||
// use all available devices
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
|
||||
@@ -9415,10 +9449,19 @@ static struct llama_model * llama_model_load_from_file_impl(
|
||||
break;
|
||||
|
||||
case GGML_BACKEND_DEVICE_TYPE_GPU:
|
||||
model->devices.push_back(dev);
|
||||
ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
|
||||
if (ggml_backend_reg_name(reg) == std::string("RPC")) {
|
||||
rpc_servers.push_back(dev);
|
||||
} else {
|
||||
model->devices.push_back(dev);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
// add RPC servers at the front of the list
|
||||
if (!rpc_servers.empty()) {
|
||||
model->devices.insert(model->devices.begin(), rpc_servers.begin(), rpc_servers.end());
|
||||
}
|
||||
}
|
||||
|
||||
// if using single GPU mode, remove all except the main GPU
|
||||
|
||||
+130
-45
@@ -1302,6 +1302,59 @@ struct test_repeat : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_REPEAT_BACK
|
||||
struct test_repeat_back : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const std::array<int, 4> nr;
|
||||
const bool v; // whether src is a noncontiguous view
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, ne, nr, v);
|
||||
}
|
||||
|
||||
size_t op_size(ggml_tensor * t) override {
|
||||
return ggml_nbytes(t) * 2;
|
||||
}
|
||||
|
||||
test_repeat_back(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {8, 6, 4, 2},
|
||||
std::array<int, 4> nr = {2, 2, 2, 2},
|
||||
bool v = false)
|
||||
: type(type), ne(ne), nr(nr), v(v) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * src = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
|
||||
ggml_set_name(src, "src");
|
||||
|
||||
if (v) {
|
||||
GGML_ASSERT(ne[0] % 2 == 0);
|
||||
GGML_ASSERT(ne[1] % 2 == 0);
|
||||
GGML_ASSERT(ne[2] % 2 == 0);
|
||||
GGML_ASSERT(ne[3] % 2 == 0);
|
||||
GGML_ASSERT(nr[0] % 2 == 0 || nr[0] == 1);
|
||||
GGML_ASSERT(nr[1] % 2 == 0 || nr[1] == 1);
|
||||
GGML_ASSERT(nr[2] % 2 == 0 || nr[2] == 1);
|
||||
GGML_ASSERT(nr[3] % 2 == 0 || nr[3] == 1);
|
||||
|
||||
const int64_t ne00 = nr[0] == 1 ? src->ne[0] : src->ne[0] / 2;
|
||||
const int64_t ne01 = nr[1] == 1 ? src->ne[1] : src->ne[1] / 2;
|
||||
const int64_t ne02 = nr[2] == 1 ? src->ne[2] : src->ne[2] / 2;
|
||||
const int64_t ne03 = nr[3] == 1 ? src->ne[3] : src->ne[3] / 2;
|
||||
|
||||
src = ggml_view_4d(ctx, src, ne00, ne01, ne02, ne03, src->nb[1], src->nb[2], src->nb[3], 0);
|
||||
}
|
||||
|
||||
ggml_tensor * target = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_set_name(target, "target");
|
||||
|
||||
ggml_tensor * out = ggml_repeat_back(ctx, src, target);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_DUP
|
||||
struct test_dup : public test_case {
|
||||
const ggml_type type;
|
||||
@@ -1849,6 +1902,10 @@ struct test_mul_mat : public test_case {
|
||||
return 5e-4;
|
||||
}
|
||||
|
||||
int64_t grad_nmax() override {
|
||||
return 20000;
|
||||
}
|
||||
|
||||
uint64_t op_flops(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1];
|
||||
@@ -1878,8 +1935,12 @@ struct test_mul_mat : public test_case {
|
||||
|
||||
a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
|
||||
b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
|
||||
ggml_set_param(ctx, a);
|
||||
ggml_set_param(ctx, b);
|
||||
if (!ggml_is_quantized(type_a)) {
|
||||
if (bs[1] == 1 && nr[1] == 1) {
|
||||
ggml_set_param(ctx, a);
|
||||
}
|
||||
ggml_set_param(ctx, b);
|
||||
}
|
||||
ggml_set_name(a, "a");
|
||||
ggml_set_name(b, "b");
|
||||
|
||||
@@ -1890,8 +1951,12 @@ struct test_mul_mat : public test_case {
|
||||
} else {
|
||||
a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
|
||||
b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
|
||||
ggml_set_param(ctx, a);
|
||||
ggml_set_param(ctx, b);
|
||||
if (!ggml_is_quantized(type_a)) {
|
||||
if (bs[1] == 1 && nr[1] == 1) {
|
||||
ggml_set_param(ctx, a);
|
||||
}
|
||||
ggml_set_param(ctx, b);
|
||||
}
|
||||
ggml_set_name(a, "a");
|
||||
ggml_set_name(b, "b");
|
||||
}
|
||||
@@ -2282,11 +2347,12 @@ struct test_soft_max : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const bool mask;
|
||||
const ggml_type m_prec;
|
||||
const float scale;
|
||||
const float max_bias;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR5(type, ne, mask, scale, max_bias);
|
||||
return VARS_TO_STR6(type, ne, mask, m_prec, scale, max_bias);
|
||||
}
|
||||
|
||||
// the 1024 test with bias occasionally fails:
|
||||
@@ -2298,9 +2364,10 @@ struct test_soft_max : public test_case {
|
||||
test_soft_max(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {10, 5, 4, 3},
|
||||
bool mask = false,
|
||||
ggml_type m_prec = GGML_TYPE_F32,
|
||||
float scale = 1.0f,
|
||||
float max_bias = 0.0f)
|
||||
: type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}
|
||||
: type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
@@ -2309,7 +2376,7 @@ struct test_soft_max : public test_case {
|
||||
|
||||
ggml_tensor * mask = nullptr;
|
||||
if (this->mask) {
|
||||
mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
|
||||
mask = ggml_new_tensor_2d(ctx, m_prec, ne[0], ne[1]);
|
||||
ggml_set_name(mask, "mask");
|
||||
}
|
||||
|
||||
@@ -3798,6 +3865,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2}));
|
||||
}
|
||||
|
||||
for (bool view : {false, true}) {
|
||||
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 1}, view));
|
||||
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
|
||||
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
|
||||
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
|
||||
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
|
||||
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
|
||||
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
|
||||
test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
|
||||
test_cases.emplace_back(new test_dup(GGML_TYPE_I32));
|
||||
@@ -3909,38 +3986,35 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
|
||||
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
|
||||
|
||||
for (int i = 1; i < 9; ++i) {
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q6_K, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_IQ4_NL, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||
for (ggml_type type_a : all_types) {
|
||||
for (int i = 1; i < 10; ++i) {
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||
}
|
||||
}
|
||||
|
||||
#if 1
|
||||
for (ggml_type type_a : base_types) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
// test cases without permutation
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 2}));
|
||||
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 1}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 2}));
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 2}));
|
||||
|
||||
// test cases with permutation
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
|
||||
@@ -4078,17 +4152,28 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
for (float scale : {1.0f, 0.1f}) {
|
||||
for (int64_t ne0 : {16, 1024}) {
|
||||
for (int64_t ne1 : {16, 1024}) {
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias));
|
||||
if (mask) {
|
||||
for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, scale, max_bias));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, scale, max_bias));
|
||||
}
|
||||
} else {
|
||||
/* The precision of mask here doesn't matter as boolean mask is false */
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, 0.1f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, 0.1f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 8.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 8.0f));
|
||||
|
||||
for (float max_bias : {0.0f, 8.0f}) {
|
||||
for (float scale : {1.0f, 0.1f}) {
|
||||
@@ -4224,13 +4309,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
|
||||
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, 1.0f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, 1.0f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, 1.0f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, 1.0f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, 1.0f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, 1.0f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, 1.0f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
|
||||
|
||||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
|
||||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user