mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-01 01:57:43 +02:00
Compare commits
51 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 24bc238303 | |||
| 16639ba217 | |||
| 9981c30130 | |||
| e9fd8dcab4 | |||
| 4e5b83b226 | |||
| bb02f74c61 | |||
| 8f91ca54ec | |||
| 81ab64f3c8 | |||
| 8af1f5f430 | |||
| 557515be1e | |||
| cb6caca191 | |||
| b5b8fa1c8b | |||
| a14b960bc7 | |||
| 091a46cb8d | |||
| a3e812811d | |||
| 51fa458a92 | |||
| a5eaa1d6a3 | |||
| e2baf02162 | |||
| e34d6d03b2 | |||
| 9c96465f99 | |||
| 4e595b250a | |||
| 0e4ebeb057 | |||
| 8b30840703 | |||
| 9eb5bfec1a | |||
| c6926d1d95 | |||
| b70d251076 | |||
| 5516b9c16a | |||
| 94242a62c0 | |||
| 6b99a223e3 | |||
| 77078e80e5 | |||
| c301172f66 | |||
| 3802d3c78f | |||
| 9da3dcd753 | |||
| bd544c94a3 | |||
| 14be5a39b1 | |||
| fbbf3ad190 | |||
| 33f890e579 | |||
| 067b8d7af3 | |||
| 50b7f076a5 | |||
| ad8d85bd94 | |||
| 12a4a47e6a | |||
| 37c35f0e1c | |||
| 5bd341c9a1 | |||
| 1c7cf94b22 | |||
| 2c1f199653 | |||
| d1e3556481 | |||
| 08f3f4a8a3 | |||
| 271191906c | |||
| 7dee9ff59a | |||
| 6df686bee6 | |||
| 1706a6d7c6 |
@@ -16,7 +16,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Get latest Vulkan SDK version
|
||||
id: vulkan_sdk_version
|
||||
@@ -24,7 +24,7 @@ jobs:
|
||||
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Setup Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-sdk
|
||||
with:
|
||||
path: ./vulkan_sdk
|
||||
@@ -47,10 +47,10 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-toolchain
|
||||
with:
|
||||
path: ./spacemit_toolchain
|
||||
@@ -73,10 +73,10 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-rocm
|
||||
with:
|
||||
path: C:\Program Files\AMD\ROCm
|
||||
|
||||
@@ -7,7 +7,7 @@ jobs:
|
||||
linux:
|
||||
runs-on: ubuntu-24.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ jobs:
|
||||
# runs-on: ubuntu-24.04
|
||||
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# - uses: actions/checkout@v6
|
||||
# - name: Setup Riscv
|
||||
# run: |
|
||||
# sudo dpkg --add-architecture riscv64
|
||||
@@ -52,7 +52,7 @@ jobs:
|
||||
# runs-on: ubuntu-24.04
|
||||
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# - uses: actions/checkout@v6
|
||||
# - name: Setup Riscv
|
||||
# run: |
|
||||
# sudo dpkg --add-architecture riscv64
|
||||
@@ -99,7 +99,7 @@ jobs:
|
||||
# runs-on: ubuntu-24.04
|
||||
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# - uses: actions/checkout@v6
|
||||
# - name: Setup Arm64
|
||||
# run: |
|
||||
# sudo dpkg --add-architecture arm64
|
||||
@@ -146,7 +146,7 @@ jobs:
|
||||
container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
- name: Setup LoongArch
|
||||
run: |
|
||||
rm -f /etc/apt/sources.list.d/*
|
||||
@@ -201,7 +201,7 @@ jobs:
|
||||
container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
- name: Setup LoongArch
|
||||
run: |
|
||||
rm -f /etc/apt/sources.list.d/*
|
||||
@@ -262,10 +262,10 @@ jobs:
|
||||
SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Use SpacemiT Toolchain Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-toolchain
|
||||
with:
|
||||
path: ./spacemit_toolchain
|
||||
|
||||
+57
-57
@@ -63,7 +63,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -99,7 +99,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -135,7 +135,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -189,7 +189,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -269,7 +269,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -317,7 +317,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
@@ -347,7 +347,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
# - name: ccache
|
||||
# uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -380,7 +380,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -414,7 +414,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -436,7 +436,7 @@ jobs:
|
||||
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Use Vulkan SDK Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-sdk
|
||||
with:
|
||||
path: ./vulkan_sdk
|
||||
@@ -472,7 +472,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -494,7 +494,7 @@ jobs:
|
||||
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Use Vulkan SDK Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-sdk
|
||||
with:
|
||||
path: ./vulkan_sdk
|
||||
@@ -543,7 +543,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -585,7 +585,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
@@ -616,7 +616,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
@@ -644,7 +644,7 @@ jobs:
|
||||
continue-on-error: true
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: add oneAPI to apt
|
||||
shell: bash
|
||||
@@ -668,7 +668,7 @@ jobs:
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -693,7 +693,7 @@ jobs:
|
||||
continue-on-error: true
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: add oneAPI to apt
|
||||
shell: bash
|
||||
@@ -717,7 +717,7 @@ jobs:
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -749,7 +749,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -781,7 +781,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -813,7 +813,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
@@ -843,7 +843,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -853,7 +853,7 @@ jobs:
|
||||
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
|
||||
|
||||
- name: Download xcframework artifact
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@v7
|
||||
with:
|
||||
name: llama-xcframework
|
||||
path: build-apple/llama.xcframework/
|
||||
@@ -885,7 +885,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -954,7 +954,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1053,7 +1053,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Install dependencies
|
||||
env:
|
||||
@@ -1092,7 +1092,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Install ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1145,7 +1145,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1177,7 +1177,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Grab rocWMMA package
|
||||
id: grab_rocwmma
|
||||
@@ -1187,7 +1187,7 @@ jobs:
|
||||
7z x data.tar
|
||||
|
||||
- name: Use ROCm Installation Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-rocm
|
||||
with:
|
||||
path: C:\Program Files\AMD\ROCm
|
||||
@@ -1239,7 +1239,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup Xcode
|
||||
uses: maxim-lobanov/setup-xcode@v1
|
||||
@@ -1269,7 +1269,7 @@ jobs:
|
||||
./build-xcframework.sh
|
||||
|
||||
- name: Upload xcframework artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: llama-xcframework
|
||||
path: build-apple/llama.xcframework/
|
||||
@@ -1285,7 +1285,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
# Disabled due to size (400MB) and always 0 cache hits
|
||||
# - name: ccache
|
||||
@@ -1295,7 +1295,7 @@ jobs:
|
||||
# evict-old-files: 1d
|
||||
|
||||
- name: Set up JDK
|
||||
uses: actions/setup-java@v3
|
||||
uses: actions/setup-java@v5
|
||||
with:
|
||||
java-version: 17
|
||||
distribution: zulu
|
||||
@@ -1327,7 +1327,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Install OpenCL Headers and Libs
|
||||
id: install_opencl
|
||||
@@ -1402,7 +1402,7 @@ jobs:
|
||||
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -1460,7 +1460,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1486,7 +1486,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1512,7 +1512,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1538,7 +1538,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1564,7 +1564,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1590,7 +1590,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
@@ -1604,7 +1604,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
@@ -1618,7 +1618,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
@@ -1632,7 +1632,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
@@ -1645,7 +1645,7 @@ jobs:
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# id: checkout
|
||||
# uses: actions/checkout@v4
|
||||
# uses: actions/checkout@v6
|
||||
|
||||
# - name: Test
|
||||
# id: ggml-ci
|
||||
@@ -1659,7 +1659,7 @@ jobs:
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# id: checkout
|
||||
# uses: actions/checkout@v4
|
||||
# uses: actions/checkout@v6
|
||||
|
||||
# - name: Test
|
||||
# id: ggml-ci
|
||||
@@ -1673,7 +1673,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
@@ -1686,7 +1686,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dawn Dependency
|
||||
id: dawn-depends
|
||||
@@ -1714,7 +1714,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
@@ -1728,7 +1728,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -1773,7 +1773,7 @@ jobs:
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Check environment
|
||||
run: |
|
||||
@@ -1875,7 +1875,7 @@ jobs:
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup ccache
|
||||
run: |
|
||||
@@ -1969,7 +1969,7 @@ jobs:
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup ccache
|
||||
run: |
|
||||
@@ -2043,7 +2043,7 @@ jobs:
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup ccache
|
||||
run: |
|
||||
@@ -2089,7 +2089,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
|
||||
@@ -23,12 +23,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.x'
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ jobs:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/stale@v5
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap"
|
||||
days-before-issue-stale: 30
|
||||
|
||||
@@ -26,7 +26,7 @@ jobs:
|
||||
# If you do not check out your code, Copilot will do this for you.
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
sudo chmod +x /usr/local/bin/git-clang-format
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ jobs:
|
||||
- { tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
|
||||
steps:
|
||||
- name: Check out the repo
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0 # preserve git history, so we can determine the build number
|
||||
|
||||
@@ -63,7 +63,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
@@ -208,7 +208,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ jobs:
|
||||
editorconfig:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
- uses: editorconfig-checker/action-editorconfig-checker@v2
|
||||
with:
|
||||
version: v3.0.3
|
||||
|
||||
@@ -24,9 +24,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.9.x'
|
||||
- name: Install dependencies
|
||||
|
||||
@@ -9,9 +9,9 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
repository: "ggml-org/llama.cpp"
|
||||
- uses: actions/labeler@v5
|
||||
- uses: actions/labeler@v6
|
||||
with:
|
||||
configuration-path: '.github/labeler.yml'
|
||||
|
||||
@@ -16,10 +16,10 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
|
||||
@@ -24,9 +24,9 @@ jobs:
|
||||
name: check-requirements
|
||||
steps:
|
||||
- name: Check out source repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
- name: Set up Python environment
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Run check-requirements.sh script
|
||||
|
||||
@@ -19,9 +19,9 @@ jobs:
|
||||
name: Lint
|
||||
steps:
|
||||
- name: Check out source repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
- name: Set up Python environment
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: flake8 Lint
|
||||
|
||||
@@ -24,9 +24,9 @@ jobs:
|
||||
name: pyright type-check
|
||||
steps:
|
||||
- name: Check out source repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
- name: Set up Python environment
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Python dependencies
|
||||
|
||||
@@ -27,7 +27,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -63,7 +63,7 @@ jobs:
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz
|
||||
name: llama-bin-macos-arm64.tar.gz
|
||||
@@ -74,7 +74,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -111,7 +111,7 @@ jobs:
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz
|
||||
name: llama-bin-macos-x64.tar.gz
|
||||
@@ -133,7 +133,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -173,7 +173,7 @@ jobs:
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz
|
||||
name: llama-bin-ubuntu-${{ matrix.build }}.tar.gz
|
||||
@@ -184,7 +184,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -226,7 +226,7 @@ jobs:
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz
|
||||
name: llama-bin-ubuntu-vulkan-x64.tar.gz
|
||||
@@ -242,7 +242,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -278,7 +278,7 @@ jobs:
|
||||
7z a -snl llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\*
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-bin-win-cpu-${{ matrix.arch }}.zip
|
||||
name: llama-bin-win-cpu-${{ matrix.arch }}.zip
|
||||
@@ -305,7 +305,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -360,7 +360,7 @@ jobs:
|
||||
7z a -snl llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip .\build\bin\Release\${{ matrix.target }}.dll
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip
|
||||
name: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip
|
||||
@@ -375,7 +375,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Install ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -416,7 +416,7 @@ jobs:
|
||||
7z a -snl llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip .\build\bin\Release\ggml-cuda.dll
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
|
||||
name: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
|
||||
@@ -431,7 +431,7 @@ jobs:
|
||||
7z a cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip $dst\*
|
||||
|
||||
- name: Upload Cuda runtime
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
|
||||
name: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
|
||||
@@ -451,7 +451,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
@@ -511,7 +511,7 @@ jobs:
|
||||
7z a -snl llama-bin-win-sycl-x64.zip ./build/bin/*
|
||||
|
||||
- name: Upload the release package
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-bin-win-sycl-x64.zip
|
||||
name: llama-bin-win-sycl-x64.zip
|
||||
@@ -531,7 +531,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Grab rocWMMA package
|
||||
id: grab_rocwmma
|
||||
@@ -542,7 +542,7 @@ jobs:
|
||||
|
||||
- name: Cache ROCm Installation
|
||||
id: cache-rocm
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: C:\Program Files\AMD\ROCm
|
||||
key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }}
|
||||
@@ -617,7 +617,7 @@ jobs:
|
||||
7z a -snl llama-bin-win-hip-${{ matrix.name }}-x64.zip .\build\bin\*
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-bin-win-hip-${{ matrix.name }}-x64.zip
|
||||
name: llama-bin-win-hip-${{ matrix.name }}-x64.zip
|
||||
@@ -627,7 +627,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -672,7 +672,7 @@ jobs:
|
||||
zip -r -y llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-xcframework.zip
|
||||
name: llama-${{ steps.tag.outputs.name }}-xcframework.zip
|
||||
@@ -703,7 +703,7 @@ jobs:
|
||||
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -763,7 +763,7 @@ jobs:
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}${{ matrix.use_acl_graph == 'on' && '-aclgraph' || '' }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}${{ matrix.use_acl_graph == 'on' && '-aclgraph' || '' }}.tar.gz
|
||||
name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}${{ matrix.use_acl_graph == 'on' && '-aclgraph' || '' }}.tar.gz
|
||||
@@ -794,7 +794,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -804,7 +804,7 @@ jobs:
|
||||
|
||||
- name: Download artifacts
|
||||
id: download-artifact
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@v7
|
||||
with:
|
||||
path: ./artifact
|
||||
merge-multiple: true
|
||||
@@ -887,7 +887,7 @@ jobs:
|
||||
|
||||
- name: Upload release
|
||||
id: upload_release
|
||||
uses: actions/github-script@v3
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
github-token: ${{secrets.GITHUB_TOKEN}}
|
||||
script: |
|
||||
@@ -897,7 +897,7 @@ jobs:
|
||||
for (let file of await fs.readdirSync('./release')) {
|
||||
if (path.extname(file) === '.zip' || file.endsWith('.tar.gz')) {
|
||||
console.log('uploadReleaseAsset', file);
|
||||
await github.repos.uploadReleaseAsset({
|
||||
await github.rest.repos.uploadReleaseAsset({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
release_id: release_id,
|
||||
|
||||
@@ -37,14 +37,14 @@ jobs:
|
||||
continue-on-error: true
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: Setup Node.js
|
||||
id: node
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: "npm"
|
||||
@@ -131,14 +131,14 @@ jobs:
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: Python setup
|
||||
id: setup_python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
@@ -148,7 +148,7 @@ jobs:
|
||||
pip install -r tools/server/tests/requirements.txt
|
||||
|
||||
- name: Setup Node.js for WebUI
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: "npm"
|
||||
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
@@ -72,12 +72,12 @@ jobs:
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -DLLAMA_BUILD_BORINGSSL=ON
|
||||
cmake -B build -DLLAMA_BUILD_BORINGSSL=ON -DGGML_SCHED_NO_REALLOC=ON
|
||||
cmake --build build --config ${{ matrix.build_type }} -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
|
||||
|
||||
- name: Python setup
|
||||
id: setup_python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
@@ -100,7 +100,7 @@ jobs:
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
@@ -108,12 +108,12 @@ jobs:
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -DLLAMA_BUILD_BORINGSSL=ON
|
||||
cmake -B build -DLLAMA_BUILD_BORINGSSL=ON -DGGML_SCHED_NO_REALLOC=ON
|
||||
cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
|
||||
|
||||
- name: Python setup
|
||||
id: setup_python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
|
||||
@@ -18,10 +18,10 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.x'
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ jobs:
|
||||
|
||||
- name: Find latest release
|
||||
id: find_latest_release
|
||||
uses: actions/github-script@v6
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
const { data: releases } = await github.rest.repos.listReleases({
|
||||
|
||||
+25
-21
@@ -1231,6 +1231,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
string_format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx),
|
||||
[](common_params & params, int value) {
|
||||
params.n_ctx = value;
|
||||
if (value == 0) {
|
||||
// disable context reduction in llama_params_fit if the user explicitly requests the full context size:
|
||||
params.fit_params_min_ctx = UINT32_MAX;
|
||||
}
|
||||
}
|
||||
).set_env("LLAMA_ARG_CTX_SIZE"));
|
||||
add_opt(common_arg(
|
||||
@@ -1573,7 +1577,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--temp"}, "N",
|
||||
string_format("temperature (default: %.1f)", (double)params.sampling.temp),
|
||||
string_format("temperature (default: %.2f)", (double)params.sampling.temp),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.temp = std::stof(value);
|
||||
params.sampling.temp = std::max(params.sampling.temp, 0.0f);
|
||||
@@ -1590,7 +1594,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam().set_env("LLAMA_ARG_TOP_K"));
|
||||
add_opt(common_arg(
|
||||
{"--top-p"}, "N",
|
||||
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
|
||||
string_format("top-p sampling (default: %.2f, 1.0 = disabled)", (double)params.sampling.top_p),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.top_p = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P;
|
||||
@@ -1598,7 +1602,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--min-p"}, "N",
|
||||
string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p),
|
||||
string_format("min-p sampling (default: %.2f, 0.0 = disabled)", (double)params.sampling.min_p),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.min_p = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P;
|
||||
@@ -1606,14 +1610,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--top-nsigma"}, "N",
|
||||
string_format("top-n-sigma sampling (default: %.1f, -1.0 = disabled)", params.sampling.top_n_sigma),
|
||||
string_format("top-n-sigma sampling (default: %.2f, -1.0 = disabled)", params.sampling.top_n_sigma),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.top_n_sigma = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--xtc-probability"}, "N",
|
||||
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
|
||||
string_format("xtc probability (default: %.2f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.xtc_probability = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY;
|
||||
@@ -1621,7 +1625,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--xtc-threshold"}, "N",
|
||||
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
|
||||
string_format("xtc threshold (default: %.2f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.xtc_threshold = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD;
|
||||
@@ -1629,7 +1633,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--typical"}, "N",
|
||||
string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sampling.typ_p),
|
||||
string_format("locally typical sampling, parameter p (default: %.2f, 1.0 = disabled)", (double)params.sampling.typ_p),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.typ_p = std::stof(value);
|
||||
}
|
||||
@@ -1648,7 +1652,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--repeat-penalty"}, "N",
|
||||
string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
|
||||
string_format("penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.penalty_repeat = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT;
|
||||
@@ -1656,21 +1660,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--presence-penalty"}, "N",
|
||||
string_format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_present),
|
||||
string_format("repeat alpha presence penalty (default: %.2f, 0.0 = disabled)", (double)params.sampling.penalty_present),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.penalty_present = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--frequency-penalty"}, "N",
|
||||
string_format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_freq),
|
||||
string_format("repeat alpha frequency penalty (default: %.2f, 0.0 = disabled)", (double)params.sampling.penalty_freq),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.penalty_freq = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--dry-multiplier"}, "N",
|
||||
string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sampling.dry_multiplier),
|
||||
string_format("set DRY sampling multiplier (default: %.2f, 0.0 = disabled)", (double)params.sampling.dry_multiplier),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.dry_multiplier = std::stof(value);
|
||||
}
|
||||
@@ -1751,14 +1755,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--dynatemp-range"}, "N",
|
||||
string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range),
|
||||
string_format("dynamic temperature range (default: %.2f, 0.0 = disabled)", (double)params.sampling.dynatemp_range),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.dynatemp_range = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--dynatemp-exp"}, "N",
|
||||
string_format("dynamic temperature exponent (default: %.1f)", (double)params.sampling.dynatemp_exponent),
|
||||
string_format("dynamic temperature exponent (default: %.2f)", (double)params.sampling.dynatemp_exponent),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.dynatemp_exponent = std::stof(value);
|
||||
}
|
||||
@@ -1774,7 +1778,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--mirostat-lr"}, "N",
|
||||
string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta),
|
||||
string_format("Mirostat learning rate, parameter eta (default: %.2f)", (double)params.sampling.mirostat_eta),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.mirostat_eta = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA;
|
||||
@@ -1782,7 +1786,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--mirostat-ent"}, "N",
|
||||
string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau),
|
||||
string_format("Mirostat target entropy, parameter tau (default: %.2f)", (double)params.sampling.mirostat_tau),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.mirostat_tau = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU;
|
||||
@@ -1916,28 +1920,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_env("LLAMA_ARG_YARN_ORIG_CTX"));
|
||||
add_opt(common_arg(
|
||||
{"--yarn-ext-factor"}, "N",
|
||||
string_format("YaRN: extrapolation mix factor (default: %.1f, 0.0 = full interpolation)", (double)params.yarn_ext_factor),
|
||||
string_format("YaRN: extrapolation mix factor (default: %.2f, 0.0 = full interpolation)", (double)params.yarn_ext_factor),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.yarn_ext_factor = std::stof(value);
|
||||
}
|
||||
).set_env("LLAMA_ARG_YARN_EXT_FACTOR"));
|
||||
add_opt(common_arg(
|
||||
{"--yarn-attn-factor"}, "N",
|
||||
string_format("YaRN: scale sqrt(t) or attention magnitude (default: %.1f)", (double)params.yarn_attn_factor),
|
||||
string_format("YaRN: scale sqrt(t) or attention magnitude (default: %.2f)", (double)params.yarn_attn_factor),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.yarn_attn_factor = std::stof(value);
|
||||
}
|
||||
).set_env("LLAMA_ARG_YARN_ATTN_FACTOR"));
|
||||
add_opt(common_arg(
|
||||
{"--yarn-beta-slow"}, "N",
|
||||
string_format("YaRN: high correction dim or alpha (default: %.1f)", (double)params.yarn_beta_slow),
|
||||
string_format("YaRN: high correction dim or alpha (default: %.2f)", (double)params.yarn_beta_slow),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.yarn_beta_slow = std::stof(value);
|
||||
}
|
||||
).set_env("LLAMA_ARG_YARN_BETA_SLOW"));
|
||||
add_opt(common_arg(
|
||||
{"--yarn-beta-fast"}, "N",
|
||||
string_format("YaRN: low correction dim or beta (default: %.1f)", (double)params.yarn_beta_fast),
|
||||
string_format("YaRN: low correction dim or beta (default: %.2f)", (double)params.yarn_beta_fast),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.yarn_beta_fast = std::stof(value);
|
||||
}
|
||||
@@ -3331,14 +3335,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_MIN"));
|
||||
add_opt(common_arg(
|
||||
{"--draft-p-split"}, "P",
|
||||
string_format("speculative decoding split probability (default: %.1f)", (double)params.speculative.p_split),
|
||||
string_format("speculative decoding split probability (default: %.2f)", (double)params.speculative.p_split),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.p_split = std::stof(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}).set_env("LLAMA_ARG_DRAFT_P_SPLIT"));
|
||||
add_opt(common_arg(
|
||||
{"--draft-p-min"}, "P",
|
||||
string_format("minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min),
|
||||
string_format("minimum speculative decoding probability (greedy) (default: %.2f)", (double)params.speculative.p_min),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.p_min = std::stof(value);
|
||||
}
|
||||
|
||||
@@ -129,7 +129,7 @@ static void parse_json_tool_calls(
|
||||
}
|
||||
}
|
||||
|
||||
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
|
||||
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_parser_params & syntax)
|
||||
: input_(input), is_partial_(is_partial), syntax_(syntax)
|
||||
{
|
||||
result_.role = "assistant";
|
||||
@@ -1611,7 +1611,7 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||
builder.finish();
|
||||
}
|
||||
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax) {
|
||||
if (syntax.format == COMMON_CHAT_FORMAT_PEG_SIMPLE ||
|
||||
syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE ||
|
||||
syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) {
|
||||
@@ -1630,12 +1630,12 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co
|
||||
}
|
||||
auto msg = builder.result();
|
||||
if (!is_partial) {
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str());
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_parser_params & syntax) {
|
||||
if (parser.empty()) {
|
||||
throw std::runtime_error("Failed to parse due to missing parser definition.");
|
||||
}
|
||||
@@ -1663,7 +1663,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
}
|
||||
if (!is_partial) {
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str());
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include "json-partial.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
@@ -19,20 +19,20 @@ class common_chat_msg_partial_exception : public std::runtime_error {
|
||||
class common_chat_msg_parser {
|
||||
std::string input_;
|
||||
bool is_partial_;
|
||||
common_chat_syntax syntax_;
|
||||
common_chat_parser_params syntax_; // TODO: rename to params
|
||||
std::string healing_marker_;
|
||||
|
||||
size_t pos_ = 0;
|
||||
common_chat_msg result_;
|
||||
|
||||
public:
|
||||
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
|
||||
const std::string & input() const { return input_; }
|
||||
size_t pos() const { return pos_; }
|
||||
const std::string & healing_marker() const { return healing_marker_; }
|
||||
const bool & is_partial() const { return is_partial_; }
|
||||
const common_chat_msg & result() const { return result_; }
|
||||
const common_chat_syntax & syntax() const { return syntax_; }
|
||||
const common_chat_parser_params & syntax() const { return syntax_; }
|
||||
|
||||
void move_to(size_t pos) {
|
||||
if (pos > input_.size()) {
|
||||
|
||||
+130
-110
@@ -7,9 +7,6 @@
|
||||
#include "log.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
// #include <minja/chat-template.hpp>
|
||||
// #include <minja/minja.hpp>
|
||||
|
||||
#include "jinja/parser.h"
|
||||
#include "jinja/value.h"
|
||||
#include "jinja/runtime.h"
|
||||
@@ -56,39 +53,73 @@ static bool has_content_or_tool_calls(const common_chat_msg & msg) {
|
||||
return !msg.content.empty() || !msg.tool_calls.empty();
|
||||
}
|
||||
|
||||
template <>
|
||||
json common_chat_msg::to_json_oaicompat() const
|
||||
{
|
||||
json message {
|
||||
{"role", "assistant"},
|
||||
};
|
||||
if (!reasoning_content.empty()) {
|
||||
message["reasoning_content"] = reasoning_content;
|
||||
json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const {
|
||||
if (!content.empty() && !content_parts.empty()) {
|
||||
throw std::runtime_error("Cannot specify both content and content_parts");
|
||||
}
|
||||
if (content.empty() && !tool_calls.empty()) {
|
||||
message["content"] = json();
|
||||
json jmsg {
|
||||
{"role", role},
|
||||
};
|
||||
if (!content.empty()) {
|
||||
jmsg["content"] = content;
|
||||
} else if (!content_parts.empty()) {
|
||||
if (concat_typed_text) {
|
||||
std::string text;
|
||||
for (const auto & part : content_parts) {
|
||||
if (part.type != "text") {
|
||||
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
|
||||
continue;
|
||||
}
|
||||
if (!text.empty()) {
|
||||
text += '\n';
|
||||
}
|
||||
text += part.text;
|
||||
}
|
||||
jmsg["content"] = text;
|
||||
} else {
|
||||
auto & parts = jmsg["content"] = json::array();
|
||||
for (const auto & part : content_parts) {
|
||||
parts.push_back({
|
||||
{"type", part.type},
|
||||
{"text", part.text},
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
message["content"] = content;
|
||||
jmsg["content"] = "";
|
||||
}
|
||||
if (!reasoning_content.empty()) {
|
||||
jmsg["reasoning_content"] = reasoning_content;
|
||||
}
|
||||
if (!tool_name.empty()) {
|
||||
jmsg["name"] = tool_name;
|
||||
}
|
||||
if (!tool_call_id.empty()) {
|
||||
jmsg["tool_call_id"] = tool_call_id;
|
||||
}
|
||||
if (!tool_calls.empty()) {
|
||||
auto arr = json::array();
|
||||
for (const auto & tc : tool_calls) {
|
||||
arr.push_back({
|
||||
jmsg["tool_calls"] = json::array();
|
||||
auto & jtool_calls = jmsg["tool_calls"];
|
||||
for (const auto & tool_call : tool_calls) {
|
||||
json tc {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", tc.name},
|
||||
{"arguments", tc.arguments},
|
||||
{"name", tool_call.name},
|
||||
{"arguments", tool_call.arguments},
|
||||
}},
|
||||
{"id", tc.id},
|
||||
// // Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
|
||||
// // We only generate a random id for the ones that don't generate one by themselves
|
||||
// // (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
|
||||
// {"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
|
||||
});
|
||||
};
|
||||
if (!tool_call.id.empty()) {
|
||||
tc["id"] = tool_call.id;
|
||||
}
|
||||
// Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
|
||||
// We only generate a random id for the ones that don't generate one by themselves
|
||||
// (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
|
||||
// {"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
|
||||
jtool_calls.push_back(tc);
|
||||
}
|
||||
message["tool_calls"] = arr;
|
||||
}
|
||||
return message;
|
||||
|
||||
return jmsg;
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new) {
|
||||
@@ -256,7 +287,6 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates *
|
||||
return rendered_no_thinking.prompt != rendered_with_thinking.prompt;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
|
||||
std::vector<common_chat_msg> msgs;
|
||||
|
||||
@@ -350,80 +380,15 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
|
||||
return msgs;
|
||||
}
|
||||
|
||||
template <>
|
||||
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
|
||||
json messages = json::array();
|
||||
for (const auto & msg : msgs) {
|
||||
if (!msg.content.empty() && !msg.content_parts.empty()) {
|
||||
throw std::runtime_error("Cannot specify both content and content_parts");
|
||||
}
|
||||
json jmsg {
|
||||
{"role", msg.role},
|
||||
};
|
||||
if (!msg.content.empty()) {
|
||||
jmsg["content"] = msg.content;
|
||||
} else if (!msg.content_parts.empty()) {
|
||||
if (concat_typed_text) {
|
||||
std::string text;
|
||||
for (const auto & part : msg.content_parts) {
|
||||
if (part.type != "text") {
|
||||
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
|
||||
continue;
|
||||
}
|
||||
if (!text.empty()) {
|
||||
text += '\n';
|
||||
}
|
||||
text += part.text;
|
||||
}
|
||||
jmsg["content"] = text;
|
||||
} else {
|
||||
auto & parts = jmsg["content"] = json::array();
|
||||
for (const auto & part : msg.content_parts) {
|
||||
parts.push_back({
|
||||
{"type", part.type},
|
||||
{"text", part.text},
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
jmsg["content"] = "";
|
||||
}
|
||||
if (!msg.reasoning_content.empty()) {
|
||||
jmsg["reasoning_content"] = msg.reasoning_content;
|
||||
}
|
||||
if (!msg.tool_name.empty()) {
|
||||
jmsg["name"] = msg.tool_name;
|
||||
}
|
||||
if (!msg.tool_call_id.empty()) {
|
||||
jmsg["tool_call_id"] = msg.tool_call_id;
|
||||
}
|
||||
if (!msg.tool_calls.empty()) {
|
||||
auto & tool_calls = jmsg["tool_calls"] = json::array();
|
||||
for (const auto & tool_call : msg.tool_calls) {
|
||||
json tc {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", tool_call.name},
|
||||
{"arguments", tool_call.arguments},
|
||||
}},
|
||||
};
|
||||
if (!tool_call.id.empty()) {
|
||||
tc["id"] = tool_call.id;
|
||||
}
|
||||
tool_calls.push_back(tc);
|
||||
}
|
||||
}
|
||||
json jmsg = msg.to_json_oaicompat(concat_typed_text);
|
||||
messages.push_back(jmsg);
|
||||
}
|
||||
return messages;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
|
||||
return common_chat_msgs_parse_oaicompat(json::parse(messages));
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
|
||||
std::vector<common_chat_tool> result;
|
||||
|
||||
@@ -459,12 +424,6 @@ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & too
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
|
||||
return common_chat_tools_parse_oaicompat(json::parse(tools));
|
||||
}
|
||||
|
||||
template <>
|
||||
json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
|
||||
if (tools.empty()) {
|
||||
return json();
|
||||
@@ -484,7 +443,7 @@ json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & t
|
||||
return result;
|
||||
}
|
||||
|
||||
template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
|
||||
json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
|
||||
json delta = json::object();
|
||||
if (!diff.reasoning_content_delta.empty()) {
|
||||
delta["reasoning_content"] = diff.reasoning_content_delta;
|
||||
@@ -601,18 +560,18 @@ bool common_chat_templates_was_explicit(const struct common_chat_templates * tmp
|
||||
return tmpls->has_explicit_template;
|
||||
}
|
||||
|
||||
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
|
||||
if (variant != nullptr) {
|
||||
if (strcmp(variant, "tool_use") == 0) {
|
||||
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant) {
|
||||
if (!variant.empty()) {
|
||||
if (variant == "tool_use") {
|
||||
if (tmpls->template_tool_use) {
|
||||
return tmpls->template_tool_use->source().c_str();
|
||||
return tmpls->template_tool_use->source();
|
||||
}
|
||||
return nullptr;
|
||||
return "";
|
||||
} else {
|
||||
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
|
||||
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant.c_str());
|
||||
}
|
||||
}
|
||||
return tmpls->template_default->source().c_str();
|
||||
return tmpls->template_default->source();
|
||||
}
|
||||
|
||||
common_chat_templates_ptr common_chat_templates_init(
|
||||
@@ -2691,6 +2650,51 @@ static common_chat_params common_chat_params_init_exaone_moe(const common_chat_t
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_translate_gemma(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
// This template does not support tools or reasoning
|
||||
// we just need to transform the messages into the correct schema
|
||||
|
||||
templates_params inputs_new = inputs;
|
||||
json & messages = inputs_new.messages;
|
||||
|
||||
// default to chat_template_kwargs, or en-GB if not specified
|
||||
std::string default_src_lang = inputs.extra_context.value("source_lang_code", "en-GB");
|
||||
std::string default_tgt_lang = inputs.extra_context.value("target_lang_code", "en-GB");
|
||||
|
||||
GGML_ASSERT(messages.is_array());
|
||||
for (auto & message : messages) {
|
||||
if (message.contains("role") && message["role"].get<std::string>() != "user") {
|
||||
continue;
|
||||
}
|
||||
if (!message.contains("content")) {
|
||||
message["content"] = json::array();
|
||||
}
|
||||
if (message.contains("content") && !message["content"].is_array()) {
|
||||
auto content_str = message["content"].get<std::string>();
|
||||
// default to en-GB if not specified (to make common_chat_format_example works)
|
||||
auto src_lang = message.contains("source_lang_code")
|
||||
? message["source_lang_code"].get<std::string>() : default_src_lang;
|
||||
auto tgt_lang = message.contains("target_lang_code")
|
||||
? message["target_lang_code"].get<std::string>() : default_tgt_lang;
|
||||
message["content"] = json::array({
|
||||
json{
|
||||
{"type", "text"},
|
||||
{"text", content_str},
|
||||
{"source_lang_code", src_lang},
|
||||
{"target_lang_code", tgt_lang},
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
data.prompt = apply(tmpl, inputs_new, std::nullopt, std::nullopt);
|
||||
data.format = COMMON_CHAT_FORMAT_GENERIC;
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs);
|
||||
@@ -2867,13 +2871,13 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
const struct common_chat_templates_inputs & inputs)
|
||||
{
|
||||
templates_params params;
|
||||
params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
|
||||
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
|
||||
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
|
||||
? *tmpls->template_tool_use
|
||||
: *tmpls->template_default;
|
||||
const auto & src = tmpl.source();
|
||||
const auto & caps = tmpl.original_caps();
|
||||
params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
|
||||
params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.reasoning_format = inputs.reasoning_format;
|
||||
@@ -2943,6 +2947,10 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
src.find("<arg_value>") != std::string::npos &&
|
||||
params.json_schema.is_null()) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
if (!params.extra_context.contains("clear_thinking")) {
|
||||
// by default, do not clear reasoning_content (added since GLM-4.7)
|
||||
params.extra_context["clear_thinking"] = false;
|
||||
}
|
||||
return common_chat_params_init_glm_4_5(tmpl, params);
|
||||
}
|
||||
|
||||
@@ -3082,6 +3090,12 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
return common_chat_params_init_solar_open(tmpl, params);
|
||||
}
|
||||
|
||||
// TranslateGemma
|
||||
if (src.find("[source_lang_code]") != std::string::npos &&
|
||||
src.find("[target_lang_code]") != std::string::npos) {
|
||||
return common_chat_params_init_translate_gemma(tmpl, params);
|
||||
}
|
||||
|
||||
// Plain handler (no tools)
|
||||
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return common_chat_params_init_without_tools(tmpl, params);
|
||||
@@ -3174,3 +3188,9 @@ common_chat_params common_chat_templates_apply(
|
||||
? common_chat_templates_apply_jinja(tmpls, inputs)
|
||||
: common_chat_templates_apply_legacy(tmpls, inputs);
|
||||
}
|
||||
|
||||
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates) {
|
||||
GGML_ASSERT(chat_templates != nullptr);
|
||||
GGML_ASSERT(chat_templates->template_default != nullptr);
|
||||
return chat_templates->template_default->caps.to_map();
|
||||
}
|
||||
|
||||
+33
-17
@@ -10,6 +10,8 @@
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
struct common_chat_templates;
|
||||
|
||||
struct common_chat_tool_call {
|
||||
@@ -26,6 +28,11 @@ struct common_chat_msg_content_part {
|
||||
std::string type;
|
||||
std::string text;
|
||||
|
||||
// TODO @ngxson : no known chat templates support reasoning_content in content parts yet
|
||||
// this can be useful for models with interleaved thinking (like Kimi-K2)
|
||||
// if you see any templates explicitly support this, please ping me
|
||||
// std::string reasoning_content;
|
||||
|
||||
bool operator==(const common_chat_msg_content_part & other) const {
|
||||
return type == other.type && text == other.text;
|
||||
}
|
||||
@@ -40,7 +47,7 @@ struct common_chat_msg {
|
||||
std::string tool_name;
|
||||
std::string tool_call_id;
|
||||
|
||||
template <class T> T to_json_oaicompat() const;
|
||||
nlohmann::ordered_json to_json_oaicompat(bool concat_typed_text = false) const;
|
||||
|
||||
bool empty() const {
|
||||
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
|
||||
@@ -145,7 +152,7 @@ struct common_chat_templates_inputs {
|
||||
std::vector<common_chat_tool> tools;
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
bool parallel_tool_calls = false;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool enable_thinking"
|
||||
bool enable_thinking = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
std::map<std::string, std::string> chat_template_kwargs;
|
||||
@@ -165,14 +172,21 @@ struct common_chat_params {
|
||||
std::string parser;
|
||||
};
|
||||
|
||||
struct common_chat_syntax {
|
||||
// per-message parsing syntax
|
||||
// should be derived from common_chat_params
|
||||
struct common_chat_parser_params {
|
||||
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
|
||||
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
|
||||
bool reasoning_in_content = false;
|
||||
bool thinking_forced_open = false;
|
||||
bool parse_tool_calls = true;
|
||||
common_peg_arena parser = {};
|
||||
common_chat_parser_params() = default;
|
||||
common_chat_parser_params(const common_chat_params & chat_params) {
|
||||
format = chat_params.format;
|
||||
thinking_forced_open = chat_params.thinking_forced_open;
|
||||
}
|
||||
};
|
||||
|
||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
@@ -191,7 +205,7 @@ common_chat_templates_ptr common_chat_templates_init(
|
||||
const std::string & eos_token_override = "");
|
||||
|
||||
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
|
||||
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
|
||||
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
|
||||
|
||||
|
||||
struct common_chat_params common_chat_templates_apply(
|
||||
@@ -213,23 +227,25 @@ std::string common_chat_format_example(
|
||||
const std::map<std::string, std::string> & chat_template_kwargs);
|
||||
|
||||
const char* common_chat_format_name(common_chat_format format);
|
||||
const char* common_reasoning_format_name(common_reasoning_format format);
|
||||
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
|
||||
|
||||
// used by arg and server
|
||||
const char * common_reasoning_format_name(common_reasoning_format format);
|
||||
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
|
||||
|
||||
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
|
||||
|
||||
bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates);
|
||||
|
||||
// Parses a JSON array of messages in OpenAI's chat completion API format.
|
||||
// T can be std::string containing JSON or nlohmann::ordered_json
|
||||
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
|
||||
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
|
||||
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const nlohmann::ordered_json & messages);
|
||||
nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
|
||||
|
||||
// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
|
||||
// T can be std::string containing JSON or nlohmann::ordered_json
|
||||
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
|
||||
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools);
|
||||
nlohmann::ordered_json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
|
||||
|
||||
template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
||||
nlohmann::ordered_json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
||||
|
||||
// get template caps, useful for reporting to server /props endpoint
|
||||
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates);
|
||||
|
||||
@@ -57,6 +57,8 @@ extern const char * LLAMA_COMMIT;
|
||||
extern const char * LLAMA_COMPILER;
|
||||
extern const char * LLAMA_BUILD_TARGET;
|
||||
|
||||
const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
|
||||
|
||||
struct common_control_vector_load_info;
|
||||
|
||||
//
|
||||
@@ -284,6 +286,7 @@ struct common_params_diffusion {
|
||||
};
|
||||
|
||||
// reasoning API response format (not to be confused as chat template's reasoning format)
|
||||
// only used by server
|
||||
enum common_reasoning_format {
|
||||
COMMON_REASONING_FORMAT_NONE,
|
||||
COMMON_REASONING_FORMAT_AUTO, // Same as deepseek, using `message.reasoning_content`
|
||||
|
||||
+19
-14
@@ -314,23 +314,26 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
|
||||
// download one single file from remote URL to local path
|
||||
// returns status code or -1 on error
|
||||
static int common_download_file_single_online(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
const common_header_list & custom_headers) {
|
||||
static int common_download_file_single_online(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
const common_header_list & custom_headers) {
|
||||
static const int max_attempts = 3;
|
||||
static const int retry_delay_seconds = 2;
|
||||
|
||||
auto [cli, parts] = common_http_client(url);
|
||||
|
||||
httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}};
|
||||
if (!bearer_token.empty()) {
|
||||
default_headers.insert({"Authorization", "Bearer " + bearer_token});
|
||||
}
|
||||
httplib::Headers headers;
|
||||
for (const auto & h : custom_headers) {
|
||||
default_headers.emplace(h.first, h.second);
|
||||
headers.emplace(h.first, h.second);
|
||||
}
|
||||
cli.set_default_headers(default_headers);
|
||||
if (headers.find("User-Agent") == headers.end()) {
|
||||
headers.emplace("User-Agent", "llama-cpp/" + build_info);
|
||||
}
|
||||
if (!bearer_token.empty()) {
|
||||
headers.emplace("Authorization", "Bearer " + bearer_token);
|
||||
}
|
||||
cli.set_default_headers(headers);
|
||||
|
||||
const bool file_exists = std::filesystem::exists(path);
|
||||
|
||||
@@ -437,10 +440,12 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
|
||||
const common_remote_params & params) {
|
||||
auto [cli, parts] = common_http_client(url);
|
||||
|
||||
httplib::Headers headers = {{"User-Agent", "llama-cpp"}};
|
||||
|
||||
for (const auto & header : params.headers) {
|
||||
headers.emplace(header.first, header.second);
|
||||
httplib::Headers headers;
|
||||
for (const auto & h : params.headers) {
|
||||
headers.emplace(h.first, h.second);
|
||||
}
|
||||
if (headers.find("User-Agent") == headers.end()) {
|
||||
headers.emplace("User-Agent", "llama-cpp/" + build_info);
|
||||
}
|
||||
|
||||
if (params.timeout > 0) {
|
||||
|
||||
@@ -57,6 +57,17 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
|
||||
throw std::runtime_error("error: invalid URL format");
|
||||
}
|
||||
|
||||
#ifndef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
if (parts.scheme == "https") {
|
||||
throw std::runtime_error(
|
||||
"HTTPS is not supported. Please rebuild with:\n"
|
||||
" -DLLAMA_BUILD_BORINGSSL=ON\n"
|
||||
" -DLLAMA_BUILD_LIBRESSL=ON\n"
|
||||
"or ensure dev files of an OpenSSL-compatible library are available when building."
|
||||
);
|
||||
}
|
||||
#endif
|
||||
|
||||
httplib::Client cli(parts.scheme + "://" + parts.host);
|
||||
|
||||
if (!parts.user.empty()) {
|
||||
|
||||
+48
-5
@@ -61,14 +61,23 @@ static void caps_print_stats(value & v, const std::string & path) {
|
||||
ops.c_str());
|
||||
}
|
||||
|
||||
std::map<std::string, bool> caps::to_map() const {
|
||||
return {
|
||||
{"requires_typed_content", requires_typed_content},
|
||||
{"supports_tools", supports_tools},
|
||||
{"supports_tool_calls", supports_tool_calls},
|
||||
{"supports_parallel_tool_calls", supports_parallel_tool_calls},
|
||||
{"supports_system_role", supports_system_role},
|
||||
{"supports_preserve_reasoning", supports_preserve_reasoning},
|
||||
};
|
||||
}
|
||||
|
||||
std::string caps::to_string() const {
|
||||
std::ostringstream ss;
|
||||
ss << "Caps(\n";
|
||||
ss << " requires_typed_content=" << requires_typed_content << "\n";
|
||||
ss << " supports_tools=" << supports_tools << "\n";
|
||||
ss << " supports_tool_calls=" << supports_tool_calls << "\n";
|
||||
ss << " supports_parallel_tool_calls=" << supports_parallel_tool_calls << "\n";
|
||||
ss << " supports_system_role=" << supports_system_role << "\n";
|
||||
for (const auto & [key, value] : to_map()) {
|
||||
ss << " " << key << "=" << (value ? "true" : "false") << "\n";
|
||||
}
|
||||
ss << ")";
|
||||
return ss.str();
|
||||
}
|
||||
@@ -229,6 +238,40 @@ caps caps_get(jinja::program & prog) {
|
||||
}
|
||||
);
|
||||
|
||||
// case: preserve reasoning content in chat history
|
||||
caps_try_execute(
|
||||
prog,
|
||||
[&]() {
|
||||
// messages
|
||||
return json::array({
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"}
|
||||
},
|
||||
{
|
||||
{"role", "assistant"},
|
||||
{"content", "Assistant message"},
|
||||
{"reasoning_content", "Reasoning content"}
|
||||
},
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"}
|
||||
},
|
||||
});
|
||||
},
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array();
|
||||
},
|
||||
[&](bool, value & messages, value &) {
|
||||
auto & content = messages->at(1)->at("reasoning_content");
|
||||
caps_print_stats(content, "messages[1].reasoning_content");
|
||||
if (content->stats.used) {
|
||||
result.supports_preserve_reasoning = true;
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
JJ_DEBUG("%s\n", result.to_string().c_str());
|
||||
|
||||
return result;
|
||||
|
||||
+5
-1
@@ -3,6 +3,7 @@
|
||||
#include "runtime.h"
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
namespace jinja {
|
||||
|
||||
@@ -11,14 +12,17 @@ struct caps {
|
||||
bool supports_tool_calls = true;
|
||||
bool supports_system_role = true;
|
||||
bool supports_parallel_tool_calls = true;
|
||||
bool supports_preserve_reasoning = false; // support assistant message with reasoning_content
|
||||
|
||||
bool requires_typed_content = false; // default: use string content
|
||||
|
||||
// for reporting on server
|
||||
std::map<std::string, bool> to_map() const;
|
||||
|
||||
// for debugging
|
||||
std::string to_string() const;
|
||||
};
|
||||
|
||||
caps caps_get(jinja::program & prog);
|
||||
void debug_print_caps(const caps & c);
|
||||
|
||||
} // namespace jinja
|
||||
|
||||
@@ -1005,6 +1005,7 @@ const func_builtins & value_none_t::get_builtins() const {
|
||||
static const func_builtins builtins = {
|
||||
{"default", default_value},
|
||||
{"tojson", tojson},
|
||||
{"string", [](const func_args &) -> value { return mk_val<value_string>("None"); }}
|
||||
};
|
||||
return builtins;
|
||||
}
|
||||
|
||||
@@ -342,12 +342,12 @@ struct value_none_t : public value_t {
|
||||
virtual std::string type() const override { return "None"; }
|
||||
virtual bool is_none() const override { return true; }
|
||||
virtual bool as_bool() const override { return false; }
|
||||
virtual string as_string() const override { return string("None"); }
|
||||
virtual std::string as_repr() const override { return type(); }
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
};
|
||||
using value_none = std::shared_ptr<value_none_t>;
|
||||
|
||||
|
||||
struct value_undefined_t : public value_t {
|
||||
std::string hint; // for debugging, to indicate where undefined came from
|
||||
value_undefined_t(const std::string & h = "") : hint(h) {}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
// TODO: use json_fwd.hpp when possible
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
|
||||
|
||||
+416
-574
File diff suppressed because it is too large
Load Diff
@@ -170,6 +170,7 @@ pre_computed_hashes = [
|
||||
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
|
||||
# jina-v2-de variants
|
||||
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
set -e
|
||||
|
||||
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
||||
BUILD_DIR="${2:-"$BUILD_DIR"}"
|
||||
|
||||
# Final check if we have a model path
|
||||
if [ -z "$CONVERTED_MODEL" ]; then
|
||||
@@ -25,9 +26,13 @@ mkdir -p ppl
|
||||
OUTPUTFILE="ppl/$(basename $CONVERTED_MODEL).kld"
|
||||
echo "Model: $CONVERTED_MODEL"
|
||||
|
||||
cmake --build ../../build --target llama-perplexity -j8
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
../.././build/bin/llama-perplexity -m $CONVERTED_MODEL \
|
||||
cmake --build $BUILD_DIR --target llama-perplexity -j8
|
||||
|
||||
${BUILD_DIR}/bin/llama-perplexity -m $CONVERTED_MODEL \
|
||||
-f ppl/wikitext-2-raw/wiki.test.raw \
|
||||
--kl-divergence-base $OUTPUTFILE
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
set -e
|
||||
|
||||
QUANTIZED_MODEL="${1:-"$QUANTIZED_MODEL"}"
|
||||
BUILD_DIR="${2:-"$BUILD_DIR"}"
|
||||
|
||||
if [ -z "$QUANTIZED_MODEL" ]; then
|
||||
echo "Error: Model path must be provided either as:" >&2
|
||||
@@ -20,8 +21,12 @@ if [ ! -d "ppl/wikitext-2-raw" ]; then
|
||||
popd
|
||||
fi
|
||||
|
||||
cmake --build ../../build --target llama-perplexity -j8
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
../.././build/bin/llama-perplexity -m $QUANTIZED_MODEL -f ppl/wikitext-2-raw/wiki.test.raw
|
||||
cmake --build $BUILD_DIR --target llama-perplexity -j8
|
||||
|
||||
${BUILD_DIR}/bin/llama-perplexity -m $QUANTIZED_MODEL -f ppl/wikitext-2-raw/wiki.test.raw
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
set -e
|
||||
|
||||
QUANTIZED_MODEL="${1:-"$QUANTIZED_MODEL"}"
|
||||
LOGITS_FILE="${1:-"$LOGITS_FILE"}"
|
||||
LOGITS_FILE="${2:-"$LOGITS_FILE"}"
|
||||
BUILD_DIR="${3:-"$BUILD_DIR"}"
|
||||
|
||||
if [ -z "$QUANTIZED_MODEL" ]; then
|
||||
echo "Error: Model path must be provided either as:" >&2
|
||||
@@ -18,11 +19,15 @@ if [ ! -f ${LOGITS_FILE} ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
echo "Model: $QUANTIZED_MODEL"
|
||||
echo "Data file: $LOGITS_FILE"
|
||||
|
||||
cmake --build ../../build --target llama-perplexity -j8
|
||||
cmake --build $BUILD_DIR --target llama-perplexity -j8
|
||||
|
||||
../.././build/bin/llama-perplexity -m $QUANTIZED_MODEL \
|
||||
${BUILD_DIR}/bin/llama-perplexity -m $QUANTIZED_MODEL \
|
||||
--kl-divergence-base $LOGITS_FILE \
|
||||
--kl-divergence
|
||||
|
||||
@@ -6,6 +6,7 @@ CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
||||
QUANTIZED_TYPE="${2:-"$QUANTIZED_TYPE"}"
|
||||
TOKEN_EMBD_TYPE="${3:-"${TOKEN_EMBD_TYPE}"}"
|
||||
OUTPUT_TYPE="${4:-"${OUTPUT_TYPE}"}"
|
||||
BUILD_DIR="${5:-"$BUILD_DIR"}"
|
||||
QUANTIZED_MODEL=$CONVERTED_MODEL
|
||||
|
||||
# Final check if we have a model path
|
||||
@@ -33,12 +34,16 @@ else
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cmake --build ../../build --target llama-quantize -j8
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
cmake --build $BUILD_DIR --target llama-quantize -j8
|
||||
|
||||
echo $TOKEN_EMBD_TYPE
|
||||
echo $OUTPUT_TYPE
|
||||
|
||||
CMD_ARGS=("../../build/bin/llama-quantize")
|
||||
CMD_ARGS=("${BUILD_DIR}/bin/llama-quantize")
|
||||
[[ -n "$TOKEN_EMBD_TYPE" ]] && CMD_ARGS+=("--token-embedding-type" "$TOKEN_EMBD_TYPE")
|
||||
[[ -n "$OUTPUT_TYPE" ]] && CMD_ARGS+=("--output-tensor-type" "$OUTPUT_TYPE")
|
||||
CMD_ARGS+=("$CONVERTED_MODEL" "$QUANTIZED_MODEL" "$QUANTIZED_TYPE")
|
||||
|
||||
@@ -4,6 +4,7 @@ set -e
|
||||
#
|
||||
# First try command line argument, then environment variable, then file
|
||||
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
||||
BUILD_DIR="${2:-"$BUILD_DIR"}"
|
||||
|
||||
# Final check if we have a model path
|
||||
if [ -z "$CONVERTED_MODEL" ]; then
|
||||
@@ -13,10 +14,14 @@ if [ -z "$CONVERTED_MODEL" ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
echo $CONVERTED_MODEL
|
||||
|
||||
cmake --build ../../build --target llama-server
|
||||
cmake --build $BUILD_DIR --target llama-server
|
||||
|
||||
../../build/bin/llama-server -m $CONVERTED_MODEL \
|
||||
${BUILD_DIR}/bin/llama-server -m $CONVERTED_MODEL \
|
||||
--embedding \
|
||||
--pooling none
|
||||
|
||||
@@ -77,39 +77,23 @@
|
||||
#include "ggml-zendnn.h"
|
||||
#endif
|
||||
|
||||
// disable C++17 deprecation warning for std::codecvt_utf8
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic push
|
||||
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic push
|
||||
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
#endif
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
static std::string path_str(const fs::path & path) {
|
||||
std::string u8path;
|
||||
try {
|
||||
#if defined(__cpp_lib_char8_t)
|
||||
// C++20 and later: u8string() returns std::u8string
|
||||
std::u8string u8str = path.u8string();
|
||||
u8path = std::string(reinterpret_cast<const char*>(u8str.c_str()));
|
||||
const std::u8string u8str = path.u8string();
|
||||
return std::string(reinterpret_cast<const char *>(u8str.data()), u8str.size());
|
||||
#else
|
||||
// C++17: u8string() returns std::string
|
||||
u8path = path.u8string();
|
||||
return path.u8string();
|
||||
#endif
|
||||
} catch (...) {
|
||||
return std::string();
|
||||
}
|
||||
return u8path;
|
||||
}
|
||||
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic pop
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
using dl_handle = std::remove_pointer_t<HMODULE>;
|
||||
|
||||
@@ -38,9 +38,10 @@
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
@@ -48,9 +49,10 @@
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
@@ -70,12 +72,14 @@
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
|
||||
@@ -94,9 +98,10 @@
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
@@ -104,9 +109,10 @@
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
@@ -126,9 +132,10 @@
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
@@ -136,9 +143,10 @@
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
@@ -165,18 +173,20 @@
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
@@ -202,9 +212,10 @@
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
@@ -212,9 +223,10 @@
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
@@ -242,9 +254,10 @@
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
@@ -252,9 +265,10 @@
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
|
||||
@@ -25,9 +25,8 @@
|
||||
#define UNUSED GGML_UNUSED
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD))
|
||||
static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
|
||||
int16x8_t * out_mins,
|
||||
int8_t * out_scales) {
|
||||
// Helper for decoding scales and mins of Q4_K and Q5_K block formats
|
||||
static inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) {
|
||||
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
||||
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
||||
constexpr uint32_t kmask3 = 0x03030303;
|
||||
@@ -561,7 +560,7 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q4sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
||||
}
|
||||
|
||||
@@ -701,7 +700,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q4sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
||||
}
|
||||
|
||||
@@ -786,6 +785,293 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
|
||||
ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q5_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
constexpr int col_pairs = ncols_interleaved / 2;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
const uint8x16_t mone = vdupq_n_u8(1);
|
||||
const uint8x16_t mtwo = vdupq_n_u8(2);
|
||||
|
||||
// 1x8 tile = 2 x 4
|
||||
float32x4_t acc_f32[ncols_interleaved / 4];
|
||||
|
||||
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < ncols_interleaved / 4; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3
|
||||
float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7
|
||||
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
|
||||
float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d);
|
||||
float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d);
|
||||
float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3
|
||||
float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7
|
||||
float32x4_t sb_min_0 = vmulq_f32(q5_dmin_0, q8_d);
|
||||
float32x4_t sb_min_1 = vmulq_f32(q5_dmin_1, q8_d);
|
||||
|
||||
// 2 sb each iteration
|
||||
int32x4_t acc_lo[col_pairs];
|
||||
int32x4_t acc_hi[col_pairs];
|
||||
|
||||
// Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
|
||||
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
|
||||
int16_t bsums_arr[8];
|
||||
vst1q_s16(bsums_arr, bsums);
|
||||
|
||||
// Load qh once per block and shift after each subblock
|
||||
const uint8_t * qh_base = q5_ptr[b].qh;
|
||||
uint8x16_t qh[col_pairs][4];
|
||||
for (int cp = 0; cp < col_pairs; cp++) {
|
||||
qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
|
||||
qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
|
||||
qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
|
||||
qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
|
||||
}
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
for (int i = 0; i < col_pairs; i++) {
|
||||
acc_lo[i] = vdupq_n_s32(0);
|
||||
acc_hi[i] = vdupq_n_s32(0);
|
||||
}
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later
|
||||
int16x8_t q5sb_scales[2];
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q5sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
|
||||
q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
|
||||
}
|
||||
|
||||
const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K;
|
||||
|
||||
// Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns
|
||||
const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
|
||||
int8x16_t q8_qs[8];
|
||||
for (int i = 0; i < 8; i++) {
|
||||
q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
|
||||
}
|
||||
|
||||
// Q5s column pair loop unrolled
|
||||
{
|
||||
// Cols 01
|
||||
uint8x16_t qs_0 = vld1q_u8(qs_base);
|
||||
uint8x16_t qs_1 = vld1q_u8(qs_base + 64);
|
||||
uint8x16_t qs_2 = vld1q_u8(qs_base + 128);
|
||||
uint8x16_t qs_3 = vld1q_u8(qs_base + 192);
|
||||
|
||||
uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone);
|
||||
uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone);
|
||||
uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone);
|
||||
uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone);
|
||||
uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3);
|
||||
uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3);
|
||||
uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3);
|
||||
uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3);
|
||||
|
||||
qh[0][0] = vshrq_n_u8(qh[0][0], 2);
|
||||
qh[0][1] = vshrq_n_u8(qh[0][1], 2);
|
||||
qh[0][2] = vshrq_n_u8(qh[0][2], 2);
|
||||
qh[0][3] = vshrq_n_u8(qh[0][3], 2);
|
||||
|
||||
acc_lo[0] = ggml_vdotq_s32(
|
||||
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
|
||||
acc_lo[0] = ggml_vdotq_s32(
|
||||
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
|
||||
acc_lo[0] = ggml_vdotq_s32(
|
||||
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
|
||||
acc_lo[0] = ggml_vdotq_s32(
|
||||
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
|
||||
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
|
||||
q8_qs[4]);
|
||||
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
|
||||
q8_qs[5]);
|
||||
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
|
||||
q8_qs[6]);
|
||||
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
|
||||
q8_qs[7]);
|
||||
|
||||
// Cols 23
|
||||
qs_0 = vld1q_u8(qs_base + 16);
|
||||
qs_1 = vld1q_u8(qs_base + 80);
|
||||
qs_2 = vld1q_u8(qs_base + 144);
|
||||
qs_3 = vld1q_u8(qs_base + 208);
|
||||
|
||||
hbit_lo_0 = vandq_u8(qh[1][0], mone);
|
||||
hbit_lo_1 = vandq_u8(qh[1][1], mone);
|
||||
hbit_lo_2 = vandq_u8(qh[1][2], mone);
|
||||
hbit_lo_3 = vandq_u8(qh[1][3], mone);
|
||||
hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3);
|
||||
hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3);
|
||||
hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3);
|
||||
hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3);
|
||||
|
||||
qh[1][0] = vshrq_n_u8(qh[1][0], 2);
|
||||
qh[1][1] = vshrq_n_u8(qh[1][1], 2);
|
||||
qh[1][2] = vshrq_n_u8(qh[1][2], 2);
|
||||
qh[1][3] = vshrq_n_u8(qh[1][3], 2);
|
||||
|
||||
acc_lo[1] = ggml_vdotq_s32(
|
||||
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
|
||||
acc_lo[1] = ggml_vdotq_s32(
|
||||
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
|
||||
acc_lo[1] = ggml_vdotq_s32(
|
||||
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
|
||||
acc_lo[1] = ggml_vdotq_s32(
|
||||
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
|
||||
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
|
||||
q8_qs[4]);
|
||||
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
|
||||
q8_qs[5]);
|
||||
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
|
||||
q8_qs[6]);
|
||||
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
|
||||
q8_qs[7]);
|
||||
|
||||
// Cols 45
|
||||
qs_0 = vld1q_u8(qs_base + 32);
|
||||
qs_1 = vld1q_u8(qs_base + 96);
|
||||
qs_2 = vld1q_u8(qs_base + 160);
|
||||
qs_3 = vld1q_u8(qs_base + 224);
|
||||
|
||||
hbit_lo_0 = vandq_u8(qh[2][0], mone);
|
||||
hbit_lo_1 = vandq_u8(qh[2][1], mone);
|
||||
hbit_lo_2 = vandq_u8(qh[2][2], mone);
|
||||
hbit_lo_3 = vandq_u8(qh[2][3], mone);
|
||||
hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3);
|
||||
hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3);
|
||||
hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3);
|
||||
hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3);
|
||||
|
||||
qh[2][0] = vshrq_n_u8(qh[2][0], 2);
|
||||
qh[2][1] = vshrq_n_u8(qh[2][1], 2);
|
||||
qh[2][2] = vshrq_n_u8(qh[2][2], 2);
|
||||
qh[2][3] = vshrq_n_u8(qh[2][3], 2);
|
||||
|
||||
acc_lo[2] = ggml_vdotq_s32(
|
||||
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
|
||||
acc_lo[2] = ggml_vdotq_s32(
|
||||
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
|
||||
acc_lo[2] = ggml_vdotq_s32(
|
||||
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
|
||||
acc_lo[2] = ggml_vdotq_s32(
|
||||
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
|
||||
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
|
||||
q8_qs[4]);
|
||||
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
|
||||
q8_qs[5]);
|
||||
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
|
||||
q8_qs[6]);
|
||||
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
|
||||
q8_qs[7]);
|
||||
|
||||
// Cols 45
|
||||
qs_0 = vld1q_u8(qs_base + 48);
|
||||
qs_1 = vld1q_u8(qs_base + 112);
|
||||
qs_2 = vld1q_u8(qs_base + 176);
|
||||
qs_3 = vld1q_u8(qs_base + 240);
|
||||
|
||||
hbit_lo_0 = vandq_u8(qh[3][0], mone);
|
||||
hbit_lo_1 = vandq_u8(qh[3][1], mone);
|
||||
hbit_lo_2 = vandq_u8(qh[3][2], mone);
|
||||
hbit_lo_3 = vandq_u8(qh[3][3], mone);
|
||||
hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3);
|
||||
hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3);
|
||||
hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3);
|
||||
hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3);
|
||||
|
||||
qh[3][0] = vshrq_n_u8(qh[3][0], 2);
|
||||
qh[3][1] = vshrq_n_u8(qh[3][1], 2);
|
||||
qh[3][2] = vshrq_n_u8(qh[3][2], 2);
|
||||
qh[3][3] = vshrq_n_u8(qh[3][3], 2);
|
||||
|
||||
acc_lo[3] = ggml_vdotq_s32(
|
||||
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
|
||||
acc_lo[3] = ggml_vdotq_s32(
|
||||
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
|
||||
acc_lo[3] = ggml_vdotq_s32(
|
||||
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
|
||||
acc_lo[3] = ggml_vdotq_s32(
|
||||
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
|
||||
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
|
||||
q8_qs[4]);
|
||||
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
|
||||
q8_qs[5]);
|
||||
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
|
||||
q8_qs[6]);
|
||||
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
|
||||
q8_qs[7]);
|
||||
}
|
||||
|
||||
// Prepare bsum vectors for bias computation
|
||||
// Each pair of subblocks share the same bsums
|
||||
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
|
||||
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
|
||||
|
||||
// Iterates over a pair of column pairs (4 columns) to use a single 128 register
|
||||
// p = 0 -> 0123 p2 -> 4567
|
||||
for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
|
||||
int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]);
|
||||
int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]);
|
||||
int16x4_t group_mins_lo = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]);
|
||||
int16x4_t group_mins_hi = p == 0 ? vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]);
|
||||
float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
|
||||
float32x4_t sb_min = p == 0 ? sb_min_0 : sb_min_1;
|
||||
|
||||
// 0123 or 4567
|
||||
float32x4_t sumf_0 =
|
||||
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
|
||||
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
|
||||
|
||||
float32x4_t sumf_1 =
|
||||
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
|
||||
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
|
||||
|
||||
// FUSED BIAS: Compute and subtract bias immediately
|
||||
// bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min
|
||||
int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo);
|
||||
bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi);
|
||||
float32x4_t bias_f32 = vcvtq_f32_s32(bias);
|
||||
acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32);
|
||||
}
|
||||
} // for sb
|
||||
} // for b
|
||||
|
||||
int base = x * ncols_interleaved;
|
||||
vst1q_f32(s + base, acc_f32[0]);
|
||||
vst1q_f32(s + base + 4, acc_f32[1]);
|
||||
} // for x
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q8_0_4x4_q8_0(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
@@ -2431,7 +2717,7 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q4sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
|
||||
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
|
||||
}
|
||||
|
||||
@@ -2595,7 +2881,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
||||
int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later
|
||||
for (int i = 0; i < 2; i++) {
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
|
||||
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
|
||||
}
|
||||
|
||||
// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
|
||||
@@ -2738,6 +3024,252 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
||||
ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q5_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
constexpr int col_pairs = ncols_interleaved / 2;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
const uint8x16_t mone = vdupq_n_u8(1);
|
||||
const uint8x16_t mtwo = vdupq_n_u8(2);
|
||||
|
||||
// 8 accumulators: 2 row pairs × 4 col pairs
|
||||
float32x4_t acc_f32[blocklen];
|
||||
|
||||
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
// bsums pairs belongs to the same q8_k subblock
|
||||
const int16x8_t bsums[4]{
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
||||
};
|
||||
int16_t bsums_arr[4][8];
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
|
||||
}
|
||||
|
||||
int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results
|
||||
int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
|
||||
int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
|
||||
for (int i = 0; i < 8; i++) {
|
||||
acc[i] = vdupq_n_s32(0);
|
||||
bias_acc[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
// Load qh once per block and shift after each subblock
|
||||
const uint8_t * qh_base = q5_ptr[b].qh;
|
||||
uint8x16_t qh[col_pairs][4];
|
||||
for (int cp = 0; cp < col_pairs; cp++) {
|
||||
qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
|
||||
qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
|
||||
qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
|
||||
qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
|
||||
}
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
int8_t q5sb_scales[2][8];
|
||||
int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later
|
||||
for (int i = 0; i < 2; i++) {
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]);
|
||||
}
|
||||
|
||||
// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
|
||||
const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
|
||||
|
||||
int8x16_t q8_qs_01[8];
|
||||
int8x16_t q8_qs_23[8];
|
||||
|
||||
// Load 32-byte per row pair, 1 subblock each time
|
||||
for (int i = 0; i < 8; i++) {
|
||||
const int offset = i * 32; // 16 for row 01, 16 for row 23
|
||||
q8_qs_01[i] = vld1q_s8(q8_base + offset);
|
||||
q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
|
||||
}
|
||||
|
||||
const int8x16_t q8s[2][8] = {
|
||||
{ q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6],
|
||||
q8_qs_01[7] },
|
||||
{ q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6],
|
||||
q8_qs_23[7] },
|
||||
};
|
||||
|
||||
// Q5s columns iterated in pairs (01, 23, 45, 67)
|
||||
for (int cp = 0; cp < col_pairs; cp++) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
sb_acc[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
|
||||
uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
|
||||
uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
|
||||
uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
|
||||
|
||||
// This is the only part of the algorithm that differs with Q4_K
|
||||
// Extract High bits and pack into 5 bit weights
|
||||
uint8x16_t hbit_lo_0 = vandq_u8(qh[cp][0], mone);
|
||||
uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3);
|
||||
qh[cp][0] = vshrq_n_u8(qh[cp][0], 2);
|
||||
// Same as Q4_K, i8mm to dequantize the weights.
|
||||
const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4));
|
||||
int32x4_t acc_0 = sb_acc[0];
|
||||
acc_0 = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]);
|
||||
int32x4_t acc_2 = sb_acc[2];
|
||||
acc_2 = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]);
|
||||
const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0));
|
||||
int32x4_t acc_1 = sb_acc[1];
|
||||
acc_1 = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]);
|
||||
int32x4_t acc_3 = sb_acc[3];
|
||||
acc_3 = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]);
|
||||
|
||||
// Repeat for the other 3 columns (8..15, 16..23, 24..31)
|
||||
uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3);
|
||||
uint8x16_t hbit_lo_1 = vandq_u8(qh[cp][1], mone);
|
||||
qh[cp][1] = vshrq_n_u8(qh[cp][1], 2);
|
||||
const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4));
|
||||
acc_0 = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]);
|
||||
acc_2 = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]);
|
||||
const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1));
|
||||
acc_1 = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]);
|
||||
acc_3 = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]);
|
||||
|
||||
uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3);
|
||||
uint8x16_t hbit_lo_2 = vandq_u8(qh[cp][2], mone);
|
||||
qh[cp][2] = vshrq_n_u8(qh[cp][2], 2);
|
||||
const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4));
|
||||
acc_0 = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]);
|
||||
acc_2 = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]);
|
||||
const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2));
|
||||
acc_1 = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]);
|
||||
acc_3 = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]);
|
||||
|
||||
uint8x16_t hbit_lo_3 = vandq_u8(qh[cp][3], mone);
|
||||
uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3);
|
||||
qh[cp][3] = vshrq_n_u8(qh[cp][3], 2);
|
||||
const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4));
|
||||
acc_0 = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]);
|
||||
sb_acc[0] = acc_0;
|
||||
acc_2 = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]);
|
||||
sb_acc[2] = acc_2;
|
||||
|
||||
// Scales[i] corresponds to column i
|
||||
const int scale_offset = cp * 2;
|
||||
const int32_t s0 = q5sb_scales[0][scale_offset];
|
||||
const int32_t s1 = q5sb_scales[0][scale_offset + 1];
|
||||
const int32x4_t block_scale = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1));
|
||||
acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale);
|
||||
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale);
|
||||
|
||||
const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3));
|
||||
acc_1 = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]);
|
||||
sb_acc[1] = acc_1;
|
||||
acc_3 = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]);
|
||||
sb_acc[3] = acc_3;
|
||||
|
||||
const int32_t s2 = q5sb_scales[1][scale_offset];
|
||||
const int32_t s3 = q5sb_scales[1][scale_offset + 1];
|
||||
const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3));
|
||||
acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale2);
|
||||
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2);
|
||||
}
|
||||
|
||||
// Multiply Acc bsum + mins
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
// Each pair of subblocks share the same bsums
|
||||
// Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
|
||||
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
|
||||
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
|
||||
|
||||
bias_acc[2 * q8_row] =
|
||||
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
|
||||
bias_acc[2 * q8_row] =
|
||||
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
|
||||
bias_acc[2 * q8_row + 1] =
|
||||
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
|
||||
bias_acc[2 * q8_row + 1] =
|
||||
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
|
||||
}
|
||||
} // for sb
|
||||
|
||||
// Reorder of i8mm output with bias and output layout
|
||||
for (int i = 0; i < 8; i++) {
|
||||
int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
|
||||
acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
|
||||
}
|
||||
int32x4_t reorder_acc[8] = {
|
||||
vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
|
||||
vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
|
||||
vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
|
||||
vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
|
||||
vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
|
||||
vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
|
||||
vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
|
||||
vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
|
||||
};
|
||||
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
|
||||
float32x4_t q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4)));
|
||||
const float32x4_t dmins = vmulq_f32(q5_dmin, q8_d);
|
||||
|
||||
float32x4_t q5_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4)));
|
||||
const float32x4_t scale = vmulq_f32(q5_d, q8_d);
|
||||
|
||||
acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
|
||||
acc_f32[2 * i + j] =
|
||||
vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
|
||||
}
|
||||
}
|
||||
} // for b
|
||||
|
||||
// With the previous reorder, the tile is already in the correct memory layout.
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
int row = y * q8_k_blocklen + i;
|
||||
for (int j = 0; j < 2; j++) {
|
||||
int col = x * ncols_interleaved + j * 4;
|
||||
int offset = row * bs + col;
|
||||
vst1q_f32(s + offset, acc_f32[2 * i + j]);
|
||||
}
|
||||
}
|
||||
} // for x
|
||||
} // for y
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q8_0_4x4_q8_0(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
|
||||
+345
-15
@@ -474,15 +474,8 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
assert (n % qk == 0);
|
||||
assert (nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(s);
|
||||
UNUSED(bs);
|
||||
UNUSED(vx);
|
||||
UNUSED(vy);
|
||||
UNUSED(nr);
|
||||
UNUSED(nc);
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
float sumf[8];
|
||||
float sum_minf[8];
|
||||
@@ -616,6 +609,100 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
static const uint32_t kmask1 = 0x3f3f3f3f;
|
||||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
static const uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[8];
|
||||
float sum_minf[8];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[j] = 0.0;
|
||||
sum_minf[j] = 0.0;
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
|
||||
|
||||
const int qh_shift = (k / 4) * 2;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * 8 + i) % 32;
|
||||
const int qh_chunk = qh_idx / 8;
|
||||
const int qh_pos = qh_idx % 8;
|
||||
const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
@@ -1212,6 +1299,108 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
|
||||
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
||||
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
||||
constexpr uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
float sumf[4][8];
|
||||
float sum_minf[4][8];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[m][j] = 0.0;
|
||||
sum_minf[m][j] = 0.0;
|
||||
}
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
|
||||
|
||||
const int qh_shift = (k / 4) * 2;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * 8 + i) % 32;
|
||||
const int qh_chunk = qh_idx / 8;
|
||||
const int qh_pos = qh_idx % 8;
|
||||
const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
@@ -1622,7 +1811,95 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in
|
||||
out.scales[i] = in[src1].scales[src2];
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_interleave) {
|
||||
block_q5_Kx8 out;
|
||||
//Delta(scale) and dmin values of the eight Q5_K structures are copied onto the output interleaved structure
|
||||
for (int i = 0; i < 8; i++) {
|
||||
out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
|
||||
}
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
|
||||
}
|
||||
|
||||
const int end = QK_K * 4 / blck_size_interleave;
|
||||
|
||||
// Interleave Q5_K quants by taking 8 bytes at a time
|
||||
for (int i = 0; i < end; ++i) {
|
||||
int src_id = i % 8;
|
||||
int src_offset = (i / 8) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elems;
|
||||
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
|
||||
}
|
||||
|
||||
// Repeat for low bits 8 bytes at a time as well, since
|
||||
// the high bits are interleaved in Q5_K and the index is
|
||||
// qh_idx = (qs_idx % 32);
|
||||
// qh_val = qh[qh_idx] >> (qs_idx / 32);
|
||||
for (int i = 0; i < end / 4; ++i) {
|
||||
int src_id = i % 8;
|
||||
int src_offset = (i / 8) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elems;
|
||||
memcpy(&elems, &in[src_id].qh[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qh[dst_offset], &elems, sizeof(uint64_t));
|
||||
}
|
||||
|
||||
// The below logic is copied over from Q4_K
|
||||
// The point is to unpack all the scales and mins for each sub block every time we load 12 bytes.
|
||||
// Currently the Q5_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value)
|
||||
// The output Q5_Kx8 structure has 96 bytes
|
||||
// Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q5_K structure
|
||||
// For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q5_K structures
|
||||
uint8_t s[8], m[8];
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
for (int j = 0; j < 8; j++) {
|
||||
s[j] = in[j].scales[i] & 63;
|
||||
m[j] = in[j].scales[i + 4] & 63;
|
||||
}
|
||||
|
||||
out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2);
|
||||
out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2);
|
||||
out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2);
|
||||
out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2);
|
||||
out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2);
|
||||
out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2);
|
||||
out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2);
|
||||
out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2);
|
||||
out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4);
|
||||
out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4);
|
||||
out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);
|
||||
out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);
|
||||
}
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
for (int j = 0; j < 8; j++) {
|
||||
s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15);
|
||||
m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4);
|
||||
}
|
||||
|
||||
out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);
|
||||
out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);
|
||||
out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);
|
||||
out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);
|
||||
out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);
|
||||
out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);
|
||||
out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);
|
||||
out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);
|
||||
out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);
|
||||
out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);
|
||||
out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);
|
||||
out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||
@@ -1718,6 +1995,38 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block
|
||||
GGML_UNUSED(data_size);
|
||||
}
|
||||
|
||||
static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t,
|
||||
int interleave_block,
|
||||
const void * GGML_RESTRICT data,
|
||||
size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_Q5_K);
|
||||
GGML_ASSERT(interleave_block == 8);
|
||||
constexpr int nrows_interleaved = 8;
|
||||
|
||||
block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data;
|
||||
const block_q5_K * src = (const block_q5_K *) data;
|
||||
block_q5_K dst_tmp[8];
|
||||
int nrow = ggml_nrows(t);
|
||||
int nblocks = t->ne[0] / QK_K;
|
||||
|
||||
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K));
|
||||
|
||||
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
for (int b = 0; b < nrow; b += nrows_interleaved) {
|
||||
for (int64_t x = 0; x < nblocks; x++) {
|
||||
for (int i = 0; i < nrows_interleaved; i++) {
|
||||
dst_tmp[i] = src[x + i * nblocks];
|
||||
}
|
||||
*dst++ = make_block_q5_Kx8(dst_tmp, interleave_block);
|
||||
}
|
||||
src += nrows_interleaved * nblocks;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
||||
GGML_ASSERT(interleave_block == 8);
|
||||
@@ -1936,6 +2245,10 @@ template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * da
|
||||
return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
|
||||
}
|
||||
@@ -1973,6 +2286,10 @@ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t
|
||||
ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
@@ -1981,8 +2298,8 @@ template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
|
||||
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
@@ -2013,20 +2330,24 @@ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t
|
||||
ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
@@ -2432,6 +2753,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
|
||||
|
||||
// instance for Q5_K
|
||||
static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;
|
||||
|
||||
// instance for Q2
|
||||
static const ggml::cpu::repack::tensor_traits<block_q2_K, 8, 8, GGML_TYPE_Q8_K> q2_K_8x8_q8_K;
|
||||
|
||||
@@ -2482,6 +2806,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
return &q2_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_Q5_K) {
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
return &q5_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_IQ4_NL) {
|
||||
if (ggml_cpu_has_avx2()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
|
||||
@@ -44,6 +44,7 @@ struct block_q4_Kx8 {
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
|
||||
|
||||
struct block_q2_Kx8 {
|
||||
ggml_half d[8]; // super-block scale for quantized scales
|
||||
ggml_half dmin[8]; // super-block scale for quantized mins
|
||||
@@ -52,6 +53,18 @@ struct block_q2_Kx8 {
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
|
||||
|
||||
struct block_q5_Kx8 {
|
||||
ggml_half d[8]; // super-block scale for quantized scales
|
||||
ggml_half dmin[8]; // super-block scale for quantized mins
|
||||
uint8_t scales[96]; // scales and mins, quantized with 6 bits
|
||||
uint8_t qh[QK_K * 8 / 8]; // high bits of 5-bit quants
|
||||
uint8_t qs[QK_K * 8 / 2]; // low bits of 5-bit quants (in groups of 4)
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5,
|
||||
"wrong q5_K block size/padding");
|
||||
|
||||
struct block_q8_Kx4 {
|
||||
float d[4]; // delta
|
||||
int8_t qs[QK_K * 4]; // quants
|
||||
@@ -82,20 +95,22 @@ void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTR
|
||||
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -111,17 +126,19 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG
|
||||
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
# include <cub/cub.cuh>
|
||||
# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1)
|
||||
# define STRIDED_ITERATOR_AVAILABLE
|
||||
# endif
|
||||
using namespace cub;
|
||||
#endif // GGML_CUDA_USE_CUB
|
||||
|
||||
@@ -14,12 +17,14 @@ static __global__ void init_indices(int * indices, const int ncols, const int nr
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef STRIDED_ITERATOR_AVAILABLE
|
||||
static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx <= nrows) {
|
||||
offsets[idx] = idx * ncols;
|
||||
}
|
||||
}
|
||||
#endif // STRIDED_ITERATOR_AVAILABLE
|
||||
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
@@ -31,19 +36,22 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
cudaStream_t stream) {
|
||||
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
|
||||
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
|
||||
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
|
||||
|
||||
int * temp_indices = temp_indices_alloc.get();
|
||||
float * temp_keys = temp_keys_alloc.get();
|
||||
int * d_offsets = offsets_alloc.get();
|
||||
|
||||
static const int block_size = 256;
|
||||
const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
|
||||
init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
|
||||
|
||||
const dim3 offset_grid((nrows + block_size - 1) / block_size);
|
||||
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
|
||||
|
||||
#ifdef STRIDED_ITERATOR_AVAILABLE
|
||||
auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);
|
||||
#else
|
||||
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
|
||||
int * offset_iterator = offsets_alloc.get();
|
||||
const dim3 offset_grid((nrows + block_size - 1) / block_size);
|
||||
init_offsets<<<offset_grid, block_size, 0, stream>>>(offset_iterator, ncols, nrows);
|
||||
#endif
|
||||
CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
|
||||
|
||||
size_t temp_storage_bytes = 0;
|
||||
@@ -57,7 +65,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
d_offsets, d_offsets + 1, stream);
|
||||
offset_iterator, offset_iterator + 1, stream);
|
||||
}
|
||||
} else {
|
||||
if (nrows == 1) {
|
||||
@@ -66,7 +74,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
ncols, 0, sizeof(float) * 8, stream);
|
||||
} else {
|
||||
DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
||||
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
|
||||
dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,7 +89,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
ncols, 0, sizeof(float) * 8, stream);
|
||||
} else {
|
||||
DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
||||
ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
|
||||
ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);
|
||||
}
|
||||
} else {
|
||||
if (nrows == 1) {
|
||||
@@ -89,8 +98,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
ncols, 0, sizeof(float) * 8, stream);
|
||||
} else {
|
||||
DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
|
||||
stream);
|
||||
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
||||
offset_iterator + 1, stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1327,10 +1327,44 @@ struct ggml_backend_cuda_context {
|
||||
cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
|
||||
cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
|
||||
|
||||
std::unique_ptr<ggml_cuda_graph> cuda_graph;
|
||||
|
||||
int curr_stream_no = 0;
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
// Map from first_node_ptr to cuda_graph - allows multiple graphs per context
|
||||
// when the computation is split across CPU/GPU (e.g., with --n-cpu-moe)
|
||||
std::unordered_map<const void *, std::unique_ptr<ggml_cuda_graph>> cuda_graphs;
|
||||
|
||||
ggml_cuda_graph * cuda_graph(const void * first_node_ptr) {
|
||||
auto it = cuda_graphs.find(first_node_ptr);
|
||||
if (it == cuda_graphs.end()) {
|
||||
cuda_graphs[first_node_ptr] = std::make_unique<ggml_cuda_graph>();
|
||||
return cuda_graphs[first_node_ptr].get();
|
||||
}
|
||||
return it->second.get();
|
||||
}
|
||||
|
||||
// Check if any CUDA graph is enabled for this context (used by kernels that need to know
|
||||
// if graphs are in use without having access to the specific graph key)
|
||||
bool any_cuda_graph_enabled() const {
|
||||
for (const auto & [key, graph] : cuda_graphs) {
|
||||
if (graph && graph->is_enabled()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if any CUDA graph has an instance for this context
|
||||
bool any_cuda_graph_has_instance() const {
|
||||
for (const auto & [key, graph] : cuda_graphs) {
|
||||
if (graph && graph->instance != nullptr) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
explicit ggml_backend_cuda_context(int device) :
|
||||
device(device),
|
||||
name(GGML_CUDA_NAME + std::to_string(device)) {
|
||||
|
||||
@@ -778,13 +778,11 @@ void launch_fattn(
|
||||
) {
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
|
||||
const bool is_mla = DV == 512; // TODO better parameterization
|
||||
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
GGML_ASSERT(V || is_mla);
|
||||
const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data;
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
@@ -794,9 +792,9 @@ void launch_fattn(
|
||||
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT( Q->nb[0] == ggml_element_size(Q));
|
||||
GGML_ASSERT( K->nb[0] == ggml_element_size(K));
|
||||
GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
|
||||
GGML_ASSERT(Q->nb[0] == ggml_element_size(Q));
|
||||
GGML_ASSERT(K->nb[0] == ggml_element_size(K));
|
||||
GGML_ASSERT(V->nb[0] == ggml_element_size(V));
|
||||
|
||||
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
||||
|
||||
@@ -817,10 +815,10 @@ void launch_fattn(
|
||||
size_t nb12 = K->nb[2];
|
||||
size_t nb13 = K->nb[3];
|
||||
|
||||
const char * V_data = V ? (const char *) V->data : nullptr;
|
||||
size_t nb21 = V ? V->nb[1] : nb11;
|
||||
size_t nb22 = V ? V->nb[2] : nb12;
|
||||
size_t nb23 = V ? V->nb[3] : nb13;
|
||||
const char * V_data = (const char *) V->data;
|
||||
size_t nb21 = V->nb[1];
|
||||
size_t nb22 = V->nb[2];
|
||||
size_t nb23 = V->nb[3];
|
||||
|
||||
if (need_f16_K && K->type != GGML_TYPE_F16) {
|
||||
const size_t bs = ggml_blck_size(K->type);
|
||||
@@ -849,32 +847,39 @@ void launch_fattn(
|
||||
K_data = (char *) K_f16.ptr;
|
||||
}
|
||||
|
||||
if (V && need_f16_V && V->type != GGML_TYPE_F16) {
|
||||
const size_t bs = ggml_blck_size(V->type);
|
||||
const size_t ts = ggml_type_size(V->type);
|
||||
|
||||
V_f16.alloc(ggml_nelements(V));
|
||||
if (ggml_is_contiguously_allocated(V)) {
|
||||
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
||||
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
|
||||
V_data = (char *) V_f16.ptr;
|
||||
|
||||
nb21 = nb21*bs*sizeof(half)/ts;
|
||||
nb22 = nb22*bs*sizeof(half)/ts;
|
||||
nb23 = nb23*bs*sizeof(half)/ts;
|
||||
if (need_f16_V && V->type != GGML_TYPE_F16) {
|
||||
if (V_is_K_view) {
|
||||
V_data = K_data;
|
||||
nb21 = nb11;
|
||||
nb22 = nb12;
|
||||
nb23 = nb13;
|
||||
} else {
|
||||
GGML_ASSERT(V->nb[0] == ts);
|
||||
to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
|
||||
const int64_t s01 = nb21 / ts;
|
||||
const int64_t s02 = nb22 / ts;
|
||||
const int64_t s03 = nb23 / ts;
|
||||
to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
|
||||
const size_t bs = ggml_blck_size(V->type);
|
||||
const size_t ts = ggml_type_size(V->type);
|
||||
|
||||
nb21 = V->ne[0] * sizeof(half);
|
||||
nb22 = V->ne[1] * nb21;
|
||||
nb23 = V->ne[2] * nb22;
|
||||
V_f16.alloc(ggml_nelements(V));
|
||||
if (ggml_is_contiguously_allocated(V)) {
|
||||
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
||||
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
|
||||
V_data = (char *) V_f16.ptr;
|
||||
|
||||
nb21 = nb21*bs*sizeof(half)/ts;
|
||||
nb22 = nb22*bs*sizeof(half)/ts;
|
||||
nb23 = nb23*bs*sizeof(half)/ts;
|
||||
} else {
|
||||
GGML_ASSERT(V->nb[0] == ts);
|
||||
to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
|
||||
const int64_t s01 = nb21 / ts;
|
||||
const int64_t s02 = nb22 / ts;
|
||||
const int64_t s03 = nb23 / ts;
|
||||
to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
|
||||
|
||||
nb21 = V->ne[0] * sizeof(half);
|
||||
nb22 = V->ne[1] * nb21;
|
||||
nb23 = V->ne[2] * nb22;
|
||||
}
|
||||
V_data = (char *) V_f16.ptr;
|
||||
}
|
||||
V_data = (char *) V_f16.ptr;
|
||||
}
|
||||
|
||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
|
||||
@@ -400,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
|
||||
bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
|
||||
bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
|
||||
typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ>
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
const float2 * const __restrict__ Q_f2,
|
||||
@@ -432,7 +432,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = get_cols_per_thread();
|
||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
|
||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
|
||||
@@ -442,8 +442,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||
constexpr int stride_tile_K = nbatch_K2 + 4;
|
||||
|
||||
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
||||
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
||||
constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
|
||||
|
||||
const int k_VKQ_0 = kb0 * nbatch_fa;
|
||||
#if defined(TURING_MMA_AVAILABLE)
|
||||
@@ -456,7 +455,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
|
||||
if constexpr (nstages > 1) {
|
||||
static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
|
||||
static_assert(!mla, "multi-stage loading not implemented for MLA");
|
||||
static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading");
|
||||
static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
|
||||
constexpr bool use_cp_async = true;
|
||||
cp_async_wait_all();
|
||||
@@ -471,8 +470,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
}
|
||||
|
||||
// For MLA K and V have the same data.
|
||||
// Therefore, iterate over K in reverse and later re-use the data if possible.
|
||||
#pragma unroll
|
||||
for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
|
||||
for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) {
|
||||
const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
|
||||
const int k0_diff = k0_stop - k0_start;
|
||||
|
||||
@@ -510,7 +511,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
|
||||
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
|
||||
@@ -522,14 +522,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
T_A_KQ K_A;
|
||||
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
||||
|
||||
// Wide version of KQ_C is column-major
|
||||
if constexpr (cols_per_warp == 8) {
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
} else {
|
||||
// Wide version of KQ_C is column-major
|
||||
#if defined(AMD_WMMA_AVAILABLE)
|
||||
// RDNA matrix C is column-major.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
// RDNA matrix C is column-major.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||
#else
|
||||
// swap A and B for CUDA.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
// swap A and B for CUDA.
|
||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -773,6 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
|
||||
if constexpr (nstages > 1) {
|
||||
static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading");
|
||||
// Preload K tile for next iteration:
|
||||
constexpr bool use_cp_async = true;
|
||||
cp_async_wait_all();
|
||||
@@ -788,10 +793,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
}
|
||||
|
||||
|
||||
// For MLA K and V have the same data.
|
||||
// Therefore, iterate over V in reverse and re-use the data if possible.
|
||||
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
|
||||
constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
|
||||
#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
|
||||
T_A_VKQ A_identity;
|
||||
make_identity_mat(A_identity);
|
||||
@@ -799,12 +800,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
|
||||
// Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
|
||||
#pragma unroll
|
||||
for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
|
||||
const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
|
||||
const int i0_diff = i0_stop - i0_start;
|
||||
for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
|
||||
static_assert(DV % (2*nbatch_V2) == 0, "bad loop size");
|
||||
const int i0_stop = i0_start + 2*nbatch_V2;
|
||||
const int i0_diff = i0_stop - i0_start;
|
||||
|
||||
if constexpr (nstages <= 1) {
|
||||
if (i0_start < reusable_cutoff) {
|
||||
if (!V_is_K_view || i0_stop > 2*nbatch_K2) {
|
||||
constexpr bool use_cp_async = nstages == 1;
|
||||
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
|
||||
(V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
|
||||
@@ -814,7 +816,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
|
||||
const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
|
||||
|
||||
#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
|
||||
@@ -917,7 +919,7 @@ template<int ncols> struct mma_tile_sizes {
|
||||
};
|
||||
#endif // defined(TURING_MMA_AVAILABLE)
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup>
|
||||
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const float2 * const __restrict__ Q_f2,
|
||||
const half2 * const __restrict__ K_h2,
|
||||
@@ -953,7 +955,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
|
||||
constexpr int cols_per_warp = T_B_KQ::I;
|
||||
constexpr int cols_per_thread = get_cols_per_thread();
|
||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
|
||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
|
||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
|
||||
@@ -971,8 +973,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
constexpr int stride_tile_Q = DKQ/2 + 4;
|
||||
constexpr int stride_tile_K = nbatch_K2 + 4;
|
||||
|
||||
static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
|
||||
constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
|
||||
constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
|
||||
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
|
||||
|
||||
extern __shared__ half2 tile_Q[];
|
||||
@@ -1076,7 +1077,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
constexpr bool last_iter = false;
|
||||
constexpr int k_VKQ_sup = nbatch_fa;
|
||||
flash_attn_ext_f16_iter
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
||||
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
||||
@@ -1085,7 +1086,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
constexpr bool last_iter = true;
|
||||
const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
|
||||
flash_attn_ext_f16_iter
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
||||
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
||||
@@ -1096,7 +1097,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
constexpr bool last_iter = false;
|
||||
constexpr int k_VKQ_sup = nbatch_fa;
|
||||
flash_attn_ext_f16_iter
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
||||
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
||||
@@ -1105,7 +1106,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
constexpr bool last_iter = true;
|
||||
constexpr int k_VKQ_sup = nbatch_fa;
|
||||
flash_attn_ext_f16_iter
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
|
||||
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
|
||||
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
|
||||
@@ -1453,7 +1454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
|
||||
}
|
||||
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool mla>
|
||||
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
|
||||
__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
|
||||
static __global__ void flash_attn_ext_f16(
|
||||
const char * __restrict__ Q,
|
||||
@@ -1484,6 +1485,13 @@ static __global__ void flash_attn_ext_f16(
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#ifdef VOLTA_MMA_AVAILABLE
|
||||
if (ncols1*ncols2 < 32) {
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
#endif // VOLTA_MMA_AVAILABLE
|
||||
|
||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||
if (ncols1*ncols2 > 32) {
|
||||
NO_DEVICE_CODE;
|
||||
@@ -1498,8 +1506,6 @@ static __global__ void flash_attn_ext_f16(
|
||||
}
|
||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||
|
||||
static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
|
||||
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
||||
constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
|
||||
@@ -1512,7 +1518,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int stride_K = nb11 / sizeof(half2);
|
||||
const int stride_mask = nb31 / sizeof(half);
|
||||
|
||||
const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
|
||||
const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
|
||||
|
||||
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
||||
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
|
||||
@@ -1542,7 +1548,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
(const half *) (mask + nb33*(sequence % ne33));
|
||||
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
|
||||
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
|
||||
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
|
||||
@@ -1553,12 +1559,12 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
||||
if (kb0_start == 0) {
|
||||
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
|
||||
} else {
|
||||
constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
|
||||
}
|
||||
@@ -1586,7 +1592,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
(const half *) (mask + nb33*(sequence % ne33));
|
||||
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
|
||||
const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
|
||||
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
|
||||
@@ -1597,7 +1603,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||
constexpr bool needs_fixup = false;
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
|
||||
#else
|
||||
@@ -1633,7 +1639,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
|
||||
constexpr bool mla = DKQ == 576;
|
||||
constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
|
||||
|
||||
const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
|
||||
const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
|
||||
@@ -1658,7 +1664,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
fattn_kernel_t fattn_kernel;
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
|
||||
|
||||
#if !defined(GGML_USE_MUSA)
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
@@ -1669,7 +1675,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
#endif // !defined(GGML_USE_MUSA)
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
|
||||
fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
|
||||
|
||||
#if !defined(GGML_USE_MUSA)
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||
@@ -1728,3 +1734,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
||||
|
||||
// For GLM 4.7 Flash
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
|
||||
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
|
||||
return 0;
|
||||
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
||||
|
||||
return 0;
|
||||
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
|
||||
|
||||
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
|
||||
|
||||
@@ -1187,6 +1195,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
|
||||
@@ -46,7 +46,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
|
||||
// are put into the template specialization without GQA optimizations.
|
||||
bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
for (const ggml_tensor * t : {Q, K, V, mask}) {
|
||||
if (t == nullptr) {
|
||||
if (t == nullptr || ggml_is_quantized(t->type)) {
|
||||
continue;
|
||||
}
|
||||
for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
|
||||
@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
GGML_ASSERT(gqa_ratio % 16 == 0);
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
GGML_ASSERT(gqa_ratio % 4 == 0);
|
||||
if (gqa_ratio % 16 == 0) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -232,7 +236,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
// The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
|
||||
bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
for (const ggml_tensor * t : {Q, K, V, mask}) {
|
||||
if (t == nullptr) {
|
||||
if (t == nullptr || ggml_is_quantized(t->type)) {
|
||||
continue;
|
||||
}
|
||||
for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
|
||||
@@ -243,6 +247,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
}
|
||||
}
|
||||
|
||||
const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data;
|
||||
|
||||
const int cc = ggml_cuda_info().devices[device].cc;
|
||||
|
||||
switch (K->ne[0]) {
|
||||
@@ -262,7 +268,10 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
if (V->ne[0] != 512) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
|
||||
if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!V_is_K_view) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
|
||||
@@ -2969,18 +2969,25 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
|
||||
return true;
|
||||
}
|
||||
|
||||
static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
|
||||
return cgraph->nodes[0];
|
||||
}
|
||||
|
||||
static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
|
||||
|
||||
bool res = false;
|
||||
|
||||
if (cuda_ctx->cuda_graph->instance == nullptr) {
|
||||
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
|
||||
if (graph->instance == nullptr) {
|
||||
res = true;
|
||||
}
|
||||
|
||||
// Check if the graph size has changed
|
||||
if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
|
||||
if (graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
|
||||
res = true;
|
||||
cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
|
||||
graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
|
||||
}
|
||||
|
||||
// Loop over nodes in GGML graph to determine if CUDA graph update is required
|
||||
@@ -2988,37 +2995,38 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
bool props_match = true;
|
||||
if (!res) {
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]);
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
|
||||
}
|
||||
if (!props_match) {
|
||||
res = true;
|
||||
}
|
||||
ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]);
|
||||
ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
|
||||
}
|
||||
|
||||
for (int i = 0; i < cgraph->n_leafs; i++) {
|
||||
bool props_match= true;
|
||||
bool props_match = true;
|
||||
if (!res) {
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]);
|
||||
props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &graph->props[cgraph->n_nodes + i]);
|
||||
}
|
||||
if (!props_match) {
|
||||
res = true;
|
||||
}
|
||||
ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
|
||||
ggml_cuda_graph_node_set_properties(&graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) {
|
||||
static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
|
||||
#if CUDART_VERSION >= 12000
|
||||
cudaGraphExecUpdateResultInfo result_info;
|
||||
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
|
||||
cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info);
|
||||
#else
|
||||
cudaGraphNode_t errorNode;
|
||||
cudaGraphExecUpdateResult result_info;
|
||||
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
|
||||
cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info);
|
||||
#endif // CUDART_VERSION >= 12000
|
||||
|
||||
if (stat == cudaErrorGraphExecUpdateFailure) {
|
||||
@@ -3029,14 +3037,14 @@ static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_c
|
||||
// The pre-existing graph exec cannot be updated due to violated constraints
|
||||
// so instead clear error and re-instantiate
|
||||
(void)cudaGetLastError();
|
||||
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
|
||||
cuda_ctx->cuda_graph->instance = nullptr;
|
||||
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
||||
CUDA_CHECK(cudaGraphExecDestroy(graph->instance));
|
||||
graph->instance = nullptr;
|
||||
CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
|
||||
} else {
|
||||
GGML_ASSERT(stat == cudaSuccess);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
|
||||
const ggml_tensor * view,
|
||||
@@ -3241,7 +3249,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
return false;
|
||||
}
|
||||
|
||||
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) {
|
||||
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
|
||||
bool graph_evaluated_or_captured = false;
|
||||
|
||||
// flag used to determine whether it is an integrated_gpu
|
||||
@@ -3695,13 +3703,14 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
|
||||
if (cuda_ctx->cuda_graph->graph != nullptr) {
|
||||
CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
|
||||
cuda_ctx->cuda_graph->graph = nullptr;
|
||||
if (graph->graph != nullptr) {
|
||||
CUDA_CHECK(cudaGraphDestroy(graph->graph));
|
||||
graph->graph = nullptr;
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
|
||||
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph));
|
||||
graph_evaluated_or_captured = true; // CUDA graph has been captured
|
||||
|
||||
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
|
||||
@@ -3714,40 +3723,39 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
||||
}
|
||||
|
||||
if (use_cuda_graph) {
|
||||
if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
|
||||
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
if (graph->instance == nullptr) { // Create executable graph from captured graph.
|
||||
CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
|
||||
}
|
||||
if (cuda_graph_update_required) { // Update graph executable
|
||||
ggml_cuda_graph_update_executable(cuda_ctx);
|
||||
ggml_cuda_graph_update_executable(cuda_ctx, graph_key);
|
||||
}
|
||||
// Launch graph
|
||||
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
|
||||
CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream()));
|
||||
#else
|
||||
graph_evaluated_or_captured = true;
|
||||
#endif // USE_CUDA_GRAPH
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) {
|
||||
static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
|
||||
if (cuda_ctx->cuda_graph == nullptr) {
|
||||
cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
|
||||
}
|
||||
|
||||
if (cuda_ctx->cuda_graph->graph == nullptr) {
|
||||
if (graph->graph == nullptr) {
|
||||
if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
|
||||
if (!cuda_ctx->cuda_graph->disable_due_to_gpu_arch) {
|
||||
if (!graph->disable_due_to_gpu_arch) {
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
|
||||
}
|
||||
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
|
||||
graph->disable_due_to_gpu_arch = true;
|
||||
}
|
||||
}
|
||||
|
||||
return cuda_ctx->cuda_graph->is_enabled();
|
||||
return graph->is_enabled();
|
||||
#else
|
||||
GGML_UNUSED(cuda_ctx);
|
||||
GGML_UNUSED(graph_key);
|
||||
return false;
|
||||
#endif // USE_CUDA_GRAPH
|
||||
}
|
||||
@@ -3759,15 +3767,19 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||
|
||||
bool use_cuda_graph = false;
|
||||
bool cuda_graph_update_required = false;
|
||||
const void * graph_key = nullptr;
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
|
||||
graph_key = ggml_cuda_graph_get_key(cgraph);
|
||||
|
||||
if (cuda_ctx->cuda_graph->is_enabled()) {
|
||||
use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
|
||||
|
||||
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
|
||||
if (graph->is_enabled()) {
|
||||
cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
|
||||
use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);
|
||||
|
||||
cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required);
|
||||
graph->record_update(use_cuda_graph, cuda_graph_update_required);
|
||||
}
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
@@ -3781,7 +3793,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
|
||||
}
|
||||
|
||||
ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required);
|
||||
ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key);
|
||||
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
@@ -3814,7 +3826,14 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
|
||||
static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
|
||||
|
||||
const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
|
||||
const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
|
||||
#else
|
||||
const bool use_cuda_graph = false;
|
||||
GGML_UNUSED(cuda_ctx);
|
||||
GGML_UNUSED(cgraph);
|
||||
#endif
|
||||
|
||||
static bool enable_graph_optimization = [] {
|
||||
const char * env = getenv("GGML_CUDA_GRAPH_OPT");
|
||||
|
||||
@@ -31,14 +31,15 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
#endif // USE_CUDA_GRAPH
|
||||
if ((nrows == 1) &&
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
// CUDA_GRAPHS_DISABLED
|
||||
((ncols > 65536) &&
|
||||
((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
|
||||
ctx.cuda_graph->is_enabled())) ||
|
||||
// CUDA_GRAPHS ENABLED
|
||||
((ncols > 32768) &&
|
||||
!((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
|
||||
ctx.cuda_graph->is_enabled()))) {
|
||||
// Determine if CUDA graphs are effectively disabled for this context
|
||||
// (no graph instance exists and we're not capturing, OR graphs are explicitly enabled)
|
||||
(((ncols > 65536) &&
|
||||
(((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) ||
|
||||
ctx.any_cuda_graph_enabled())) ||
|
||||
// CUDA graphs are enabled - use lower threshold
|
||||
((ncols > 32768) &&
|
||||
!(((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) ||
|
||||
ctx.any_cuda_graph_enabled())))) {
|
||||
#else
|
||||
(ncols > 65536)) {
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
|
||||
DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||
|
||||
@@ -85,7 +85,7 @@ for ncols in [8, 16, 32, 64]:
|
||||
continue
|
||||
if head_size_kq != 576 and ncols2 == 16:
|
||||
continue
|
||||
if head_size_kq == 576 and ncols2 != 16:
|
||||
if head_size_kq == 576 and ncols2 not in (4, 16):
|
||||
continue
|
||||
head_size_v = head_size_kq if head_size_kq != 576 else 512
|
||||
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
# include <cub/cub.cuh>
|
||||
# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2)
|
||||
# include <cuda/iterator>
|
||||
# define CUB_TOP_K_AVAILABLE
|
||||
using namespace cub;
|
||||
# endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <assert.h>
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_perf.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
@@ -111,7 +111,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
|
||||
hvx_vec_store_u(r, 4, rsum);
|
||||
}
|
||||
|
||||
// MAD: y (F32) += x (F16) * v (float)
|
||||
// MAD: y (F32) += x (F16) * s (float)
|
||||
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
|
||||
const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
|
||||
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
|
||||
@@ -318,9 +318,12 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
uint32_t ic = 0;
|
||||
|
||||
// Process in blocks of 32 (VLEN_FP32)
|
||||
for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
|
||||
static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 == 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
|
||||
HVX_Vector_x4 scores_x4;
|
||||
HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
|
||||
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
|
||||
// 1. Compute scores
|
||||
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
|
||||
float __attribute__((aligned(VLEN))) scores_arr[FLASH_ATTN_BLOCK_SIZE];
|
||||
for (int j = 0; j < VLEN_FP32; ++j) {
|
||||
const uint32_t cur_ic = ic + j;
|
||||
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
|
||||
@@ -356,36 +359,43 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
}
|
||||
|
||||
// 4. Online Softmax Update
|
||||
HVX_Vector v_max = hvx_vec_reduce_max_f32(scores);
|
||||
float m_block = hvx_vec_get_f32(v_max);
|
||||
scores_x4.v[iv] = scores;
|
||||
v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max);
|
||||
}
|
||||
|
||||
{
|
||||
// 4. Online Softmax Update
|
||||
v_max = hvx_vec_reduce_max_f32(v_max);
|
||||
float m_block = hvx_vec_get_f32(v_max);
|
||||
float M_old = M;
|
||||
float M_new = (m_block > M) ? m_block : M;
|
||||
M = M_new;
|
||||
|
||||
float ms = expf(M_old - M_new);
|
||||
|
||||
const float ms = expf(M_old - M_new);
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
|
||||
S = S * ms;
|
||||
|
||||
HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
|
||||
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
|
||||
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
|
||||
HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
|
||||
for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
|
||||
HVX_Vector scores = scores_x4.v[iv];
|
||||
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
|
||||
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
|
||||
|
||||
HVX_Vector p_sum_vec = hvx_vec_reduce_sum_f32(P);
|
||||
float p_sum = hvx_vec_get_f32(p_sum_vec);
|
||||
S += p_sum;
|
||||
p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
|
||||
|
||||
// 5. Accumulate V
|
||||
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
|
||||
*(HVX_Vector*)p_arr = P;
|
||||
// 5. Accumulate V
|
||||
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
|
||||
*(HVX_Vector*)p_arr = P;
|
||||
|
||||
for (int j = 0; j < VLEN_FP32; ++j) {
|
||||
const uint32_t cur_ic = ic + j;
|
||||
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
|
||||
for (int j = 0; j < VLEN_FP32; ++j) {
|
||||
const uint32_t cur_ic = ic2 + j;
|
||||
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
|
||||
}
|
||||
}
|
||||
|
||||
p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
|
||||
S = S * ms + hvx_vec_get_f32(p_sum_vec);
|
||||
}
|
||||
|
||||
// Leftover
|
||||
|
||||
@@ -1078,12 +1078,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
op->src[0]->ne[0] != 112 &&
|
||||
op->src[0]->ne[0] != 128 &&
|
||||
op->src[0]->ne[0] != 192 &&
|
||||
op->src[0]->ne[0] != 256) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 576) {
|
||||
// DeepSeek sizes
|
||||
// TODO: disabled for now, until optmized
|
||||
op->src[0]->ne[0] != 256 &&
|
||||
op->src[0]->ne[0] != 576) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type != op->src[2]->type) {
|
||||
|
||||
@@ -2520,7 +2520,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
|
||||
int32_t nsg = 4;
|
||||
int32_t nsg = ne00 >= 512 ? 8 : 4;
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
|
||||
@@ -5552,9 +5552,7 @@ void kernel_flash_attn_ext_impl(
|
||||
|
||||
constexpr short NC = (C/8)/NSG;
|
||||
|
||||
// note: do not unroll for large heads
|
||||
#pragma unroll (DK <= 64 ? NC : 1)
|
||||
for (short cc = 0; cc < NC; ++cc) {
|
||||
FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
|
||||
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
|
||||
|
||||
if (DK % 16 != 0) {
|
||||
@@ -5575,7 +5573,9 @@ void kernel_flash_attn_ext_impl(
|
||||
k8x8_t mk[2];
|
||||
q8x8_t mq[2];
|
||||
|
||||
FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
|
||||
// note: too much unroll can tank the performance for large heads
|
||||
#pragma unroll (MIN(DK8/2, 4*NSG))
|
||||
for (short i = 0; i < DK8/2; ++i) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
|
||||
@@ -5749,7 +5749,9 @@ void kernel_flash_attn_ext_impl(
|
||||
pv += 8*NS20;
|
||||
}
|
||||
} else {
|
||||
FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
|
||||
constexpr short NC = (C/8)/2;
|
||||
|
||||
FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
|
||||
s8x8_t vs[2];
|
||||
|
||||
simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
|
||||
@@ -5952,6 +5954,7 @@ kernel void kernel_flash_attn_ext(
|
||||
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
|
||||
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
|
||||
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
|
||||
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
|
||||
}
|
||||
#undef FWD_TMPL
|
||||
#undef FWD_ARGS
|
||||
|
||||
@@ -57,6 +57,7 @@ set(GGML_OPENCL_KERNELS
|
||||
add
|
||||
add_id
|
||||
argsort
|
||||
tri
|
||||
fill
|
||||
clamp
|
||||
cpy
|
||||
|
||||
@@ -398,6 +398,7 @@ struct ggml_backend_opencl_context {
|
||||
int adreno_wave_size;
|
||||
|
||||
cl_bool non_uniform_workgroups;
|
||||
size_t image_max_buffer_size;
|
||||
|
||||
cl_context context;
|
||||
cl_command_queue queue;
|
||||
@@ -407,6 +408,10 @@ struct ggml_backend_opencl_context {
|
||||
ggml_cl_buffer prealloc_scales_trans;
|
||||
ggml_cl_buffer prealloc_act_trans;
|
||||
|
||||
// prealloc buffers for src0 and src1
|
||||
ggml_cl_buffer prealloc_src0;
|
||||
ggml_cl_buffer prealloc_src1;
|
||||
|
||||
cl_program program_add;
|
||||
cl_program program_add_id;
|
||||
cl_program program_clamp;
|
||||
@@ -489,6 +494,7 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;
|
||||
cl_kernel kernel_relu;
|
||||
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
|
||||
cl_kernel kernel_tri;
|
||||
cl_kernel kernel_fill;
|
||||
cl_kernel kernel_clamp;
|
||||
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
|
||||
@@ -793,6 +799,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// tri
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "tri.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("tri.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_tri = clCreateKernel(prog, "kernel_tri_f32", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
}
|
||||
|
||||
// fill
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -2639,6 +2663,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL);
|
||||
GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024);
|
||||
|
||||
clGetDeviceInfo(device, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE, sizeof(size_t), &backend_ctx->image_max_buffer_size, NULL);
|
||||
GGML_LOG_INFO("ggml_opencl: device max image buffer size (pixels): %lu\n", backend_ctx->image_max_buffer_size);
|
||||
|
||||
clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL);
|
||||
GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size);
|
||||
|
||||
@@ -3205,6 +3232,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
case GGML_OP_TRI:
|
||||
return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op);
|
||||
case GGML_OP_FILL:
|
||||
return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op);
|
||||
case GGML_OP_CLAMP:
|
||||
@@ -4690,6 +4719,81 @@ static bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct gg
|
||||
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
|
||||
}
|
||||
|
||||
// Copy a noncontiguous tensor to contiguous tensor. ne[] remains the same but
|
||||
// nb[] is recalculated such that tensor is contiguous.
|
||||
static void ggml_cl_copy_to_contiguous(ggml_backend_t backend, const ggml_tensor * src, cl_mem dst,
|
||||
cl_ulong &nb0, cl_ulong &nb1, cl_ulong &nb2, cl_ulong &nb3) {
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
const int tensor_type_size = ggml_type_size(src->type);
|
||||
|
||||
const int ne00 = src->ne[0];
|
||||
const int ne01 = src->ne[1];
|
||||
const int ne02 = src->ne[2];
|
||||
const int ne03 = src->ne[3];
|
||||
|
||||
const cl_ulong nb00 = src->nb[0];
|
||||
const cl_ulong nb01 = src->nb[1];
|
||||
const cl_ulong nb02 = src->nb[2];
|
||||
const cl_ulong nb03 = src->nb[3];
|
||||
|
||||
const int ne0 = src->ne[0];
|
||||
const int ne1 = src->ne[1];
|
||||
const int ne2 = src->ne[2];
|
||||
const int ne3 = src->ne[3];
|
||||
|
||||
nb0 = tensor_type_size;
|
||||
nb1 = tensor_type_size*ne00;
|
||||
nb2 = tensor_type_size*ne00*ne01;
|
||||
nb3 = tensor_type_size*ne00*ne01*ne02;
|
||||
|
||||
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *)src->extra;
|
||||
|
||||
cl_ulong offset0 = extra->offset + src->view_offs;
|
||||
cl_ulong offsetd = 0;
|
||||
|
||||
cl_kernel kernel;
|
||||
|
||||
switch (src->type) {
|
||||
case GGML_TYPE_F32:
|
||||
kernel = backend_ctx->kernel_cpy_f32_f32;
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
kernel = backend_ctx->kernel_cpy_f16_f16;
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &dst));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne3));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));
|
||||
|
||||
const int nth = MIN(64, ne00);
|
||||
|
||||
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
||||
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, src);
|
||||
}
|
||||
|
||||
static void ggml_cl_nop(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
UNUSED(backend);
|
||||
UNUSED(src0);
|
||||
@@ -5965,6 +6069,44 @@ static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
|
||||
}
|
||||
|
||||
static void ggml_cl_tri(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0);
|
||||
GGML_ASSERT(src0->extra);
|
||||
GGML_ASSERT(dst);
|
||||
GGML_ASSERT(dst->extra);
|
||||
|
||||
UNUSED(src1);
|
||||
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
|
||||
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
||||
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||
|
||||
const int tri_type = ggml_get_op_params_i32(dst, 0);
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
const int ne0 = dst->ne[0];
|
||||
const int ne1 = dst->ne[1];
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_tri;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &n));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &tri_type));
|
||||
|
||||
size_t local_work_size[1] = { 256 };
|
||||
size_t global_work_size[1] = { ((size_t)n + local_work_size[0] - 1) / local_work_size[0] * local_work_size[0] };
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst);
|
||||
}
|
||||
|
||||
static void ggml_cl_fill(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst);
|
||||
GGML_ASSERT(dst->extra);
|
||||
@@ -7665,9 +7807,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
cl_context context = backend_ctx->context;
|
||||
|
||||
if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){
|
||||
if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0) {
|
||||
if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0 &&
|
||||
// dst is wrapped with image1d_buffer, the size limit applies, also src0
|
||||
(ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4 <= backend_ctx->image_max_buffer_size)) {
|
||||
// For KQ
|
||||
if (ggml_is_permuted(src0) && ggml_is_permuted(src1) &&
|
||||
((nb01 * ne01 / 4)/4 <= backend_ctx->image_max_buffer_size) &&
|
||||
nb00 <= nb02 &&
|
||||
nb02 <= nb01 &&
|
||||
nb01 <= nb03 &&
|
||||
@@ -7678,7 +7823,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
return;
|
||||
}
|
||||
// For KQV
|
||||
if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||
if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
|
||||
((nb02 * ne02 / 4)/4 <= backend_ctx->image_max_buffer_size)) {
|
||||
ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
|
||||
return;
|
||||
}
|
||||
@@ -7984,9 +8130,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
|
||||
// GEMM using local memory
|
||||
// Current BK = 16, so ne00 % 16 == 0
|
||||
if (ggml_is_contiguous(src0) &&
|
||||
ggml_is_contiguous(src1) &&
|
||||
src1t == GGML_TYPE_F32 &&
|
||||
if (src1t == GGML_TYPE_F32 &&
|
||||
ne00 % 16 == 0 &&
|
||||
ne11 > 1) {
|
||||
switch(src0t) {
|
||||
@@ -7998,10 +8142,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
int batch_stride_b = ne10*ne11;
|
||||
int batch_stride_d = ne0*ne1;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
||||
cl_mem mem_src0 = extra0->data_device;
|
||||
cl_mem mem_src1 = extra1->data_device;
|
||||
|
||||
cl_ulong nb00_cont = nb00;
|
||||
cl_ulong nb01_cont = nb01;
|
||||
cl_ulong nb02_cont = nb02;
|
||||
cl_ulong nb03_cont = nb03;
|
||||
|
||||
cl_ulong nb10_cont = nb10;
|
||||
cl_ulong nb11_cont = nb11;
|
||||
cl_ulong nb12_cont = nb12;
|
||||
cl_ulong nb13_cont = nb13;
|
||||
|
||||
cl_ulong offset0_cont = offset0;
|
||||
cl_ulong offset1_cont = offset1;
|
||||
|
||||
if (!ggml_is_contiguous(src0)) {
|
||||
backend_ctx->prealloc_src0.allocate(backend_ctx->context, ggml_nbytes(src0));
|
||||
ggml_cl_copy_to_contiguous(backend, src0, backend_ctx->prealloc_src0.buffer,
|
||||
nb00_cont, nb01_cont, nb02_cont, nb03_cont);
|
||||
mem_src0 = backend_ctx->prealloc_src0.buffer;
|
||||
offset0_cont = 0;
|
||||
}
|
||||
|
||||
if (!ggml_is_contiguous(src1)) {
|
||||
backend_ctx->prealloc_src1.allocate(backend_ctx->context, ggml_nbytes(src1));
|
||||
ggml_cl_copy_to_contiguous(backend, src1, backend_ctx->prealloc_src1.buffer,
|
||||
nb10_cont, nb11_cont, nb12_cont, nb13_cont);
|
||||
mem_src1 = backend_ctx->prealloc_src1.buffer;
|
||||
offset1_cont = 0;
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &mem_src0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_cont));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &mem_src1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1_cont));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
||||
@@ -8033,10 +8209,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
int batch_stride_b = ne10*ne11;
|
||||
int batch_stride_d = ne0*ne1;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
||||
cl_mem mem_src0 = extra0->data_device;
|
||||
cl_mem mem_src1 = extra1->data_device;
|
||||
|
||||
cl_ulong nb00_cont = nb00;
|
||||
cl_ulong nb01_cont = nb01;
|
||||
cl_ulong nb02_cont = nb02;
|
||||
cl_ulong nb03_cont = nb03;
|
||||
|
||||
cl_ulong nb10_cont = nb10;
|
||||
cl_ulong nb11_cont = nb11;
|
||||
cl_ulong nb12_cont = nb12;
|
||||
cl_ulong nb13_cont = nb13;
|
||||
|
||||
cl_ulong offset0_cont = offset0;
|
||||
cl_ulong offset1_cont = offset1;
|
||||
|
||||
if (!ggml_is_contiguous(src0)) {
|
||||
backend_ctx->prealloc_src0.allocate(backend_ctx->context, ggml_nbytes(src0));
|
||||
ggml_cl_copy_to_contiguous(backend, src0, backend_ctx->prealloc_src0.buffer,
|
||||
nb00_cont, nb01_cont, nb02_cont, nb03_cont);
|
||||
mem_src0 = backend_ctx->prealloc_src0.buffer;
|
||||
offset0_cont = 0;
|
||||
}
|
||||
|
||||
if (!ggml_is_contiguous(src1)) {
|
||||
backend_ctx->prealloc_src1.allocate(backend_ctx->context, ggml_nbytes(src1));
|
||||
ggml_cl_copy_to_contiguous(backend, src1, backend_ctx->prealloc_src1.buffer,
|
||||
nb10_cont, nb11_cont, nb12_cont, nb13_cont);
|
||||
mem_src1 = backend_ctx->prealloc_src1.buffer;
|
||||
offset1_cont = 0;
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &mem_src0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_cont));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &mem_src1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1_cont));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
||||
@@ -8064,6 +8272,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
if (ne11 < 32) {
|
||||
break;
|
||||
}
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
|
||||
break;
|
||||
}
|
||||
|
||||
kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm;
|
||||
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
|
||||
|
||||
@@ -10012,6 +10224,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
|
||||
}
|
||||
func = ggml_cl_glu;
|
||||
break;
|
||||
case GGML_OP_TRI:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cl_tri;
|
||||
break;
|
||||
case GGML_OP_FILL:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// tri
|
||||
//------------------------------------------------------------------------------
|
||||
__kernel void kernel_tri_f32(
|
||||
global float * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int n,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int tri_type
|
||||
) {
|
||||
src0 = (global float*)((global char*)src0 + offset0);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int idx = get_global_id(0);
|
||||
if (idx >= n) return;
|
||||
|
||||
int i0 = idx % ne0;
|
||||
int i1 = (idx / ne0) % ne1;
|
||||
|
||||
int keep = 0;
|
||||
if (tri_type == 0) keep = (i0 >= i1);
|
||||
else if (tri_type == 1) keep = (i0 > i1);
|
||||
else if (tri_type == 2) keep = (i0 <= i1);
|
||||
else keep = (i0 < i1);
|
||||
|
||||
dst[idx] = keep ? src0[idx] : 0.0f;
|
||||
}
|
||||
@@ -1157,13 +1157,28 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
inline void * aligned_malloc_host(size_t alignment, size_t size) {
|
||||
#ifdef _WIN32
|
||||
return _aligned_malloc(size, alignment);
|
||||
#else
|
||||
return aligned_alloc(alignment, size);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void free_aligned_mem_host(void * memblock) {
|
||||
#ifdef _WIN32
|
||||
_aligned_free(memblock);
|
||||
#else
|
||||
free(memblock);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
ggml_sycl_host_free(buffer->context);
|
||||
free_aligned_mem_host((void *)buffer->context);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||
void * ptr = ggml_sycl_host_malloc(size);
|
||||
|
||||
void * ptr = aligned_malloc_host(TENSOR_ALIGNMENT, size);
|
||||
if (ptr == nullptr) {
|
||||
// fallback to cpu buffer
|
||||
return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
|
||||
|
||||
@@ -991,6 +991,8 @@ struct vk_mat_vec_id_push_constants {
|
||||
uint32_t fusion_flags;
|
||||
uint32_t nei0;
|
||||
uint32_t ne11;
|
||||
uint32_t expert_i1;
|
||||
uint32_t nbi1;
|
||||
};
|
||||
|
||||
struct vk_flash_attn_push_constants {
|
||||
@@ -1516,6 +1518,15 @@ struct vk_quantize_q8_1_push_constants {
|
||||
uint32_t num_blocks;
|
||||
};
|
||||
|
||||
struct vk_op_flash_attn_split_k_reduce_push_constants {
|
||||
uint32_t D;
|
||||
uint32_t ne1;
|
||||
uint32_t ne2;
|
||||
uint32_t ne3;
|
||||
uint32_t k_num;
|
||||
uint32_t sinks;
|
||||
};
|
||||
|
||||
// Allow pre-recording command buffers
|
||||
struct vk_staging_memcpy {
|
||||
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
|
||||
@@ -1802,7 +1813,6 @@ struct ggml_backend_vk_context {
|
||||
bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
|
||||
|
||||
vk_context_ref compute_ctx;
|
||||
vk_context_ref transfer_ctx;
|
||||
|
||||
std::vector<vk_context_ref> tensor_ctxs;
|
||||
|
||||
@@ -1812,7 +1822,6 @@ struct ggml_backend_vk_context {
|
||||
uint32_t pipeline_descriptor_set_requirements {};
|
||||
|
||||
vk_command_pool compute_cmd_pool;
|
||||
vk_command_pool transfer_cmd_pool;
|
||||
|
||||
// number of additional consecutive nodes that are being fused with the
|
||||
// node currently being processed
|
||||
@@ -3178,15 +3187,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
if (path == FAPATH) { \
|
||||
if (aligned) { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
} \
|
||||
} else { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
@@ -3980,7 +3989,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
|
||||
|
||||
if (device->subgroup_clustered && device->subgroup_require_full_support) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
|
||||
@@ -5647,7 +5656,6 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
|
||||
ctx->almost_ready_fence = ctx->device->device.createFence({});
|
||||
|
||||
ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue);
|
||||
ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue);
|
||||
|
||||
if (vk_perf_logger_enabled) {
|
||||
ctx->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
|
||||
@@ -8083,8 +8091,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
||||
|
||||
const uint64_t nei0 = ids->ne[0];
|
||||
const uint64_t nei1 = ids->ne[1];
|
||||
|
||||
GGML_ASSERT(nei1 == 1);
|
||||
const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int));
|
||||
|
||||
const uint64_t ne20 = dst->ne[0];
|
||||
const uint64_t ne21 = dst->ne[1];
|
||||
@@ -8168,7 +8175,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
||||
if (quantize_y) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
|
||||
}
|
||||
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
|
||||
ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1);
|
||||
}
|
||||
|
||||
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
|
||||
@@ -8226,7 +8233,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
||||
uint32_t stride_batch_y = ne10*ne11;
|
||||
|
||||
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
|
||||
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
|
||||
stride_batch_y = src1->nb[2] / ggml_type_size(src1->type);
|
||||
}
|
||||
|
||||
const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
|
||||
@@ -8262,23 +8269,25 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
||||
fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
|
||||
}
|
||||
|
||||
// compute
|
||||
const vk_mat_vec_id_push_constants pc = {
|
||||
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
|
||||
(uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
|
||||
fusion_flags,
|
||||
(uint32_t)nei0, (uint32_t)ne11,
|
||||
};
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
|
||||
{
|
||||
d_X,
|
||||
d_Y,
|
||||
d_D,
|
||||
d_F0,
|
||||
d_F1,
|
||||
d_ids,
|
||||
},
|
||||
pc, { groups_x, (uint32_t)nei0, groups_z });
|
||||
// Loop over the batch dimension
|
||||
for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) {
|
||||
const vk_mat_vec_id_push_constants pc = {
|
||||
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
|
||||
(uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
|
||||
fusion_flags,
|
||||
(uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1
|
||||
};
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
|
||||
{
|
||||
d_X,
|
||||
d_Y,
|
||||
d_D,
|
||||
d_F0,
|
||||
d_F1,
|
||||
d_ids,
|
||||
},
|
||||
pc, { groups_x, (uint32_t)nei0, groups_z });
|
||||
}
|
||||
|
||||
if (x_non_contig) {
|
||||
ctx->prealloc_x_need_sync = true;
|
||||
@@ -8292,7 +8301,7 @@ static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int no
|
||||
ggml_tensor * dst = cgraph->nodes[node_idx];
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
ggml_tensor * src2 = dst->src[2];
|
||||
return src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
|
||||
return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
|
||||
@@ -8454,14 +8463,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
GGML_ASSERT(0);
|
||||
}
|
||||
|
||||
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
||||
if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
||||
qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
|
||||
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
|
||||
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
|
||||
// and change addressing calculations to index Q's dimension 2.
|
||||
gqa_ratio = qk_ratio;
|
||||
N = gqa_ratio;
|
||||
workgroups_y /= N;
|
||||
workgroups_y /= gqa_ratio;
|
||||
}
|
||||
|
||||
bool small_rows = N <= get_fa_num_small_rows(path);
|
||||
@@ -8523,6 +8532,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
}
|
||||
|
||||
assert(pipeline);
|
||||
// Compile early to initialize wg_denoms.
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||
|
||||
uint32_t split_kv = KV;
|
||||
uint32_t split_k = 1;
|
||||
@@ -8530,22 +8541,24 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
// Use a placeholder core count if one isn't available. split_k is a big help for perf.
|
||||
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
|
||||
|
||||
// Try to use split_k when KV is large enough to be worth the overhead
|
||||
if (workgroups_x == 1 && shader_core_count > 0) {
|
||||
// Try to use split_k when KV is large enough to be worth the overhead.
|
||||
// Must either be a single batch or be using gqa, we can't mix the two.
|
||||
if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) {
|
||||
// Try to run two workgroups per SM.
|
||||
split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
|
||||
split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
|
||||
if (split_k > 1) {
|
||||
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
|
||||
// of "align", so recompute split_k based on that.
|
||||
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
|
||||
split_k = CEIL_DIV(KV, split_kv);
|
||||
workgroups_x = split_k;
|
||||
}
|
||||
}
|
||||
|
||||
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
|
||||
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
|
||||
const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
|
||||
// For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].
|
||||
// For L/M, the order is (inner to outer) [ne1, k, ne2, ne3].
|
||||
const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0;
|
||||
if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) {
|
||||
GGML_ABORT("Requested preallocation size is too large");
|
||||
}
|
||||
@@ -8556,7 +8569,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
|
||||
{
|
||||
// Request descriptor sets
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||
if (split_k > 1) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
|
||||
}
|
||||
@@ -8605,7 +8617,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
if (ctx->prealloc_split_k_need_sync) {
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
|
||||
workgroups_x *= pipeline->wg_denoms[0];
|
||||
vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf},
|
||||
@@ -8613,15 +8625,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
// there's no more than one tile of rows (i.e. workgroups_x would have been
|
||||
// one). We reuse workgroups_x to mean the number of splits, so we need to
|
||||
// cancel out the divide by wg_denoms[0].
|
||||
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
|
||||
pc, { split_k * workgroups_x, workgroups_y, workgroups_z });
|
||||
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
|
||||
const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
|
||||
{split_k_buf, sinks_buf, dst_buf},
|
||||
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
|
||||
pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) });
|
||||
ctx->prealloc_split_k_need_sync = true;
|
||||
} else {
|
||||
if (gqa_ratio > 1) {
|
||||
// When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms
|
||||
workgroups_x *= pipeline->wg_denoms[0];
|
||||
}
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf},
|
||||
pc, { workgroups_x, workgroups_y, workgroups_z });
|
||||
@@ -11560,7 +11576,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
||||
free(d_chk);
|
||||
|
||||
ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
|
||||
ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
|
||||
|
||||
ggml_vk_destroy_buffer(d_X);
|
||||
ggml_vk_destroy_buffer(d_Y);
|
||||
@@ -12145,7 +12160,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
|
||||
ggml_vk_submit(subctx, {});
|
||||
ctx->submit_pending = true;
|
||||
ggml_vk_synchronize(ctx);
|
||||
GGML_ASSERT(ctx->compute_ctx.expired());
|
||||
ggml_vk_ctx_begin(ctx->device, subctx);
|
||||
ctx->compute_ctx = subctx;
|
||||
}
|
||||
|
||||
if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) {
|
||||
@@ -12163,6 +12180,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
|
||||
ggml_vk_destroy_buffer(ctx->prealloc_y);
|
||||
}
|
||||
ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
|
||||
ctx->prealloc_y_last_tensor_used = nullptr;
|
||||
}
|
||||
if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
|
||||
VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
|
||||
@@ -12743,7 +12761,6 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
|
||||
ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
|
||||
|
||||
ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
|
||||
ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
|
||||
|
||||
for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
|
||||
ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
|
||||
@@ -12772,7 +12789,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
|
||||
static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
|
||||
VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")");
|
||||
// discard any unsubmitted command buffers
|
||||
ctx->transfer_ctx.reset();
|
||||
ctx->compute_ctx.reset();
|
||||
// wait for any pending command buffers to finish
|
||||
ggml_vk_synchronize(ctx);
|
||||
|
||||
@@ -12805,7 +12822,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
|
||||
ctx->descriptor_sets.clear();
|
||||
|
||||
ctx->compute_cmd_pool.destroy(ctx->device->device);
|
||||
ctx->transfer_cmd_pool.destroy(ctx->device->device);
|
||||
if (vk_perf_logger_enabled) {
|
||||
ctx->perf_logger->print_timings(true);
|
||||
}
|
||||
@@ -13077,34 +13093,34 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
|
||||
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
|
||||
|
||||
vk_context transfer_ctx;
|
||||
vk_context compute_ctx;
|
||||
|
||||
if (ctx->transfer_ctx.expired()) {
|
||||
if (ctx->compute_ctx.expired()) {
|
||||
// Initialize new transfer context
|
||||
transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->transfer_ctx = transfer_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, transfer_ctx);
|
||||
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->compute_ctx = compute_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
||||
} else {
|
||||
transfer_ctx = ctx->transfer_ctx.lock();
|
||||
compute_ctx = ctx->compute_ctx.lock();
|
||||
}
|
||||
|
||||
vk_buffer buf = buf_ctx->dev_buffer;
|
||||
|
||||
auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
|
||||
|
||||
bool ret = ggml_vk_buffer_write_async(transfer_ctx, buf, dst_offset, data, size);
|
||||
bool ret = ggml_vk_buffer_write_async(compute_ctx, buf, dst_offset, data, size);
|
||||
|
||||
if (!ret) {
|
||||
ggml_vk_ensure_sync_staging_buffer(ctx, size);
|
||||
ggml_vk_sync_buffers(nullptr, transfer_ctx);
|
||||
ggml_vk_sync_buffers(nullptr, compute_ctx);
|
||||
|
||||
vk::BufferCopy buffer_cpy;
|
||||
buffer_cpy.srcOffset = 0;
|
||||
buffer_cpy.dstOffset = dst_offset;
|
||||
buffer_cpy.size = size;
|
||||
|
||||
transfer_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
|
||||
deferred_memcpy(ctx->sync_staging->ptr, data, size, &transfer_ctx->in_memcpys);
|
||||
compute_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
|
||||
deferred_memcpy(ctx->sync_staging->ptr, data, size, &compute_ctx->in_memcpys);
|
||||
ggml_vk_synchronize(ctx);
|
||||
}
|
||||
}
|
||||
@@ -13116,34 +13132,34 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
|
||||
|
||||
ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
|
||||
|
||||
vk_context transfer_ctx;
|
||||
vk_context compute_ctx;
|
||||
|
||||
if (ctx->transfer_ctx.expired()) {
|
||||
if (ctx->compute_ctx.expired()) {
|
||||
// Initialize new transfer context
|
||||
transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->transfer_ctx = transfer_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, transfer_ctx);
|
||||
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->compute_ctx = compute_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
||||
} else {
|
||||
transfer_ctx = ctx->transfer_ctx.lock();
|
||||
compute_ctx = ctx->compute_ctx.lock();
|
||||
}
|
||||
|
||||
vk_buffer buf = buf_ctx->dev_buffer;
|
||||
|
||||
auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
|
||||
bool ret = ggml_vk_buffer_read_async(transfer_ctx, buf, src_offset, data, size);
|
||||
bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size);
|
||||
|
||||
// If that failed, copy synchronously through a staging buffer
|
||||
if (!ret) {
|
||||
ggml_vk_ensure_sync_staging_buffer(ctx, size);
|
||||
ggml_vk_sync_buffers(nullptr, transfer_ctx);
|
||||
ggml_vk_sync_buffers(nullptr, compute_ctx);
|
||||
|
||||
vk::BufferCopy buffer_cpy;
|
||||
buffer_cpy.srcOffset = src_offset;
|
||||
buffer_cpy.dstOffset = 0;
|
||||
buffer_cpy.size = size;
|
||||
|
||||
transfer_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });
|
||||
deferred_memcpy(data, ctx->sync_staging->ptr, size, &transfer_ctx->out_memcpys);
|
||||
compute_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });
|
||||
deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys);
|
||||
ggml_vk_synchronize(ctx);
|
||||
}
|
||||
}
|
||||
@@ -13155,21 +13171,21 @@ static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_
|
||||
ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
|
||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
|
||||
vk_context transfer_ctx;
|
||||
vk_context compute_ctx;
|
||||
|
||||
if (ctx->transfer_ctx.expired()) {
|
||||
if (ctx->compute_ctx.expired()) {
|
||||
// Initialize new transfer context
|
||||
transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->transfer_ctx = transfer_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, transfer_ctx);
|
||||
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->compute_ctx = compute_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
||||
} else {
|
||||
transfer_ctx = ctx->transfer_ctx.lock();
|
||||
compute_ctx = ctx->compute_ctx.lock();
|
||||
}
|
||||
|
||||
vk_buffer src_buf = src_buf_ctx->dev_buffer;
|
||||
vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
|
||||
|
||||
ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
|
||||
ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -13179,19 +13195,19 @@ static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_
|
||||
static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
|
||||
VK_LOG_DEBUG("ggml_vk_synchronize()");
|
||||
|
||||
bool do_transfer = !ctx->transfer_ctx.expired();
|
||||
bool do_transfer = !ctx->compute_ctx.expired();
|
||||
|
||||
vk_context transfer_ctx;
|
||||
vk_context compute_ctx;
|
||||
if (do_transfer) {
|
||||
transfer_ctx = ctx->transfer_ctx.lock();
|
||||
compute_ctx = ctx->compute_ctx.lock();
|
||||
|
||||
ggml_vk_ctx_end(transfer_ctx);
|
||||
ggml_vk_ctx_end(compute_ctx);
|
||||
|
||||
for (auto& cpy : transfer_ctx->in_memcpys) {
|
||||
for (auto& cpy : compute_ctx->in_memcpys) {
|
||||
memcpy(cpy.dst, cpy.src, cpy.n);
|
||||
}
|
||||
|
||||
ggml_vk_submit(transfer_ctx, {});
|
||||
ggml_vk_submit(compute_ctx, {});
|
||||
ctx->submit_pending = true;
|
||||
}
|
||||
|
||||
@@ -13205,10 +13221,10 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
|
||||
}
|
||||
|
||||
if (do_transfer) {
|
||||
for (auto& cpy : transfer_ctx->out_memcpys) {
|
||||
for (auto& cpy : compute_ctx->out_memcpys) {
|
||||
memcpy(cpy.dst, cpy.src, cpy.n);
|
||||
}
|
||||
ctx->transfer_ctx.reset();
|
||||
ctx->compute_ctx.reset();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13877,6 +13893,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
ggml_vk_submit(compute_ctx, ctx->device->fence);
|
||||
VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences");
|
||||
ctx->device->device.resetFences({ ctx->device->fence });
|
||||
ctx->compute_ctx.reset();
|
||||
|
||||
// Get the results and pass them to the logger
|
||||
std::vector<uint64_t> timestamps(cgraph->n_nodes + 1);
|
||||
@@ -14163,15 +14180,15 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev
|
||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||
vk_event *vkev = (vk_event *)event->context;
|
||||
|
||||
vk_context transfer_ctx;
|
||||
vk_context compute_ctx;
|
||||
|
||||
if (ctx->transfer_ctx.expired()) {
|
||||
if (ctx->compute_ctx.expired()) {
|
||||
// Initialize new transfer context
|
||||
transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->transfer_ctx = transfer_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, transfer_ctx);
|
||||
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->compute_ctx = compute_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
||||
} else {
|
||||
transfer_ctx = ctx->transfer_ctx.lock();
|
||||
compute_ctx = ctx->compute_ctx.lock();
|
||||
}
|
||||
|
||||
// the backend interface doesn't have an explicit reset, so reset it here
|
||||
@@ -14179,13 +14196,13 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev
|
||||
ctx->device->device.resetEvent(vkev->event);
|
||||
ctx->device->device.resetFences({ vkev->fence });
|
||||
|
||||
ggml_vk_set_event(transfer_ctx, vkev->event);
|
||||
ggml_vk_set_event(compute_ctx, vkev->event);
|
||||
|
||||
ggml_vk_ctx_end(transfer_ctx);
|
||||
ggml_vk_ctx_end(compute_ctx);
|
||||
|
||||
ggml_vk_submit(transfer_ctx, {vkev->fence});
|
||||
ggml_vk_submit(compute_ctx, {vkev->fence});
|
||||
ctx->submit_pending = true;
|
||||
ctx->transfer_ctx.reset();
|
||||
ctx->compute_ctx.reset();
|
||||
}
|
||||
|
||||
static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
|
||||
@@ -14193,20 +14210,20 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even
|
||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||
vk_event *vkev = (vk_event *)event->context;
|
||||
|
||||
vk_context transfer_ctx;
|
||||
vk_context compute_ctx;
|
||||
|
||||
if (ctx->transfer_ctx.expired()) {
|
||||
if (ctx->compute_ctx.expired()) {
|
||||
// Initialize new transfer context
|
||||
transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->transfer_ctx = transfer_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, transfer_ctx);
|
||||
compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
ctx->compute_ctx = compute_ctx;
|
||||
ggml_vk_ctx_begin(ctx->device, compute_ctx);
|
||||
} else {
|
||||
transfer_ctx = ctx->transfer_ctx.lock();
|
||||
compute_ctx = ctx->compute_ctx.lock();
|
||||
}
|
||||
|
||||
ggml_vk_wait_events(transfer_ctx, {vkev->event});
|
||||
ggml_vk_ctx_end(transfer_ctx);
|
||||
ctx->transfer_ctx.reset();
|
||||
ggml_vk_wait_events(compute_ctx, {vkev->event});
|
||||
ggml_vk_ctx_end(compute_ctx);
|
||||
ctx->compute_ctx.reset();
|
||||
}
|
||||
|
||||
// TODO: enable async and synchronize
|
||||
|
||||
@@ -53,7 +53,7 @@ void main() {
|
||||
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
|
||||
const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
|
||||
|
||||
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
||||
uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4;
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
@@ -101,9 +101,9 @@ void main() {
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
||||
#endif
|
||||
uint32_t m_offset = 0;
|
||||
uint32_t m_offset = gqa_iq1*KV;
|
||||
if (p.nem2 != 1 || p.nem3 != 1) {
|
||||
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
|
||||
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
|
||||
}
|
||||
|
||||
[[dont_unroll]]
|
||||
@@ -320,7 +320,8 @@ void main() {
|
||||
// If there is split_k, then the split_k resolve shader does the final
|
||||
// division by L. Store the intermediate O value and per-row m and L values.
|
||||
if (p.k_num > 1) {
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
|
||||
// note: O and Q have swapped coord 1,2.
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
if (r < N) {
|
||||
@@ -332,7 +333,7 @@ void main() {
|
||||
}
|
||||
}
|
||||
|
||||
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
|
||||
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
if (r < N) {
|
||||
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
||||
@@ -378,7 +379,7 @@ void main() {
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
|
||||
uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
|
||||
|
||||
if (p.gqa_ratio > 1) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
|
||||
@@ -165,7 +165,7 @@ ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC
|
||||
}
|
||||
|
||||
uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
|
||||
iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
|
||||
gqa_iq1, iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
|
||||
q_stride, k_stride, v_stride, m_stride;
|
||||
|
||||
void init_indices()
|
||||
@@ -173,12 +173,19 @@ void init_indices()
|
||||
N = p.N;
|
||||
KV = p.KV;
|
||||
|
||||
i = gl_WorkGroupID.x;
|
||||
split_k_index = 0;
|
||||
|
||||
if (p.k_num > 1) {
|
||||
i = 0;
|
||||
split_k_index = gl_WorkGroupID.x;
|
||||
// batch and split_k share gl_WorkGroupID.x
|
||||
gqa_iq1 = gl_WorkGroupID.x / p.k_num;
|
||||
split_k_index = gl_WorkGroupID.x % p.k_num;
|
||||
} else if (p.gqa_ratio > 1) {
|
||||
i = 0;
|
||||
gqa_iq1 = gl_WorkGroupID.x;
|
||||
split_k_index = 0;
|
||||
} else {
|
||||
i = gl_WorkGroupID.x;
|
||||
gqa_iq1 = 0;
|
||||
split_k_index = 0;
|
||||
}
|
||||
|
||||
Tr = CEIL_DIV(N, Br);
|
||||
|
||||
@@ -90,7 +90,7 @@ void main() {
|
||||
barrier();
|
||||
}
|
||||
|
||||
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
|
||||
uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02+iq3*p.nb03) / 4;
|
||||
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
@@ -141,9 +141,9 @@ void main() {
|
||||
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
|
||||
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
|
||||
#endif
|
||||
uint32_t m_offset = 0;
|
||||
uint32_t m_offset = gqa_iq1*KV;
|
||||
if (p.nem2 != 1 || p.nem3 != 1) {
|
||||
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
|
||||
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
|
||||
}
|
||||
|
||||
[[dont_unroll]]
|
||||
@@ -370,7 +370,8 @@ void main() {
|
||||
// If there is split_k, then the split_k resolve shader does the final
|
||||
// division by L. Store the intermediate O value and per-row m and L values.
|
||||
if (p.k_num > 1) {
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
|
||||
// note: O and Q have swapped coord 1,2.
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (tile_row(r) < N) {
|
||||
@@ -382,7 +383,7 @@ void main() {
|
||||
}
|
||||
}
|
||||
|
||||
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
|
||||
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (tile_row(r) < N) {
|
||||
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
||||
@@ -428,7 +429,7 @@ void main() {
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
|
||||
uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
|
||||
|
||||
if (p.gqa_ratio > 1) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
|
||||
@@ -111,7 +111,7 @@ void main() {
|
||||
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;
|
||||
|
||||
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
|
||||
uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03;
|
||||
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
|
||||
|
||||
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
|
||||
@@ -138,9 +138,9 @@ void main() {
|
||||
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
|
||||
}
|
||||
|
||||
uint32_t m_offset = 0;
|
||||
uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/;
|
||||
if (p.nem2 != 1 || p.nem3 != 1) {
|
||||
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
|
||||
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
|
||||
}
|
||||
|
||||
[[dont_unroll]]
|
||||
@@ -272,10 +272,11 @@ void main() {
|
||||
if (p.k_num > 1) {
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
|
||||
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
|
||||
// note: O and Q have swapped coord 1,2.
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
|
||||
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
|
||||
|
||||
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
|
||||
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
|
||||
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
|
||||
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
|
||||
return;
|
||||
@@ -325,7 +326,7 @@ void main() {
|
||||
[[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
|
||||
#endif
|
||||
|
||||
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
|
||||
uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
|
||||
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
|
||||
if (p.gqa_ratio > 1) {
|
||||
|
||||
@@ -12,7 +12,8 @@ layout (binding = 2) writeonly buffer D {float data_d[];};
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint D;
|
||||
uint N;
|
||||
uint ne1;
|
||||
uint ne2;
|
||||
uint ne3;
|
||||
uint k_num;
|
||||
uint sinks;
|
||||
@@ -24,15 +25,15 @@ void main() {
|
||||
// Each workgroup handles a row
|
||||
const uint n = gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
const uint iq3 = gl_WorkGroupID.z;
|
||||
const uint i2 = gl_WorkGroupID.z % p.ne2;
|
||||
const uint i3 = gl_WorkGroupID.z / p.ne2;
|
||||
|
||||
uint D = p.D;
|
||||
uint N = p.N;
|
||||
uint k_num = p.k_num;
|
||||
|
||||
uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
|
||||
uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
|
||||
uint lm_stride = N * 2;
|
||||
uint l_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + n;
|
||||
uint m_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + p.ne1 + n;
|
||||
uint lm_stride = p.ne1 * 2;
|
||||
|
||||
// Compute the max m value for the row
|
||||
float m_max = -1.0/0.0;
|
||||
@@ -99,7 +100,7 @@ void main() {
|
||||
if (d < D) {
|
||||
float O = 0.0;
|
||||
[[unroll]] for (uint k = 0; k < k_num; ++k) {
|
||||
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
|
||||
uint o_offset = D * p.ne1 * (k + p.k_num * (i2 + p.ne2 * i3)) + D * n + d;
|
||||
float m = data_a[m_offset + k * lm_stride];
|
||||
O += exp(m - m_max) * data_a[o_offset];
|
||||
}
|
||||
@@ -115,6 +116,6 @@ void main() {
|
||||
const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
|
||||
O = clamp(O, -FLT_MAX, FLT_MAX);
|
||||
|
||||
data_d[iq3 * D * N + D * n + d] = O;
|
||||
data_d[(i3 * p.ne2 + i2) * p.ne1 * D + D * n + d] = O;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,8 @@ layout (push_constant) uniform parameter
|
||||
#ifdef MUL_MAT_ID
|
||||
uint nei0;
|
||||
uint ne11;
|
||||
uint expert_i1;
|
||||
uint nbi1;
|
||||
#else
|
||||
uint ne02;
|
||||
uint ne12;
|
||||
@@ -43,7 +45,7 @@ uint expert_id;
|
||||
|
||||
void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
|
||||
#ifdef MUL_MAT_ID
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
const uint expert_i0 = gl_GlobalInvocationID.y;
|
||||
#else
|
||||
const uint batch_idx = gl_GlobalInvocationID.y;
|
||||
#endif
|
||||
@@ -60,7 +62,7 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
|
||||
batch_idx_a = i03 * p.ne02 + i02;
|
||||
}
|
||||
#else
|
||||
expert_id = data_ids[expert_idx];
|
||||
expert_id = data_ids[expert_i0 + p.expert_i1 * p.nbi1];
|
||||
#endif
|
||||
|
||||
a_offset =
|
||||
@@ -71,13 +73,13 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
|
||||
#endif
|
||||
b_offset =
|
||||
#ifdef MUL_MAT_ID
|
||||
(expert_idx % p.ne11) * p.stride_b;
|
||||
(expert_i0 % p.ne11) * p.stride_b + p.expert_i1 * p.batch_stride_b;
|
||||
#else
|
||||
batch_idx * p.batch_stride_b;
|
||||
#endif
|
||||
d_offset =
|
||||
#ifdef MUL_MAT_ID
|
||||
expert_idx * p.stride_d;
|
||||
expert_i0 * p.stride_d + p.expert_i1 * p.batch_stride_d;
|
||||
#else
|
||||
batch_idx * p.batch_stride_d;
|
||||
#endif
|
||||
@@ -103,12 +105,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
|
||||
temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
|
||||
const uint expert_i0 = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
|
||||
const uint expert_i0 = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
|
||||
}
|
||||
#else
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
@@ -158,12 +160,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
|
||||
temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
|
||||
const uint expert_i0 = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
|
||||
const uint expert_i0 = gl_GlobalInvocationID.y;
|
||||
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
|
||||
}
|
||||
#else
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
@@ -203,12 +205,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
|
||||
tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]);
|
||||
const uint expert_i0 = gl_GlobalInvocationID.y;
|
||||
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_i0]);
|
||||
}
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
|
||||
const uint expert_idx = gl_GlobalInvocationID.y;
|
||||
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]);
|
||||
const uint expert_i0 = gl_GlobalInvocationID.y;
|
||||
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_i0]);
|
||||
}
|
||||
#else
|
||||
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
|
||||
|
||||
@@ -372,7 +372,8 @@ static size_t ggml_backend_zdnn_buffer_type_get_alignment(ggml_backend_buffer_ty
|
||||
}
|
||||
|
||||
static bool ggml_backend_zdnn_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
||||
return true;
|
||||
/* while it resides in host memory, additional transformation is needed */
|
||||
return false;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
+9
-1
@@ -585,6 +585,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
|
||||
break;
|
||||
}
|
||||
|
||||
// check that the size of the tensor in bytes is representable
|
||||
if (ok && uint64_t(ggml_nelements(&info.t)/ggml_blck_size(info.t.type)) > SIZE_MAX/ggml_type_size(info.t.type)) {
|
||||
GGML_LOG_ERROR("%s: tensor '%s' with shape (%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") has a size in bytes > %zu\n",
|
||||
__func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], SIZE_MAX);
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
|
||||
// calculate byte offsets given the tensor shape and type
|
||||
info.t.nb[0] = type_size;
|
||||
info.t.nb[1] = info.t.nb[0]*(info.t.ne[0]/blck_size);
|
||||
@@ -734,7 +742,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
|
||||
FILE * file = ggml_fopen(fname, "rb");
|
||||
|
||||
if (!file) {
|
||||
GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname);
|
||||
GGML_LOG_ERROR("%s: failed to open GGUF file '%s' (%s)\n", __func__, fname, strerror(errno));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
+3
-2
@@ -489,6 +489,7 @@ extern "C" {
|
||||
// - returns true if the parameters could be successfully modified to fit device memory
|
||||
// - this function is NOT thread safe because it modifies the global llama logger state
|
||||
// - only parameters that have the same value as in llama_default_model_params are modified
|
||||
// with the exception of the context size which is modified if and only if equal to 0
|
||||
LLAMA_API enum llama_params_fit_status llama_params_fit(
|
||||
const char * path_model,
|
||||
struct llama_model_params * mparams,
|
||||
@@ -1475,12 +1476,12 @@ extern "C" {
|
||||
/// @details Build a split GGUF final path for this chunk.
|
||||
/// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf"
|
||||
// Returns the split_path length.
|
||||
LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count);
|
||||
LLAMA_API int32_t llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int32_t split_no, int32_t split_count);
|
||||
|
||||
/// @details Extract the path prefix from the split_path if and only if the split_no and split_count match.
|
||||
/// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0"
|
||||
// Returns the split_prefix length.
|
||||
LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count);
|
||||
LLAMA_API int32_t llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int32_t split_no, int32_t split_count);
|
||||
|
||||
// Print system information
|
||||
LLAMA_API const char * llama_print_system_info(void);
|
||||
|
||||
@@ -3,7 +3,7 @@ pytest~=8.3.3
|
||||
huggingface_hub>=0.34.0,<1.0
|
||||
matplotlib~=3.10.0
|
||||
numpy~=1.26.4
|
||||
openai~=1.55.3
|
||||
openai~=2.14.0
|
||||
pandas~=2.2.3
|
||||
prometheus-client~=0.20.0
|
||||
requests~=2.32.3
|
||||
|
||||
@@ -29,7 +29,7 @@ LLAMA_BENCH_DB_FIELDS = [
|
||||
"cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers",
|
||||
"split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
|
||||
"use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth",
|
||||
"test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts",
|
||||
"test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", "n_cpu_moe"
|
||||
]
|
||||
|
||||
LLAMA_BENCH_DB_TYPES = [
|
||||
@@ -38,7 +38,7 @@ LLAMA_BENCH_DB_TYPES = [
|
||||
"TEXT", "INTEGER", "INTEGER", "TEXT", "TEXT", "INTEGER",
|
||||
"TEXT", "INTEGER", "INTEGER", "INTEGER", "TEXT", "TEXT",
|
||||
"INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
|
||||
"TEXT", "INTEGER", "INTEGER", "REAL", "REAL",
|
||||
"TEXT", "INTEGER", "INTEGER", "REAL", "REAL", "INTEGER",
|
||||
]
|
||||
|
||||
# All test-backend-ops SQL fields
|
||||
@@ -59,7 +59,7 @@ assert len(TEST_BACKEND_OPS_DB_FIELDS) == len(TEST_BACKEND_OPS_DB_TYPES)
|
||||
|
||||
# Properties by which to differentiate results per commit for llama-bench:
|
||||
LLAMA_BENCH_KEY_PROPERTIES = [
|
||||
"cpu_info", "gpu_info", "backends", "n_gpu_layers", "tensor_buft_overrides", "model_filename", "model_type",
|
||||
"cpu_info", "gpu_info", "backends", "n_gpu_layers", "n_cpu_moe", "tensor_buft_overrides", "model_filename", "model_type",
|
||||
"n_batch", "n_ubatch", "embeddings", "cpu_mask", "cpu_strict", "poll", "n_threads", "type_k", "type_v",
|
||||
"use_mmap", "no_kv_offload", "split_mode", "main_gpu", "tensor_split", "flash_attn", "n_prompt", "n_gen", "n_depth"
|
||||
]
|
||||
|
||||
@@ -24,6 +24,7 @@ add_library(llama
|
||||
llama-kv-cache-iswa.cpp
|
||||
llama-memory.cpp
|
||||
llama-memory-hybrid.cpp
|
||||
llama-memory-hybrid-iswa.cpp
|
||||
llama-memory-recurrent.cpp
|
||||
llama-mmap.cpp
|
||||
llama-model-loader.cpp
|
||||
|
||||
@@ -2903,7 +2903,7 @@ void llama_context::opt_epoch_iter(
|
||||
};
|
||||
ctx_compute_opt = ggml_init(params);
|
||||
}
|
||||
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
|
||||
ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_inp_tokens(), res->get_logits());
|
||||
ggml_opt_alloc(opt_ctx, train);
|
||||
|
||||
res->set_inputs(&ubatch);
|
||||
|
||||
+165
-17
@@ -7,6 +7,7 @@
|
||||
#include "llama-kv-cache.h"
|
||||
#include "llama-kv-cache-iswa.h"
|
||||
#include "llama-memory-hybrid.h"
|
||||
#include "llama-memory-hybrid-iswa.h"
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
#include <cassert>
|
||||
@@ -22,7 +23,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
|
||||
if (ubatch->embd) {
|
||||
const int64_t n_embd = embd->ne[0];
|
||||
GGML_ASSERT(n_embd == embd->ne[0]);
|
||||
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
|
||||
@@ -32,8 +34,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
||||
bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
|
||||
bool res = true;
|
||||
|
||||
res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
|
||||
res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
|
||||
res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
|
||||
res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -510,6 +512,76 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
|
||||
return res;
|
||||
}
|
||||
|
||||
void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
const auto * attn_ctx = mctx->get_attn();
|
||||
|
||||
// base tensors may not be allocated if there are no non-SWA attention layers
|
||||
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
|
||||
attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
|
||||
attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
|
||||
|
||||
attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
// swa tensors may not be allocated if there are no SWA attention layers
|
||||
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
|
||||
attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
|
||||
attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
|
||||
|
||||
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
||||
|
||||
if (inp_rs->s_copy) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
|
||||
int32_t * data = (int32_t *) inp_rs->s_copy->data;
|
||||
|
||||
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
||||
for (uint32_t i = 0; i < n_rs; ++i) {
|
||||
data[i] = mctx->get_recr()->s_copy(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) {
|
||||
const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx);
|
||||
|
||||
this->mctx = mctx;
|
||||
|
||||
bool res = true;
|
||||
|
||||
const auto * attn_ctx = mctx->get_attn();
|
||||
|
||||
// base tensors may not be allocated if there are no non-SWA attention layers
|
||||
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
|
||||
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv();
|
||||
res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||
}
|
||||
|
||||
// swa tensors may not be allocated if there are no SWA attention layers
|
||||
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
|
||||
res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv();
|
||||
res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
|
||||
}
|
||||
|
||||
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
||||
|
||||
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
|
||||
res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
|
||||
|
||||
res &= inp_rs->head == mctx->get_recr()->get_head();
|
||||
res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
|
||||
// set the inputs only for the active samplers in the current ubatch
|
||||
std::unordered_set<llama_seq_id> active_samplers;
|
||||
@@ -563,7 +635,8 @@ int64_t llm_graph_result::get_max_nodes() const {
|
||||
}
|
||||
|
||||
void llm_graph_result::reset() {
|
||||
t_tokens = nullptr;
|
||||
t_inp_tokens = nullptr;
|
||||
t_inp_embd = nullptr;
|
||||
t_logits = nullptr;
|
||||
t_embd = nullptr;
|
||||
t_embd_pooled = nullptr;
|
||||
@@ -1267,17 +1340,29 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
|
||||
// input embeddings with optional lora
|
||||
ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
||||
const int64_t n_embd = hparams.n_embd_inp();
|
||||
const int64_t n_embd_inp = hparams.n_embd_inp();
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_embd>();
|
||||
assert(n_embd_inp >= n_embd);
|
||||
|
||||
ggml_tensor * cur = nullptr;
|
||||
auto inp = std::make_unique<llm_graph_input_embd>(n_embd_inp);
|
||||
|
||||
if (ubatch.token) {
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
//cb(inp->tokens, "inp_tokens", -1);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_tokens = inp->tokens;
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
cb(inp->tokens, "inp_tokens", -1);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_inp_tokens = inp->tokens;
|
||||
|
||||
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens);
|
||||
cb(inp->embd, "inp_embd", -1);
|
||||
ggml_set_input(inp->embd);
|
||||
|
||||
// select one of the 2 inputs, based on the batch contents
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18550
|
||||
std::array<ggml_tensor *, 2> inps;
|
||||
|
||||
// token embeddings path (ubatch.token != nullptr)
|
||||
{
|
||||
auto & cur = inps[0];
|
||||
|
||||
cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
|
||||
|
||||
@@ -1298,19 +1383,36 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
||||
|
||||
cur = ggml_add(ctx0, cur, inpL_delta);
|
||||
}
|
||||
} else {
|
||||
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
|
||||
ggml_set_input(inp->embd);
|
||||
|
||||
if (n_embd_inp != n_embd) {
|
||||
cur = ggml_pad(ctx0, cur, hparams.n_embd_inp() - n_embd, 0, 0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// vector embeddings path (ubatch.embd != nullptr)
|
||||
{
|
||||
auto & cur = inps[1];
|
||||
|
||||
cur = inp->embd;
|
||||
}
|
||||
|
||||
assert(ggml_are_same_shape (inps[0], inps[1]));
|
||||
assert(ggml_are_same_stride(inps[0], inps[1]));
|
||||
|
||||
ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1);
|
||||
|
||||
if (n_embd_inp != n_embd) {
|
||||
cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0);
|
||||
}
|
||||
|
||||
res->t_inp_embd = cur;
|
||||
|
||||
// For Granite architecture
|
||||
if (hparams.f_embedding_scale != 0.0f) {
|
||||
cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
|
||||
}
|
||||
|
||||
cb(cur, "inp_embd", -1);
|
||||
cb(cur, "embd", -1);
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
|
||||
@@ -1409,7 +1511,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
|
||||
//}
|
||||
|
||||
const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp();
|
||||
const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
||||
const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
||||
|
||||
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc);
|
||||
ggml_set_input(cur);
|
||||
@@ -1494,6 +1596,11 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
v = ggml_transpose(ctx0, v);
|
||||
}
|
||||
|
||||
// TODO: update llama_kv_cache to not store V cache in the MLA case and automatically return a view of K
|
||||
if (v_mla) {
|
||||
v = ggml_view_4d(ctx0, k, v->ne[0], v->ne[1], v->ne[2], v->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
|
||||
}
|
||||
|
||||
// this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
|
||||
if (k->type == GGML_TYPE_F32) {
|
||||
k = ggml_cast(ctx0, k, GGML_TYPE_F16);
|
||||
@@ -2056,6 +2163,47 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
||||
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
|
||||
const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);
|
||||
|
||||
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
|
||||
|
||||
// build iswa attention input
|
||||
const auto * attn_ctx = mctx_cur->get_attn();
|
||||
|
||||
auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
|
||||
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
{
|
||||
const auto n_kv = attn_ctx->get_base()->get_n_kv();
|
||||
|
||||
inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_set_input(inp_attn->self_kq_mask);
|
||||
|
||||
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
|
||||
}
|
||||
|
||||
{
|
||||
const auto n_kv = attn_ctx->get_swa()->get_n_kv();
|
||||
|
||||
inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_set_input(inp_attn->self_kq_mask_swa);
|
||||
|
||||
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
|
||||
}
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
|
||||
|
||||
return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
void llm_graph_context::build_dense_out(
|
||||
ggml_tensor * dense_2,
|
||||
ggml_tensor * dense_3) const {
|
||||
|
||||
+37
-3
@@ -24,6 +24,7 @@ class llama_kv_cache_context;
|
||||
class llama_kv_cache_iswa_context;
|
||||
class llama_memory_recurrent_context;
|
||||
class llama_memory_hybrid_context;
|
||||
class llama_memory_hybrid_iswa_context;
|
||||
|
||||
// certain models (typically multi-modal) can produce different types of graphs
|
||||
enum llm_graph_type {
|
||||
@@ -105,7 +106,7 @@ using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
|
||||
|
||||
class llm_graph_input_embd : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_embd() = default;
|
||||
llm_graph_input_embd(int64_t n_embd) : n_embd(n_embd) {}
|
||||
virtual ~llm_graph_input_embd() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
@@ -114,6 +115,8 @@ public:
|
||||
|
||||
ggml_tensor * tokens = nullptr; // I32 [n_batch]
|
||||
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
|
||||
|
||||
const int64_t n_embd = 0;
|
||||
};
|
||||
|
||||
class llm_graph_input_pos : public llm_graph_input_i {
|
||||
@@ -397,6 +400,34 @@ public:
|
||||
const llama_memory_hybrid_context * mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_mem_hybrid_iswa(
|
||||
const llama_cparams & cparams,
|
||||
std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn,
|
||||
std::unique_ptr<llm_graph_input_rs> inp_rs,
|
||||
const llama_memory_hybrid_iswa_context * mctx) :
|
||||
inp_attn(std::move(inp_attn)),
|
||||
inp_rs(std::move(inp_rs)),
|
||||
cparams(cparams),
|
||||
mctx(mctx) { }
|
||||
virtual ~llm_graph_input_mem_hybrid_iswa() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
bool can_reuse(const llm_graph_params & params) override;
|
||||
|
||||
std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn;
|
||||
std::unique_ptr<llm_graph_input_rs> inp_rs;
|
||||
|
||||
llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); }
|
||||
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
|
||||
|
||||
const llama_cparams cparams;
|
||||
|
||||
const llama_memory_hybrid_iswa_context * mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_sampling : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
|
||||
@@ -537,7 +568,7 @@ public:
|
||||
|
||||
virtual ~llm_graph_result() = default;
|
||||
|
||||
ggml_tensor * get_tokens() const { return t_tokens; }
|
||||
ggml_tensor * get_inp_tokens() const { return t_inp_tokens; }
|
||||
ggml_tensor * get_logits() const { return t_logits; }
|
||||
ggml_tensor * get_embd() const { return t_embd; }
|
||||
ggml_tensor * get_embd_pooled() const { return t_embd_pooled; }
|
||||
@@ -564,7 +595,8 @@ public:
|
||||
void set_params(const llm_graph_params & params);
|
||||
|
||||
// important graph nodes
|
||||
ggml_tensor * t_tokens = nullptr;
|
||||
ggml_tensor * t_inp_tokens = nullptr;
|
||||
ggml_tensor * t_inp_embd = nullptr; // [n_embd_inp, n_tokens]
|
||||
ggml_tensor * t_logits = nullptr;
|
||||
ggml_tensor * t_embd = nullptr;
|
||||
ggml_tensor * t_embd_pooled = nullptr;
|
||||
@@ -881,6 +913,8 @@ struct llm_graph_context {
|
||||
|
||||
llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
|
||||
|
||||
llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const;
|
||||
|
||||
//
|
||||
// pooling
|
||||
//
|
||||
|
||||
@@ -1594,6 +1594,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
||||
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
||||
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
||||
|
||||
const auto & n_rot = hparams.n_rot;
|
||||
|
||||
const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
|
||||
|
||||
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
|
||||
@@ -1614,10 +1618,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
||||
|
||||
ggml_tensor * k =
|
||||
ggml_view_3d(ctx, layer.k,
|
||||
n_embd_head_k, n_head_kv, get_size()*n_stream,
|
||||
n_rot, n_head_kv, get_size()*n_stream,
|
||||
ggml_row_size(layer.k->type, n_embd_head_k),
|
||||
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
||||
0);
|
||||
ggml_row_size(layer.k->type, n_embd_nope));
|
||||
|
||||
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
|
||||
|
||||
|
||||
@@ -0,0 +1,275 @@
|
||||
#include "llama-memory-hybrid-iswa.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-context.h"
|
||||
|
||||
//
|
||||
// llama_memory_hybrid_iswa
|
||||
//
|
||||
|
||||
llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
|
||||
const llama_model & model,
|
||||
/* attn */
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool swa_full,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad,
|
||||
/* recurrent */
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
bool offload,
|
||||
bool unified,
|
||||
/* layer filters */
|
||||
const layer_filter_cb & filter_attn,
|
||||
const layer_filter_cb & filter_recr) :
|
||||
hparams(model.hparams),
|
||||
mem_attn(new llama_kv_cache_iswa(
|
||||
model,
|
||||
type_k,
|
||||
type_v,
|
||||
v_trans,
|
||||
offload,
|
||||
swa_full,
|
||||
unified,
|
||||
kv_size,
|
||||
n_seq_max,
|
||||
n_ubatch,
|
||||
n_pad,
|
||||
filter_attn == nullptr ?
|
||||
[&](int32_t il) { return !hparams.is_recurrent(il); }
|
||||
: filter_attn,
|
||||
nullptr
|
||||
)),
|
||||
mem_recr(new llama_memory_recurrent(
|
||||
model,
|
||||
type_r,
|
||||
type_s,
|
||||
offload,
|
||||
rs_size,
|
||||
n_seq_max,
|
||||
filter_recr == nullptr ?
|
||||
[&](int32_t il) { return hparams.is_recurrent(il); }
|
||||
: filter_recr
|
||||
)) {}
|
||||
|
||||
llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
||||
do {
|
||||
balloc.split_reset();
|
||||
|
||||
// follow the recurrent pattern for creating the ubatch splits
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
while (true) {
|
||||
llama_ubatch ubatch;
|
||||
|
||||
if (embd_all) {
|
||||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
// TODO: non-sequential equal split can be done if using unified KV cache
|
||||
// for simplicity, we always use sequential equal split for now
|
||||
ubatch = balloc.split_equal(n_ubatch, true);
|
||||
}
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
if (balloc.get_n_used() < balloc.get_n_tokens()) {
|
||||
// failed to find a suitable split
|
||||
break;
|
||||
}
|
||||
|
||||
// prepare the recurrent batches first
|
||||
if (!mem_recr->prepare(ubatches)) {
|
||||
// TODO: will the recurrent cache be in an undefined context at this point?
|
||||
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
|
||||
return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
// prepare the attention cache (iswa version returns both base and swa slot infos)
|
||||
auto sinfos_base = mem_attn->get_base()->prepare(ubatches);
|
||||
if (sinfos_base.empty()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to prepare attention base ubatches\n", __func__);
|
||||
return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
auto sinfos_swa = mem_attn->get_swa()->prepare(ubatches);
|
||||
if (sinfos_swa.empty()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to prepare attention swa ubatches\n", __func__);
|
||||
return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
return std::make_unique<llama_memory_hybrid_iswa_context>(
|
||||
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
|
||||
} while(false);
|
||||
|
||||
return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_memory_hybrid_iswa::init_full() {
|
||||
return std::make_unique<llama_memory_hybrid_iswa_context>(this);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_memory_hybrid_iswa::init_update(llama_context * lctx, bool optimize) {
|
||||
return std::make_unique<llama_memory_hybrid_iswa_context>(this, lctx, optimize);
|
||||
}
|
||||
|
||||
bool llama_memory_hybrid_iswa::get_can_shift() const {
|
||||
// Shifting is trivially supported for recurrent
|
||||
return mem_attn->get_can_shift();
|
||||
}
|
||||
|
||||
void llama_memory_hybrid_iswa::clear(bool data) {
|
||||
mem_attn->clear(data);
|
||||
mem_recr->clear(data);
|
||||
}
|
||||
|
||||
bool llama_memory_hybrid_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
// Try removing from the recurrent cache first since it may fail. If it does
|
||||
// fail, the cache will not have been mutated.
|
||||
if (!mem_recr->seq_rm(seq_id, p0, p1)) {
|
||||
return false;
|
||||
}
|
||||
return mem_attn->seq_rm(seq_id, p0, p1);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||
mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid_iswa::seq_keep(llama_seq_id seq_id) {
|
||||
mem_attn->seq_keep(seq_id);
|
||||
mem_recr->seq_keep(seq_id);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
mem_attn->seq_add(seq_id, p0, p1, shift);
|
||||
mem_recr->seq_add(seq_id, p0, p1, shift);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
mem_attn->seq_div(seq_id, p0, p1, d);
|
||||
mem_recr->seq_div(seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
llama_pos llama_memory_hybrid_iswa::seq_pos_min(llama_seq_id seq_id) const {
|
||||
// the min of the total cache is the max of the two caches' min values
|
||||
return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
|
||||
}
|
||||
|
||||
llama_pos llama_memory_hybrid_iswa::seq_pos_max(llama_seq_id seq_id) const {
|
||||
// the max of the total cache is the min of the two caches' max values
|
||||
return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
|
||||
}
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid_iswa::memory_breakdown() const {
|
||||
std::map<ggml_backend_buffer_type_t, size_t> mb = mem_attn->memory_breakdown();
|
||||
for (const auto & buft_size : mem_recr->memory_breakdown()) {
|
||||
mb[buft_size.first] += buft_size.second;
|
||||
}
|
||||
return mb;
|
||||
}
|
||||
|
||||
void llama_memory_hybrid_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
|
||||
mem_attn->state_write(io, seq_id, flags);
|
||||
mem_recr->state_write(io, seq_id, flags);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
|
||||
mem_attn->state_read(io, seq_id, flags);
|
||||
mem_recr->state_read(io, seq_id, flags);
|
||||
}
|
||||
|
||||
llama_kv_cache_iswa * llama_memory_hybrid_iswa::get_mem_attn() const {
|
||||
return mem_attn.get();
|
||||
}
|
||||
|
||||
llama_memory_recurrent * llama_memory_hybrid_iswa::get_mem_recr() const {
|
||||
return mem_recr.get();
|
||||
}
|
||||
|
||||
//
|
||||
// llama_memory_hybrid_iswa_context
|
||||
//
|
||||
|
||||
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_status status) : status(status) {}
|
||||
|
||||
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem) :
|
||||
ctx_attn(mem->get_mem_attn()->init_full()),
|
||||
ctx_recr(mem->get_mem_recr()->init_full()),
|
||||
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||
}
|
||||
|
||||
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(
|
||||
llama_memory_hybrid_iswa * mem,
|
||||
llama_context * lctx,
|
||||
bool optimize) :
|
||||
ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
|
||||
ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
|
||||
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||
}
|
||||
|
||||
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(
|
||||
llama_memory_hybrid_iswa * mem,
|
||||
slot_info_vec_t sinfos_base,
|
||||
slot_info_vec_t sinfos_swa,
|
||||
std::vector<llama_ubatch> ubatches) :
|
||||
ubatches(std::move(ubatches)),
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
ctx_attn(new llama_kv_cache_iswa_context(mem->get_mem_attn(), std::move(sinfos_base), std::move(sinfos_swa), this->ubatches)),
|
||||
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
|
||||
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||
}
|
||||
|
||||
bool llama_memory_hybrid_iswa_context::next() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
ctx_attn->next();
|
||||
ctx_recr->next();
|
||||
|
||||
if (++i_next >= ubatches.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_memory_hybrid_iswa_context::apply() {
|
||||
assert(!llama_memory_status_is_fail(status));
|
||||
|
||||
bool res = true;
|
||||
|
||||
res = res & ctx_attn->apply();
|
||||
res = res & ctx_recr->apply();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
llama_memory_status llama_memory_hybrid_iswa_context::get_status() const {
|
||||
return status;
|
||||
}
|
||||
|
||||
const llama_ubatch & llama_memory_hybrid_iswa_context::get_ubatch() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
return ubatches[i_next];
|
||||
}
|
||||
|
||||
const llama_kv_cache_iswa_context * llama_memory_hybrid_iswa_context::get_attn() const {
|
||||
return static_cast<const llama_kv_cache_iswa_context *>(ctx_attn.get());
|
||||
}
|
||||
|
||||
const llama_memory_recurrent_context * llama_memory_hybrid_iswa_context::get_recr() const {
|
||||
return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama-batch.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-kv-cache-iswa.h"
|
||||
#include "llama-memory.h"
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
//
|
||||
// llama_memory_hybrid_iswa
|
||||
//
|
||||
|
||||
// utilizes instances of llama_memory_recurrent and llama_kv_cache_iswa to
|
||||
// support models where each layer may be either attention-based (with SWA support) or recurrent
|
||||
|
||||
class llama_memory_hybrid_iswa : public llama_memory_i {
|
||||
public:
|
||||
llama_memory_hybrid_iswa(
|
||||
const llama_model & model,
|
||||
/* attn */
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool swa_full,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad,
|
||||
/* recurrent */
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
bool offload,
|
||||
bool unified,
|
||||
/* layer filters */
|
||||
const layer_filter_cb & filter_attn = nullptr,
|
||||
const layer_filter_cb & filter_recr = nullptr);
|
||||
|
||||
~llama_memory_hybrid_iswa() = default;
|
||||
|
||||
//
|
||||
// llama_memory_i
|
||||
//
|
||||
|
||||
llama_memory_context_ptr init_batch(
|
||||
llama_batch_allocr & balloc,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_all) override;
|
||||
|
||||
llama_memory_context_ptr init_full() override;
|
||||
|
||||
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
void clear(bool data) override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
|
||||
|
||||
// state write/load
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
|
||||
|
||||
//
|
||||
// llama_memory_hybrid_iswa specific API
|
||||
//
|
||||
|
||||
llama_kv_cache_iswa * get_mem_attn() const;
|
||||
llama_memory_recurrent * get_mem_recr() const;
|
||||
|
||||
private:
|
||||
const llama_hparams & hparams;
|
||||
|
||||
const std::unique_ptr<llama_kv_cache_iswa> mem_attn;
|
||||
const std::unique_ptr<llama_memory_recurrent> mem_recr;
|
||||
};
|
||||
|
||||
class llama_memory_hybrid_iswa_context : public llama_memory_context_i {
|
||||
public:
|
||||
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
|
||||
|
||||
// init failure
|
||||
explicit llama_memory_hybrid_iswa_context(llama_memory_status status);
|
||||
|
||||
// init full
|
||||
explicit llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem);
|
||||
|
||||
// init update
|
||||
explicit llama_memory_hybrid_iswa_context(
|
||||
llama_memory_hybrid_iswa * mem,
|
||||
llama_context * lctx,
|
||||
bool optimize);
|
||||
|
||||
// init success
|
||||
llama_memory_hybrid_iswa_context(
|
||||
llama_memory_hybrid_iswa * mem,
|
||||
slot_info_vec_t sinfos_base,
|
||||
slot_info_vec_t sinfos_swa,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
~llama_memory_hybrid_iswa_context() = default;
|
||||
|
||||
bool next() override;
|
||||
bool apply() override;
|
||||
|
||||
llama_memory_status get_status() const override;
|
||||
const llama_ubatch & get_ubatch() const override;
|
||||
|
||||
//
|
||||
// llama_memory_hybrid_iswa_context
|
||||
//
|
||||
|
||||
const llama_kv_cache_iswa_context * get_attn() const;
|
||||
const llama_memory_recurrent_context * get_recr() const;
|
||||
|
||||
private:
|
||||
// the index of the next ubatch to process
|
||||
size_t i_next = 0;
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
const llama_memory_context_ptr ctx_attn;
|
||||
const llama_memory_context_ptr ctx_recr;
|
||||
|
||||
const llama_memory_status status;
|
||||
};
|
||||
+45
-18
@@ -8,6 +8,7 @@
|
||||
#include "llama-kv-cache.h"
|
||||
#include "llama-kv-cache-iswa.h"
|
||||
#include "llama-memory-hybrid.h"
|
||||
#include "llama-memory-hybrid-iswa.h"
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
@@ -1713,7 +1714,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
|
||||
// for compatibility with existing DeepSeek V2 and V2.5 GGUFs
|
||||
// that have no expert_gating_func model parameter set
|
||||
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
|
||||
if ((hparams.n_layer == 47 || hparams.n_layer == 48) && n_vocab == 154880) {
|
||||
// GLM 4.7 Lite
|
||||
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
|
||||
} else {
|
||||
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
|
||||
}
|
||||
}
|
||||
|
||||
if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) {
|
||||
@@ -7523,23 +7529,44 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
};
|
||||
}
|
||||
|
||||
res = new llama_memory_hybrid(
|
||||
/* model */ *this,
|
||||
/* attn_type_k */ params.type_k,
|
||||
/* attn_type_v */ params.type_v,
|
||||
/* attn_v_trans */ !cparams.flash_attn,
|
||||
/* attn_kv_size */ cparams.n_ctx,
|
||||
/* attn_n_pad */ 1,
|
||||
/* attn_n_swa */ hparams.n_swa,
|
||||
/* attn_swa_type */ hparams.swa_type,
|
||||
/* recurrent_type_k */ GGML_TYPE_F32,
|
||||
/* recurrent_type_v */ GGML_TYPE_F32,
|
||||
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
|
||||
/* n_seq_max */ cparams.n_seq_max,
|
||||
/* offload */ cparams.offload_kqv,
|
||||
/* unified */ cparams.kv_unified,
|
||||
/* filter_attn */ std::move(filter_attn),
|
||||
/* filter_recr */ std::move(filter_recr));
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
// Use hybrid-iswa for hybrid models with SWA
|
||||
res = new llama_memory_hybrid_iswa(
|
||||
/* model */ *this,
|
||||
/* attn_type_k */ params.type_k,
|
||||
/* attn_type_v */ params.type_v,
|
||||
/* attn_v_trans */ !cparams.flash_attn,
|
||||
/* attn_swa_full */ params.swa_full,
|
||||
/* attn_kv_size */ cparams.n_ctx,
|
||||
/* attn_n_ubatch */ cparams.n_ubatch,
|
||||
/* attn_n_pad */ 1,
|
||||
/* recurrent_type_r */ GGML_TYPE_F32,
|
||||
/* recurrent_type_s */ GGML_TYPE_F32,
|
||||
/* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max),
|
||||
/* n_seq_max */ cparams.n_seq_max,
|
||||
/* offload */ cparams.offload_kqv,
|
||||
/* unified */ cparams.kv_unified,
|
||||
/* filter_attn */ std::move(filter_attn),
|
||||
/* filter_recr */ std::move(filter_recr));
|
||||
} else {
|
||||
res = new llama_memory_hybrid(
|
||||
/* model */ *this,
|
||||
/* attn_type_k */ params.type_k,
|
||||
/* attn_type_v */ params.type_v,
|
||||
/* attn_v_trans */ !cparams.flash_attn,
|
||||
/* attn_kv_size */ cparams.n_ctx,
|
||||
/* attn_n_pad */ 1,
|
||||
/* attn_n_swa */ hparams.n_swa,
|
||||
/* attn_swa_type */ hparams.swa_type,
|
||||
/* recurrent_type_k */ GGML_TYPE_F32,
|
||||
/* recurrent_type_v */ GGML_TYPE_F32,
|
||||
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
|
||||
/* n_seq_max */ cparams.n_seq_max,
|
||||
/* offload */ cparams.offload_kqv,
|
||||
/* unified */ cparams.kv_unified,
|
||||
/* filter_attn */ std::move(filter_attn),
|
||||
/* filter_recr */ std::move(filter_recr));
|
||||
}
|
||||
} else {
|
||||
llama_memory_i::layer_reuse_cb reuse = nullptr;
|
||||
|
||||
|
||||
+53
-56
@@ -422,57 +422,6 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
|
||||
++qs.i_ffn_up;
|
||||
}
|
||||
|
||||
// if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
|
||||
//}
|
||||
// IK: let's remove this, else Q2_K is almost the same as Q3_K_S
|
||||
//else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) {
|
||||
// if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
|
||||
//}
|
||||
// This can be used to reduce the size of the Q5_K_S model.
|
||||
// The associated PPL increase is fully in line with the size reduction
|
||||
//else {
|
||||
// if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
|
||||
//}
|
||||
bool convert_incompatible_tensor = false;
|
||||
{
|
||||
const int64_t nx = tensor->ne[0];
|
||||
const int64_t ny = tensor->ne[1];
|
||||
const int64_t qk_k = ggml_blck_size(new_type);
|
||||
|
||||
if (nx % qk_k != 0) {
|
||||
LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type));
|
||||
convert_incompatible_tensor = true;
|
||||
} else {
|
||||
++qs.n_k_quantized;
|
||||
}
|
||||
}
|
||||
|
||||
if (convert_incompatible_tensor) {
|
||||
switch (new_type) {
|
||||
case GGML_TYPE_TQ1_0:
|
||||
case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
|
||||
case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break;
|
||||
case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break;
|
||||
case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
|
||||
default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
|
||||
}
|
||||
if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
|
||||
new_type = GGML_TYPE_F16;
|
||||
}
|
||||
LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
|
||||
++qs.n_fallback;
|
||||
}
|
||||
|
||||
return new_type;
|
||||
}
|
||||
|
||||
@@ -875,21 +824,69 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
|
||||
// get more optimal quantization type based on the tensor shape, layer, etc.
|
||||
if (!params->pure && ggml_is_quantized(default_type)) {
|
||||
int fallback = qs.n_fallback;
|
||||
new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
|
||||
// unless the user specifies a type, and the tensor geometry will not require fallback quantisation
|
||||
if (params->tensor_types && qs.n_fallback - fallback == 0) {
|
||||
// if the user provided tensor types - use those
|
||||
bool manual = false;
|
||||
if (params->tensor_types) {
|
||||
const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
|
||||
const std::string tensor_name(tensor->name);
|
||||
for (const auto & [tname, qtype] : tensor_types) {
|
||||
if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
|
||||
if (qtype != new_type) {
|
||||
LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
|
||||
LLAMA_LOG_WARN("(manual override: %s -> %s) ", ggml_type_name(new_type), ggml_type_name(qtype));
|
||||
new_type = qtype; // if two or more types are specified for the same tensor, the last match wins
|
||||
manual = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if not manual - use the standard logic for choosing the quantization type based on the selected mixture
|
||||
if (!manual) {
|
||||
new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
|
||||
}
|
||||
|
||||
// incompatible tensor shapes are handled here - fallback to a compatible type
|
||||
{
|
||||
bool convert_incompatible_tensor = false;
|
||||
|
||||
const int64_t nx = tensor->ne[0];
|
||||
const int64_t ny = tensor->ne[1];
|
||||
const int64_t qk_k = ggml_blck_size(new_type);
|
||||
|
||||
if (nx % qk_k != 0) {
|
||||
LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type));
|
||||
convert_incompatible_tensor = true;
|
||||
} else {
|
||||
++qs.n_k_quantized;
|
||||
}
|
||||
|
||||
if (convert_incompatible_tensor) {
|
||||
switch (new_type) {
|
||||
case GGML_TYPE_TQ1_0:
|
||||
case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
|
||||
case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break;
|
||||
case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break;
|
||||
case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
|
||||
default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
|
||||
}
|
||||
if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
|
||||
new_type = GGML_TYPE_F16;
|
||||
}
|
||||
LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
|
||||
++qs.n_fallback;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
|
||||
new_type = params->token_embedding_type;
|
||||
|
||||
+50
-16
@@ -311,8 +311,12 @@ static void llama_params_fit_impl(
|
||||
__func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
|
||||
}
|
||||
} else {
|
||||
LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n",
|
||||
__func__, hp_nct, n_ctx_min);
|
||||
if (n_ctx_min == UINT32_MAX) {
|
||||
LLAMA_LOG_INFO("%s: user has requested full context size of %" PRIu32 " -> no change\n", __func__, hp_nct);
|
||||
} else {
|
||||
LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n",
|
||||
__func__, hp_nct, n_ctx_min);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LLAMA_LOG_INFO("%s: context size set by user to %" PRIu32 " -> no change\n", __func__, cparams->n_ctx);
|
||||
@@ -1091,25 +1095,55 @@ int32_t llama_chat_apply_template(
|
||||
// model split
|
||||
//
|
||||
|
||||
int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
|
||||
int32_t llama_split_path(
|
||||
char * split_path,
|
||||
size_t maxlen,
|
||||
const char * path_prefix,
|
||||
int32_t split_no,
|
||||
int32_t split_count) {
|
||||
|
||||
static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
|
||||
if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
|
||||
return strlen(split_path);
|
||||
|
||||
const int written = snprintf(
|
||||
split_path,
|
||||
maxlen,
|
||||
SPLIT_PATH_FORMAT,
|
||||
path_prefix,
|
||||
split_no + 1,
|
||||
split_count
|
||||
);
|
||||
|
||||
if (written < 0 || (size_t) written >= maxlen) {
|
||||
return 0;
|
||||
}
|
||||
return 0;
|
||||
|
||||
return (int32_t) written;
|
||||
}
|
||||
|
||||
int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count) {
|
||||
std::string str_split_path(split_path);
|
||||
char postfix[32];
|
||||
snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count);
|
||||
std::string str_postfix(postfix);
|
||||
int32_t llama_split_prefix(
|
||||
char * split_prefix,
|
||||
size_t maxlen,
|
||||
const char * split_path,
|
||||
int32_t split_no,
|
||||
int32_t split_count) {
|
||||
|
||||
// check if split_prefix ends with postfix
|
||||
int size_prefix = str_split_path.size() - str_postfix.size();
|
||||
if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) {
|
||||
snprintf(split_prefix, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
|
||||
return size_prefix;
|
||||
const std::string str_split_path(split_path);
|
||||
|
||||
char postfix[32];
|
||||
snprintf(postfix, sizeof(postfix), "-%05d-of-%05d.gguf", split_no + 1, split_count);
|
||||
|
||||
const std::string str_postfix(postfix);
|
||||
if (str_split_path.size() <= str_postfix.size()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const size_t size_prefix = str_split_path.size() - str_postfix.size();
|
||||
|
||||
if (str_split_path.compare(size_prefix, std::string::npos, str_postfix) == 0) {
|
||||
const size_t copy_len = std::min(size_prefix + 1, maxlen);
|
||||
snprintf(split_prefix, copy_len, "%s", split_path);
|
||||
|
||||
return (int32_t) size_prefix;
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -124,14 +124,14 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
|
||||
|
||||
// {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
|
||||
// note: rope must go first for in-place context shifting in build_rope_shift()
|
||||
ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
|
||||
ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
|
||||
cb(kv_cmpr, "kv_cmpr_reshape", il);
|
||||
|
||||
// {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
|
||||
ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
|
||||
ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
// {kv_lora_rank, 1, n_tokens}
|
||||
@@ -169,11 +169,10 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
|
||||
Vcur = ggml_cont(ctx0, Vcur);
|
||||
cb(Vcur, "Vcur_cont", il);
|
||||
|
||||
// note: rope must go first for in-place context shifting in build_rope_shift()
|
||||
ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0);
|
||||
ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0);
|
||||
ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
if (inp_attn_scale) {
|
||||
|
||||
@@ -245,12 +245,12 @@ ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
|
||||
// equivalent to get_per_layer_inputs() in python code
|
||||
// output shape: [n_embd_altup, n_layer, n_tokens]
|
||||
ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
|
||||
auto inp = std::make_unique<llm_graph_input_embd>();
|
||||
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
|
||||
ggml_tensor * inp_per_layer;
|
||||
if (ubatch.token) {
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
ggml_set_input(inp->tokens);
|
||||
res->t_tokens = inp->tokens;
|
||||
res->t_inp_tokens = inp->tokens;
|
||||
inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
|
||||
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
|
||||
inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
|
||||
|
||||
@@ -9,6 +9,7 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap
|
||||
|
||||
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
|
||||
|
||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||
|
||||
ggml_tensor * cur;
|
||||
|
||||
@@ -67,7 +67,7 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor *
|
||||
const llama_model & model,
|
||||
const int64_t n_embd_head,
|
||||
const int il) {
|
||||
// compute Q and K and (optionally) RoPE them
|
||||
// compute Q and K
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
if (model.layers[il].bq) {
|
||||
|
||||
@@ -5,6 +5,7 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params &
|
||||
|
||||
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
|
||||
|
||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||
|
||||
ggml_tensor * cur;
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const size_t n_deepstack_layers = hparams.n_deepstack_layers;
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
@@ -16,17 +17,6 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_
|
||||
int sections[4];
|
||||
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
|
||||
|
||||
std::vector<ggml_tensor *> deepstack_features(n_deepstack_layers, nullptr);
|
||||
|
||||
if (ubatch.embd) {
|
||||
// Image input: split main embd and deepstack embds
|
||||
ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0);
|
||||
for (size_t i = 0; i < n_deepstack_layers; i++) {
|
||||
deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float));
|
||||
}
|
||||
inpL = inpL_main;
|
||||
}
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
@@ -120,8 +110,9 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
if (ubatch.embd && (size_t)il < n_deepstack_layers) {
|
||||
cur = ggml_add(ctx0, cur, deepstack_features[il]);
|
||||
if (il < (int) n_deepstack_layers) {
|
||||
ggml_tensor * ds = ggml_view_2d(ctx0, res->t_inp_embd, n_embd, n_tokens, res->t_inp_embd->nb[1], (il + 1) * n_embd * sizeof(float));
|
||||
cur = ggml_add(ctx0, cur, ds);
|
||||
cb(cur, "deepstack_out", il);
|
||||
}
|
||||
|
||||
|
||||
+5
-14
@@ -2,7 +2,8 @@
|
||||
|
||||
llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const size_t n_deepstack_layers = hparams.n_deepstack_layers;
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
@@ -16,17 +17,6 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_
|
||||
int sections[4];
|
||||
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
|
||||
|
||||
std::vector<ggml_tensor *> deepstack_features(n_deepstack_layers, nullptr);
|
||||
|
||||
if (ubatch.embd) {
|
||||
// Image input: split main embd and deepstack embds
|
||||
ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0);
|
||||
for (size_t i = 0; i < n_deepstack_layers; i++) {
|
||||
deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float));
|
||||
}
|
||||
inpL = inpL_main;
|
||||
}
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
@@ -113,8 +103,9 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
if (ubatch.embd && (size_t)il < n_deepstack_layers) {
|
||||
cur = ggml_add(ctx0, cur, deepstack_features[il]);
|
||||
if (il < (int) n_deepstack_layers) {
|
||||
ggml_tensor * ds = ggml_view_2d(ctx0, res->t_inp_embd, n_embd, n_tokens, res->t_inp_embd->nb[1], (il + 1) * n_embd * sizeof(float));
|
||||
cur = ggml_add(ctx0, cur, ds);
|
||||
cb(cur, "deepstack_out", il);
|
||||
}
|
||||
|
||||
|
||||
@@ -6122,7 +6122,19 @@ struct test_flash_attn_ext : public test_case {
|
||||
ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1], true); // the K tensor is usually a view of the K cache
|
||||
ggml_set_name(k, "k");
|
||||
|
||||
ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
|
||||
ggml_tensor * v = nullptr;
|
||||
if (hsk_padded == 576 && hsv_padded == 512) {
|
||||
// TODO: this branch should become a separate test case parameter instead of hardcoding this for these head shapes
|
||||
|
||||
// in this branch, the V cache is sub-view of the K cache. this is used by some MLA-based models
|
||||
// for more info:
|
||||
// - https://github.com/ggml-org/llama.cpp/pull/13435
|
||||
// - https://github.com/ggml-org/llama.cpp/pull/18953#issuecomment-3774948392
|
||||
// - https://github.com/ggml-org/llama.cpp/pull/18986
|
||||
v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0);
|
||||
} else {
|
||||
v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
|
||||
}
|
||||
ggml_set_name(v, "v");
|
||||
|
||||
ggml_tensor * m = nullptr;
|
||||
@@ -8460,6 +8472,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||
// Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012
|
||||
test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
|
||||
|
||||
test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
|
||||
test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 4, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
|
||||
|
||||
for (int kv : { 4096, 8192, 16384, }) {
|
||||
for (int hs : { 64, 128, }) {
|
||||
for (int nr : { 1, 4, }) {
|
||||
|
||||
+114
-126
@@ -54,113 +54,109 @@ static void assert_throws(const std::function<void()> & fn, const std::string &
|
||||
static void test_reasoning() {
|
||||
//common_log_set_verbosity_thold(LOG_DEFAULT_DEBUG);
|
||||
{
|
||||
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
});
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = false;
|
||||
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, params);
|
||||
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals("<tnk>Cogito</tnk>Ergo sum", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
});
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = false;
|
||||
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, params);
|
||||
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals(std::string("Cogito"), builder.result().reasoning_content);
|
||||
assert_equals("Ergo sum", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
});
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = false;
|
||||
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, params);
|
||||
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals("Cogito</tnk>Ergo sum", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
});
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = true;
|
||||
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, params);
|
||||
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals(std::string("Cogito"), builder.result().reasoning_content);
|
||||
assert_equals("Ergo sum", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ true,
|
||||
/* .thinking_forced_open = */ true,
|
||||
});
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = true;
|
||||
params.thinking_forced_open = true;
|
||||
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, params);
|
||||
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
|
||||
assert_equals("<think>Cogito</think>", builder.result().content);
|
||||
assert_equals("Ergo sum", builder.consume_rest());
|
||||
}
|
||||
{
|
||||
const std::string variant("content_only_inline_think");
|
||||
common_chat_syntax syntax = {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
/* .parse_tool_calls = */ false,
|
||||
};
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = false;
|
||||
params.parse_tool_calls = false;
|
||||
const std::string input = "<think>Pense</think>Bonjour";
|
||||
auto msg = common_chat_parse(input, false, syntax);
|
||||
auto msg = common_chat_parse(input, false, params);
|
||||
assert_equals(variant, std::string("Pense"), msg.reasoning_content);
|
||||
assert_equals(variant, std::string("Bonjour"), msg.content);
|
||||
}
|
||||
{
|
||||
const std::string variant("llama_3_inline_think");
|
||||
common_chat_syntax syntax = {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_LLAMA_3_X,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
/* .parse_tool_calls = */ false,
|
||||
};
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_LLAMA_3_X;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = false;
|
||||
params.parse_tool_calls = false;
|
||||
const std::string input = "<think>Plan</think>Réponse";
|
||||
auto msg = common_chat_parse(input, false, syntax);
|
||||
auto msg = common_chat_parse(input, false, params);
|
||||
assert_equals(variant, std::string("Plan"), msg.reasoning_content);
|
||||
assert_equals(variant, std::string("Réponse"), msg.content);
|
||||
}
|
||||
// Test DeepSeek V3.1 parsing - reasoning content followed by "</think>" and then regular content
|
||||
{
|
||||
common_chat_syntax syntax = {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .parse_tool_calls = */ true,
|
||||
};
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = true;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string variant("deepseek_v3_1_reasoning_format_deepseek");
|
||||
common_chat_msg_parser builder("REASONING</think>ok", /* is_partial= */ false, syntax);
|
||||
common_chat_msg_parser builder("REASONING</think>ok", /* is_partial= */ false, params);
|
||||
assert_equals(variant, true, builder.try_parse_reasoning("<think>", "</think>"));
|
||||
assert_equals(variant, std::string("REASONING"), builder.result().reasoning_content);
|
||||
assert_equals(variant, std::string("ok"), builder.consume_rest());
|
||||
}
|
||||
// Test DeepSeek V3.1 parsing - reasoning_format none - reasoning content followed by "</think>" and then regular content
|
||||
{
|
||||
common_chat_syntax syntax = {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .parse_tool_calls = */ true,
|
||||
};
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = true;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string variant("deepseek_v3_1_reasoning_format_none");
|
||||
const std::string input = "REASONING</think>ok";
|
||||
auto msg = common_chat_parse(input, false, syntax);
|
||||
auto msg = common_chat_parse(input, false, params);
|
||||
assert_equals(variant, std::string("REASONING</think>ok"), msg.content);
|
||||
assert_equals(variant, std::string(""), msg.reasoning_content);
|
||||
}
|
||||
@@ -256,15 +252,14 @@ static void test_deepseek_v3_1_tool_calls() {
|
||||
//common_log_set_verbosity_thold(LOG_DEFAULT_DEBUG);
|
||||
// variant: happy path for when it works as the model card says it should
|
||||
const std::string variant("simple");
|
||||
common_chat_syntax syntax = {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
/* .parse_tool_calls = */ true,
|
||||
};
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = false;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string input = "<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>";
|
||||
auto msg = common_chat_parse(input, false, syntax);
|
||||
auto msg = common_chat_parse(input, false, params);
|
||||
assert_equals<std::size_t>(variant, 1, msg.tool_calls.size());
|
||||
assert_equals(variant, std::string("get_time"), msg.tool_calls[0].name);
|
||||
// JSON arguments are dumped without spaces
|
||||
@@ -274,16 +269,15 @@ static void test_deepseek_v3_1_tool_calls() {
|
||||
|
||||
// variant: simple + thinking open
|
||||
{
|
||||
common_chat_syntax syntax = {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .parse_tool_calls = */ true,
|
||||
};
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = true;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string variant("simple_thinking");
|
||||
const std::string in = "REASONING</think><|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>";
|
||||
auto m = common_chat_parse(in, false, syntax);
|
||||
auto m = common_chat_parse(in, false, params);
|
||||
assert_equals<std::size_t>(variant, 1, m.tool_calls.size());
|
||||
assert_equals(variant, std::string("get_time"), m.tool_calls[0].name);
|
||||
assert_equals(variant, std::string("{\"city\":\"Tokyo\"}"), m.tool_calls[0].arguments);
|
||||
@@ -292,16 +286,15 @@ static void test_deepseek_v3_1_tool_calls() {
|
||||
}
|
||||
// variant: simple + multiple tool calls
|
||||
{
|
||||
common_chat_syntax syntax = {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
/* .parse_tool_calls = */ true,
|
||||
};
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = false;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string variant("simple_multiple_tool_calls");
|
||||
const std::string in = "CONTENT<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Paris\"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"city\": \"Paris\"}<|tool▁call▁end|><|tool▁calls▁end|>";
|
||||
auto m = common_chat_parse(in, false, syntax);
|
||||
auto m = common_chat_parse(in, false, params);
|
||||
assert_equals<std::size_t>(variant, 2, m.tool_calls.size());
|
||||
assert_equals(variant, std::string("get_time"), m.tool_calls[0].name);
|
||||
assert_equals(variant, std::string("{\"city\":\"Paris\"}"), m.tool_calls[0].arguments);
|
||||
@@ -314,16 +307,15 @@ static void test_deepseek_v3_1_tool_calls() {
|
||||
|
||||
// variant: thinking forced open + tool call in reasoning content
|
||||
{
|
||||
common_chat_syntax syntax = {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .parse_tool_calls = */ true,
|
||||
};
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = true;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string variant("thinking_forced_open_tool_call_in_reasoning");
|
||||
const std::string in = "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time2<|tool▁sep|>{\"city\": \"Tokyo2\"}<|tool▁call▁end|><|tool▁calls▁end|>REASONING</think><|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>";
|
||||
auto m = common_chat_parse(in, false, syntax);
|
||||
auto m = common_chat_parse(in, false, params);
|
||||
assert_equals<std::size_t>(variant, 1, m.tool_calls.size());
|
||||
assert_equals(variant, std::string("get_time"), m.tool_calls[0].name);
|
||||
assert_equals(variant, std::string("{\"city\":\"Tokyo\"}"), m.tool_calls[0].arguments);
|
||||
@@ -336,16 +328,15 @@ static void test_deepseek_v3_1_tool_calls() {
|
||||
// to make tool calls in reasoning content according to the model card, but it does sometimes, so
|
||||
// add the reasoning content as regular content and parse the tool calls.
|
||||
{
|
||||
common_chat_syntax syntax = {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .parse_tool_calls = */ true,
|
||||
};
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = true;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string variant("thinking_forced_open_tool_call_in_reasoning_no_closing_think_not_partial");
|
||||
const std::string in = "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>";
|
||||
auto m = common_chat_parse(in, false, syntax);
|
||||
auto m = common_chat_parse(in, false, params);
|
||||
assert_equals(variant, std::string("REASONING"), m.content);
|
||||
assert_equals(variant, std::string(""), m.reasoning_content);
|
||||
assert_equals<std::size_t>(variant, 1, m.tool_calls.size());
|
||||
@@ -355,16 +346,15 @@ static void test_deepseek_v3_1_tool_calls() {
|
||||
|
||||
// variant: thinking forced open + tool call in reasoning content + no closing think + partial
|
||||
{
|
||||
common_chat_syntax syntax = {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .parse_tool_calls = */ true,
|
||||
};
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = true;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string variant("thinking_forced_open_tool_call_in_reasoning_no_closing_think_partial");
|
||||
const std::string in = "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>";
|
||||
auto m = common_chat_parse(in, /* is_partial= */ true, syntax);
|
||||
auto m = common_chat_parse(in, /* is_partial= */ true, params);
|
||||
assert_equals(variant, std::string("REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"), m.reasoning_content);
|
||||
assert_equals(variant, std::string(""), m.content);
|
||||
assert_equals<std::size_t>(variant, 0, m.tool_calls.size());
|
||||
@@ -372,32 +362,30 @@ static void test_deepseek_v3_1_tool_calls() {
|
||||
|
||||
// variant: thinking not forced open + reasoning + regular content + no tool calls
|
||||
{
|
||||
common_chat_syntax syntax = {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ true,
|
||||
/* .parse_tool_calls = */ true,
|
||||
};
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = true;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string variant("thinking_forced_open_reasoning_regular_content_no_tool_calls");
|
||||
const std::string in = "REASONING</think>CONTENT";
|
||||
auto m = common_chat_parse(in, false, syntax);
|
||||
auto m = common_chat_parse(in, false, params);
|
||||
assert_equals<std::size_t>(variant, 0, m.tool_calls.size());
|
||||
assert_equals(variant, std::string("CONTENT"), m.content);
|
||||
assert_equals(variant, std::string("REASONING"), m.reasoning_content);
|
||||
}
|
||||
// variant: thinking not forced open + missing reasoning + no tool calls
|
||||
{
|
||||
common_chat_syntax syntax = {
|
||||
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
/* .parse_tool_calls = */ true,
|
||||
};
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = false;
|
||||
params.parse_tool_calls = true;
|
||||
const std::string variant("thinking_not_forced_open_missing_reasoning_no_tool_calls");
|
||||
const std::string in = "CONTENT";
|
||||
auto m = common_chat_parse(in, false, syntax);
|
||||
auto m = common_chat_parse(in, false, params);
|
||||
assert_equals<std::size_t>(variant, 0, m.tool_calls.size());
|
||||
assert_equals(variant, std::string("CONTENT"), m.content);
|
||||
assert_equals(variant, std::string(""), m.reasoning_content);
|
||||
|
||||
@@ -616,15 +616,15 @@ void test_command7_parser_compare(testing & t) {
|
||||
|
||||
auto test_legacy = [&](const std::string & input, bool need_more_input, bool print_results) {
|
||||
// Original common_chat_combinator_parser taken from chat.cpp
|
||||
common_chat_parser_params params;
|
||||
params.format = COMMON_CHAT_FORMAT_GENERIC;
|
||||
params.reasoning_format = COMMON_REASONING_FORMAT_AUTO;
|
||||
params.reasoning_in_content = false;
|
||||
params.thinking_forced_open = false;
|
||||
common_chat_msg_parser builder(
|
||||
input,
|
||||
/* .is_partial = */ need_more_input,
|
||||
{
|
||||
/* .format = */ COMMON_CHAT_FORMAT_GENERIC,
|
||||
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
|
||||
/* .reasoning_in_content = */ false,
|
||||
/* .thinking_forced_open = */ false,
|
||||
}
|
||||
params
|
||||
);
|
||||
|
||||
builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>");
|
||||
|
||||
+246
-224
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user