mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-28 08:37:46 +02:00
Compare commits
81 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7049736b2d | |||
| 01d2bdc2bc | |||
| 56fc38b965 | |||
| 1fb9504eb7 | |||
| 3f750f8d76 | |||
| c515fc5771 | |||
| f9bc66c3eb | |||
| a31cf36ad9 | |||
| 81d54bbfd5 | |||
| c7be9febcb | |||
| 8415f61e23 | |||
| 2c301e91ab | |||
| 4b2dae383d | |||
| 41aac5c69b | |||
| a2fba89a42 | |||
| 20cc625edc | |||
| 11f0af5504 | |||
| a3cb04744f | |||
| 4a8fbe0a5e | |||
| 31d0ff1869 | |||
| 97870e6497 | |||
| 477a66b035 | |||
| e60f01d941 | |||
| 81086cd6a3 | |||
| 68ee98ae18 | |||
| cdb6da468c | |||
| 6d69ab3f26 | |||
| 1faa13a118 | |||
| 1deee0f8d4 | |||
| d00cbea63c | |||
| 8328fd4bae | |||
| 56b4795842 | |||
| 2c0d875ae6 | |||
| aa4711d369 | |||
| d80d6d2400 | |||
| b260213755 | |||
| e08db42595 | |||
| 12bbc3fa50 | |||
| 9d0882840e | |||
| d2ee056e1d | |||
| b2c08c9ec4 | |||
| 7fdd16b432 | |||
| 74b8fc17f9 | |||
| aeaf8a36f0 | |||
| df1b612e29 | |||
| 4e0388aa8a | |||
| ef4c5b87ea | |||
| c61ae20d05 | |||
| 0123ff38f5 | |||
| 0a319bb75e | |||
| 1d6092fc72 | |||
| 8ae32dc9ec | |||
| 3df2244df4 | |||
| c08002a198 | |||
| 3a002afafa | |||
| a23b9bdbd3 | |||
| 04e632a4aa | |||
| a80ff183ab | |||
| 1d49ca3759 | |||
| c5fef0fcea | |||
| ca71fb9b36 | |||
| 35266573b9 | |||
| 86df2c9ae4 | |||
| f39283960b | |||
| 898acba681 | |||
| e29acf74fe | |||
| 128d522c04 | |||
| f6dcda3900 | |||
| 606a73f531 | |||
| 946f71ed9a | |||
| 638d330246 | |||
| 84c8e305e8 | |||
| 2aaf0a2a20 | |||
| 0e1f838556 | |||
| ad126479c2 | |||
| 77233277c9 | |||
| e308efda8e | |||
| 136bda78c5 | |||
| 5113efd34c | |||
| d64c8104f0 | |||
| ef07a40906 |
@@ -128,10 +128,6 @@ effectiveStdenv.mkDerivation (finalAttrs: {
|
||||
};
|
||||
|
||||
postPatch = ''
|
||||
substituteInPlace ./ggml/src/ggml-metal/ggml-metal.m \
|
||||
--replace '[bundle pathForResource:@"ggml-metal" ofType:@"metal"];' "@\"$out/bin/ggml-metal.metal\";"
|
||||
substituteInPlace ./ggml/src/ggml-metal/ggml-metal.m \
|
||||
--replace '[bundle pathForResource:@"default" ofType:@"metallib"];' "@\"$out/bin/default.metallib\";"
|
||||
'';
|
||||
|
||||
# With PR#6015 https://github.com/ggml-org/llama.cpp/pull/6015,
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
name: "Install exe"
|
||||
description: "Download and install exe"
|
||||
inputs:
|
||||
url:
|
||||
description: "URL of the exe installer"
|
||||
required: true
|
||||
args:
|
||||
description: "Installer arguments"
|
||||
required: true
|
||||
timeout:
|
||||
description: "Timeout (in ms)"
|
||||
required: false
|
||||
default: "600000"
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Install EXE
|
||||
shell: pwsh
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
write-host "Downloading Installer EXE"
|
||||
Invoke-WebRequest -Uri "${{ inputs.url }}" -OutFile "${env:RUNNER_TEMP}\temp-install.exe"
|
||||
write-host "Installing"
|
||||
$proc = Start-Process "${env:RUNNER_TEMP}\temp-install.exe" -ArgumentList '${{ inputs.args }}' -NoNewWindow -PassThru
|
||||
$completed = $proc.WaitForExit(${{ inputs.timeout }})
|
||||
if (-not $completed) {
|
||||
Write-Error "Installer timed out. Killing the process"
|
||||
$proc.Kill()
|
||||
exit 1
|
||||
}
|
||||
if ($proc.ExitCode -ne 0) {
|
||||
Write-Error "Installer failed with exit code $($proc.ExitCode)"
|
||||
exit 1
|
||||
}
|
||||
write-host "Completed installation"
|
||||
@@ -0,0 +1,20 @@
|
||||
name: "Linux - Setup SpacemiT Toolchain"
|
||||
description: "Setup SpacemiT Toolchain for Linux"
|
||||
inputs:
|
||||
path:
|
||||
description: "Installation path"
|
||||
required: true
|
||||
version:
|
||||
description: "SpacemiT toolchain version"
|
||||
required: true
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Setup SpacemiT Toolchain
|
||||
id: setup
|
||||
uses: ./.github/actions/unarchive-tar
|
||||
with:
|
||||
url: https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v${{ inputs.version }}.tar.xz
|
||||
path: ${{ inputs.path }}
|
||||
strip: 1
|
||||
@@ -0,0 +1,20 @@
|
||||
name: "Linux - Setup Vulkan SDK"
|
||||
description: "Setup Vulkan SDK for Linux"
|
||||
inputs:
|
||||
path:
|
||||
description: "Installation path"
|
||||
required: true
|
||||
version:
|
||||
description: "Vulkan SDK version"
|
||||
required: true
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Setup Vulkan SDK
|
||||
id: setup
|
||||
uses: ./.github/actions/unarchive-tar
|
||||
with:
|
||||
url: https://sdk.lunarg.com/sdk/download/${{ inputs.version }}/linux/vulkan_sdk.tar.xz
|
||||
path: ${{ inputs.path }}
|
||||
strip: 1
|
||||
@@ -0,0 +1,27 @@
|
||||
name: "Unarchive tar"
|
||||
description: "Download and unarchive tar into directory"
|
||||
inputs:
|
||||
url:
|
||||
description: "URL of the tar archive"
|
||||
required: true
|
||||
path:
|
||||
description: "Directory to unarchive into"
|
||||
required: true
|
||||
type:
|
||||
description: "Compression type (tar option)"
|
||||
required: false
|
||||
default: "J"
|
||||
strip:
|
||||
description: "Strip components"
|
||||
required: false
|
||||
default: "0"
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Unarchive into directory
|
||||
shell: bash
|
||||
run: |
|
||||
mkdir -p ${{ inputs.path }}
|
||||
cd ${{ inputs.path }}
|
||||
curl --no-progress-meter ${{ inputs.url }} | tar -${{ inputs.type }}x --strip-components=${{ inputs.strip }}
|
||||
@@ -0,0 +1,15 @@
|
||||
name: "Windows - Setup ROCm"
|
||||
description: "Setup ROCm for Windows"
|
||||
inputs:
|
||||
version:
|
||||
description: "ROCm version"
|
||||
required: true
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Setup ROCm
|
||||
uses: ./.github/actions/install-exe
|
||||
with:
|
||||
url: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ inputs.version }}-WinSvr2022-For-HIP.exe
|
||||
args: -install
|
||||
@@ -0,0 +1,89 @@
|
||||
name: Build Actions Cache
|
||||
|
||||
on:
|
||||
workflow_dispatch: # allows manual triggering
|
||||
schedule:
|
||||
- cron: '0 * * * *'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
ubuntu-24-vulkan-cache:
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get latest Vulkan SDK version
|
||||
id: vulkan_sdk_version
|
||||
run: |
|
||||
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Setup Cache
|
||||
uses: actions/cache@v4
|
||||
id: cache-sdk
|
||||
with:
|
||||
path: ./vulkan_sdk
|
||||
key: vulkan-sdk-${{ env.VULKAN_SDK_VERSION }}-${{ runner.os }}
|
||||
|
||||
- name: Setup Vulkan SDK
|
||||
if: steps.cache-sdk.outputs.cache-hit != 'true'
|
||||
uses: ./.github/actions/linux-setup-vulkan
|
||||
with:
|
||||
path: ./vulkan_sdk
|
||||
version: ${{ env.VULKAN_SDK_VERSION }}
|
||||
|
||||
ubuntu-24-spacemit-cache:
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
env:
|
||||
# Make sure this is in sync with build-linux-cross.yml
|
||||
SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Cache
|
||||
uses: actions/cache@v4
|
||||
id: cache-toolchain
|
||||
with:
|
||||
path: ./spacemit_toolchain
|
||||
key: spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}-${{ runner.os }}
|
||||
|
||||
- name: Setup SpacemiT Toolchain
|
||||
if: steps.cache-toolchain.outputs.cache-hit != 'true'
|
||||
uses: ./.github/actions/linux-setup-spacemit
|
||||
with:
|
||||
path: ./spacemit_toolchain
|
||||
version: ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}
|
||||
|
||||
windows-2022-rocm-cache:
|
||||
runs-on: windows-2022
|
||||
|
||||
env:
|
||||
# Make sure this is in sync with build.yml
|
||||
HIPSDK_INSTALLER_VERSION: "25.Q3"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Cache
|
||||
uses: actions/cache@v4
|
||||
id: cache-rocm
|
||||
with:
|
||||
path: C:\Program Files\AMD\ROCm
|
||||
key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }}
|
||||
|
||||
- name: Setup ROCm
|
||||
if: steps.cache-rocm.outputs.cache-hit != 'true'
|
||||
uses: ./.github/actions/windows-setup-rocm
|
||||
with:
|
||||
version: ${{ env.HIPSDK_INSTALLER_VERSION }}
|
||||
@@ -258,31 +258,29 @@ jobs:
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
env:
|
||||
# Make sure this is in sync with build-cache.yml
|
||||
SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2"
|
||||
SPACEMIT_IME_TOOLCHAIN_PATH: "spacemit-toolchain-linux-glibc-x86_64"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Cache Toolchain
|
||||
- name: Use SpacemiT Toolchain Cache
|
||||
uses: actions/cache@v4
|
||||
id: cache-spacemit-ime-cross-toolchain
|
||||
id: cache-toolchain
|
||||
with:
|
||||
path: ./${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}
|
||||
key: ${{ runner.os }}-spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}
|
||||
path: ./spacemit_toolchain
|
||||
key: spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}-${{ runner.os }}
|
||||
|
||||
- name: Setup Toolchain
|
||||
if: steps.cache-spacemit-ime-cross-toolchain.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
wget --quiet --no-check-certificate https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}.tar.xz -O ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz
|
||||
rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}
|
||||
mkdir -p ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}
|
||||
tar xf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz -C ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} --strip-components=1
|
||||
rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz
|
||||
- name: Setup SpacemiT Toolchain
|
||||
if: steps.cache-toolchain.outputs.cache-hit != 'true'
|
||||
uses: ./.github/actions/linux-setup-spacemit
|
||||
with:
|
||||
path: ./spacemit_toolchain
|
||||
version: ${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
export RISCV_ROOT_PATH=${PWD}/${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}
|
||||
export RISCV_ROOT_PATH=${PWD}/spacemit_toolchain
|
||||
cmake -B build -DLLAMA_CURL=OFF \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_OPENMP=OFF \
|
||||
|
||||
+101
-39
@@ -97,7 +97,7 @@ jobs:
|
||||
ctest -L 'main|curl' --verbose --timeout 900
|
||||
|
||||
macOS-latest-cmake-x64:
|
||||
runs-on: macos-13
|
||||
runs-on: macos-15-intel
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -387,6 +387,39 @@ jobs:
|
||||
cd build
|
||||
ctest -L main --verbose
|
||||
|
||||
ubuntu-24-cmake-vulkan-deb:
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-24-cmake-vulkan-deb
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get install -y glslc libvulkan-dev libcurl4-openssl-dev
|
||||
|
||||
- name: Configure
|
||||
id: cmake_configure
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
|
||||
-DGGML_BACKEND_DL=ON \
|
||||
-DGGML_CPU_ALL_VARIANTS=ON \
|
||||
-DGGML_VULKAN=ON
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
ubuntu-24-cmake-vulkan:
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
@@ -413,20 +446,19 @@ jobs:
|
||||
run: |
|
||||
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Cache Vulkan SDK
|
||||
id: cache_vulkan_sdk
|
||||
- name: Use Vulkan SDK Cache
|
||||
uses: actions/cache@v4
|
||||
id: cache-sdk
|
||||
with:
|
||||
path: ./vulkan_sdk
|
||||
key: vulkan-sdk-${{ env.VULKAN_SDK_VERSION }}-${{ runner.os }}
|
||||
|
||||
- name: Install Vulkan SDK
|
||||
if: steps.cache_vulkan_sdk.outputs.cache-hit != 'true'
|
||||
id: vulkan_sdk_install
|
||||
run: |
|
||||
mkdir -p vulkan_sdk
|
||||
cd vulkan_sdk
|
||||
curl --no-progress-meter https://sdk.lunarg.com/sdk/download/latest/linux/vulkan_sdk.tar.xz | tar -Jx --strip-components=1
|
||||
- name: Setup Vulkan SDK
|
||||
if: steps.cache-sdk.outputs.cache-hit != 'true'
|
||||
uses: ./.github/actions/linux-setup-vulkan
|
||||
with:
|
||||
path: ./vulkan_sdk
|
||||
version: ${{ env.VULKAN_SDK_VERSION }}
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
@@ -445,8 +477,8 @@ jobs:
|
||||
# This is using llvmpipe and runs slower than other backends
|
||||
ctest -L main --verbose --timeout 4200
|
||||
|
||||
ubuntu-22-cmake-webgpu:
|
||||
runs-on: ubuntu-22.04
|
||||
ubuntu-24-cmake-webgpu:
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -456,16 +488,34 @@ jobs:
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-22-cmake-webgpu
|
||||
key: ubuntu-24-cmake-webgpu
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Vulkan SDK Dependencies
|
||||
id: vulkan-depends
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add -
|
||||
sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list
|
||||
sudo add-apt-repository -y ppa:kisak/kisak-mesa
|
||||
sudo apt-get update -y
|
||||
sudo apt-get install -y build-essential mesa-vulkan-drivers vulkan-sdk libcurl4-openssl-dev
|
||||
sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libcurl4-openssl-dev
|
||||
|
||||
- name: Get latest Vulkan SDK version
|
||||
id: vulkan_sdk_version
|
||||
run: |
|
||||
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Use Vulkan SDK Cache
|
||||
uses: actions/cache@v4
|
||||
id: cache-sdk
|
||||
with:
|
||||
path: ./vulkan_sdk
|
||||
key: vulkan-sdk-${{ env.VULKAN_SDK_VERSION }}-${{ runner.os }}
|
||||
|
||||
- name: Setup Vulkan SDK
|
||||
if: steps.cache-sdk.outputs.cache-hit != 'true'
|
||||
uses: ./.github/actions/linux-setup-vulkan
|
||||
with:
|
||||
path: ./vulkan_sdk
|
||||
version: ${{ env.VULKAN_SDK_VERSION }}
|
||||
|
||||
- name: Dawn Dependency
|
||||
id: dawn-depends
|
||||
@@ -1111,6 +1161,7 @@ jobs:
|
||||
env:
|
||||
# The ROCm version must correspond to the version used in the HIP SDK.
|
||||
ROCM_VERSION: "6.4.2"
|
||||
# Make sure this is in sync with build-cache.yml
|
||||
HIPSDK_INSTALLER_VERSION: "25.Q3"
|
||||
|
||||
steps:
|
||||
@@ -1125,33 +1176,18 @@ jobs:
|
||||
7z x rocwmma.deb
|
||||
7z x data.tar
|
||||
|
||||
- name: Cache ROCm Installation
|
||||
id: cache-rocm
|
||||
- name: Use ROCm Installation Cache
|
||||
uses: actions/cache@v4
|
||||
id: cache-rocm
|
||||
with:
|
||||
path: C:\Program Files\AMD\ROCm
|
||||
key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }}
|
||||
|
||||
- name: Install ROCm
|
||||
- name: Setup ROCm
|
||||
if: steps.cache-rocm.outputs.cache-hit != 'true'
|
||||
id: depends
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
write-host "Downloading AMD HIP SDK Installer"
|
||||
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ env.HIPSDK_INSTALLER_VERSION }}-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
|
||||
write-host "Installing AMD HIP SDK"
|
||||
$proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru
|
||||
$completed = $proc.WaitForExit(600000)
|
||||
if (-not $completed) {
|
||||
Write-Error "ROCm installation timed out after 10 minutes. Killing the process"
|
||||
$proc.Kill()
|
||||
exit 1
|
||||
}
|
||||
if ($proc.ExitCode -ne 0) {
|
||||
Write-Error "ROCm installation failed with exit code $($proc.ExitCode)"
|
||||
exit 1
|
||||
}
|
||||
write-host "Completed AMD HIP SDK installation"
|
||||
uses: ./.github/actions/windows-setup-rocm
|
||||
with:
|
||||
version: ${{ env.HIPSDK_INSTALLER_VERSION }}
|
||||
|
||||
- name: Verify ROCm
|
||||
id: verify
|
||||
@@ -1512,3 +1548,29 @@ jobs:
|
||||
run: |
|
||||
vulkaninfo --summary
|
||||
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
|
||||
|
||||
ggml-ci-arm64-cpu-kleidiai:
|
||||
runs-on: ubuntu-22.04-arm
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ggml-ci-arm64-cpu-kleidiai
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential libcurl4-openssl-dev
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
GG_BUILD_KLEIDIAI=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ jobs:
|
||||
name: llama-bin-macos-arm64.zip
|
||||
|
||||
macOS-x64:
|
||||
runs-on: macos-13
|
||||
runs-on: macos-15-intel
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
||||
+2
-1
@@ -2,7 +2,7 @@
|
||||
# multiplie collaborators per item can be specified
|
||||
|
||||
/.devops/*.Dockerfile @ngxson
|
||||
/.github/actions/ @slaren
|
||||
/.github/actions/ @slaren @CISC
|
||||
/.github/workflows/ @CISC
|
||||
/.github/workflows/release.yml @slaren
|
||||
/.github/workflows/winget.yml @slaren
|
||||
@@ -70,6 +70,7 @@
|
||||
/ggml/src/ggml-rpc/ @rgerganov
|
||||
/ggml/src/ggml-threading.* @ggerganov @slaren
|
||||
/ggml/src/ggml-vulkan/ @0cc4m
|
||||
/ggml/src/ggml-webgpu/ @reeselevine
|
||||
/ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM
|
||||
/ggml/src/ggml.c @ggerganov @slaren
|
||||
/ggml/src/ggml.cpp @ggerganov @slaren
|
||||
|
||||
@@ -22,6 +22,9 @@
|
||||
# # with MUSA support
|
||||
# GG_BUILD_MUSA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
|
||||
#
|
||||
# # with KLEIDIAI support
|
||||
# GG_BUILD_KLEIDIAI=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
|
||||
#
|
||||
|
||||
if [ -z "$2" ]; then
|
||||
echo "usage: $0 <output-dir> <mnt-dir>"
|
||||
@@ -115,6 +118,34 @@ if [ ! -z ${GG_BUILD_NO_SVE} ]; then
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm"
|
||||
fi
|
||||
|
||||
if [ -n "${GG_BUILD_KLEIDIAI}" ]; then
|
||||
echo ">>===== Enabling KleidiAI support"
|
||||
|
||||
CANDIDATES=("armv9-a+dotprod+i8mm" "armv8.6-a+dotprod+i8mm" "armv8.2-a+dotprod")
|
||||
CPU=""
|
||||
|
||||
for cpu in "${CANDIDATES[@]}"; do
|
||||
if echo 'int main(){}' | ${CXX:-c++} -march="$cpu" -x c++ - -c -o /dev/null >/dev/null 2>&1; then
|
||||
CPU="$cpu"
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
if [ -z "$CPU" ]; then
|
||||
echo "ERROR: None of the required ARM baselines (armv9/armv8.6/armv8.2 + dotprod) are supported by this compiler."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ">>===== Using ARM baseline: ${CPU}"
|
||||
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA:+$CMAKE_EXTRA } \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DGGML_CPU_KLEIDIAI=ON \
|
||||
-DGGML_CPU_AARCH64=ON \
|
||||
-DGGML_CPU_ARM_ARCH=${CPU} \
|
||||
-DBUILD_SHARED_LIBS=OFF"
|
||||
fi
|
||||
|
||||
## helpers
|
||||
|
||||
# download a file if it does not exist or if it is outdated
|
||||
@@ -512,12 +543,7 @@ function gg_run_rerank_tiny {
|
||||
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer_config.json
|
||||
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/special_tokens_map.json
|
||||
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/resolve/main/pytorch_model.bin
|
||||
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/sentence_bert_config.json
|
||||
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.txt
|
||||
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/modules.json
|
||||
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json
|
||||
|
||||
gg_wget models-mnt/rerank-tiny/1_Pooling https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/1_Pooling/config.json
|
||||
gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.json
|
||||
|
||||
path_models="../models-mnt/rerank-tiny"
|
||||
|
||||
|
||||
+189
-152
@@ -1615,18 +1615,14 @@ static void add_rpc_devices(const std::string & servers) {
|
||||
if (!rpc_reg) {
|
||||
throw std::invalid_argument("failed to find RPC backend");
|
||||
}
|
||||
typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
|
||||
ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
|
||||
if (!ggml_backend_rpc_add_device_fn) {
|
||||
throw std::invalid_argument("failed to find RPC device add function");
|
||||
typedef ggml_backend_reg_t (*ggml_backend_rpc_add_server_t)(const char * endpoint);
|
||||
ggml_backend_rpc_add_server_t ggml_backend_rpc_add_server_fn = (ggml_backend_rpc_add_server_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server");
|
||||
if (!ggml_backend_rpc_add_server_fn) {
|
||||
throw std::invalid_argument("failed to find RPC add server function");
|
||||
}
|
||||
for (const auto & server : rpc_servers) {
|
||||
ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
|
||||
if (dev) {
|
||||
ggml_backend_device_register(dev);
|
||||
} else {
|
||||
throw std::invalid_argument("failed to register RPC device");
|
||||
}
|
||||
auto reg = ggml_backend_rpc_add_server_fn(server.c_str());
|
||||
ggml_backend_register(reg);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1932,13 +1928,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
).set_env("LLAMA_ARG_SWA_FULL"));
|
||||
add_opt(common_arg(
|
||||
{"--swa-checkpoints"}, "N",
|
||||
string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
|
||||
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
|
||||
{"--ctx-checkpoints", "--swa-checkpoints"}, "N",
|
||||
string_format("max number of context checkpoints to create per slot (default: %d)\n"
|
||||
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_ctx_checkpoints),
|
||||
[](common_params & params, int value) {
|
||||
params.n_swa_checkpoints = value;
|
||||
params.n_ctx_checkpoints = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--cache-ram", "-cram"}, "N",
|
||||
string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)\n"
|
||||
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)", params.cache_ram_mib),
|
||||
[](common_params & params, int value) {
|
||||
params.cache_ram_mib = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--kv-unified", "-kvu"},
|
||||
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
|
||||
@@ -2588,6 +2592,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.no_extra_bufts = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_NO_REPACK"));
|
||||
add_opt(common_arg(
|
||||
{"--no-host"},
|
||||
"bypass host buffer allowing extra buffers to be used",
|
||||
[](common_params & params) {
|
||||
params.no_host = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_NO_HOST"));
|
||||
add_opt(common_arg(
|
||||
{"-ctk", "--cache-type-k"}, "TYPE",
|
||||
string_format(
|
||||
@@ -3347,7 +3358,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
add_opt(common_arg(
|
||||
{"--chat-template-kwargs"}, "STRING",
|
||||
string_format("sets additional params for the json template parser"),
|
||||
[](common_params & params, const std::string & value) {
|
||||
[](common_params & params, const std::string & value) {
|
||||
auto parsed = json::parse(value);
|
||||
for (const auto & item : parsed.items()) {
|
||||
params.default_template_kwargs[item.key()] = item.value().dump();
|
||||
@@ -3429,7 +3440,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"--reasoning-format"}, "FORMAT",
|
||||
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
|
||||
"- none: leaves thoughts unparsed in `message.content`\n"
|
||||
"- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n"
|
||||
"- deepseek: puts thoughts in `message.reasoning_content`\n"
|
||||
"- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`\n"
|
||||
"(default: auto)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.reasoning_format = common_reasoning_format_from_name(value);
|
||||
@@ -3558,21 +3570,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
common_log_set_file(common_log_main(), value.c_str());
|
||||
}
|
||||
));
|
||||
add_opt(common_arg({ "--log-colors" }, "[on|off|auto]",
|
||||
"Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
|
||||
"'auto' enables colors when output is to a terminal",
|
||||
[](common_params &, const std::string & value) {
|
||||
if (is_truthy(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_ENABLED);
|
||||
} else if (is_falsey(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED);
|
||||
} else if (is_autoy(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
string_format("error: unkown value for --log-colors: '%s'\n", value.c_str()));
|
||||
}
|
||||
}).set_env("LLAMA_LOG_COLORS"));
|
||||
add_opt(common_arg(
|
||||
{"--log-colors"}, "[on|off|auto]",
|
||||
"Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
|
||||
"'auto' enables colors when output is to a terminal",
|
||||
[](common_params &, const std::string & value) {
|
||||
if (is_truthy(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_ENABLED);
|
||||
} else if (is_falsey(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED);
|
||||
} else if (is_autoy(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
string_format("error: unkown value for --log-colors: '%s'\n", value.c_str()));
|
||||
}
|
||||
}
|
||||
).set_env("LLAMA_LOG_COLORS"));
|
||||
add_opt(common_arg(
|
||||
{"-v", "--verbose", "--log-verbose"},
|
||||
"Set verbosity level to infinity (i.e. log all messages, useful for debugging)",
|
||||
@@ -3838,7 +3852,87 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_TTS}));
|
||||
|
||||
// model-specific
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-steps"}, "N",
|
||||
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
|
||||
[](common_params & params, int value) { params.diffusion.steps = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-visual"},
|
||||
string_format("enable visual diffusion mode (show progressive generation) (default: %s)", params.diffusion.visual_mode ? "true" : "false"),
|
||||
[](common_params & params) { params.diffusion.visual_mode = true; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-eps"}, "F",
|
||||
string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-algorithm"}, "N",
|
||||
string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)", params.diffusion.algorithm),
|
||||
[](common_params & params, int value) { params.diffusion.algorithm = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-alg-temp"}, "F",
|
||||
string_format("dream algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-block-length"}, "N",
|
||||
string_format("llada block length for generation (default: %d)", params.diffusion.block_length),
|
||||
[](common_params & params, int value) { params.diffusion.block_length = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-cfg-scale"}, "F",
|
||||
string_format("llada classifier-free guidance scale (default: %.3f)", (double) params.diffusion.cfg_scale),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.cfg_scale = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-add-gumbel-noise"}, "F",
|
||||
string_format("add gumbel noise to the logits if temp > 0.0 (default: %s)", params.diffusion.add_gumbel_noise ? "true" : "false"),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.add_gumbel_noise = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "-lr", "--learning-rate" }, "ALPHA",
|
||||
string_format("adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)", (double) params.lr.lr0),
|
||||
[](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
|
||||
string_format("(if >0) final learning rate after decay (if -decay-epochs is set, default=%.2g)",
|
||||
(double) params.lr.lr_min),
|
||||
[](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-decay-epochs", "--learning-rate-decay-epochs"}, "ALPHA",
|
||||
string_format("(if >0) decay learning rate to -lr-min after this many epochs (exponential decay, default=%.2g)", (double) params.lr.decay_epochs),
|
||||
[](common_params & params, const std::string & value) { params.lr.decay_epochs = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-wd", "--weight-decay"}, "WD",
|
||||
string_format("adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).", (double) params.lr.wd),
|
||||
[](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-val-split", "--val-split"}, "FRACTION",
|
||||
string_format("fraction of data to use as validation set for training (default: %.2g).", (double) params.val_split),
|
||||
[](common_params & params, const std::string & value) { params.val_split = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-epochs", "--epochs"}, "N",
|
||||
string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
|
||||
[](common_params & params, int epochs) { params.lr.epochs = epochs; }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-opt", "--optimizer"}, "sgd|adamw", "adamw or sgd",
|
||||
[](common_params & params, const std::string & name) {
|
||||
params.optimizer = common_opt_get_optimizer(name.c_str());
|
||||
if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
|
||||
throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
|
||||
}
|
||||
}
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
|
||||
// presets
|
||||
add_opt(common_arg(
|
||||
{"--tts-oute-default"},
|
||||
string_format("use default OuteTTS models (note: can download weights from the internet)"),
|
||||
@@ -3851,42 +3945,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_TTS}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--embd-bge-small-en-default"},
|
||||
string_format("use default bge-small-en-v1.5 model (note: can download weights from the internet)"),
|
||||
{"--embd-gemma-default"},
|
||||
string_format("use default EmbeddingGemma model (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF";
|
||||
params.model.hf_file = "bge-small-en-v1.5-q8_0.gguf";
|
||||
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
params.embd_normalize = 2;
|
||||
params.n_ctx = 512;
|
||||
params.verbose_prompt = true;
|
||||
params.embedding = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--embd-e5-small-en-default"},
|
||||
string_format("use default e5-small-v2 model (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF";
|
||||
params.model.hf_file = "e5-small-v2-q8_0.gguf";
|
||||
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
params.embd_normalize = 2;
|
||||
params.n_ctx = 512;
|
||||
params.verbose_prompt = true;
|
||||
params.embedding = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--embd-gte-small-default"},
|
||||
string_format("use default gte-small model (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gte-small-Q8_0-GGUF";
|
||||
params.model.hf_file = "gte-small-q8_0.gguf";
|
||||
params.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
params.embd_normalize = 2;
|
||||
params.n_ctx = 512;
|
||||
params.model.hf_repo = "ggml-org/embeddinggemma-300M-qat-q4_0-GGUF";
|
||||
params.model.hf_file = "embeddinggemma-300M-qat-Q4_0.gguf";
|
||||
params.port = 8011;
|
||||
params.n_ubatch = 2048;
|
||||
params.n_batch = 2048;
|
||||
params.n_parallel = 32;
|
||||
params.n_ctx = 2048*params.n_parallel;
|
||||
params.verbose_prompt = true;
|
||||
params.embedding = true;
|
||||
}
|
||||
@@ -3981,96 +4049,65 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-steps" }, "N",
|
||||
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
|
||||
[](common_params & params, int value) { params.diffusion.steps = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-visual" },
|
||||
string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
|
||||
params.diffusion.visual_mode ? "true" : "false"),
|
||||
[](common_params & params) { params.diffusion.visual_mode = true; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
{"--gpt-oss-20b-default"},
|
||||
string_format("use gpt-oss-20b (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gpt-oss-20b-GGUF";
|
||||
params.model.hf_file = "gpt-oss-20b-mxfp4.gguf";
|
||||
params.port = 8013;
|
||||
params.n_ubatch = 2048;
|
||||
params.n_batch = 32768;
|
||||
params.n_parallel = 2;
|
||||
params.n_ctx = 131072*params.n_parallel;
|
||||
params.sampling.temp = 1.0f;
|
||||
params.sampling.top_p = 1.0f;
|
||||
params.sampling.top_k = 0;
|
||||
params.sampling.min_p = 0.01f;
|
||||
params.use_jinja = true;
|
||||
//params.default_template_kwargs["reasoning_effort"] = "\"high\"";
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-eps" }, "F",
|
||||
string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-algorithm" }, "N",
|
||||
string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)",
|
||||
params.diffusion.algorithm),
|
||||
[](common_params & params, int value) { params.diffusion.algorithm = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-alg-temp" }, "F",
|
||||
string_format("dream algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
{"--gpt-oss-120b-default"},
|
||||
string_format("use gpt-oss-120b (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gpt-oss-120b-GGUF";
|
||||
params.port = 8013;
|
||||
params.n_ubatch = 2048;
|
||||
params.n_batch = 32768;
|
||||
params.n_parallel = 2;
|
||||
params.n_ctx = 131072*params.n_parallel;
|
||||
params.sampling.temp = 1.0f;
|
||||
params.sampling.top_p = 1.0f;
|
||||
params.sampling.top_k = 0;
|
||||
params.sampling.min_p = 0.01f;
|
||||
params.use_jinja = true;
|
||||
//params.default_template_kwargs["reasoning_effort"] = "\"high\"";
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-block-length" }, "N",
|
||||
string_format("llada block length for generation (default: %d)", params.diffusion.block_length),
|
||||
[](common_params & params, int value) { params.diffusion.block_length = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-cfg-scale" }, "F",
|
||||
string_format("llada classifier-free guidance scale (default: %.3f)", (double) params.diffusion.cfg_scale),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.cfg_scale = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-add-gumbel-noise" }, "F",
|
||||
string_format("add gumbel noise to the logits if temp > 0.0 (default: %s)", params.diffusion.add_gumbel_noise ? "true" : "false"),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.add_gumbel_noise = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
{"--vision-gemma-4b-default"},
|
||||
string_format("use Gemma 3 4B QAT (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gemma-3-4b-it-qat-GGUF";
|
||||
params.port = 8014;
|
||||
params.n_ctx = 0;
|
||||
params.use_jinja = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
|
||||
add_opt(
|
||||
common_arg({ "-lr", "--learning-rate" }, "ALPHA",
|
||||
string_format(
|
||||
"adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)",
|
||||
(double) params.lr.lr0),
|
||||
[](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(
|
||||
common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
|
||||
string_format(
|
||||
"(if >0) final learning rate after decay (if -decay-epochs is set, default=%.2g)",
|
||||
(double) params.lr.lr_min),
|
||||
[](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(
|
||||
common_arg({ "-decay-epochs", "--learning-rate-decay-epochs" }, "ALPHA",
|
||||
string_format(
|
||||
"(if >0) decay learning rate to -lr-min after this many epochs (exponential decay, default=%.2g)",
|
||||
(double) params.lr.decay_epochs),
|
||||
[](common_params & params, const std::string & value) { params.lr.decay_epochs = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{ "-wd", "--weight-decay" }, "WD",
|
||||
string_format(
|
||||
"adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).",
|
||||
(double) params.lr.wd),
|
||||
[](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg({ "-val-split", "--val-split" }, "FRACTION",
|
||||
string_format("fraction of data to use as validation set for training (default: %.2g).",
|
||||
(double) params.val_split),
|
||||
[](common_params & params, const std::string & value) { params.val_split = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg({ "-epochs", "--epochs" }, "N",
|
||||
string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
|
||||
[](common_params & params, int epochs) { params.lr.epochs = epochs; })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or sgd",
|
||||
[](common_params & params, const std::string & name) {
|
||||
params.optimizer = common_opt_get_optimizer(name.c_str());
|
||||
if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
|
||||
throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
|
||||
}
|
||||
})
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
{"--vision-gemma-12b-default"},
|
||||
string_format("use Gemma 3 12B QAT (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gemma-3-12b-it-qat-GGUF";
|
||||
params.port = 8014;
|
||||
params.n_ctx = 0;
|
||||
params.use_jinja = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
return ctx_arg;
|
||||
}
|
||||
|
||||
+127
-15
@@ -3,9 +3,12 @@
|
||||
#include "log.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <optional>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
@@ -166,6 +169,27 @@ void common_chat_msg_parser::consume_literal(const std::string & literal) {
|
||||
}
|
||||
|
||||
bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) {
|
||||
std::string pending_reasoning_prefix;
|
||||
|
||||
if (syntax_.reasoning_format == COMMON_REASONING_FORMAT_NONE) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto set_reasoning_prefix = [&](size_t prefix_pos) {
|
||||
if (!syntax_.thinking_forced_open || syntax_.reasoning_in_content) {
|
||||
return;
|
||||
}
|
||||
if (prefix_pos + start_think.size() > input_.size()) {
|
||||
pending_reasoning_prefix.clear();
|
||||
return;
|
||||
}
|
||||
// Capture the exact literal that opened the reasoning section so we can
|
||||
// surface it back to callers. This ensures formats that force the
|
||||
// reasoning tag open (e.g. DeepSeek R1) retain their original prefix
|
||||
// instead of dropping it during parsing.
|
||||
pending_reasoning_prefix = input_.substr(prefix_pos, start_think.size());
|
||||
};
|
||||
|
||||
auto handle_reasoning = [&](const std::string & reasoning, bool closed) {
|
||||
auto stripped_reasoning = string_strip(reasoning);
|
||||
if (stripped_reasoning.empty()) {
|
||||
@@ -178,28 +202,116 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think
|
||||
add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "</think>" : end_think);
|
||||
}
|
||||
} else {
|
||||
if (!pending_reasoning_prefix.empty()) {
|
||||
add_reasoning_content(pending_reasoning_prefix);
|
||||
pending_reasoning_prefix.clear();
|
||||
}
|
||||
add_reasoning_content(stripped_reasoning);
|
||||
}
|
||||
};
|
||||
if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
|
||||
if (syntax_.thinking_forced_open || try_consume_literal(start_think)) {
|
||||
if (auto res = try_find_literal(end_think)) {
|
||||
handle_reasoning(res->prelude, /* closed */ true);
|
||||
consume_spaces();
|
||||
return true;
|
||||
}
|
||||
auto rest = consume_rest();
|
||||
|
||||
const size_t saved_pos = pos_;
|
||||
const size_t saved_content_size = result_.content.size();
|
||||
const size_t saved_reasoning_size = result_.reasoning_content.size();
|
||||
|
||||
auto restore_state = [&]() {
|
||||
move_to(saved_pos);
|
||||
result_.content.resize(saved_content_size);
|
||||
result_.reasoning_content.resize(saved_reasoning_size);
|
||||
};
|
||||
|
||||
// Allow leading whitespace to be preserved as content when reasoning is present at the start
|
||||
size_t cursor = pos_;
|
||||
size_t whitespace_end = cursor;
|
||||
while (whitespace_end < input_.size() && std::isspace(static_cast<unsigned char>(input_[whitespace_end]))) {
|
||||
++whitespace_end;
|
||||
}
|
||||
|
||||
if (whitespace_end >= input_.size()) {
|
||||
restore_state();
|
||||
if (syntax_.thinking_forced_open) {
|
||||
auto rest = input_.substr(saved_pos);
|
||||
if (!rest.empty()) {
|
||||
handle_reasoning(rest, /* closed */ !is_partial());
|
||||
}
|
||||
// Allow unclosed thinking tags, for now (https://github.com/ggml-org/llama.cpp/issues/13812, https://github.com/ggml-org/llama.cpp/issues/13877)
|
||||
// if (!syntax_.thinking_forced_open) {
|
||||
// throw common_chat_msg_partial_exception(end_think);
|
||||
// }
|
||||
move_to(input_.size());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
cursor = whitespace_end;
|
||||
const size_t remaining = input_.size() - cursor;
|
||||
const size_t start_prefix = std::min(start_think.size(), remaining);
|
||||
const bool has_start_tag = input_.compare(cursor, start_prefix, start_think, 0, start_prefix) == 0;
|
||||
|
||||
if (has_start_tag && start_prefix < start_think.size()) {
|
||||
move_to(input_.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
if (has_start_tag) {
|
||||
if (whitespace_end > pos_) {
|
||||
add_content(input_.substr(pos_, whitespace_end - pos_));
|
||||
}
|
||||
set_reasoning_prefix(cursor);
|
||||
cursor += start_think.size();
|
||||
} else if (syntax_.thinking_forced_open) {
|
||||
cursor = whitespace_end;
|
||||
} else {
|
||||
restore_state();
|
||||
return false;
|
||||
}
|
||||
while (true) {
|
||||
if (cursor >= input_.size()) {
|
||||
move_to(input_.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t end_pos = input_.find(end_think, cursor);
|
||||
if (end_pos == std::string::npos) {
|
||||
std::string_view remaining_view(input_.data() + cursor, input_.size() - cursor);
|
||||
size_t partial_off = string_find_partial_stop(remaining_view, end_think);
|
||||
size_t reasoning_end = partial_off == std::string::npos ? input_.size() : cursor + partial_off;
|
||||
if (reasoning_end > cursor) {
|
||||
handle_reasoning(input_.substr(cursor, reasoning_end - cursor), /* closed */ partial_off == std::string::npos && !is_partial());
|
||||
}
|
||||
move_to(input_.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
if (end_pos > cursor) {
|
||||
handle_reasoning(input_.substr(cursor, end_pos - cursor), /* closed */ true);
|
||||
} else {
|
||||
handle_reasoning("", /* closed */ true);
|
||||
}
|
||||
|
||||
cursor = end_pos + end_think.size();
|
||||
|
||||
while (cursor < input_.size() && std::isspace(static_cast<unsigned char>(input_[cursor]))) {
|
||||
++cursor;
|
||||
}
|
||||
|
||||
const size_t next_remaining = input_.size() - cursor;
|
||||
if (next_remaining == 0) {
|
||||
move_to(cursor);
|
||||
return true;
|
||||
}
|
||||
|
||||
const size_t next_prefix = std::min(start_think.size(), next_remaining);
|
||||
if (input_.compare(cursor, next_prefix, start_think, 0, next_prefix) == 0) {
|
||||
if (next_prefix < start_think.size()) {
|
||||
move_to(input_.size());
|
||||
return true;
|
||||
}
|
||||
set_reasoning_prefix(cursor);
|
||||
cursor += start_think.size();
|
||||
continue;
|
||||
}
|
||||
|
||||
move_to(cursor);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string common_chat_msg_parser::consume_rest() {
|
||||
@@ -320,7 +432,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
|
||||
if (is_arguments_path({})) {
|
||||
// Entire JSON is the arguments and was parsed fully.
|
||||
return consume_json_result {
|
||||
partial->json.dump(),
|
||||
partial->json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true),
|
||||
/* .is_partial = */ false,
|
||||
};
|
||||
}
|
||||
@@ -332,7 +444,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
|
||||
std::vector<std::string> path;
|
||||
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
|
||||
if (is_arguments_path(path)) {
|
||||
auto arguments = j.dump();
|
||||
auto arguments = j.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true);
|
||||
if (is_partial() && !partial->healing_marker.marker.empty()) {
|
||||
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
|
||||
if (idx != std::string::npos) {
|
||||
|
||||
@@ -625,6 +625,7 @@ const char * common_chat_format_name(common_chat_format format) {
|
||||
case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only";
|
||||
case COMMON_CHAT_FORMAT_GENERIC: return "Generic";
|
||||
case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo";
|
||||
case COMMON_CHAT_FORMAT_MAGISTRAL: return "Magistral";
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
|
||||
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
|
||||
@@ -984,6 +985,65 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
|
||||
data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_MAGISTRAL;
|
||||
data.preserved_tokens = {
|
||||
"[THINK]",
|
||||
"[/THINK]",
|
||||
};
|
||||
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
auto schemas = json::array();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
schemas.push_back({
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"name", {
|
||||
{"type", "string"},
|
||||
{"const", function.at("name")},
|
||||
}},
|
||||
{"arguments", function.at("parameters")},
|
||||
{"id", {
|
||||
{"type", "string"},
|
||||
{"pattern", "^[a-zA-Z0-9]{9}$"},
|
||||
}},
|
||||
}},
|
||||
{"required", json::array({"name", "arguments", "id"})},
|
||||
});
|
||||
});
|
||||
auto schema = json {
|
||||
{"type", "array"},
|
||||
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
||||
{"minItems", 1},
|
||||
};
|
||||
if (!inputs.parallel_tool_calls) {
|
||||
schema["maxItems"] = 1;
|
||||
}
|
||||
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
|
||||
});
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"});
|
||||
data.preserved_tokens.push_back("[TOOL_CALLS]");
|
||||
} else {
|
||||
data.grammar_lazy = false;
|
||||
if (!inputs.json_schema.is_null()) {
|
||||
if (!inputs.grammar.empty()) {
|
||||
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
||||
}
|
||||
data.grammar = json_schema_to_grammar(inputs.json_schema);
|
||||
} else {
|
||||
data.grammar = inputs.grammar;
|
||||
}
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
@@ -994,6 +1054,18 @@ static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
|
||||
parse_prefixed_json_tool_call_array(builder, prefix);
|
||||
}
|
||||
|
||||
static void common_chat_parse_magistral(common_chat_msg_parser & builder) {
|
||||
builder.try_parse_reasoning("[THINK]", "[/THINK]");
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
|
||||
parse_prefixed_json_tool_call_array(builder, prefix);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
@@ -1336,6 +1408,8 @@ static common_chat_params common_chat_params_init_apertus(const common_chat_temp
|
||||
return data;
|
||||
}
|
||||
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
@@ -2702,6 +2776,10 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
|
||||
}
|
||||
|
||||
if (src.find("[THINK]") != std::string::npos && src.find("[/THINK]") != std::string::npos) {
|
||||
return common_chat_params_init_magistral(tmpl, params);
|
||||
}
|
||||
|
||||
// Plain handler (no tools)
|
||||
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return common_chat_params_init_without_tools(tmpl, params);
|
||||
@@ -2786,6 +2864,7 @@ common_chat_params common_chat_templates_apply(
|
||||
}
|
||||
|
||||
static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
|
||||
@@ -2802,6 +2881,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||
case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
|
||||
common_chat_parse_mistral_nemo(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_MAGISTRAL:
|
||||
common_chat_parse_magistral(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_LLAMA_3_X:
|
||||
common_chat_parse_llama_3_1(builder);
|
||||
break;
|
||||
|
||||
+4
-3
@@ -33,8 +33,8 @@ struct common_chat_msg_content_part {
|
||||
struct common_chat_msg {
|
||||
std::string role;
|
||||
std::string content;
|
||||
std::vector<common_chat_msg_content_part> content_parts = {};
|
||||
std::vector<common_chat_tool_call> tool_calls = {};
|
||||
std::vector<common_chat_msg_content_part> content_parts;
|
||||
std::vector<common_chat_tool_call> tool_calls;
|
||||
std::string reasoning_content;
|
||||
std::string tool_name;
|
||||
std::string tool_call_id;
|
||||
@@ -44,7 +44,7 @@ struct common_chat_msg {
|
||||
bool empty() const {
|
||||
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
|
||||
}
|
||||
void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
|
||||
void set_tool_call_ids(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
|
||||
for (auto i = 0u; i < tool_calls.size(); i++) {
|
||||
if (ids_cache.size() <= i) {
|
||||
auto id = tool_calls[i].id;
|
||||
@@ -101,6 +101,7 @@ enum common_chat_format {
|
||||
COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
COMMON_CHAT_FORMAT_GENERIC,
|
||||
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
|
||||
COMMON_CHAT_FORMAT_MAGISTRAL,
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X,
|
||||
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
|
||||
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
|
||||
|
||||
@@ -1133,6 +1133,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
|
||||
mparams.use_mlock = params.use_mlock;
|
||||
mparams.check_tensors = params.check_tensors;
|
||||
mparams.use_extra_bufts = !params.no_extra_bufts;
|
||||
mparams.no_host = params.no_host;
|
||||
|
||||
if (params.kv_overrides.empty()) {
|
||||
mparams.kv_overrides = NULL;
|
||||
|
||||
+5
-3
@@ -378,7 +378,7 @@ struct common_params {
|
||||
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
|
||||
bool cont_batching = true; // insert new sequences for decoding on-the-fly
|
||||
bool no_perf = false; // disable performance metrics
|
||||
bool ctx_shift = false; // context shift on infinite text generation
|
||||
bool ctx_shift = false; // context shift on infinite text generation
|
||||
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||
bool kv_unified = false; // enable unified KV cache
|
||||
|
||||
@@ -392,6 +392,7 @@ struct common_params {
|
||||
bool check_tensors = false; // validate tensor data
|
||||
bool no_op_offload = false; // globally disable offload host tensor operations to device
|
||||
bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)
|
||||
bool no_host = false; // bypass host buffer allowing extra buffers to be used
|
||||
|
||||
bool single_turn = false; // single turn chat conversation
|
||||
|
||||
@@ -424,7 +425,8 @@ struct common_params {
|
||||
int32_t timeout_write = timeout_read; // http write timeout in seconds
|
||||
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
|
||||
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
|
||||
int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot
|
||||
int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
|
||||
std::string hostname = "127.0.0.1";
|
||||
std::string public_path = ""; // NOLINT
|
||||
@@ -432,7 +434,7 @@ struct common_params {
|
||||
std::string chat_template = ""; // NOLINT
|
||||
bool use_jinja = false; // NOLINT
|
||||
bool enable_chat_template = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
int reasoning_budget = -1;
|
||||
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <string>
|
||||
#include <regex>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
@@ -168,6 +169,47 @@ bool common_json_parse(
|
||||
}
|
||||
}
|
||||
|
||||
// Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
|
||||
static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
|
||||
|
||||
auto is_high_surrogate = [&](const std::string & s) {
|
||||
// Check if a partial of a high surrogate (U+D800-U+DBFF)
|
||||
return s.length() >= 4 &&
|
||||
s[0] == '\\' && s[1] == 'u' &&
|
||||
std::tolower(s[2]) == 'd' &&
|
||||
(s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
|
||||
};
|
||||
|
||||
// Initialize the unicode marker to a low surrogate to handle the edge case
|
||||
// where a high surrogate (U+D800-U+DBFF) is immediately followed by a
|
||||
// backslash (\)
|
||||
std::string unicode_marker_padding = "udc00";
|
||||
std::smatch last_unicode_seq;
|
||||
|
||||
if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
|
||||
std::smatch second_last_seq;
|
||||
std::string prelude = str.substr(0, last_unicode_seq.position());
|
||||
|
||||
// Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
|
||||
unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
|
||||
|
||||
if (is_high_surrogate(last_unicode_seq.str())) {
|
||||
// If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF)
|
||||
unicode_marker_padding += "\\udc00";
|
||||
} else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
|
||||
if (is_high_surrogate(second_last_seq.str())) {
|
||||
// If this follows a high surrogate, pad it to be a low surrogate
|
||||
if (last_unicode_seq.length() == 2) {
|
||||
unicode_marker_padding = "dc00";
|
||||
} else if (last_unicode_seq.length() == 3) {
|
||||
unicode_marker_padding = "c00";
|
||||
} else {
|
||||
// The original unicode_marker_padding is already padded with 0s
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
|
||||
|
||||
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
|
||||
@@ -186,6 +228,9 @@ bool common_json_parse(
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an object value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
|
||||
// Was inside an object value string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
|
||||
} else {
|
||||
// find last :
|
||||
auto last_pos = str.find_last_of(':');
|
||||
@@ -205,6 +250,9 @@ bool common_json_parse(
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an array value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
|
||||
// Was inside an array value string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
|
||||
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
|
||||
// Had just finished a value
|
||||
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
|
||||
@@ -230,6 +278,9 @@ bool common_json_parse(
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
|
||||
// Was inside an object key string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
|
||||
// Was inside an object key string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
|
||||
} else {
|
||||
auto last_pos = str.find_last_of(':');
|
||||
if (last_pos == std::string::npos) {
|
||||
|
||||
+150
-14
@@ -93,13 +93,15 @@ class ModelBase:
|
||||
# Mistral format specifics
|
||||
is_mistral_format: bool = False
|
||||
disable_mistral_community_chat_template: bool = False
|
||||
sentence_transformers_dense_modules: bool = False
|
||||
|
||||
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
|
||||
use_temp_file: bool = False, eager: bool = False,
|
||||
metadata_override: Path | None = None, model_name: str | None = None,
|
||||
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
|
||||
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
|
||||
disable_mistral_community_chat_template: bool = False):
|
||||
disable_mistral_community_chat_template: bool = False,
|
||||
sentence_transformers_dense_modules: bool = False):
|
||||
if type(self) is ModelBase or \
|
||||
type(self) is TextModel or \
|
||||
type(self) is MmprojModel:
|
||||
@@ -114,6 +116,7 @@ class ModelBase:
|
||||
self.lazy = not eager or (remote_hf_model_id is not None)
|
||||
self.dry_run = dry_run
|
||||
self.remote_hf_model_id = remote_hf_model_id
|
||||
self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
|
||||
if remote_hf_model_id is not None:
|
||||
self.is_safetensors = True
|
||||
|
||||
@@ -891,6 +894,9 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
|
||||
# ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
|
||||
res = "llada-moe"
|
||||
if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e":
|
||||
# ref: https://huggingface.co/ibm-granite/granite-docling-258M
|
||||
res = "granite-docling"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
@@ -1325,6 +1331,7 @@ class MmprojModel(ModelBase):
|
||||
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)
|
||||
|
||||
# load preprocessor config
|
||||
self.preprocessor_config = {}
|
||||
if not self.is_mistral_format:
|
||||
with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
|
||||
self.preprocessor_config = json.load(f)
|
||||
@@ -1347,7 +1354,8 @@ class MmprojModel(ModelBase):
|
||||
self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
|
||||
|
||||
# vision config
|
||||
self.gguf_writer.add_vision_image_size(self.find_vparam(["image_size"]))
|
||||
self.image_size = self.find_vparam(["image_size"])
|
||||
self.gguf_writer.add_vision_image_size(self.image_size)
|
||||
self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"]))
|
||||
self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
|
||||
self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
|
||||
@@ -2378,6 +2386,10 @@ class SmolVLMModel(MmprojModel):
|
||||
self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("scale_factor", 2))
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
|
||||
# Add the preprocessor longest edge size
|
||||
preproc_image_size = self.preprocessor_config.get("size", {}).get("longest_edge", self.image_size)
|
||||
self.gguf_writer.add_vision_preproc_image_size(preproc_image_size)
|
||||
|
||||
def tensor_force_quant(self, name, new_name, bid, n_dims):
|
||||
if ".embeddings." in name:
|
||||
return gguf.GGMLQuantizationType.F32
|
||||
@@ -5260,6 +5272,53 @@ class Gemma3Model(TextModel):
|
||||
@ModelBase.register("Gemma3TextModel")
|
||||
class EmbeddingGemma(Gemma3Model):
|
||||
model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
|
||||
module_paths = []
|
||||
dense_features_dims = {}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.sentence_transformers_dense_modules:
|
||||
# read modules.json to determine if model has Dense layers
|
||||
modules_file = self.dir_model / "modules.json"
|
||||
if modules_file.is_file():
|
||||
with open(modules_file, encoding="utf-8") as modules_json_file:
|
||||
mods = json.load(modules_json_file)
|
||||
for mod in mods:
|
||||
if mod["type"] == "sentence_transformers.models.Dense":
|
||||
mod_path = mod["path"]
|
||||
# check if model.safetensors file for Dense layer exists
|
||||
model_tensors_file = self.dir_model / mod_path / "model.safetensors"
|
||||
if model_tensors_file.is_file():
|
||||
self.module_paths.append(mod_path)
|
||||
# read config.json of the Dense layer to get in/out features
|
||||
mod_conf_file = self.dir_model / mod_path / "config.json"
|
||||
if mod_conf_file.is_file():
|
||||
with open(mod_conf_file, encoding="utf-8") as mod_conf_json_file:
|
||||
mod_conf = json.load(mod_conf_json_file)
|
||||
# hparams dense_2_feat_out and dense_3_feat_in are required when loading model's dense weights
|
||||
prefix = self._get_dense_prefix(mod_path)
|
||||
if mod_conf["in_features"] is not None and mod_conf["out_features"] is not None:
|
||||
self.dense_features_dims[prefix] = (mod_conf["in_features"], mod_conf["out_features"])
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
from safetensors.torch import load_file
|
||||
module_paths = list(self.module_paths)
|
||||
for i, module_path in enumerate(module_paths):
|
||||
tensors_file = self.dir_model / module_path / "model.safetensors"
|
||||
local_tensors = load_file(tensors_file)
|
||||
tensor_name = self._get_dense_prefix(module_path)
|
||||
for name, local_tensor in local_tensors.items():
|
||||
if not name.endswith(".weight"):
|
||||
continue
|
||||
orig_name = name.replace("linear", tensor_name)
|
||||
name = self.map_tensor_name(orig_name)
|
||||
yield name, local_tensor.clone()
|
||||
|
||||
@staticmethod
|
||||
def _get_dense_prefix(module_path) -> str:
|
||||
"""Get the tensor name prefix for the Dense layer from module path."""
|
||||
tensor_name = "dense_2" if module_path == "2_Dense" else "dense_3"
|
||||
return tensor_name
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
@@ -5276,6 +5335,10 @@ class EmbeddingGemma(Gemma3Model):
|
||||
logger.info(f"Using original sliding_window from config: {orig_sliding_window} "
|
||||
f"instead of {self.hparams['sliding_window']}")
|
||||
self.gguf_writer.add_sliding_window(orig_sliding_window)
|
||||
if self.sentence_transformers_dense_modules:
|
||||
for dense, dims in self.dense_features_dims.items():
|
||||
logger.info(f"Setting dense layer {dense} in/out features to {dims}")
|
||||
self.gguf_writer.add_dense_features_dims(dense, dims[0], dims[1])
|
||||
|
||||
self._try_set_pooling_type()
|
||||
|
||||
@@ -5903,20 +5966,12 @@ class Mamba2Model(TextModel):
|
||||
class JambaModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.JAMBA
|
||||
|
||||
def get_vocab_base_pre(self, tokenizer) -> str:
|
||||
del tokenizer # unused
|
||||
|
||||
return "gpt-2"
|
||||
|
||||
def set_vocab(self):
|
||||
if (self.dir_model / "tokenizer.model").is_file():
|
||||
# Using Jamba's tokenizer.json causes errors on model load
|
||||
# (something about "byte not found in vocab"),
|
||||
# but there's a working tokenizer.model
|
||||
self._set_vocab_sentencepiece()
|
||||
else:
|
||||
# Some Jamba models only have a tokenizer.json, which works.
|
||||
self._set_vocab_gpt2()
|
||||
self._set_vocab_llama_hf()
|
||||
self.gguf_writer.add_add_space_prefix(False)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
|
||||
@@ -8827,6 +8882,75 @@ class LFM2Model(TextModel):
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register("Lfm2MoeForCausalLM")
|
||||
class LFM2MoeModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.LFM2MOE
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
# set num_key_value_heads only for attention layers
|
||||
self.hparams["num_key_value_heads"] = [
|
||||
self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0
|
||||
for layer_type in self.hparams["layer_types"]
|
||||
]
|
||||
|
||||
super().set_gguf_parameters()
|
||||
|
||||
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
|
||||
self.gguf_writer.add_leading_dense_block_count(self.hparams["num_dense_layers"])
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
|
||||
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
|
||||
self.gguf_writer.add_shortconv_l_cache(self.hparams["conv_L_cache"])
|
||||
|
||||
# cache for experts weights for merging
|
||||
_experts_cache: dict[int, dict[str, Tensor]] = {}
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# conv op requires 2d tensor
|
||||
if 'conv.conv' in name:
|
||||
data_torch = data_torch.squeeze(1)
|
||||
|
||||
if name.endswith(".expert_bias"):
|
||||
name = name.replace(".expert_bias", ".expert_bias.bias")
|
||||
|
||||
# merge expert weights
|
||||
if 'experts' in name:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
assert bid is not None
|
||||
|
||||
expert_cache = self._experts_cache.setdefault(bid, {})
|
||||
expert_cache[name] = data_torch
|
||||
expert_weights = ["w1", "w2", "w3"]
|
||||
|
||||
# not enough expert weights to merge
|
||||
if len(expert_cache) < n_experts * len(expert_weights):
|
||||
return []
|
||||
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
for w_name in expert_weights:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{w_name}.weight"
|
||||
datas.append(expert_cache[ename])
|
||||
del expert_cache[ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
merged_name = f"layers.{bid}.feed_forward.experts.{w_name}.weight"
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
tensors.append((new_name, data_torch))
|
||||
|
||||
del self._experts_cache[bid]
|
||||
return tensors
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
assert not self._experts_cache
|
||||
|
||||
|
||||
@ModelBase.register("Lfm2VlForConditionalGeneration")
|
||||
class LFM2VLModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -9257,6 +9381,13 @@ def parse_args() -> argparse.Namespace:
|
||||
)
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sentence-transformers-dense-modules", action="store_true",
|
||||
help=("Whether to include sentence-transformers dense modules."
|
||||
"It can be used for sentence-transformers models, like google/embeddinggemma-300m"
|
||||
"Default these modules are not included.")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if not args.print_supported_models and args.model is None:
|
||||
parser.error("the following arguments are required: model")
|
||||
@@ -9319,9 +9450,13 @@ def main() -> None:
|
||||
if args.remote:
|
||||
hf_repo_id = args.model
|
||||
from huggingface_hub import snapshot_download
|
||||
allowed_patterns = ["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]
|
||||
if args.sentence_transformers_dense_modules:
|
||||
# include sentence-transformers dense modules safetensors files
|
||||
allowed_patterns.append("*.safetensors")
|
||||
local_dir = snapshot_download(
|
||||
repo_id=hf_repo_id,
|
||||
allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
|
||||
allow_patterns=allowed_patterns)
|
||||
dir_model = Path(local_dir)
|
||||
logger.info(f"Downloaded config and tokenizer to {local_dir}")
|
||||
else:
|
||||
@@ -9389,7 +9524,8 @@ def main() -> None:
|
||||
split_max_tensors=args.split_max_tensors,
|
||||
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
|
||||
small_first_shard=args.no_tensor_first_split,
|
||||
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
|
||||
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
|
||||
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
|
||||
)
|
||||
|
||||
if args.vocab_only:
|
||||
|
||||
@@ -140,6 +140,7 @@ models = [
|
||||
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
|
||||
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
|
||||
{"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", },
|
||||
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
|
||||
]
|
||||
|
||||
# some models are known to be broken upstream, so we will skip them as exceptions
|
||||
|
||||
+10
-8
@@ -31,7 +31,7 @@ Legend:
|
||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
@@ -51,7 +51,7 @@ Legend:
|
||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
@@ -65,11 +65,11 @@ Legend:
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
@@ -92,9 +92,9 @@ Legend:
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
@@ -102,9 +102,11 @@ Legend:
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
|
||||
+12095
-4249
File diff suppressed because it is too large
Load Diff
@@ -116,20 +116,39 @@ embedding-convert-model:
|
||||
METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
|
||||
./scripts/embedding/convert-model.sh
|
||||
|
||||
embedding-convert-model-st:
|
||||
$(call validate_embedding_model_path,embedding-convert-model-st)
|
||||
@MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \
|
||||
METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
|
||||
./scripts/embedding/convert-model.sh -st
|
||||
|
||||
embedding-run-original-model:
|
||||
$(call validate_embedding_model_path,embedding-run-original-model)
|
||||
@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \
|
||||
USE_SENTENCE_TRANSFORMERS="$(USE_SENTENCE_TRANSFORMERS)" \
|
||||
./scripts/embedding/run-original-model.py \
|
||||
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
|
||||
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \
|
||||
$(if $(USE_SENTENCE_TRANSFORMERS),--use-sentence-transformers)
|
||||
|
||||
embedding-run-original-model-st: USE_SENTENCE_TRANSFORMERS=1
|
||||
embedding-run-original-model-st: embedding-run-original-model
|
||||
|
||||
embedding-run-converted-model:
|
||||
@./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \
|
||||
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
|
||||
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \
|
||||
$(if $(USE_POOLING),--pooling)
|
||||
|
||||
embedding-run-converted-model-st: USE_POOLING=1
|
||||
embedding-run-converted-model-st: embedding-run-converted-model
|
||||
|
||||
embedding-verify-logits: embedding-run-original-model embedding-run-converted-model
|
||||
@./scripts/embedding/compare-embeddings-logits.sh \
|
||||
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
|
||||
|
||||
embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model-st
|
||||
@./scripts/embedding/compare-embeddings-logits.sh \
|
||||
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
|
||||
|
||||
embedding-inspect-original-model:
|
||||
$(call validate_embedding_model_path,embedding-inspect-original-model)
|
||||
@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH}
|
||||
|
||||
@@ -189,6 +189,23 @@ This command will save two files to the `data` directory, one is a binary
|
||||
file containing logits which will be used for comparison with the converted
|
||||
model, and the other is a text file which allows for manual visual inspection.
|
||||
|
||||
#### Using SentenceTransformer with numbered layers
|
||||
For models that have numbered SentenceTransformer layers (01_Pooling, 02_Dense,
|
||||
03_Dense, 04_Normalize), use the `-st` targets to apply all these layers:
|
||||
|
||||
```console
|
||||
# Run original model with SentenceTransformer (applies all numbered layers)
|
||||
(venv) $ make embedding-run-original-model-st
|
||||
|
||||
# Run converted model with pooling enabled
|
||||
(venv) $ make embedding-run-converted-model-st
|
||||
```
|
||||
|
||||
This will use the SentenceTransformer library to load and run the model, which
|
||||
automatically applies all the numbered layers in the correct order. This is
|
||||
particularly useful when comparing with models that should include these
|
||||
additional transformation layers beyond just the base model output.
|
||||
|
||||
### Model conversion
|
||||
After updates have been made to [gguf-py](../../gguf-py) to add support for the
|
||||
new model the model can be converted to GGUF format using the following command:
|
||||
@@ -208,6 +225,13 @@ was done manually in the previous steps) and compare the logits:
|
||||
(venv) $ make embedding-verify-logits
|
||||
```
|
||||
|
||||
For models with SentenceTransformer layers, use the `-st` verification target:
|
||||
```console
|
||||
(venv) $ make embedding-verify-logits-st
|
||||
```
|
||||
This convenience target automatically runs both the original model with SentenceTransformer
|
||||
and the converted model with pooling enabled, then compares the results.
|
||||
|
||||
### llama-server verification
|
||||
To verify that the converted model works with llama-server, the following
|
||||
command can be used:
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
#include "llama.h"
|
||||
#include "common.h"
|
||||
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
@@ -8,7 +11,10 @@
|
||||
|
||||
static void print_usage(int, char ** argv) {
|
||||
printf("\nexample usage:\n");
|
||||
printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [prompt]\n", argv[0]);
|
||||
printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [-pooling] [-embd-norm <norm>] [prompt]\n", argv[0]);
|
||||
printf("\n");
|
||||
printf(" -embd-norm: normalization type for pooled embeddings (default: 2)\n");
|
||||
printf(" -1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm\n");
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
@@ -17,6 +23,8 @@ int main(int argc, char ** argv) {
|
||||
std::string prompt = "Hello, my name is";
|
||||
int ngl = 0;
|
||||
bool embedding_mode = false;
|
||||
bool pooling_enabled = false;
|
||||
int32_t embd_norm = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
|
||||
|
||||
{
|
||||
int i = 1;
|
||||
@@ -41,9 +49,13 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
} else if (strcmp(argv[i], "-embd-mode") == 0) {
|
||||
embedding_mode = true;
|
||||
} else if (strcmp(argv[i], "-pooling") == 0) {
|
||||
pooling_enabled = true;
|
||||
} else if (strcmp(argv[i], "-embd-norm") == 0) {
|
||||
if (i + 1 < argc) {
|
||||
try {
|
||||
embedding_mode = true;
|
||||
embd_norm = std::stoi(argv[++i]);
|
||||
} catch (...) {
|
||||
print_usage(argc, argv);
|
||||
return 1;
|
||||
@@ -112,7 +124,7 @@ int main(int argc, char ** argv) {
|
||||
ctx_params.no_perf = false;
|
||||
if (embedding_mode) {
|
||||
ctx_params.embeddings = true;
|
||||
ctx_params.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
ctx_params.pooling_type = pooling_enabled ? LLAMA_POOLING_TYPE_MEAN : LLAMA_POOLING_TYPE_NONE;
|
||||
ctx_params.n_ubatch = ctx_params.n_batch;
|
||||
}
|
||||
|
||||
@@ -143,17 +155,27 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
float * logits;
|
||||
int n_logits;
|
||||
float * data_ptr;
|
||||
int data_size;
|
||||
const char * type;
|
||||
std::vector<float> embd_out;
|
||||
|
||||
if (embedding_mode) {
|
||||
logits = llama_get_embeddings(ctx);
|
||||
n_logits = llama_model_n_embd(model) * batch.n_tokens;
|
||||
const int n_embd = llama_model_n_embd(model);
|
||||
const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens;
|
||||
const int n_embeddings = n_embd * n_embd_count;
|
||||
float * embeddings;
|
||||
type = "-embeddings";
|
||||
|
||||
const int n_embd = llama_model_n_embd(model);
|
||||
const int n_embd_count = batch.n_tokens;
|
||||
if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
|
||||
embeddings = llama_get_embeddings_seq(ctx, 0);
|
||||
embd_out.resize(n_embeddings);
|
||||
printf("Normalizing embeddings using norm: %d\n", embd_norm);
|
||||
common_embd_normalize(embeddings, embd_out.data(), n_embeddings, embd_norm);
|
||||
embeddings = embd_out.data();
|
||||
} else {
|
||||
embeddings = llama_get_embeddings(ctx);
|
||||
}
|
||||
|
||||
printf("Embedding dimension: %d\n", n_embd);
|
||||
printf("\n");
|
||||
@@ -164,7 +186,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// Print first 3 values
|
||||
for (int i = 0; i < 3 && i < n_embd; i++) {
|
||||
printf("%9.6f ", logits[j * n_embd + i]);
|
||||
printf("%9.6f ", embeddings[j * n_embd + i]);
|
||||
}
|
||||
|
||||
printf(" ... ");
|
||||
@@ -172,7 +194,7 @@ int main(int argc, char ** argv) {
|
||||
// Print last 3 values
|
||||
for (int i = n_embd - 3; i < n_embd; i++) {
|
||||
if (i >= 0) {
|
||||
printf("%9.6f ", logits[j * n_embd + i]);
|
||||
printf("%9.6f ", embeddings[j * n_embd + i]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,27 +202,33 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
printf("Embeddings size: %d\n", n_logits);
|
||||
printf("Embeddings size: %d\n", n_embeddings);
|
||||
|
||||
data_ptr = embeddings;
|
||||
data_size = n_embeddings;
|
||||
} else {
|
||||
logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
||||
n_logits = llama_vocab_n_tokens(vocab);
|
||||
float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
||||
const int n_logits = llama_vocab_n_tokens(vocab);
|
||||
type = "";
|
||||
printf("Vocab size: %d\n", n_logits);
|
||||
|
||||
data_ptr = logits;
|
||||
data_size = n_logits;
|
||||
}
|
||||
|
||||
std::filesystem::create_directory("data");
|
||||
|
||||
// Save logits to binary file
|
||||
// Save data to binary file
|
||||
char bin_filename[512];
|
||||
snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type);
|
||||
printf("Saving logits to %s\n", bin_filename);
|
||||
printf("Saving data to %s\n", bin_filename);
|
||||
|
||||
FILE * f = fopen(bin_filename, "wb");
|
||||
if (f == NULL) {
|
||||
fprintf(stderr, "%s: error: failed to open binary output file\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
fwrite(logits, sizeof(float), n_logits, f);
|
||||
fwrite(data_ptr, sizeof(float), data_size, f);
|
||||
fclose(f);
|
||||
|
||||
// Also save as text for debugging
|
||||
@@ -211,27 +239,27 @@ int main(int argc, char ** argv) {
|
||||
fprintf(stderr, "%s: error: failed to open text output file\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
for (int i = 0; i < n_logits; i++) {
|
||||
fprintf(f, "%d: %.6f\n", i, logits[i]);
|
||||
for (int i = 0; i < data_size; i++) {
|
||||
fprintf(f, "%d: %.6f\n", i, data_ptr[i]);
|
||||
}
|
||||
fclose(f);
|
||||
|
||||
if (!embedding_mode) {
|
||||
printf("First 10 logits: ");
|
||||
for (int i = 0; i < 10 && i < n_logits; i++) {
|
||||
printf("%.6f ", logits[i]);
|
||||
for (int i = 0; i < 10 && i < data_size; i++) {
|
||||
printf("%.6f ", data_ptr[i]);
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
printf("Last 10 logits: ");
|
||||
for (int i = n_logits - 10; i < n_logits; i++) {
|
||||
if (i >= 0) printf("%.6f ", logits[i]);
|
||||
for (int i = data_size - 10; i < data_size; i++) {
|
||||
if (i >= 0) printf("%.6f ", data_ptr[i]);
|
||||
}
|
||||
printf("\n\n");
|
||||
}
|
||||
|
||||
printf("Logits saved to %s\n", bin_filename);
|
||||
printf("Logits saved to %s\n", txt_filename);
|
||||
printf("Data saved to %s\n", bin_filename);
|
||||
printf("Data saved to %s\n", txt_filename);
|
||||
|
||||
llama_free(ctx);
|
||||
llama_model_free(model);
|
||||
|
||||
@@ -4,3 +4,4 @@ torchvision
|
||||
transformers
|
||||
huggingface-hub
|
||||
accelerate
|
||||
sentence-transformers
|
||||
|
||||
@@ -2,6 +2,21 @@
|
||||
|
||||
set -e
|
||||
|
||||
# Parse command line arguments
|
||||
SENTENCE_TRANSFORMERS=""
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
-st|--sentence-transformers)
|
||||
SENTENCE_TRANSFORMERS="--sentence-transformers-dense-modules"
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
MODEL_NAME="${MODEL_NAME:-$(basename "$EMBEDDING_MODEL_PATH")}"
|
||||
OUTPUT_DIR="${OUTPUT_DIR:-../../models}"
|
||||
TYPE="${OUTTYPE:-f16}"
|
||||
@@ -15,7 +30,8 @@ echo "Converted model path:: ${CONVERTED_MODEL}"
|
||||
python ../../convert_hf_to_gguf.py --verbose \
|
||||
${EMBEDDING_MODEL_PATH} \
|
||||
--outfile ${CONVERTED_MODEL} \
|
||||
--outtype ${TYPE}
|
||||
--outtype ${TYPE} \
|
||||
${SENTENCE_TRANSFORMERS}
|
||||
|
||||
echo ""
|
||||
echo "The environment variable CONVERTED_EMBEDDING MODEL can be set to this path using:"
|
||||
|
||||
@@ -5,6 +5,7 @@ set -e
|
||||
# Parse command line arguments
|
||||
CONVERTED_MODEL=""
|
||||
PROMPTS_FILE=""
|
||||
USE_POOLING=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
@@ -12,6 +13,10 @@ while [[ $# -gt 0 ]]; do
|
||||
PROMPTS_FILE="$2"
|
||||
shift 2
|
||||
;;
|
||||
--pooling)
|
||||
USE_POOLING="1"
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
if [ -z "$CONVERTED_MODEL" ]; then
|
||||
CONVERTED_MODEL="$1"
|
||||
@@ -47,4 +52,8 @@ echo $CONVERTED_MODEL
|
||||
|
||||
cmake --build ../../build --target llama-logits -j8
|
||||
# TODO: update logits.cpp to accept a --file/-f option for the prompt
|
||||
../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
|
||||
if [ -n "$USE_POOLING" ]; then
|
||||
../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT"
|
||||
else
|
||||
../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
|
||||
fi
|
||||
|
||||
@@ -14,6 +14,8 @@ unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
|
||||
parser = argparse.ArgumentParser(description='Process model with specified path')
|
||||
parser.add_argument('--model-path', '-m', help='Path to the model')
|
||||
parser.add_argument('--prompts-file', '-p', help='Path to file containing prompts (one per line)')
|
||||
parser.add_argument('--use-sentence-transformers', action='store_true',
|
||||
help='Use SentenceTransformer to apply all numbered layers (01_Pooling, 02_Dense, 03_Dense, 04_Normalize)')
|
||||
args = parser.parse_args()
|
||||
|
||||
def read_prompt_from_file(file_path):
|
||||
@@ -31,41 +33,52 @@ model_path = os.environ.get('EMBEDDING_MODEL_PATH', args.model_path)
|
||||
if model_path is None:
|
||||
parser.error("Model path must be specified either via --model-path argument or EMBEDDING_MODEL_PATH environment variable")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
# Determine if we should use SentenceTransformer
|
||||
use_sentence_transformers = args.use_sentence_transformers or os.environ.get('USE_SENTENCE_TRANSFORMERS', '').lower() in ('1', 'true', 'yes')
|
||||
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
|
||||
# This can be used to override the sliding window size for manual testing. This
|
||||
# can be useful to verify the sliding window attention mask in the original model
|
||||
# and compare it with the converted .gguf model.
|
||||
if hasattr(config, 'sliding_window'):
|
||||
original_sliding_window = config.sliding_window
|
||||
#original_sliding_window = 6
|
||||
print(f"Modified sliding window: {original_sliding_window} -> {config.sliding_window}")
|
||||
|
||||
print(f"Using unreleased model: {unreleased_model_name}")
|
||||
if unreleased_model_name:
|
||||
model_name_lower = unreleased_model_name.lower()
|
||||
unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
|
||||
class_name = f"{unreleased_model_name}Model"
|
||||
print(f"Importing unreleased model module: {unreleased_module_path}")
|
||||
|
||||
try:
|
||||
model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
|
||||
model = model_class.from_pretrained(model_path, config=config)
|
||||
except (ImportError, AttributeError) as e:
|
||||
print(f"Failed to import or load model: {e}")
|
||||
exit(1)
|
||||
if use_sentence_transformers:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
print("Using SentenceTransformer to apply all numbered layers")
|
||||
model = SentenceTransformer(model_path)
|
||||
tokenizer = model.tokenizer
|
||||
config = model[0].auto_model.config # type: ignore
|
||||
else:
|
||||
model = AutoModel.from_pretrained(model_path, config=config)
|
||||
print(f"Model class: {type(model)}")
|
||||
print(f"Model file: {type(model).__module__}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
|
||||
# This can be used to override the sliding window size for manual testing. This
|
||||
# can be useful to verify the sliding window attention mask in the original model
|
||||
# and compare it with the converted .gguf model.
|
||||
if hasattr(config, 'sliding_window'):
|
||||
original_sliding_window = config.sliding_window
|
||||
#original_sliding_window = 6
|
||||
print(f"Modified sliding window: {original_sliding_window} -> {config.sliding_window}")
|
||||
|
||||
print(f"Using unreleased model: {unreleased_model_name}")
|
||||
if unreleased_model_name:
|
||||
model_name_lower = unreleased_model_name.lower()
|
||||
unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
|
||||
class_name = f"{unreleased_model_name}Model"
|
||||
print(f"Importing unreleased model module: {unreleased_module_path}")
|
||||
|
||||
try:
|
||||
model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
|
||||
model = model_class.from_pretrained(model_path, config=config)
|
||||
except (ImportError, AttributeError) as e:
|
||||
print(f"Failed to import or load model: {e}")
|
||||
exit(1)
|
||||
else:
|
||||
model = AutoModel.from_pretrained(model_path, config=config)
|
||||
print(f"Model class: {type(model)}")
|
||||
print(f"Model file: {type(model).__module__}")
|
||||
|
||||
# Verify the model is using the correct sliding window
|
||||
if hasattr(model.config, 'sliding_window'):
|
||||
print(f"Model's sliding_window: {model.config.sliding_window}")
|
||||
else:
|
||||
print("Model config does not have sliding_window attribute")
|
||||
if not use_sentence_transformers:
|
||||
if hasattr(model.config, 'sliding_window'): # type: ignore
|
||||
print(f"Model's sliding_window: {model.config.sliding_window}") # type: ignore
|
||||
else:
|
||||
print("Model config does not have sliding_window attribute")
|
||||
|
||||
model_name = os.path.basename(model_path)
|
||||
|
||||
@@ -75,34 +88,56 @@ if args.prompts_file:
|
||||
else:
|
||||
texts = ["Hello world today"]
|
||||
|
||||
encoded = tokenizer(
|
||||
texts,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
tokens = encoded['input_ids'][0]
|
||||
token_strings = tokenizer.convert_ids_to_tokens(tokens)
|
||||
for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
|
||||
print(f"{token_id:6d} -> '{token_str}'")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**encoded)
|
||||
hidden_states = outputs.last_hidden_state # Shape: [batch_size, seq_len, hidden_size]
|
||||
if use_sentence_transformers:
|
||||
embeddings = model.encode(texts, convert_to_numpy=True)
|
||||
all_embeddings = embeddings # Shape: [batch_size, hidden_size]
|
||||
|
||||
# Extract embeddings for each token (matching LLAMA_POOLING_TYPE_NONE behavior)
|
||||
all_embeddings = hidden_states[0].cpu().numpy() # Shape: [seq_len, hidden_size]
|
||||
encoded = tokenizer(
|
||||
texts,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
tokens = encoded['input_ids'][0]
|
||||
token_strings = tokenizer.convert_ids_to_tokens(tokens)
|
||||
for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
|
||||
print(f"{token_id:6d} -> '{token_str}'")
|
||||
|
||||
print(f"Hidden states shape: {hidden_states.shape}")
|
||||
print(f"All embeddings shape: {all_embeddings.shape}")
|
||||
print(f"Embedding dimension: {all_embeddings.shape[1]}")
|
||||
print(f"Embeddings shape (after all SentenceTransformer layers): {all_embeddings.shape}")
|
||||
print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}") # type: ignore
|
||||
else:
|
||||
# Standard approach: use base model output only
|
||||
encoded = tokenizer(
|
||||
texts,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
# Print embeddings exactly like embedding.cpp does for LLAMA_POOLING_TYPE_NONE
|
||||
n_embd = all_embeddings.shape[1]
|
||||
n_embd_count = all_embeddings.shape[0]
|
||||
tokens = encoded['input_ids'][0]
|
||||
token_strings = tokenizer.convert_ids_to_tokens(tokens)
|
||||
for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
|
||||
print(f"{token_id:6d} -> '{token_str}'")
|
||||
|
||||
print() # Empty line to match C++ output
|
||||
outputs = model(**encoded)
|
||||
hidden_states = outputs.last_hidden_state # Shape: [batch_size, seq_len, hidden_size]
|
||||
|
||||
all_embeddings = hidden_states[0].cpu().numpy() # Shape: [seq_len, hidden_size]
|
||||
|
||||
print(f"Hidden states shape: {hidden_states.shape}")
|
||||
print(f"All embeddings shape: {all_embeddings.shape}")
|
||||
print(f"Embedding dimension: {all_embeddings.shape[1]}")
|
||||
|
||||
if len(all_embeddings.shape) == 1:
|
||||
n_embd = all_embeddings.shape[0] # type: ignore
|
||||
n_embd_count = 1
|
||||
all_embeddings = all_embeddings.reshape(1, -1)
|
||||
else:
|
||||
n_embd = all_embeddings.shape[1] # type: ignore
|
||||
n_embd_count = all_embeddings.shape[0] # type: ignore
|
||||
|
||||
print()
|
||||
|
||||
for j in range(n_embd_count):
|
||||
embedding = all_embeddings[j]
|
||||
@@ -120,29 +155,23 @@ with torch.no_grad():
|
||||
|
||||
print() # New line
|
||||
|
||||
print() # Final empty line to match C++ output
|
||||
print()
|
||||
|
||||
data_dir = Path("data")
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin"
|
||||
txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt"
|
||||
|
||||
# Save all embeddings flattened (matching what embedding.cpp would save if it did)
|
||||
flattened_embeddings = all_embeddings.flatten()
|
||||
flattened_embeddings.astype(np.float32).tofile(bin_filename)
|
||||
|
||||
with open(txt_filename, "w") as f:
|
||||
f.write(f"# Model class: {model_name}\n")
|
||||
f.write(f"# Tokens: {token_strings}\n")
|
||||
f.write(f"# Shape: {all_embeddings.shape}\n")
|
||||
f.write(f"# n_embd_count: {n_embd_count}, n_embd: {n_embd}\n\n")
|
||||
|
||||
idx = 0
|
||||
for j in range(n_embd_count):
|
||||
f.write(f"# Token {j} ({token_strings[j]}):\n")
|
||||
for i, value in enumerate(all_embeddings[j]):
|
||||
f.write(f"{j}_{i}: {value:.6f}\n")
|
||||
f.write("\n")
|
||||
print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} tokens × {n_embd} dimensions)")
|
||||
for value in all_embeddings[j]:
|
||||
f.write(f"{idx}: {value:.6f}\n")
|
||||
idx += 1
|
||||
print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} embeddings × {n_embd} dimensions)")
|
||||
print("")
|
||||
print(f"Saved bin embeddings to: {bin_filename}")
|
||||
print(f"Saved txt embeddings to: {txt_filename}")
|
||||
|
||||
@@ -35,7 +35,11 @@ def cosine_similarity(a, b=None):
|
||||
|
||||
def load_embeddings_from_file(filename, n_tokens, n_embd):
|
||||
embeddings = np.fromfile(filename, dtype=np.float32)
|
||||
return embeddings.reshape(n_tokens, n_embd)
|
||||
# Check if this is pooled (single embedding) or per-token embeddings
|
||||
if len(embeddings) == n_embd:
|
||||
return embeddings.reshape(1, n_embd)
|
||||
else:
|
||||
return embeddings.reshape(n_tokens, n_embd)
|
||||
|
||||
def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt):
|
||||
np.set_printoptions(suppress=True, precision=6)
|
||||
@@ -48,58 +52,83 @@ def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt):
|
||||
print(f"Embeddings shape: Python {python_emb.shape}, llama.cpp {cpp_emb.shape}")
|
||||
|
||||
n_tokens = len(tokens)
|
||||
is_pooled = python_emb.shape[0] == 1
|
||||
|
||||
# 1. Direct embedding comparison
|
||||
print(f"\n1. Raw Embedding Magnitude Comparison:")
|
||||
# Check if the distance of each token embedding from the origin and compare
|
||||
# if the vectors are on the same "sphere". This does not tell us about
|
||||
# direction (meaning of the token embedding), just magnitude.
|
||||
for i in range(n_tokens):
|
||||
py_mag = np.linalg.norm(python_emb[i]) # calculate standard euclidean norm for Python embeddings
|
||||
cpp_mag = np.linalg.norm(cpp_emb[i]) # calculate standard euclidean norm for llama.cpp embeddings
|
||||
if is_pooled:
|
||||
print(f"\n[Pooled Embeddings Mode - comparing single sentence embeddings]")
|
||||
|
||||
# 1. Direct embedding comparison for pooled embeddings
|
||||
print(f"\n1. Raw Embedding Magnitude Comparison:")
|
||||
py_mag = np.linalg.norm(python_emb[0])
|
||||
cpp_mag = np.linalg.norm(cpp_emb[0])
|
||||
ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf')
|
||||
print(f" Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
|
||||
print(f" Pooled embedding: Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
|
||||
|
||||
# 2. Cosine similarity between tokens within each model
|
||||
# Here we check the direction of token embeddings to see if the have the
|
||||
# same meaning (similarity). This is done by calculating cosine similarity
|
||||
# of a pair of token embeddings within each model.
|
||||
print(f"\n2. Within-Model Token Similarities:")
|
||||
print(" Python model:")
|
||||
for i in range(n_tokens):
|
||||
for j in range(i+1, n_tokens):
|
||||
sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0]
|
||||
print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
|
||||
# 2. Cross-model similarity for pooled embeddings
|
||||
print(f"\n2. Cross-Model Pooled Embedding Similarity:")
|
||||
sim = cosine_similarity([python_emb[0]], [cpp_emb[0]])[0][0]
|
||||
print(f" Cosine similarity: {sim:.6f}")
|
||||
|
||||
print(" llama.cpp model:")
|
||||
for i in range(n_tokens):
|
||||
for j in range(i+1, n_tokens):
|
||||
sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0]
|
||||
print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
|
||||
return {
|
||||
'cross_model_similarities': [sim],
|
||||
'similarity_matrix_diff': np.array([[0.0]]),
|
||||
'max_diff': 0.0,
|
||||
'mean_diff': 0.0,
|
||||
'rms_diff': 0.0
|
||||
}
|
||||
else:
|
||||
# Original per-token comparison logic
|
||||
# 1. Direct embedding comparison
|
||||
print(f"\n1. Raw Embedding Magnitude Comparison:")
|
||||
# Check if the distance of each token embedding from the origin and compare
|
||||
# if the vectors are on the same "sphere". This does not tell us about
|
||||
# direction (meaning of the token embedding), just magnitude.
|
||||
for i in range(n_tokens):
|
||||
py_mag = np.linalg.norm(python_emb[i]) # calculate standard euclidean norm for Python embeddings
|
||||
cpp_mag = np.linalg.norm(cpp_emb[i]) # calculate standard euclidean norm for llama.cpp embeddings
|
||||
ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf')
|
||||
print(f" Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
|
||||
|
||||
# 3. Cross-model similarity (same token position)
|
||||
print(f"\n3. Cross-Model Same-Token Similarities:")
|
||||
for i in range(n_tokens):
|
||||
sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0]
|
||||
print(f" Token {i} ({tokens[i]}): {sim:.4f}")
|
||||
# 2. Cosine similarity between tokens within each model
|
||||
# Here we check the direction of token embeddings to see if the have the
|
||||
# same meaning (similarity). This is done by calculating cosine similarity
|
||||
# of a pair of token embeddings within each model.
|
||||
print(f"\n2. Within-Model Token Similarities:")
|
||||
print(" Python model:")
|
||||
for i in range(n_tokens):
|
||||
for j in range(i+1, n_tokens):
|
||||
sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0]
|
||||
print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
|
||||
|
||||
# 4. Similarity matrix comparison
|
||||
print(f"\n4. Similarity Matrix Differences:")
|
||||
py_sim_matrix = cosine_similarity(python_emb)
|
||||
cpp_sim_matrix = cosine_similarity(cpp_emb)
|
||||
diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix)
|
||||
print(" llama.cpp model:")
|
||||
for i in range(n_tokens):
|
||||
for j in range(i+1, n_tokens):
|
||||
sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0]
|
||||
print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
|
||||
|
||||
print(f" Max difference: {np.max(diff_matrix):.4f}")
|
||||
print(f" Mean difference: {np.mean(diff_matrix):.4f}")
|
||||
print(f" RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}")
|
||||
# 3. Cross-model similarity (same token position)
|
||||
print(f"\n3. Cross-Model Same-Token Similarities:")
|
||||
for i in range(n_tokens):
|
||||
sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0]
|
||||
print(f" Token {i} ({tokens[i]}): {sim:.4f}")
|
||||
|
||||
return {
|
||||
'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)],
|
||||
'similarity_matrix_diff': diff_matrix,
|
||||
'max_diff': np.max(diff_matrix),
|
||||
'mean_diff': np.mean(diff_matrix),
|
||||
'rms_diff': np.sqrt(np.mean(diff_matrix**2))
|
||||
}
|
||||
# 4. Similarity matrix comparison
|
||||
print(f"\n4. Similarity Matrix Differences:")
|
||||
py_sim_matrix = cosine_similarity(python_emb)
|
||||
cpp_sim_matrix = cosine_similarity(cpp_emb)
|
||||
diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix)
|
||||
|
||||
print(f" Max difference: {np.max(diff_matrix):.4f}")
|
||||
print(f" Mean difference: {np.mean(diff_matrix):.4f}")
|
||||
print(f" RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}")
|
||||
|
||||
return {
|
||||
'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)],
|
||||
'similarity_matrix_diff': diff_matrix,
|
||||
'max_diff': np.max(diff_matrix),
|
||||
'mean_diff': np.mean(diff_matrix),
|
||||
'rms_diff': np.sqrt(np.mean(diff_matrix**2))
|
||||
}
|
||||
|
||||
def read_prompt_from_file(file_path):
|
||||
try:
|
||||
|
||||
@@ -222,6 +222,9 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation"
|
||||
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
|
||||
option(GGML_WEBGPU "ggml: use WebGPU" OFF)
|
||||
option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF)
|
||||
option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)" OFF)
|
||||
option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF)
|
||||
|
||||
option(GGML_ZDNN "ggml: use zDNN" OFF)
|
||||
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
|
||||
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
|
||||
|
||||
@@ -215,6 +215,8 @@ extern "C" {
|
||||
// Backend registry
|
||||
//
|
||||
|
||||
GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
|
||||
|
||||
GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
|
||||
|
||||
// Backend (reg) enumeration
|
||||
|
||||
@@ -7,26 +7,25 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define RPC_PROTO_MAJOR_VERSION 2
|
||||
#define RPC_PROTO_MAJOR_VERSION 3
|
||||
#define RPC_PROTO_MINOR_VERSION 0
|
||||
#define RPC_PROTO_PATCH_VERSION 0
|
||||
#define GGML_RPC_MAX_SERVERS 16
|
||||
|
||||
// backend API
|
||||
GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint);
|
||||
GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device);
|
||||
GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend);
|
||||
|
||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint);
|
||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device);
|
||||
|
||||
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
|
||||
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
|
||||
|
||||
GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
|
||||
const char * cache_dir,
|
||||
size_t free_mem, size_t total_mem);
|
||||
GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
|
||||
size_t n_threads, size_t n_devices,
|
||||
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
|
||||
|
||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
|
||||
|
||||
GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint);
|
||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@@ -1630,6 +1630,13 @@ extern "C" {
|
||||
float scale,
|
||||
float max_bias);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_soft_max_ext_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * mask,
|
||||
float scale,
|
||||
float max_bias);
|
||||
|
||||
GGML_API void ggml_soft_max_add_sinks(
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * sinks);
|
||||
|
||||
@@ -145,6 +145,9 @@ endif()
|
||||
# which was introduced in POSIX.1-2008, forcing us to go higher
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
|
||||
add_compile_definitions(_XOPEN_SOURCE=700)
|
||||
elseif (CMAKE_SYSTEM_NAME MATCHES "AIX")
|
||||
# Don't define _XOPEN_SOURCE. We need _ALL_SOURCE, which is the default,
|
||||
# in order to define _SC_PHYS_PAGES.
|
||||
else()
|
||||
add_compile_definitions(_XOPEN_SOURCE=600)
|
||||
endif()
|
||||
|
||||
+16
-14
@@ -392,12 +392,8 @@ static void ggml_dyn_tallocr_free(struct ggml_dyn_tallocr * alloc) {
|
||||
free(alloc);
|
||||
}
|
||||
|
||||
static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc) {
|
||||
size_t max_size = 0;
|
||||
for (int i = 0; i < alloc->n_chunks; i++) {
|
||||
max_size += alloc->chunks[i]->max_size;
|
||||
}
|
||||
return max_size;
|
||||
static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc, int chunk) {
|
||||
return chunk < alloc->n_chunks ? alloc->chunks[chunk]->max_size : 0;
|
||||
}
|
||||
|
||||
|
||||
@@ -417,10 +413,8 @@ static void ggml_vbuffer_free(struct vbuffer * buf) {
|
||||
free(buf);
|
||||
}
|
||||
|
||||
static int ggml_vbuffer_n_chunks(struct vbuffer * buf) {
|
||||
int n = 0;
|
||||
while (n < GGML_VBUFFER_MAX_CHUNKS && buf->chunks[n]) n++;
|
||||
return n;
|
||||
static size_t ggml_vbuffer_chunk_size(struct vbuffer * buf, int chunk) {
|
||||
return buf->chunks[chunk] ? ggml_backend_buffer_get_size(buf->chunks[chunk]) : 0;
|
||||
}
|
||||
|
||||
static size_t ggml_vbuffer_size(struct vbuffer * buf) {
|
||||
@@ -885,12 +879,20 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
|
||||
}
|
||||
}
|
||||
|
||||
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
|
||||
size_t new_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]);
|
||||
|
||||
// even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views
|
||||
if (new_size > cur_size || galloc->buffers[i] == NULL) {
|
||||
bool realloc = galloc->buffers[i] == NULL;
|
||||
size_t new_size = 0;
|
||||
for (int c = 0; c < galloc->buf_tallocs[i]->n_chunks; c++) {
|
||||
size_t cur_chunk_size = galloc->buffers[i] ? ggml_vbuffer_chunk_size(galloc->buffers[i], c) : 0;
|
||||
size_t new_chunk_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i], c);
|
||||
new_size += new_chunk_size;
|
||||
if (new_chunk_size > cur_chunk_size) {
|
||||
realloc = true;
|
||||
}
|
||||
}
|
||||
if (realloc) {
|
||||
#ifndef NDEBUG
|
||||
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
|
||||
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
|
||||
#endif
|
||||
|
||||
|
||||
@@ -209,9 +209,6 @@ extern "C" {
|
||||
void * context;
|
||||
};
|
||||
|
||||
// Internal backend registry API
|
||||
GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
|
||||
|
||||
// Add backend dynamic loading support to the backend
|
||||
|
||||
// Initialize the backend
|
||||
|
||||
+100
-107
@@ -146,9 +146,7 @@ void ggml_cann_op_unary_gated(
|
||||
unary_op(ctx, acl_src0, acl_dst);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst, acl_src1);
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_src0, acl_dst);
|
||||
if(src1)
|
||||
ggml_cann_release_resources(ctx, acl_src1);
|
||||
ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -894,14 +892,13 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get or expand a cached float32 tensor filled with a scalar value.
|
||||
* @brief Get or expand a cached tensor filled with a scalar value.
|
||||
*
|
||||
* This function manages cached device memory for float32 tensors. If the current
|
||||
* This function manages cached device memory for tensors. If the current
|
||||
* cache size is insufficient for the requested tensor shape, the old memory will
|
||||
* be released and new memory will be allocated. The allocated buffer is then
|
||||
* initialized either with zeros (when @p value == 0.0f) or with the given scalar
|
||||
* value using CANN operations. Finally, an aclTensor object is created from the
|
||||
* cached memory and returned.
|
||||
* be released and new memory will be allocated. The allocated buffer is
|
||||
* initialized with the given scalar value using CANN operations.
|
||||
* Finally, an aclTensor object is created from the cached memory and returned.
|
||||
*
|
||||
* @param ctx The CANN backend context that manages device memory.
|
||||
* @param buffer A pointer to the cached device buffer (will be allocated
|
||||
@@ -910,17 +907,19 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
|
||||
* updated when the cache is expanded.
|
||||
* @param ne The tensor shape array (number of elements in each dimension).
|
||||
* @param nb The stride size for each dimension.
|
||||
* @param dtype Data type of cached tensor.
|
||||
* @param dims The number of tensor dimensions.
|
||||
* @param value The scalar value used to fill the tensor (supports zero
|
||||
* initialization via memset or arbitrary values via fill_scalar).
|
||||
* @return An aclTensor pointer created from the cached buffer.
|
||||
*/
|
||||
static aclTensor* get_f32_cache_acl_tensor(
|
||||
static aclTensor* get_cache_acl_tensor(
|
||||
ggml_backend_cann_context& ctx,
|
||||
void** buffer,
|
||||
int64_t &cache_element,
|
||||
int64_t* ne,
|
||||
size_t* nb,
|
||||
ggml_type dtype,
|
||||
int64_t dims,
|
||||
float value) {
|
||||
// Calculate total number of elements
|
||||
@@ -928,7 +927,7 @@ static aclTensor* get_f32_cache_acl_tensor(
|
||||
for (int i = 0; i < dims; i++) {
|
||||
n_element *= ne[i];
|
||||
}
|
||||
size_t size = n_element * sizeof(float);
|
||||
size_t size = n_element * ggml_type_size(dtype);
|
||||
|
||||
// Allocate or expand cache if needed
|
||||
if (cache_element < n_element) {
|
||||
@@ -941,19 +940,17 @@ static aclTensor* get_f32_cache_acl_tensor(
|
||||
cache_element = n_element;
|
||||
|
||||
// Initialize cache
|
||||
if (value == 0.0f) {
|
||||
ACL_CHECK(aclrtMemsetAsync(*buffer, size, 0, size, ctx.stream()));
|
||||
} else {
|
||||
int64_t pool_ne[1] = { n_element };
|
||||
size_t pool_nb[1] = { sizeof(float) };
|
||||
aclTensor* acl_value = ggml_cann_create_tensor(
|
||||
*buffer, ACL_FLOAT, sizeof(float), pool_ne, pool_nb, 1);
|
||||
aclnn_fill_scalar(ctx, 1, acl_value);
|
||||
ggml_cann_release_resources(ctx, acl_value);
|
||||
}
|
||||
int64_t pool_ne[1] = { n_element };
|
||||
size_t pool_nb[1] = { ggml_type_size(dtype) };
|
||||
aclTensor* acl_value = ggml_cann_create_tensor(
|
||||
*buffer, ggml_cann_type_mapping(dtype), ggml_type_size(dtype),
|
||||
pool_ne, pool_nb, 1);
|
||||
aclnn_fill_scalar(ctx, value, acl_value);
|
||||
ggml_cann_release_resources(ctx, acl_value);
|
||||
}
|
||||
|
||||
return ggml_cann_create_tensor(*buffer, ACL_FLOAT, sizeof(float), ne, nb, dims);
|
||||
return ggml_cann_create_tensor(*buffer, ggml_cann_type_mapping(dtype),
|
||||
ggml_type_size(dtype), ne, nb, dims);
|
||||
}
|
||||
|
||||
void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
@@ -965,35 +962,39 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
// build gamma, one...
|
||||
// build gamma.
|
||||
size_t acl_gamma_nb[GGML_MAX_DIMS];
|
||||
acl_gamma_nb[0] = sizeof(float);
|
||||
// gamma's type is the same with dst.
|
||||
acl_gamma_nb[0] = ggml_type_size(dst->type);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
acl_gamma_nb[i] = acl_gamma_nb[i - 1] * src->ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_gamma = get_f32_cache_acl_tensor(
|
||||
aclTensor* acl_gamma = get_cache_acl_tensor(
|
||||
ctx,
|
||||
&ctx.rms_norm_one_tensor_cache.cache,
|
||||
ctx.rms_norm_one_tensor_cache.size,
|
||||
src->ne,
|
||||
acl_gamma_nb,
|
||||
dst->type,
|
||||
1, // dims
|
||||
1.0f // value
|
||||
);
|
||||
|
||||
// build rstd, zero...
|
||||
// build rstd.
|
||||
int64_t acl_rstd_ne[] = {src->ne[1], src->ne[2], src->ne[3]};
|
||||
size_t acl_rstd_nb[GGML_MAX_DIMS - 1];
|
||||
// rstd will always be F32.
|
||||
acl_rstd_nb[0] = sizeof(float);
|
||||
for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
|
||||
acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
|
||||
}
|
||||
aclTensor* acl_rstd = get_f32_cache_acl_tensor(
|
||||
aclTensor* acl_rstd = get_cache_acl_tensor(
|
||||
ctx,
|
||||
&ctx.rms_norm_zero_tensor_cache.cache,
|
||||
ctx.rms_norm_zero_tensor_cache.size,
|
||||
acl_rstd_ne,
|
||||
acl_rstd_nb,
|
||||
GGML_TYPE_F32,
|
||||
GGML_MAX_DIMS - 1,
|
||||
0.0f // value
|
||||
);
|
||||
@@ -1765,33 +1766,35 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
ggml_tensor* src0 = dst->src[0]; // src
|
||||
ggml_tensor* src1 = dst->src[1]; // index
|
||||
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
break;
|
||||
}
|
||||
case GGML_TYPE_F16: {
|
||||
aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
|
||||
ggml_cann_pool_alloc src_buffer_allocator(
|
||||
ctx.pool(), ggml_nelements(src0) * sizeof(float));
|
||||
void* src_trans_buffer = src_buffer_allocator.get();
|
||||
size_t src_trans_nb[GGML_MAX_DIMS];
|
||||
src_trans_nb[0] = sizeof(float);
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
if(src0->type == dst->type) {
|
||||
aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
} else {
|
||||
aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
|
||||
ggml_cann_pool_alloc src_buffer_allocator(
|
||||
ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst));
|
||||
void* src_trans_buffer = src_buffer_allocator.get();
|
||||
size_t src_trans_nb[GGML_MAX_DIMS];
|
||||
src_trans_nb[0] = dst->nb[0];
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
aclTensor* src_trans_tensor = ggml_cann_create_tensor(
|
||||
src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
|
||||
src0->ne, src_trans_nb, GGML_MAX_DIMS);
|
||||
aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
|
||||
aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
|
||||
}
|
||||
aclTensor* src_trans_tensor = ggml_cann_create_tensor(
|
||||
src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type),
|
||||
src0->ne, src_trans_nb, GGML_MAX_DIMS);
|
||||
aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
|
||||
aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
|
||||
break;
|
||||
}
|
||||
case GGML_TYPE_Q8_0: {
|
||||
// add 1 dim for bcast mul.
|
||||
size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1],
|
||||
@@ -1799,7 +1802,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1],
|
||||
*dequant_ne;
|
||||
int64_t scale_offset = 0;
|
||||
|
||||
// [3,4,5,64] -> [3,4,5,2,32]
|
||||
weight_ne[0] = QK8_0;
|
||||
weight_ne[1] = src0->ne[0] / QK8_0;
|
||||
@@ -1809,7 +1811,6 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
weight_ne[i] = src0->ne[i - 1];
|
||||
weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1];
|
||||
}
|
||||
|
||||
// [3,4,5,64] -> [3,4,5,2,1]
|
||||
scale_ne[0] = 1;
|
||||
scale_ne[1] = src0->ne[0] / QK8_0;
|
||||
@@ -1819,18 +1820,15 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
scale_ne[i] = src0->ne[i - 1];
|
||||
scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1];
|
||||
}
|
||||
|
||||
// [3,4,5,64] -> [3,4,5,2,32]
|
||||
dequant_ne = weight_ne;
|
||||
dequant_nb[0] = sizeof(float);
|
||||
dequant_nb[0] = ggml_type_size(dst->type);
|
||||
for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
|
||||
dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
|
||||
}
|
||||
|
||||
scale_offset = ggml_nelements(src0) * sizeof(int8_t);
|
||||
ggml_cann_pool_alloc dequant_buffer_allocator(
|
||||
ctx.pool(), ggml_nelements(src0) * sizeof(float));
|
||||
|
||||
ctx.pool(), ggml_nelements(src0) * ggml_type_size(dst->type));
|
||||
aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
|
||||
src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb,
|
||||
GGML_MAX_DIMS + 1);
|
||||
@@ -1838,22 +1836,20 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
|
||||
GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
|
||||
aclTensor* dequant_tensor = ggml_cann_create_tensor(
|
||||
dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float),
|
||||
dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
|
||||
dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
|
||||
|
||||
aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
|
||||
dequant_nb[0] = sizeof(float);
|
||||
dequant_nb[0] = ggml_type_size(dst->type);
|
||||
dequant_ne = src0->ne;
|
||||
for (int i = 1; i < GGML_MAX_DIMS; i++) {
|
||||
dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
|
||||
}
|
||||
|
||||
aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(),
|
||||
dequant_ne, dequant_nb,
|
||||
dst->data, dst->ne, dst->nb,
|
||||
src1, dst->type);
|
||||
|
||||
ggml_cann_release_resources(ctx, dequant_tensor);
|
||||
ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@@ -1965,16 +1961,8 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
|
||||
// Only check env once.
|
||||
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
|
||||
if (weight_to_nz && is_matmul_weight(weight)) {
|
||||
int64_t acl_stride[2] = {1, transpose_ne[1]};
|
||||
|
||||
// Reverse ne.
|
||||
std::reverse(transpose_ne, transpose_ne + n_dims);
|
||||
|
||||
std::vector<int64_t> storageDims = {transpose_ne[0], transpose_ne[1]};
|
||||
|
||||
acl_weight_tensor = aclCreateTensor(
|
||||
transpose_ne, n_dims, ggml_cann_type_mapping(weight->type), acl_stride,
|
||||
0, ACL_FORMAT_FRACTAL_NZ, storageDims.data(), 2, weight->data);
|
||||
acl_weight_tensor =
|
||||
ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
|
||||
} else {
|
||||
acl_weight_tensor =
|
||||
ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND);
|
||||
@@ -3178,7 +3166,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
aclTensor* acl_src0_f16_tensor = nullptr;
|
||||
aclTensor* acl_src1_f16_tensor = nullptr;
|
||||
aclTensor* acl_src2_f16_tensor = nullptr;
|
||||
aclTensor* acl_dst_f16_tensor = nullptr;
|
||||
|
||||
// Step 1: cast the src0 (Query) to fp16 if needed
|
||||
ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
|
||||
@@ -3216,22 +3203,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
|
||||
src2_bsnd_nb, GGML_MAX_DIMS);
|
||||
|
||||
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
|
||||
void* out_f16_buffer = out_f16_allocator.alloc(
|
||||
ggml_nelements(dst) * faElemSize);
|
||||
|
||||
int64_t* out_f16_ne = src0_bsnd_ne;
|
||||
size_t out_f16_nb[GGML_MAX_DIMS];
|
||||
out_f16_nb[0] = faElemSize;
|
||||
for(int i = 1; i < GGML_MAX_DIMS; ++i){
|
||||
out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
|
||||
}
|
||||
|
||||
acl_dst_f16_tensor = ggml_cann_create_tensor(
|
||||
out_f16_buffer, faDataType, faElemSize,
|
||||
out_f16_ne, out_f16_nb, GGML_MAX_DIMS
|
||||
);
|
||||
|
||||
// Step 3: create the PSEShift tensor if needed
|
||||
// this tensor is considered as mask (f16) in the llama.cpp
|
||||
aclTensor* bcast_pse_tensor = nullptr;
|
||||
@@ -3317,8 +3288,8 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
aclTensor* acl_q_tensor = acl_src0_f16_tensor;
|
||||
aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
|
||||
aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor};
|
||||
auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
|
||||
auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
|
||||
aclTensorList* acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
|
||||
aclTensorList* acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
|
||||
|
||||
int64_t numHeads = src0->ne[2]; // N
|
||||
int64_t numKeyValueHeads = src1->ne[2];
|
||||
@@ -3334,8 +3305,29 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
int64_t keyAntiquantMode = 0;
|
||||
int64_t valueAntiquantMode = 0;
|
||||
|
||||
// Step 5: launch the FusedInferAttentionScoreV2 kernel.
|
||||
// Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
aclTensor * fa_dst_tensor = nullptr;
|
||||
aclTensor * acl_dst_tensor = nullptr;
|
||||
ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
void* out_f16_buffer = out_f16_allocator.alloc(
|
||||
ggml_nelements(dst) * faElemSize);
|
||||
|
||||
int64_t* out_f16_ne = src0_bsnd_ne;
|
||||
size_t out_f16_nb[GGML_MAX_DIMS];
|
||||
out_f16_nb[0] = faElemSize;
|
||||
for(int i = 1; i < GGML_MAX_DIMS; ++i){
|
||||
out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
|
||||
}
|
||||
|
||||
fa_dst_tensor = ggml_cann_create_tensor(
|
||||
out_f16_buffer, faDataType, faElemSize,
|
||||
out_f16_ne, out_f16_nb, GGML_MAX_DIMS
|
||||
);
|
||||
}
|
||||
else {
|
||||
fa_dst_tensor = ggml_cann_create_tensor(dst);
|
||||
}
|
||||
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
|
||||
acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
|
||||
@@ -3357,23 +3349,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
|
||||
blockSize, antiquantMode, // blockSize, antiquantMode
|
||||
softmaxLseFlag, // softmaxLseFlag
|
||||
keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
|
||||
acl_dst_f16_tensor, // attentionOut
|
||||
fa_dst_tensor, // attentionOut
|
||||
nullptr // softmaxLse
|
||||
);
|
||||
|
||||
// Step 6: post-processing, permute and cast to f32
|
||||
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
|
||||
// TODO: when dst is fp16, don't need cast
|
||||
aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
|
||||
ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
|
||||
acl_src1_f16_tensor,
|
||||
acl_src2_f16_tensor,
|
||||
acl_dst_f16_tensor,
|
||||
acl_dst_tensor);
|
||||
if(src3 != nullptr){
|
||||
ggml_cann_release_resources(ctx, bcast_pse_tensor);
|
||||
if (dst->type == GGML_TYPE_F32) {
|
||||
// Step 6: post-processing, permute and cast to f32
|
||||
aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
|
||||
aclnn_cast(ctx, fa_dst_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
|
||||
}
|
||||
}else{
|
||||
|
||||
ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
|
||||
acl_k_tensor_list,
|
||||
acl_v_tensor_list,
|
||||
fa_dst_tensor,
|
||||
acl_dst_tensor,
|
||||
bcast_pse_tensor);
|
||||
|
||||
} else {
|
||||
GGML_ABORT("Function is not implemented.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -341,11 +341,18 @@ private:
|
||||
|
||||
#ifdef USE_ACL_GRAPH
|
||||
struct ggml_graph_node_properties {
|
||||
// dst tensor
|
||||
void * node_address;
|
||||
ggml_op node_op;
|
||||
int64_t ne[GGML_MAX_DIMS];
|
||||
size_t nb[GGML_MAX_DIMS];
|
||||
|
||||
// src tensor
|
||||
void * src_address[GGML_MAX_SRC];
|
||||
int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];
|
||||
size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];
|
||||
|
||||
// op
|
||||
ggml_op node_op;
|
||||
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
|
||||
};
|
||||
|
||||
|
||||
@@ -2186,7 +2186,15 @@ static void add_lru_matched_graph_node_properties(
|
||||
std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
|
||||
|
||||
for (int src = 0; src < GGML_MAX_SRC; ++src) {
|
||||
prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr;
|
||||
if (node->src[src]) {
|
||||
prop.src_address[src] = node->src[src]->data;
|
||||
std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
|
||||
std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
|
||||
} else {
|
||||
prop.src_address[src] = nullptr;
|
||||
std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
|
||||
std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
|
||||
}
|
||||
}
|
||||
|
||||
memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
|
||||
@@ -2206,14 +2214,18 @@ static void add_lru_matched_graph_node_properties(
|
||||
* @param graph_node_properties The stored properties of a CANN graph node.
|
||||
* @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
|
||||
*/
|
||||
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
|
||||
static bool ggml_graph_node_has_matching_properties(
|
||||
ggml_tensor * node,
|
||||
ggml_graph_node_properties * graph_node_properties) {
|
||||
if (node->data != graph_node_properties->node_address &&
|
||||
node->op != GGML_OP_VIEW) {
|
||||
node->op != GGML_OP_VIEW) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node->op != graph_node_properties->node_op) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
if (node->ne[i] != graph_node_properties->ne[i]) {
|
||||
return false;
|
||||
@@ -2222,17 +2234,31 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (node->src[i] &&
|
||||
node->src[i]->data != graph_node_properties->src_address[i] &&
|
||||
node->op != GGML_OP_VIEW
|
||||
) {
|
||||
return false;
|
||||
if (node->src[i]) {
|
||||
if (node->src[i]->data != graph_node_properties->src_address[i] &&
|
||||
node->op != GGML_OP_VIEW) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int d = 0; d < GGML_MAX_DIMS; d++) {
|
||||
if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) {
|
||||
return false;
|
||||
}
|
||||
if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (graph_node_properties->src_address[i] != nullptr) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (node->op == GGML_OP_SCALE &&
|
||||
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
|
||||
return false;
|
||||
|
||||
if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
|
||||
return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -149,6 +149,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
||||
if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous
|
||||
is_contiguous_2d(op->src[1]) && // src1 must be contiguous
|
||||
op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
|
||||
op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315)
|
||||
op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x
|
||||
(qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) {
|
||||
// src1 must be host buffer
|
||||
|
||||
@@ -68,7 +68,7 @@ struct ggml_compute_params {
|
||||
#endif // __VXE2__
|
||||
#endif // __s390x__ && __VEC__
|
||||
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
#if defined(__ARM_FEATURE_SVE) && defined(__linux__)
|
||||
#include <sys/prctl.h>
|
||||
#endif
|
||||
|
||||
|
||||
@@ -689,8 +689,13 @@ bool ggml_is_numa(void) {
|
||||
#endif
|
||||
|
||||
static void ggml_init_arm_arch_features(void) {
|
||||
#if defined(__linux__) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
|
||||
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
|
||||
#if defined(__linux__)
|
||||
ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
|
||||
#else
|
||||
// TODO: add support of SVE for non-linux systems
|
||||
#error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here."
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -29,6 +29,108 @@
|
||||
|
||||
#define NELEMS(x) sizeof(x) / sizeof(*x)
|
||||
|
||||
template<size_t(*Fn)(size_t,size_t,size_t)>
|
||||
static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) {
|
||||
return Fn(a, b, c);
|
||||
}
|
||||
|
||||
template<size_t(*Fn)(size_t,size_t)>
|
||||
static inline size_t kernel_offs_fn2(size_t a, size_t b, size_t) {
|
||||
return Fn(a, b);
|
||||
}
|
||||
|
||||
template<void(*Fn)(size_t,size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
|
||||
static inline void kernel_run_fn11(size_t m, size_t n, size_t k, size_t bl,
|
||||
const void* lhs, const void* rhs, void* dst,
|
||||
size_t dst_stride_row, size_t dst_stride_col,
|
||||
float clamp_min, float clamp_max) {
|
||||
Fn(m, n, k, bl, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
|
||||
}
|
||||
|
||||
template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,void*,size_t,size_t,float,float)>
|
||||
static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
|
||||
const void* lhs, const void* rhs, void* dst,
|
||||
size_t dst_stride_row, size_t dst_stride_col,
|
||||
float clamp_min, float clamp_max) {
|
||||
Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
|
||||
}
|
||||
|
||||
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
|
||||
static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
|
||||
return Fn(m, k, bl, mr, kr, sr);
|
||||
}
|
||||
|
||||
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
|
||||
static inline size_t lhs_ps_fn5(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) {
|
||||
return Fn(m, k, mr, kr, sr);
|
||||
}
|
||||
|
||||
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
|
||||
static inline size_t lhs_offs_fn6(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
|
||||
return Fn(m_idx, k, bl, mr, kr, sr);
|
||||
}
|
||||
|
||||
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
|
||||
static inline size_t lhs_offs_fn5(size_t m_idx, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) {
|
||||
return Fn(m_idx, k, mr, kr, sr);
|
||||
}
|
||||
|
||||
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>
|
||||
static inline void lhs_pack_float_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
|
||||
size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {
|
||||
Fn(m, k, bl, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);
|
||||
}
|
||||
|
||||
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,size_t,void*)>
|
||||
static inline void lhs_pack_void_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
|
||||
size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {
|
||||
Fn(m, k, bl, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
|
||||
}
|
||||
|
||||
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const void*,size_t,void*)>
|
||||
static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
|
||||
size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {
|
||||
Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
|
||||
}
|
||||
|
||||
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
|
||||
static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
|
||||
return Fn(n, k, nr, kr, bl);
|
||||
}
|
||||
|
||||
template<size_t(*Fn)(size_t,size_t)>
|
||||
static inline size_t rhs_ps_fn2(size_t n, size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) {
|
||||
return Fn(n, k);
|
||||
}
|
||||
|
||||
template<size_t(*Fn)(size_t,size_t,size_t,size_t)>
|
||||
static inline size_t rhs_stride_fn4(size_t k, size_t nr, size_t kr, size_t bl) {
|
||||
return Fn(k, nr, kr, bl);
|
||||
}
|
||||
|
||||
template<size_t(*Fn)(size_t)>
|
||||
static inline size_t rhs_stride_fn1(size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) {
|
||||
return Fn(k);
|
||||
}
|
||||
|
||||
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const uint8_t*,const float*,void*,size_t,const struct kai_rhs_pack_qs4cxs1s0_param*)>
|
||||
static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl,
|
||||
size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* /*scale*/,
|
||||
void* rhs_packed, size_t extra_bytes, const void* params) {
|
||||
Fn(num_groups, n, k, nr, kr, sr, bl,
|
||||
static_cast<const uint8_t*>(rhs),
|
||||
static_cast<const float*>(bias),
|
||||
rhs_packed, extra_bytes,
|
||||
static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
|
||||
}
|
||||
|
||||
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
|
||||
static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
|
||||
size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
|
||||
void* rhs_packed, size_t extra_bytes, const void* params) {
|
||||
Fn(num_groups, n, k, nr, kr, sr, rhs_stride, rhs, bias, scale, rhs_packed, extra_bytes, params);
|
||||
}
|
||||
|
||||
static const size_t INT4_PER_BYTE = 2;
|
||||
static const size_t INT4_BITS = 4;
|
||||
static const int Q4_0_ZERO_POINT = 8;
|
||||
@@ -122,17 +224,18 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
|
||||
},
|
||||
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>,
|
||||
},
|
||||
/* SME GEMV */
|
||||
/* .kern_info = */ {
|
||||
@@ -142,23 +245,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,
|
||||
},
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
||||
/* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
||||
/* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16,
|
||||
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,
|
||||
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_SME,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
@@ -174,17 +278,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn10<kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme>,
|
||||
/* .pack_func_ex = */ &lhs_pack_void_fn9<kai_run_lhs_pack_bf16p2vlx2_f32_sme>,
|
||||
},
|
||||
/* SME GEMV */
|
||||
/* .kern_info = */ {
|
||||
@@ -194,23 +298,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_lhs_offset_ex = */ nullptr,
|
||||
/* .get_rhs_packed_offset_ex = */ nullptr,
|
||||
/* .run_kernel_ex = */ nullptr,
|
||||
},
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme>,
|
||||
/* .pack_func_ex = */ &lhs_pack_void_fn9<kai_run_lhs_pack_bf16p2vlx2_f32_sme>,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
|
||||
/* .packed_stride = */ NULL,
|
||||
/* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
|
||||
/* .to_float = */ NULL,
|
||||
/* .packed_stride = */ nullptr,
|
||||
/* .to_float = */ nullptr,
|
||||
/* .packed_size_ex = */ &rhs_ps_fn2<kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn1<kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,
|
||||
/* .pack_func_ex = */ &rhs_pack_fn13<kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_SME,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
@@ -229,17 +334,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
||||
},
|
||||
/* DOTPROD GEMV */
|
||||
/* .kern_info = */ {
|
||||
@@ -249,23 +354,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
|
||||
},
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
@@ -283,17 +389,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
||||
},
|
||||
/* i8mm GEMV */
|
||||
/* .kern_info = */ {
|
||||
@@ -303,23 +409,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
|
||||
},
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
@@ -338,17 +445,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
||||
},
|
||||
/* i8mm GEMV */
|
||||
/* .kern_info = */ {
|
||||
@@ -358,23 +465,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
|
||||
},
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
@@ -392,17 +500,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
||||
},
|
||||
/* DOTPROD GEMV */
|
||||
/* .kern_info = */ {
|
||||
@@ -412,23 +520,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
|
||||
},
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
@@ -443,6 +552,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
||||
ggml_kleidiai_kernels * kernel = nullptr;
|
||||
|
||||
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
|
||||
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
||||
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
|
||||
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
|
||||
@@ -452,6 +562,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
||||
break;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
return kernel;
|
||||
@@ -460,12 +571,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
||||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
|
||||
ggml_kleidiai_kernels * kernels = nullptr;
|
||||
|
||||
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
|
||||
if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
|
||||
kernels = &gemm_gemv_kernels[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return kernels;
|
||||
}
|
||||
|
||||
@@ -4,8 +4,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <variant>
|
||||
#include "ggml.h"
|
||||
|
||||
enum cpu_feature {
|
||||
@@ -15,6 +13,7 @@ enum cpu_feature {
|
||||
CPU_FEATURE_SVE = 4,
|
||||
CPU_FEATURE_SME = 8
|
||||
};
|
||||
|
||||
inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) {
|
||||
lhs = static_cast<cpu_feature>(lhs | rhs);
|
||||
return lhs;
|
||||
@@ -30,63 +29,52 @@ struct kernel_info {
|
||||
size_t (*get_nr)(void);
|
||||
size_t (*get_kr)(void);
|
||||
size_t (*get_sr)(void);
|
||||
std::variant<
|
||||
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
|
||||
std::function<size_t(size_t m_idx, size_t k)>
|
||||
> get_lhs_offset;
|
||||
std::variant<
|
||||
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
|
||||
std::function<size_t(size_t n_idx, size_t k)>
|
||||
> get_rhs_packed_offset;
|
||||
|
||||
size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
|
||||
size_t (*get_dst_size)(size_t m, size_t n);
|
||||
std::variant<
|
||||
std::function<void(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
|
||||
float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max)>,
|
||||
std::function<void(size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
|
||||
size_t dst_stride_col, float clamp_min, float clamp_max)>
|
||||
> run_kernel;
|
||||
|
||||
size_t (*get_lhs_offset_ex)(size_t m_idx, size_t k, size_t bl);
|
||||
|
||||
size_t (*get_rhs_packed_offset_ex)(size_t n_idx, size_t k, size_t bl);
|
||||
|
||||
void (*run_kernel_ex)(
|
||||
size_t m, size_t n, size_t k, size_t bl,
|
||||
const void* lhs_packed, const void* rhs_packed,
|
||||
void* dst, size_t dst_stride_row, size_t dst_stride_col,
|
||||
float clamp_min, float clamp_max);
|
||||
};
|
||||
|
||||
struct lhs_packing_info {
|
||||
size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
|
||||
std::variant<
|
||||
std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
|
||||
std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)>
|
||||
> get_packed_offset;
|
||||
std::variant<
|
||||
std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
|
||||
std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)>
|
||||
> packed_size;
|
||||
std::variant<
|
||||
std::function<void(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
|
||||
size_t lhs_stride, void* lhs_packed)>,
|
||||
std::function<void(size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
|
||||
void* lhs_packed)>
|
||||
> pack_func;
|
||||
|
||||
size_t (*get_packed_offset_ex)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
|
||||
|
||||
size_t (*packed_size_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
|
||||
|
||||
void (*pack_func_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
|
||||
size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed);
|
||||
};
|
||||
|
||||
struct rhs_packing_info {
|
||||
std::variant<
|
||||
std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
|
||||
std::function<size_t(size_t n, size_t k)>
|
||||
> packed_size;
|
||||
size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);
|
||||
std::variant<
|
||||
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
|
||||
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
|
||||
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
|
||||
const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
|
||||
> pack_func;
|
||||
void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride,
|
||||
size_t kr, size_t bl, size_t num_bytes_multiplier);
|
||||
|
||||
void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out,
|
||||
size_t nr_pack, size_t packed_row_stride, size_t kr, size_t bl,
|
||||
size_t num_bytes_multiplier);
|
||||
|
||||
size_t (*packed_size_ex)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
|
||||
|
||||
size_t (*packed_stride_ex)(size_t k, size_t nr, size_t kr, size_t bl);
|
||||
|
||||
void (*pack_func_ex)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl,
|
||||
size_t rhs_stride, const void * rhs, const void * bias, const void * scale, void * rhs_packed, size_t extra_bytes, const void * params);
|
||||
};
|
||||
|
||||
struct ggml_kleidiai_kernels {
|
||||
kernel_info gemm;
|
||||
kernel_info gemm;
|
||||
lhs_packing_info gemm_lhs_info;
|
||||
|
||||
kernel_info gemv;
|
||||
kernel_info gemv;
|
||||
lhs_packing_info gemv_lhs_info;
|
||||
|
||||
rhs_packing_info rhs_info;
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <stdexcept>
|
||||
#include <stdint.h>
|
||||
#include <string.h>
|
||||
#include <string>
|
||||
#if defined(__linux__)
|
||||
#include <asm/hwcap.h>
|
||||
#include <sys/auxv.h>
|
||||
@@ -87,40 +88,6 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
|
||||
return tensor->ne[dim];
|
||||
}
|
||||
|
||||
template <typename Variant, typename Ret, typename... Args, std::size_t... Is>
|
||||
constexpr bool variant_any_invocable_impl(std::index_sequence<Is...>) {
|
||||
using V = std::remove_reference_t<Variant>;
|
||||
return (std::is_invocable_r_v<
|
||||
Ret,
|
||||
std::variant_alternative_t<Is, V>,
|
||||
Args...> || ...);
|
||||
}
|
||||
|
||||
template <typename Variant, typename Ret, typename... Args>
|
||||
constexpr bool variant_any_invocable_v =
|
||||
variant_any_invocable_impl<Variant, Ret, Args...>(
|
||||
std::make_index_sequence<
|
||||
std::variant_size_v<std::remove_reference_t<Variant>>>{});
|
||||
|
||||
template<typename Ret, typename Variant, typename... Args>
|
||||
static inline Ret variant_call(Variant && var, Args&&... args) {
|
||||
static_assert(variant_any_invocable_v<std::remove_reference_t<Variant>, Ret, Args...>,
|
||||
"No alternative in Variant is invocable with the provided arguments and return type.");
|
||||
|
||||
return std::visit(
|
||||
[&](auto && f) -> Ret {
|
||||
using F = std::decay_t<decltype(f)>;
|
||||
if constexpr (std::is_invocable_r_v<Ret, F, Args...>) {
|
||||
return std::invoke(std::forward<decltype(f)>(f), std::forward<Args>(args)...);
|
||||
} else {
|
||||
GGML_ABORT("Invalid function type in variant_call");
|
||||
GGML_UNREACHABLE();
|
||||
}
|
||||
},
|
||||
std::forward<Variant>(var)
|
||||
);
|
||||
}
|
||||
|
||||
namespace ggml::cpu::kleidiai {
|
||||
|
||||
static size_t round_down(size_t x, size_t y) {
|
||||
@@ -145,7 +112,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
return false;
|
||||
}
|
||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
|
||||
GGML_ASSERT(kernels);
|
||||
if (!kernels) {
|
||||
return false;
|
||||
}
|
||||
bool is_gemv = op->src[1]->ne[1] == 1;
|
||||
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
||||
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
||||
@@ -159,16 +128,18 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
size_t sr = kernel->get_sr();
|
||||
|
||||
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
|
||||
size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
|
||||
if (!lhs_info->packed_size_ex) return false;
|
||||
size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
|
||||
} else if (kernels->rhs_type == GGML_TYPE_F16) {
|
||||
if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
|
||||
const int64_t lhs_batch_size0 = op->src[1]->ne[2];
|
||||
const int64_t rhs_batch_size0 = op->src[0]->ne[2];
|
||||
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
|
||||
size = variant_call<size_t>(lhs_info->packed_size, m * r, k, mr, kr, sr) +
|
||||
variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
|
||||
size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) +
|
||||
kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) +
|
||||
k * n * sizeof(float) + n * sizeof(float);
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -196,12 +167,18 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
|
||||
GGML_ASSERT(kernels);
|
||||
if (!kernels) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool is_gemv = src1->ne[1] == 1;
|
||||
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
||||
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
||||
GGML_ASSERT(kernel);
|
||||
if (!kernels->rhs_info.pack_func_ex ||
|
||||
!kernel->get_lhs_offset_ex || !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int nth = params->nth;
|
||||
const int ith = params->ith;
|
||||
@@ -228,10 +205,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
const int64_t kr = (int64_t) kernel->get_kr();
|
||||
const int64_t sr = (int64_t) kernel->get_sr();
|
||||
|
||||
const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
|
||||
const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, (size_t)n, (size_t)k);
|
||||
const size_t kxn_size = (size_t)k * (size_t)n * sizeof(float);
|
||||
const size_t bias_size = (size_t)n * sizeof(float);
|
||||
const size_t lhs_packed_size = lhs_info->packed_size_ex(m, k, 0, mr, kr, sr);
|
||||
const size_t rhs_packed_size = kernels->rhs_info.packed_size_ex(n, k, nr, kr, 0);
|
||||
const size_t kxn_size = k * n * sizeof(float);
|
||||
const size_t bias_size = n * sizeof(float);
|
||||
|
||||
const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
|
||||
GGML_ASSERT(wsize_required <= params->wsize);
|
||||
@@ -259,10 +236,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
|
||||
|
||||
// Base packed offset (aligned) and per-row stride in bytes
|
||||
const size_t base_packed_off = variant_call<size_t>(
|
||||
lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
|
||||
const size_t next_block_off = variant_call<size_t>(
|
||||
lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
|
||||
const size_t base_packed_off = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
|
||||
const size_t next_block_off = lhs_info->get_packed_offset_ex(m_start + mr, k, 0, mr, kr, sr);
|
||||
const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;
|
||||
|
||||
int64_t remaining = m_count;
|
||||
@@ -278,9 +253,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
|
||||
void * dst_ptr = lhs_packed + dst_off;
|
||||
|
||||
variant_call<void>(lhs_info->pack_func,
|
||||
(size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr,
|
||||
/*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr);
|
||||
lhs_info->pack_func_ex(take, k, 0, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
|
||||
|
||||
cur += take;
|
||||
remaining -= take;
|
||||
@@ -296,10 +269,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
reinterpret_cast<const uint16_t *>(rhs_batch_base),
|
||||
rhs_stride);
|
||||
|
||||
variant_call<void>(kernels->rhs_info.pack_func,
|
||||
/*num_groups*/ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr,
|
||||
/*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)),
|
||||
rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr);
|
||||
kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, n * sizeof(float),
|
||||
rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
@@ -320,20 +291,15 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
|
||||
|
||||
// LHS packed base at row 0 (consistent with packing above)
|
||||
const size_t lhs_packed_offset0 = variant_call<size_t>(
|
||||
lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
|
||||
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k);
|
||||
const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);
|
||||
const size_t lhs_packed_offset0 = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
|
||||
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
|
||||
const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);
|
||||
|
||||
const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
|
||||
const void * rhs_ptr = rhs_packed + rhs_packed_offset;
|
||||
float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
|
||||
|
||||
variant_call<void>(kernel->run_kernel,
|
||||
(size_t)m, (size_t)n_to_process, (size_t)k,
|
||||
lhs_ptr, rhs_ptr,
|
||||
dst_ptr, dst_stride, sizeof(float),
|
||||
-FLT_MAX, FLT_MAX);
|
||||
kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -354,13 +320,19 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
|
||||
GGML_ASSERT(kernels);
|
||||
if (!kernels) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_gemv = src1->ne[1] == 1;
|
||||
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
||||
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
||||
|
||||
GGML_ASSERT(kernel);
|
||||
if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
|
||||
!kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth_raw = params->nth;
|
||||
@@ -402,25 +374,26 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
// Transform LHS
|
||||
const size_t src_stride = src1->nb[1];
|
||||
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
|
||||
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
|
||||
const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr);
|
||||
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
|
||||
|
||||
variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
|
||||
// Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer
|
||||
lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
// Perform the operation
|
||||
const size_t dst_stride = dst->nb[1];
|
||||
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
|
||||
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
|
||||
const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr);
|
||||
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0);
|
||||
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
|
||||
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
|
||||
const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
|
||||
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
|
||||
|
||||
if (n_to_process > 0) {
|
||||
variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
|
||||
kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
|
||||
sizeof(float), -FLT_MAX, FLT_MAX);
|
||||
}
|
||||
|
||||
@@ -429,7 +402,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
|
||||
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
|
||||
GGML_ASSERT(ctx.kernels);
|
||||
if (!ctx.kernels) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
@@ -438,6 +413,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
|
||||
rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
|
||||
kernel_info * kernel = &ctx.kernels->gemm;
|
||||
if (!rhs_info->to_float || !kernel->get_nr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int64_t nc = ne00;
|
||||
const int64_t nr = ggml_nelements(src1);
|
||||
@@ -480,7 +458,7 @@ public:
|
||||
struct kai_rhs_pack_qs4cxs1s0_param params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, ¶ms);
|
||||
ctx.kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, (const uint8_t*)data, nullptr, nullptr, tensor->data, 0, ¶ms);
|
||||
|
||||
return 0;
|
||||
GGML_UNUSED(data_size);
|
||||
@@ -548,7 +526,7 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_
|
||||
const size_t nr = ctx.kernels->gemm.get_nr();
|
||||
const size_t kr = ctx.kernels->gemm.get_kr();
|
||||
|
||||
return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
|
||||
return ctx.kernels->rhs_info.packed_size_ex(n, k, nr, kr, QK4_0);
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
+11
-15
@@ -3467,31 +3467,27 @@ static void ggml_compute_forward_norm_f32(
|
||||
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
// TODO: optimize
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
ggml_float sum = 0.0;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
sum += (ggml_float)x[i00];
|
||||
}
|
||||
|
||||
float sum = 0.0;
|
||||
ggml_vec_sum_f32(ne00, &sum, x);
|
||||
float mean = sum/ne00;
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
float variance = 0;
|
||||
|
||||
ggml_float sum2 = 0.0;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
float v = x[i00] - mean;
|
||||
y[i00] = v;
|
||||
sum2 += (ggml_float)(v*v);
|
||||
}
|
||||
#ifdef GGML_USE_ACCELERATE
|
||||
mean = -mean;
|
||||
vDSP_vsadd(x, 1, &mean, y, 1, ne00);
|
||||
vDSP_measqv(y, 1, &variance, ne00);
|
||||
#else
|
||||
variance = ggml_vec_cvar_f32(ne00, y, x, mean);
|
||||
#endif //GGML_USE_ACCELERATE
|
||||
|
||||
float variance = sum2/ne00;
|
||||
const float scale = 1.0f/sqrtf(variance + eps);
|
||||
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
}
|
||||
}
|
||||
@@ -8135,7 +8131,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
}
|
||||
|
||||
// V /= S
|
||||
const float S_inv = 1.0f/S;
|
||||
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
|
||||
ggml_vec_scale_f32(DV, VKQ32, S_inv);
|
||||
|
||||
// dst indices
|
||||
|
||||
@@ -404,6 +404,72 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
|
||||
}
|
||||
}
|
||||
|
||||
ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) {
|
||||
int i = 0;
|
||||
ggml_float sum = 0;
|
||||
// TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344
|
||||
#if defined(__AVX512F__) && defined(__AVX512DQ__)
|
||||
for (; i + 15 < n; i += 16) {
|
||||
__m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i),
|
||||
_mm512_set1_ps(mean));
|
||||
_mm512_storeu_ps(y + i, val);
|
||||
sum += (ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val));
|
||||
}
|
||||
#elif defined(__AVX2__) && defined(__FMA__)
|
||||
for (; i + 7 < n; i += 8) {
|
||||
__m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i),
|
||||
_mm256_set1_ps(mean));
|
||||
_mm256_storeu_ps(y + i, val);
|
||||
val = _mm256_mul_ps(val,val);
|
||||
__m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
|
||||
_mm256_castps256_ps128(val));
|
||||
val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
|
||||
val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
|
||||
sum += (ggml_float)_mm_cvtss_f32(val2);
|
||||
}
|
||||
#elif defined(__SSE2__)
|
||||
for (; i + 3 < n; i += 4) {
|
||||
__m128 val = _mm_sub_ps(_mm_loadu_ps(x + i),
|
||||
_mm_set1_ps(mean));
|
||||
_mm_storeu_ps(y + i, val);
|
||||
val = _mm_mul_ps(val, val);
|
||||
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
||||
val = _mm_add_ps(val, _mm_movehl_ps(val, val));
|
||||
val = _mm_add_ss(val, _mm_movehdup_ps(val));
|
||||
#else
|
||||
__m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
|
||||
val = _mm_add_ps(val, tmp);
|
||||
tmp = _mm_movehl_ps(tmp, val);
|
||||
val = _mm_add_ss(val, tmp);
|
||||
#endif // __AVX__ || __AVX2__ || __AVX512F__
|
||||
sum += (ggml_float)_mm_cvtss_f32(val);
|
||||
}
|
||||
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
||||
for (; i + 3 < n; i += 4) {
|
||||
float32x4_t val = vsubq_f32(vld1q_f32(x + i),
|
||||
vdupq_n_f32(mean));
|
||||
vst1q_f32(y + i, val);
|
||||
val = vmulq_f32(val, val);
|
||||
sum += (ggml_float)vaddvq_f32(val);
|
||||
}
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
for (; i + 3 < n; i += 4) {
|
||||
float32x4_t val = vec_sub(vec_xl(0, x + i), vec_splats(mean));
|
||||
vec_xst(val, 0, y + i);
|
||||
val = vec_mul(val, val);
|
||||
sum += (ggml_float)vec_hsum_f32x4(val);
|
||||
}
|
||||
#endif
|
||||
for (; i < n; ++i) {
|
||||
float val = x[i] - mean;
|
||||
y[i] = val;
|
||||
val *= val;
|
||||
sum += (ggml_float)val;
|
||||
}
|
||||
return sum/n;
|
||||
}
|
||||
|
||||
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
|
||||
int i = 0;
|
||||
ggml_float sum = 0;
|
||||
|
||||
+10
-8
@@ -44,6 +44,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
|
||||
void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
|
||||
|
||||
void ggml_vec_silu_f32(const int n, float * y, const float * x);
|
||||
ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean); //it will also center y ( y = y - mean )
|
||||
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max);
|
||||
ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max);
|
||||
|
||||
@@ -143,14 +144,14 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
|
||||
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements
|
||||
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elemnst
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements
|
||||
sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements
|
||||
sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1);
|
||||
|
||||
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements
|
||||
|
||||
ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 ekements
|
||||
ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements
|
||||
sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2);
|
||||
ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1);
|
||||
sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2);
|
||||
@@ -159,7 +160,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
|
||||
|
||||
ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2);
|
||||
sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3);
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
|
||||
ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
|
||||
sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3);
|
||||
|
||||
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
|
||||
@@ -654,11 +655,11 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
|
||||
}
|
||||
// leftovers
|
||||
// maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
|
||||
if (np < n) {
|
||||
svbool_t pg = svwhilelt_b32(np, n);
|
||||
ay1 = svld1_f32(pg, y + np);
|
||||
for (int i = np; i < n; i += ggml_f32_epr) {
|
||||
svbool_t pg = svwhilelt_b32(i, n);
|
||||
ay1 = svld1_f32(pg, y + i);
|
||||
ay1 = svmul_f32_m(pg, ay1, vx);
|
||||
svst1_f32(pg, y + np, ay1);
|
||||
svst1_f32(pg, y + i, ay1);
|
||||
}
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
for (int i = 0, avl; i < n; i += avl) {
|
||||
@@ -819,7 +820,8 @@ inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_f
|
||||
inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
|
||||
inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(expm1f(GGML_CPU_FP16_TO_FP32(x[i])));
|
||||
const float v = GGML_CPU_FP16_TO_FP32(x[i]);
|
||||
y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : expm1f(v));
|
||||
}
|
||||
}
|
||||
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
|
||||
|
||||
@@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND)
|
||||
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
|
||||
|
||||
file(GLOB GGML_SOURCES_CUDA "*.cu")
|
||||
file(GLOB SRCS "template-instances/fattn-tile*.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/fattn-mma*.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/mmq*.cu")
|
||||
|
||||
@@ -245,7 +245,8 @@ static bool fp16_available(const int cc) {
|
||||
}
|
||||
|
||||
static bool fast_fp16_available(const int cc) {
|
||||
return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
|
||||
return GGML_CUDA_CC_IS_AMD(cc) ||
|
||||
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610);
|
||||
}
|
||||
|
||||
// To be used for feature selection of external libraries, e.g. cuBLAS.
|
||||
@@ -571,6 +572,10 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v,
|
||||
}
|
||||
|
||||
// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
|
||||
// Important: do not use this function if dst and src both point at registers.
|
||||
// Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types.
|
||||
// The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions.
|
||||
// If dst and src point at different address spaces then they are guaranteed to not be aliased.
|
||||
template <int nbytes, int alignment = 0>
|
||||
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
|
||||
if constexpr (alignment != 0) {
|
||||
|
||||
@@ -793,8 +793,6 @@ void launch_fattn(
|
||||
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
||||
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
||||
|
||||
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t main_stream = ctx.stream();
|
||||
const int id = ggml_cuda_get_device();
|
||||
@@ -878,7 +876,7 @@ void launch_fattn(
|
||||
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
|
||||
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
|
||||
// multiple sequences of possibly different lengths.
|
||||
if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
|
||||
if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
|
||||
const int s31 = mask->nb[1] / sizeof(half2);
|
||||
const int s33 = mask->nb[3] / sizeof(half2);
|
||||
|
||||
@@ -916,8 +914,7 @@ void launch_fattn(
|
||||
|
||||
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
|
||||
} else {
|
||||
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
|
||||
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
||||
const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
||||
|
||||
// parallel_blocks must not be larger than what the tensor size allows:
|
||||
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
|
||||
@@ -946,7 +943,7 @@ void launch_fattn(
|
||||
|
||||
blocks_num.x = ntiles_x;
|
||||
blocks_num.y = parallel_blocks;
|
||||
blocks_num.z = Q->ne[2]*Q->ne[3];
|
||||
blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
|
||||
@@ -1,756 +1,45 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
#include "fattn-tile.cuh"
|
||||
#include "fattn-wmma-f16.cuh"
|
||||
|
||||
// kq_stride == number of KQ rows to process per iteration
|
||||
// kq_nbatch == number of K columns to load in parallel for KQ calculation
|
||||
|
||||
static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) {
|
||||
if (GGML_CUDA_CC_IS_AMD(cc)) {
|
||||
if (GGML_CUDA_CC_IS_RDNA(cc)) {
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 128;
|
||||
case 128:
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
switch (D) {
|
||||
case 64:
|
||||
return ncols == 32 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols == 32 ? 64 : 32;
|
||||
case 256:
|
||||
return 32;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
if (fast_fp16_available(cc)) {
|
||||
switch (D) {
|
||||
case 64:
|
||||
case 128:
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
switch (D) {
|
||||
case 64:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols <= 16 ? 64 : 32;
|
||||
case 256:
|
||||
return 32;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
}
|
||||
GGML_UNUSED(warp_size);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
|
||||
#ifdef GGML_USE_HIP
|
||||
#ifdef RDNA
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 128;
|
||||
case 128:
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
switch (D) {
|
||||
case 64:
|
||||
return ncols == 32 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols == 32 ? 64 : 32;
|
||||
case 256:
|
||||
return 32;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#endif // RDNA
|
||||
#else
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
switch (D) {
|
||||
case 64:
|
||||
case 128:
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
switch (D) {
|
||||
case 64:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols <= 16 ? 64 : 32;
|
||||
case 256:
|
||||
return 32;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // GGML_USE_HIP
|
||||
GGML_UNUSED_VARS(ncols, warp_size);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) {
|
||||
#ifdef GGML_USE_HIP
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 64;
|
||||
case 128:
|
||||
case 256:
|
||||
return 128;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 64;
|
||||
case 128:
|
||||
case 256:
|
||||
return 128;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 64;
|
||||
case 128:
|
||||
return 128;
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // GGML_USE_HIP
|
||||
GGML_UNUSED_VARS(ncols, warp_size);
|
||||
}
|
||||
|
||||
static int fattn_tile_get_nthreads_host(const int cc, const int ncols) {
|
||||
return 256;
|
||||
GGML_UNUSED_VARS(cc, ncols);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) {
|
||||
return 256;
|
||||
GGML_UNUSED(ncols);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) {
|
||||
#ifdef RDNA
|
||||
return 3;
|
||||
#else
|
||||
return ncols <= 16 ? 3 : 2;
|
||||
#endif // RDNA
|
||||
GGML_UNUSED(ncols);
|
||||
}
|
||||
|
||||
template<int D, int ncols, bool use_logit_softcap> // D == head size
|
||||
__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols))
|
||||
static __global__ void flash_attn_tile(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
const int * __restrict__ KV_max,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
||||
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
#ifdef GGML_USE_WMMA_FATTN
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // GGML_USE_WMMA_FATTN
|
||||
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int warp_size = 32;
|
||||
constexpr int nwarps = fattn_tile_get_nthreads_device(ncols) / warp_size;
|
||||
constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size);
|
||||
static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size.");
|
||||
constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size);
|
||||
static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch");
|
||||
|
||||
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const float * sinksf = (const float *) (sinks);
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
|
||||
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
constexpr int cpw = ncols/nwarps; // cols per warp
|
||||
|
||||
// softmax_iter_j == number of KQ columns for which to calculate softmax in parallel.
|
||||
// KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes.
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
|
||||
|
||||
__shared__ half KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
|
||||
__shared__ half2 Q_tmp[ncols][D/2];
|
||||
__shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
||||
half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
#else
|
||||
constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
|
||||
|
||||
__shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
|
||||
__shared__ float Q_tmp[ncols][D];
|
||||
__shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
||||
float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j");
|
||||
|
||||
float KQ_max[cpw];
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
|
||||
}
|
||||
float KQ_sum[cpw] = {0.0f};
|
||||
|
||||
// Load Q data, convert to FP16 if fast.
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
const int j = j0 + threadIdx.y*cpw;
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
float tmp_f[cpy_ne_D] = {0.0f};
|
||||
if (ic0 + j < ne01) {
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f)>(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
|
||||
tmp_f[i1] *= scale;
|
||||
}
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
half2 tmp_h2[cpy_ne_D/2];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
|
||||
tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2);
|
||||
#else
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f)> (&Q_tmp[j][i0 + threadIdx.x* cpy_ne_D], tmp_f);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Main loop over KV cache:
|
||||
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
|
||||
for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
float KQ_max_new[cpw];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < cpw; ++j) {
|
||||
KQ_max_new[j] = KQ_max[j];
|
||||
}
|
||||
|
||||
float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication.
|
||||
|
||||
// KQ = K @ Q matrix multiplication:
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) {
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.y;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size);
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_kqnb*4>(
|
||||
&KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb],
|
||||
&K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]);
|
||||
}
|
||||
#else
|
||||
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size;
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) {
|
||||
half2 tmp_h2[cpy_ne_kqnb/2];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
|
||||
tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]);
|
||||
|
||||
float2 tmp_f2[cpy_ne_kqnb/2];
|
||||
#pragma unroll
|
||||
for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) {
|
||||
tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
|
||||
&KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
|
||||
half2 K_k[kq_stride/warp_size][cpy_ne];
|
||||
half2 Q_k[cpw][cpy_ne];
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
|
||||
float K_k[kq_stride/warp_size][cpy_ne];
|
||||
float Q_k[cpw][cpy_ne];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
|
||||
#else
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
|
||||
#else
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < cpy_ne; ++k) {
|
||||
ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (k_KQ_0 + kq_nbatch < D) {
|
||||
__syncthreads(); // Sync not needed on last iteration.
|
||||
}
|
||||
}
|
||||
|
||||
// Apply logit softcap, mask, update KQ_max:
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
|
||||
|
||||
if (use_logit_softcap) {
|
||||
KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
|
||||
}
|
||||
|
||||
KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
|
||||
KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
half tmp[kq_stride/warp_size][softmax_iter_j];
|
||||
#else
|
||||
float tmp[kq_stride/warp_size][softmax_iter_j];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
|
||||
KQ_max_new[j0+j1] = warp_reduce_max<warp_size>(KQ_max_new[j0+j1]);
|
||||
const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]);
|
||||
KQ_max[j0+j1] = KQ_max_new[j0+j1];
|
||||
|
||||
float KQ_sum_add = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
|
||||
const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]);
|
||||
KQ_sum_add += val;
|
||||
tmp[i0/warp_size][j1] = val;
|
||||
}
|
||||
KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
ggml_cuda_memcpy_1<sizeof(tmp[0])>(
|
||||
KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]);
|
||||
}
|
||||
}
|
||||
|
||||
// VKQ = V @ KQ matrix multiplication:
|
||||
constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K.
|
||||
static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter");
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) {
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) {
|
||||
const int k_tile = k1 + threadIdx.y;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(
|
||||
&KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D],
|
||||
&V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#else
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
half2 tmp_h2[cpy_ne_D/2];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
|
||||
tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]);
|
||||
|
||||
float2 tmp_f2[cpy_ne_D/2];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
|
||||
tmp_f2[i1] = __half22float2(tmp_h2[i1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
|
||||
&KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
|
||||
half2 V_k[(D/2)/warp_size];
|
||||
half2 KQ_k[cpw];
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
|
||||
|
||||
half tmp[softmax_iter_j];
|
||||
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(half)>(
|
||||
&tmp, KQ[j][k0 + k1]);
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
|
||||
KQ_k[j0+j1] = __half2half2(tmp[j1]);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0];
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
|
||||
float2 V_k[(D/2)/warp_size];
|
||||
float KQ_k[cpw];
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
|
||||
|
||||
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(float)>(
|
||||
&KQ_k[j0], KQ[j][k0 + k1]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0];
|
||||
VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Attention sink: adjust running max and sum once per head
|
||||
if (sinksf && blockIdx.y == 0) {
|
||||
const float sink = sinksf[head];
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
float KQ_max_new_j = fmaxf(KQ_max[j0], sink);
|
||||
KQ_max_new_j = warp_reduce_max<warp_size>(KQ_max_new_j);
|
||||
|
||||
const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j);
|
||||
KQ_max[j0] = KQ_max_new_j;
|
||||
|
||||
const float val = expf(sink - KQ_max[j0]);
|
||||
KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale;
|
||||
if (threadIdx.x == 0) {
|
||||
KQ_sum[j0] += val;
|
||||
}
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0][i0/warp_size] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0][i0/warp_size].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
KQ_sum[j_VKQ_0] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ_0]);
|
||||
}
|
||||
if (gridDim.y == 1) {
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (D/2)/warp_size; ++i) {
|
||||
VKQ[j_VKQ_0][i] *= KQ_sum_j_inv;
|
||||
}
|
||||
#else
|
||||
const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (D/2)/warp_size; ++i) {
|
||||
VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv;
|
||||
VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
// Write back results:
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw;
|
||||
|
||||
if (ic0 + j_VKQ >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
float2 tmp[cpy_ne_D];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
|
||||
tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
|
||||
}
|
||||
#else
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(
|
||||
&dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
if (gridDim.y != 1 && threadIdx.x == 0) {
|
||||
dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]);
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
||||
template <int D, bool use_logit_softcap>
|
||||
static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int warp_size = 32;
|
||||
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
|
||||
#ifdef GGML_USE_HIP
|
||||
if constexpr (D <= 128) {
|
||||
if (Q->ne[1] > 32) {
|
||||
constexpr int cols_per_block = 64;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif // GGML_USE_HIP
|
||||
|
||||
if (Q->ne[1] > 16) {
|
||||
constexpr int cols_per_block = 32;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 16;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
|
||||
}
|
||||
|
||||
template <bool use_logit_softcap>
|
||||
static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
switch (Q->ne[0]) {
|
||||
void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
switch (K->ne[0]) {
|
||||
case 40: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 40, 40>(ctx, dst);
|
||||
} break;
|
||||
case 64: {
|
||||
launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst);
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst);
|
||||
} break;
|
||||
case 80: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst);
|
||||
} break;
|
||||
case 96: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst);
|
||||
} break;
|
||||
case 112: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst);
|
||||
} break;
|
||||
case 128: {
|
||||
launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst);
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst);
|
||||
} break;
|
||||
case 256: {
|
||||
launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst);
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
|
||||
} break;
|
||||
case 576: {
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("Unsupported head size");
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
|
||||
|
||||
+41
-29
@@ -198,6 +198,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
#endif// FLASH_ATTN_AVAILABLE
|
||||
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
@@ -206,31 +207,32 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
// The effective batch size for the kernel can be increased by gqa_ratio.
|
||||
// The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
|
||||
const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
const int cc = ggml_cuda_info().devices[device].cc;
|
||||
|
||||
switch (K->ne[0]) {
|
||||
case 40:
|
||||
case 64:
|
||||
case 128:
|
||||
case 256:
|
||||
if (V->ne[0] != K->ne[0]) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
case 80:
|
||||
case 96:
|
||||
case 128:
|
||||
case 112:
|
||||
case 256:
|
||||
if (V->ne[0] != K->ne[0]) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
case 576:
|
||||
if (V->ne[0] != 512) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) {
|
||||
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
@@ -264,47 +266,57 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0;
|
||||
|
||||
// If Turing tensor cores available, use them except for some cases with batch size 1:
|
||||
if (turing_mma_available(cc)) {
|
||||
best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16;
|
||||
// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
// If Turing tensor cores available, use them:
|
||||
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
|
||||
if (can_use_vector_kernel) {
|
||||
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
|
||||
best = BEST_FATTN_KERNEL_VEC;
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
} else {
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
|
||||
if (Q->ne[1] <= 2) {
|
||||
best = BEST_FATTN_KERNEL_VEC;
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
} else {
|
||||
if (Q->ne[1] == 1) {
|
||||
best = BEST_FATTN_KERNEL_VEC;
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
}
|
||||
if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
|
||||
best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
|
||||
if (!gqa_opt_applies && Q->ne[1] == 1) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
|
||||
return best;
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
|
||||
// Use kernels specialized for small batch sizes if possible:
|
||||
if (Q->ne[1] <= 8 && can_use_vector_kernel) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
|
||||
// For large batch sizes, use the WMMA kernel if possible:
|
||||
if (ggml_cuda_should_use_wmma_fattn(cc)) {
|
||||
// Use the WMMA kernel if possible:
|
||||
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
|
||||
if (can_use_vector_kernel && Q->ne[1] <= 2) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
return BEST_FATTN_KERNEL_WMMA_F16;
|
||||
}
|
||||
|
||||
// If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
|
||||
// If there are no tensor cores available, use the generic tile kernel:
|
||||
if (can_use_vector_kernel) {
|
||||
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
|
||||
if (Q->ne[1] == 1) {
|
||||
if (!gqa_opt_applies) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (Q->ne[1] <= 2) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
}
|
||||
return BEST_FATTN_KERNEL_TILE;
|
||||
}
|
||||
|
||||
|
||||
@@ -231,7 +231,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
|
||||
info.default_tensor_split[id] = total_vram;
|
||||
total_vram += prop.totalGlobalMem;
|
||||
info.devices[id].integrated = prop.integrated;
|
||||
info.devices[id].integrated = false; // Temporarily disabled due to issues with corrupted output (e.g. #15034)
|
||||
info.devices[id].nsm = prop.multiProcessorCount;
|
||||
info.devices[id].smpb = prop.sharedMemPerBlock;
|
||||
info.devices[id].warp_size = prop.warpSize;
|
||||
@@ -3867,7 +3867,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||
dev_ctx->device = i;
|
||||
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
|
||||
|
||||
ggml_cuda_set_device(i);
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||
dev_ctx->description = prop.name;
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(112, 112);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(128, 128);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(256, 256);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(40, 40);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(576, 512);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(64, 64);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(80, 80);
|
||||
@@ -0,0 +1,5 @@
|
||||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE(96, 96);
|
||||
@@ -3,8 +3,17 @@
|
||||
from glob import glob
|
||||
import os
|
||||
|
||||
HEAD_SIZES_KQ = [40, 64, 80, 96, 112, 128, 256, 576]
|
||||
|
||||
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
|
||||
|
||||
SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-tile.cuh"
|
||||
|
||||
DECL_FATTN_TILE_CASE({head_size_kq}, {head_size_v});
|
||||
"""
|
||||
|
||||
SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../fattn-vec.cuh"
|
||||
@@ -51,6 +60,11 @@ def get_short_name(long_quant_name):
|
||||
for filename in glob("*.cu"):
|
||||
os.remove(filename)
|
||||
|
||||
for head_size_kq in HEAD_SIZES_KQ:
|
||||
head_size_v = head_size_kq if head_size_kq != 576 else 512
|
||||
with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v))
|
||||
|
||||
for type_k in TYPES_KV:
|
||||
for type_v in TYPES_KV:
|
||||
with open(f"fattn-vec-instance-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
|
||||
@@ -64,7 +78,9 @@ for ncols in [8, 16, 32, 64]:
|
||||
with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
|
||||
f.write(SOURCE_FATTN_MMA_START)
|
||||
|
||||
for head_size_kq in [64, 80, 96, 112, 128, 256, 576]:
|
||||
for head_size_kq in HEAD_SIZES_KQ:
|
||||
if head_size_kq == 40:
|
||||
continue
|
||||
if head_size_kq != 576 and ncols2 == 16:
|
||||
continue
|
||||
if head_size_kq == 576 and ncols2 != 16:
|
||||
|
||||
@@ -53,6 +53,8 @@ file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
|
||||
list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
|
||||
|
||||
file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu")
|
||||
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
|
||||
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
|
||||
|
||||
@@ -112,7 +112,7 @@ static bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * t
|
||||
}
|
||||
|
||||
bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (tensor->src[i]) {
|
||||
ggml_mem_ranges_add_src(mrs, tensor->src[i]);
|
||||
}
|
||||
@@ -173,7 +173,7 @@ static bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor *
|
||||
}
|
||||
|
||||
bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (tensor->src[i]) {
|
||||
if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
|
||||
return false;
|
||||
|
||||
@@ -268,6 +268,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_SUM);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
|
||||
|
||||
@@ -338,7 +357,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_ssm_conv_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
|
||||
const char * suffix = "";
|
||||
|
||||
if (op->src[1]->ne[0] % 4 == 0) {
|
||||
suffix = "_4";
|
||||
}
|
||||
|
||||
snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
@@ -352,15 +377,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
if (op->src[3]->ne[0] == 1) {
|
||||
snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type));
|
||||
} else {
|
||||
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
|
||||
}
|
||||
snprintf(name, 256, "%s", base);
|
||||
const int nsg = (ne00 + 31)/32;
|
||||
|
||||
snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
@@ -369,7 +394,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
|
||||
ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg);
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -918,6 +943,96 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
bool has_mask,
|
||||
int32_t ncpsg) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
GGML_UNUSED(op);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_%s",
|
||||
"flash_attn_ext_pad");
|
||||
|
||||
snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
|
||||
base,
|
||||
has_mask,
|
||||
ncpsg);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
|
||||
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
|
||||
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
|
||||
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
|
||||
|
||||
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
|
||||
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
|
||||
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
|
||||
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
|
||||
//ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
|
||||
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
int32_t nqptg,
|
||||
int32_t ncpsg) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
GGML_UNUSED(op);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_%s",
|
||||
"flash_attn_ext_blk");
|
||||
|
||||
snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
|
||||
base,
|
||||
nqptg,
|
||||
ncpsg);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
//ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
|
||||
//ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
|
||||
//ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
|
||||
//ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
|
||||
|
||||
//ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
|
||||
//ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
|
||||
//ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
|
||||
//ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
|
||||
ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
|
||||
ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
ggml_metal_library_t lib,
|
||||
const ggml_tensor * op,
|
||||
@@ -925,6 +1040,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
bool has_sinks,
|
||||
bool has_bias,
|
||||
bool has_scap,
|
||||
bool has_kvpad,
|
||||
int32_t nsg) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
|
||||
@@ -937,18 +1053,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
|
||||
const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
|
||||
|
||||
// do bounds checks for the mask?
|
||||
const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
|
||||
|
||||
snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
|
||||
"flash_attn_ext",
|
||||
ggml_type_name(op->src[1]->type),
|
||||
dk,
|
||||
dv);
|
||||
|
||||
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
|
||||
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
|
||||
base,
|
||||
has_mask,
|
||||
has_sinks,
|
||||
has_bias,
|
||||
has_scap,
|
||||
has_kvpad,
|
||||
bc_mask,
|
||||
ns10,
|
||||
ns20,
|
||||
nsg);
|
||||
@@ -964,6 +1085,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
|
||||
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
|
||||
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
|
||||
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
|
||||
|
||||
ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
|
||||
|
||||
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
|
||||
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
|
||||
@@ -983,6 +1107,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
bool has_sinks,
|
||||
bool has_bias,
|
||||
bool has_scap,
|
||||
bool has_kvpad,
|
||||
int32_t nsg,
|
||||
int32_t nwg) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
@@ -1002,12 +1127,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
dk,
|
||||
dv);
|
||||
|
||||
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
|
||||
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
|
||||
base,
|
||||
has_mask,
|
||||
has_sinks,
|
||||
has_bias,
|
||||
has_scap,
|
||||
has_kvpad,
|
||||
ns10,
|
||||
ns20,
|
||||
nsg, nwg);
|
||||
@@ -1023,6 +1149,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
|
||||
ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
|
||||
ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
|
||||
ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
|
||||
|
||||
ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
|
||||
ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
|
||||
@@ -1374,3 +1501,40 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_OPT_STEP_ADAMW);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_OPT_STEP_SGD);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -109,6 +109,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_me
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
@@ -134,6 +135,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_me
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
bool has_mask,
|
||||
int32_t ncpsg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
|
||||
ggml_metal_library_t lib,
|
||||
const struct ggml_tensor * op,
|
||||
int32_t nqptg,
|
||||
int32_t ncpsg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
ggml_metal_library_t lib,
|
||||
@@ -142,6 +157,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
bool has_sinks,
|
||||
bool has_bias,
|
||||
bool has_scap,
|
||||
bool has_kvpad,
|
||||
int32_t nsg);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
@@ -151,6 +167,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
bool has_sinks,
|
||||
bool has_bias,
|
||||
bool has_scap,
|
||||
bool has_kvpad,
|
||||
int32_t nsg,
|
||||
int32_t nwg);
|
||||
|
||||
|
||||
@@ -656,6 +656,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_LOG:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
@@ -776,9 +777,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
};
|
||||
}
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
return op->ne[3] == 1;
|
||||
}
|
||||
return true;
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
if (op->src[0]->type != GGML_TYPE_F32) {
|
||||
@@ -800,6 +799,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return false;
|
||||
};
|
||||
}
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
return has_simdgroup_reduction;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -69,11 +69,20 @@
|
||||
#define N_SG_IQ4_XS 2
|
||||
|
||||
// function constants offsets
|
||||
#define FC_FLASH_ATTN_EXT 100
|
||||
#define FC_FLASH_ATTN_EXT_VEC 200
|
||||
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
|
||||
#define FC_MUL_MV 400
|
||||
#define FC_MUL_MM 500
|
||||
#define FC_FLASH_ATTN_EXT_PAD 100
|
||||
#define FC_FLASH_ATTN_EXT_BLK 200
|
||||
#define FC_FLASH_ATTN_EXT 300
|
||||
#define FC_FLASH_ATTN_EXT_VEC 400
|
||||
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
|
||||
#define FC_MUL_MV 600
|
||||
#define FC_MUL_MM 700
|
||||
|
||||
// op-specific constants
|
||||
#define OP_FLASH_ATTN_EXT_NQPTG 8
|
||||
#define OP_FLASH_ATTN_EXT_NCPSG 64
|
||||
|
||||
#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1
|
||||
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
|
||||
|
||||
// kernel argument structs
|
||||
//
|
||||
@@ -178,6 +187,7 @@ typedef struct {
|
||||
} ggml_metal_kargs_clamp;
|
||||
|
||||
typedef struct {
|
||||
int64_t nk0;
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
@@ -243,6 +253,35 @@ typedef struct {
|
||||
int32_t sect_3;
|
||||
} ggml_metal_kargs_rope;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne11;
|
||||
int32_t ne_12_2; // assume K and V are same shape
|
||||
int32_t ne_12_3;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
uint64_t nb21;
|
||||
uint64_t nb22;
|
||||
uint64_t nb23;
|
||||
int32_t ne31;
|
||||
int32_t ne32;
|
||||
int32_t ne33;
|
||||
uint64_t nb31;
|
||||
uint64_t nb32;
|
||||
uint64_t nb33;
|
||||
} ggml_metal_kargs_flash_attn_ext_pad;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne01;
|
||||
int32_t ne30;
|
||||
int32_t ne31;
|
||||
int32_t ne32;
|
||||
int32_t ne33;
|
||||
uint64_t nb31;
|
||||
uint64_t nb32;
|
||||
uint64_t nb33;
|
||||
} ggml_metal_kargs_flash_attn_ext_blk;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
@@ -261,6 +300,7 @@ typedef struct {
|
||||
uint64_t nb21;
|
||||
uint64_t nb22;
|
||||
uint64_t nb23;
|
||||
int32_t ne31;
|
||||
int32_t ne32;
|
||||
int32_t ne33;
|
||||
uint64_t nb31;
|
||||
@@ -295,6 +335,7 @@ typedef struct {
|
||||
uint64_t nb21;
|
||||
uint64_t nb22;
|
||||
uint64_t nb23;
|
||||
int32_t ne31;
|
||||
int32_t ne32;
|
||||
int32_t ne33;
|
||||
uint64_t nb31;
|
||||
@@ -503,6 +544,10 @@ typedef struct{
|
||||
float limit;
|
||||
} ggml_metal_kargs_glu;
|
||||
|
||||
typedef struct {
|
||||
uint64_t np;
|
||||
} ggml_metal_kargs_sum;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
@@ -572,32 +617,45 @@ typedef struct {
|
||||
int64_t n_seq_tokens;
|
||||
int64_t n_seqs;
|
||||
uint64_t s_off;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t ns12;
|
||||
uint64_t nb13;
|
||||
uint64_t nb20;
|
||||
uint64_t nb21;
|
||||
uint64_t ns21;
|
||||
uint64_t nb22;
|
||||
int64_t ne30;
|
||||
uint64_t nb31;
|
||||
uint64_t nb41;
|
||||
uint64_t nb42;
|
||||
uint64_t ns42;
|
||||
uint64_t nb43;
|
||||
uint64_t nb51;
|
||||
uint64_t nb52;
|
||||
uint64_t ns52;
|
||||
uint64_t nb53;
|
||||
uint64_t nb0;
|
||||
} ggml_metal_kargs_ssm_scan;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int32_t ne00t;
|
||||
int32_t ne00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
int64_t ne10;
|
||||
uint64_t nb03;
|
||||
int32_t ne10;
|
||||
uint64_t nb10;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_get_rows;
|
||||
|
||||
typedef struct {
|
||||
@@ -719,4 +777,12 @@ typedef struct {
|
||||
uint64_t nb01;
|
||||
} ggml_metal_kargs_argmax;
|
||||
|
||||
typedef struct {
|
||||
int64_t np;
|
||||
} ggml_metal_kargs_opt_step_adamw;
|
||||
|
||||
typedef struct {
|
||||
int64_t np;
|
||||
} ggml_metal_kargs_opt_step_sgd;
|
||||
|
||||
#endif // GGML_METAL_IMPL
|
||||
|
||||
@@ -226,6 +226,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
|
||||
GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
|
||||
GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
|
||||
|
||||
@@ -237,6 +241,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
|
||||
ggml_is_contiguous(node->src[1]), node->src[1]->name);
|
||||
}
|
||||
if (node->src[2]) {
|
||||
GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
|
||||
ggml_is_contiguous(node->src[2]), node->src[2]->name);
|
||||
}
|
||||
if (node->src[3]) {
|
||||
GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
|
||||
ggml_is_contiguous(node->src[3]), node->src[3]->name);
|
||||
}
|
||||
if (node) {
|
||||
GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
|
||||
node->name);
|
||||
@@ -289,6 +301,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_glu(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SUM:
|
||||
{
|
||||
n_fuse = ggml_metal_op_sum(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
{
|
||||
@@ -398,6 +414,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_argmax(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
{
|
||||
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
{
|
||||
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
|
||||
@@ -577,6 +601,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
||||
|
||||
ggml_metal_kargs_cpy args = {
|
||||
/*.nk0 =*/ ne00,
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
@@ -827,6 +852,30 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
|
||||
|
||||
ggml_metal_kargs_sum args = {
|
||||
/*.np =*/ n,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
@@ -906,23 +955,31 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
|
||||
|
||||
ggml_metal_kargs_get_rows args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.ne10 =*/ ne10,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne10 =*/ ne10,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
const int nw0 = (args.ne00t + nth - 1)/nth;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
@@ -1117,7 +1174,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
|
||||
|
||||
@@ -1172,25 +1229,36 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
||||
/*.n_seq_tokens =*/ n_seq_tokens,
|
||||
/*.n_seqs =*/ n_seqs,
|
||||
/*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float),
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.ns12 =*/ nb12/nb10,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.nb20 =*/ nb20,
|
||||
/*.nb21 =*/ nb21,
|
||||
/*.ns21 =*/ nb21/nb20,
|
||||
/*.nb22 =*/ nb22,
|
||||
/*.ne30 =*/ ne30,
|
||||
/*.nb31 =*/ nb31,
|
||||
/*.nb41 =*/ nb41,
|
||||
/*.nb42 =*/ nb42,
|
||||
/*.ns42 =*/ nb42/nb40,
|
||||
/*.nb43 =*/ nb43,
|
||||
/*.nb51 =*/ nb51,
|
||||
/*.nb52 =*/ nb52,
|
||||
/*.ns52 =*/ nb52/nb50,
|
||||
/*.nb53 =*/ nb53,
|
||||
/*.nb0 =*/ nb0,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
|
||||
|
||||
GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
const size_t sms = ggml_metal_pipeline_get_smem(pipeline);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
@@ -1206,13 +1274,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
|
||||
|
||||
if (ne30 == 1) {
|
||||
// Mamba-2
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
|
||||
} else {
|
||||
GGML_ASSERT(d_inner == 1);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
|
||||
}
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
@@ -1273,26 +1335,23 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
|
||||
|
||||
// TODO: support
|
||||
//const int32_t nk00 = ne00/ggml_blck_size(op->type);
|
||||
const int32_t nk00 = ne00;
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nth *= 2;
|
||||
int64_t nk0 = ne00;
|
||||
if (ggml_is_quantized(op->src[0]->type)) {
|
||||
nk0 = ne00/16;
|
||||
} else if (ggml_is_quantized(op->type)) {
|
||||
nk0 = ne00/ggml_blck_size(op->type);
|
||||
}
|
||||
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
// when rows are small, we can batch them together in a single threadgroup
|
||||
int nrptg = 1;
|
||||
|
||||
// TODO: relax this constraint in the future
|
||||
if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
|
||||
if (nth > nk00) {
|
||||
nrptg = (nth + nk00 - 1)/nk00;
|
||||
nth = nk00;
|
||||
if (nth > nk0) {
|
||||
nrptg = (nth + nk0 - 1)/nk0;
|
||||
nth = nk0;
|
||||
|
||||
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nrptg--;
|
||||
@@ -1300,10 +1359,11 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
}
|
||||
|
||||
nth = std::min(nth, nk00);
|
||||
nth = std::min<int>(nth, nk0);
|
||||
|
||||
ggml_metal_kargs_cpy args = {
|
||||
/*.ne00 =*/ nk00,
|
||||
/*.nk0 =*/ nk0,
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
@@ -1321,12 +1381,14 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
@@ -1520,9 +1582,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
!ggml_is_transposed(op->src[1]) &&
|
||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||
props_dev->has_simdgroup_mm && ne00 >= 64 &&
|
||||
(ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
|
||||
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
||||
props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
|
||||
//GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
||||
|
||||
// some Metal matrix data types require aligned pointers
|
||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
||||
@@ -1875,20 +1936,107 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
|
||||
return (ne01 < 20) && (ne00 % 32 == 0);
|
||||
}
|
||||
|
||||
size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
|
||||
|
||||
size_t res = 0;
|
||||
|
||||
const bool has_mask = op->src[3] != nullptr;
|
||||
|
||||
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
||||
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
|
||||
|
||||
if (has_kvpad) {
|
||||
res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
|
||||
nb11*ne12*ne13 +
|
||||
nb21*ne22*ne23 +
|
||||
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
|
||||
}
|
||||
} else {
|
||||
const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
|
||||
|
||||
if (has_kvpad) {
|
||||
res += OP_FLASH_ATTN_EXT_NCPSG*(
|
||||
nb11*ne12*ne13 +
|
||||
nb21*ne22*ne23 +
|
||||
(has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
//GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
//GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
//GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
//GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
||||
//GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
|
||||
|
||||
size_t res = 0;
|
||||
|
||||
const bool has_mask = op->src[3] != nullptr;
|
||||
|
||||
if (!has_mask) {
|
||||
return res;
|
||||
}
|
||||
|
||||
const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
|
||||
|
||||
// this optimization is not useful for the vector kernels
|
||||
if (is_vec) {
|
||||
return res;
|
||||
}
|
||||
|
||||
const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
|
||||
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
|
||||
|
||||
const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
|
||||
const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
|
||||
|
||||
res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
|
||||
const int64_t nwg = 32;
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
//GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
//GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
|
||||
//GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
|
||||
//GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
|
||||
|
||||
const int64_t ne01 = op->src[0]->ne[1];
|
||||
const int64_t ne02 = op->src[0]->ne[2];
|
||||
const int64_t ne03 = op->src[0]->ne[3];
|
||||
const int64_t ne20 = op->src[2]->ne[0];
|
||||
size_t res = 0;
|
||||
|
||||
// temp buffer for writing the results from each workgroup
|
||||
// - ne20: the size of the Value head
|
||||
// - + 2: the S and M values for each intermediate result
|
||||
return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
|
||||
if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
||||
const int64_t nwg = 32;
|
||||
|
||||
// temp buffer for writing the results from each workgroup
|
||||
// - ne20: the size of the Value head
|
||||
// - + 2: the S and M values for each intermediate result
|
||||
res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
@@ -1910,8 +2058,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
|
||||
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
GGML_ASSERT(ne11 % 32 == 0);
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
|
||||
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(op->src[1]->type == op->src[2]->type);
|
||||
@@ -1921,8 +2068,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_ASSERT(ne12 == ne22);
|
||||
|
||||
GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= GGML_PAD(op->src[0]->ne[1], 8) &&
|
||||
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
|
||||
GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
|
||||
"the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
@@ -1949,15 +2096,111 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
GGML_ASSERT(ne01 < 65536);
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
||||
ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
|
||||
ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
|
||||
ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
|
||||
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
ggml_metal_buffer_id bid_pad = bid_dst;
|
||||
bid_pad.offs += ggml_nbytes(op);
|
||||
|
||||
ggml_metal_buffer_id bid_blk = bid_pad;
|
||||
bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
|
||||
|
||||
ggml_metal_buffer_id bid_tmp = bid_blk;
|
||||
bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);
|
||||
|
||||
if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
||||
// half8x8 kernel
|
||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||
const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
|
||||
const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
|
||||
|
||||
GGML_ASSERT(nqptg <= 32);
|
||||
GGML_ASSERT(nqptg % 8 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
bool need_sync = false;
|
||||
|
||||
const bool has_kvpad = ne11 % ncpsg != 0;
|
||||
|
||||
if (has_kvpad) {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
|
||||
|
||||
ggml_metal_kargs_flash_attn_ext_pad args0 = {
|
||||
/*.ne11 =*/ne11,
|
||||
/*.ne_12_2 =*/ne12,
|
||||
/*.ne_12_3 =*/ne13,
|
||||
/*.nb11 =*/nb11,
|
||||
/*.nb12 =*/nb12,
|
||||
/*.nb13 =*/nb13,
|
||||
/*.nb21 =*/nb21,
|
||||
/*.nb22 =*/nb22,
|
||||
/*.nb23 =*/nb23,
|
||||
/*.ne31 =*/ne31,
|
||||
/*.ne32 =*/ne32,
|
||||
/*.ne33 =*/ne33,
|
||||
/*.nb31 =*/nb31,
|
||||
/*.nb32 =*/nb32,
|
||||
/*.nb33 =*/nb33,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
||||
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
|
||||
|
||||
assert(ne12 == ne22);
|
||||
assert(ne13 == ne23);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
|
||||
|
||||
need_sync = true;
|
||||
} else {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
|
||||
}
|
||||
|
||||
if (has_mask) {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
|
||||
|
||||
ggml_metal_kargs_flash_attn_ext_blk args0 = {
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne30 =*/ ne30,
|
||||
/*.ne31 =*/ ne31,
|
||||
/*.ne32 =*/ ne32,
|
||||
/*.ne33 =*/ ne33,
|
||||
/*.nb31 =*/ nb31,
|
||||
/*.nb32 =*/ nb32,
|
||||
/*.nb33 =*/ nb33,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
||||
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_blk, 2);
|
||||
|
||||
const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
|
||||
const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
|
||||
|
||||
need_sync = true;
|
||||
} else {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0);
|
||||
}
|
||||
|
||||
if (need_sync) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
}
|
||||
|
||||
const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
|
||||
|
||||
// 2*(2*ncpsg)
|
||||
@@ -2007,6 +2250,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb21 =*/ nb21,
|
||||
/*.nb22 =*/ nb22,
|
||||
/*.nb23 =*/ nb23,
|
||||
/*.ne31 =*/ ne31,
|
||||
/*.ne32 =*/ ne32,
|
||||
/*.ne33 =*/ ne33,
|
||||
/*.nb31 =*/ nb31,
|
||||
@@ -2023,24 +2267,18 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.logit_softcap =*/ logit_softcap,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg);
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
|
||||
if (op->src[3]) {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
|
||||
} else {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
|
||||
}
|
||||
if (op->src[4]) {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
|
||||
} else {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
|
||||
}
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 6);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
@@ -2048,14 +2286,62 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
#undef FATTN_SMEM
|
||||
} else {
|
||||
// half4x4 kernel
|
||||
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||
const int64_t nkpsg = 1*ncpsg;
|
||||
const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
|
||||
const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||
const int nkpsg = 1*ncpsg;
|
||||
|
||||
GGML_ASSERT(nqptg <= 32);
|
||||
GGML_ASSERT(nqptg % 1 == 0);
|
||||
GGML_ASSERT(ncpsg % 32 == 0);
|
||||
|
||||
bool need_sync = false;
|
||||
|
||||
const bool has_kvpad = ne11 % ncpsg != 0;
|
||||
|
||||
if (has_kvpad) {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
|
||||
|
||||
ggml_metal_kargs_flash_attn_ext_pad args0 = {
|
||||
/*.ne11 =*/ne11,
|
||||
/*.ne_12_2 =*/ne12,
|
||||
/*.ne_12_3 =*/ne13,
|
||||
/*.nb11 =*/nb11,
|
||||
/*.nb12 =*/nb12,
|
||||
/*.nb13 =*/nb13,
|
||||
/*.nb21 =*/nb21,
|
||||
/*.nb22 =*/nb22,
|
||||
/*.nb23 =*/nb23,
|
||||
/*.ne31 =*/ne31,
|
||||
/*.ne32 =*/ne32,
|
||||
/*.ne33 =*/ne33,
|
||||
/*.nb31 =*/nb31,
|
||||
/*.nb32 =*/nb32,
|
||||
/*.nb33 =*/nb33,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
||||
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
|
||||
|
||||
assert(ne12 == ne22);
|
||||
assert(ne13 == ne23);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
|
||||
|
||||
need_sync = true;
|
||||
} else {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
|
||||
}
|
||||
|
||||
if (need_sync) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
}
|
||||
|
||||
// ne00 + 2*ncpsg*(nsg)
|
||||
// for each query, we load it as f16 in shared memory (ne00)
|
||||
// and store the soft_max values and the mask
|
||||
@@ -2120,6 +2406,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb21 =*/ nb21,
|
||||
/*.nb22 =*/ nb22,
|
||||
/*.nb23 =*/ nb23,
|
||||
/*.ne31 =*/ ne31,
|
||||
/*.ne32 =*/ ne32,
|
||||
/*.ne33 =*/ ne33,
|
||||
/*.nb31 =*/ nb31,
|
||||
@@ -2136,25 +2423,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
/*.logit_softcap =*/ logit_softcap,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
|
||||
|
||||
GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
|
||||
if (op->src[3]) {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
|
||||
} else {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
|
||||
}
|
||||
if (op->src[4]) {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
|
||||
} else {
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
|
||||
}
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
@@ -2162,23 +2441,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
|
||||
|
||||
if (nwg == 1) {
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
|
||||
|
||||
// using 1 workgroup -> write the result directly into dst
|
||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6);
|
||||
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
|
||||
ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
|
||||
} else {
|
||||
// sanity checks
|
||||
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
|
||||
|
||||
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
|
||||
GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
|
||||
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
// write the results from each workgroup into a temp buffer
|
||||
ggml_metal_buffer_id bid_tmp = bid_dst;
|
||||
bid_tmp.offs += ggml_nbytes(op);
|
||||
ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
|
||||
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
|
||||
ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
|
||||
@@ -3156,3 +3437,73 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
|
||||
|
||||
const int64_t np = ggml_nelements(op->src[0]);
|
||||
ggml_metal_kargs_opt_step_adamw args = {
|
||||
/*.np =*/ np,
|
||||
};
|
||||
|
||||
int ida = 0;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
||||
const int64_t n = (np + nth - 1) / nth;
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
|
||||
|
||||
const int64_t np = ggml_nelements(op->src[0]);
|
||||
ggml_metal_kargs_opt_step_sgd args = {
|
||||
/*.np =*/ np,
|
||||
};
|
||||
|
||||
int ida = 0;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
||||
const int64_t n = (np + nth - 1) / nth;
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -39,6 +39,8 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);
|
||||
// return true if we should use the FA vector kernel for this op
|
||||
bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);
|
||||
|
||||
size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op);
|
||||
size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op);
|
||||
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
|
||||
|
||||
int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
|
||||
@@ -48,6 +50,7 @@ int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
|
||||
@@ -76,6 +79,8 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@@ -193,9 +193,9 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
if (ggml_metal_op_flash_attn_ext_use_vec(tensor)) {
|
||||
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
|
||||
}
|
||||
res += ggml_metal_op_flash_attn_ext_extra_pad(tensor);
|
||||
res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
|
||||
res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -30,6 +30,8 @@ if (MUSAToolkit_FOUND)
|
||||
list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh")
|
||||
|
||||
file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu")
|
||||
list(APPEND GGML_SOURCES_MUSA ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
|
||||
list(APPEND GGML_SOURCES_MUSA ${SRCS})
|
||||
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
|
||||
|
||||
+301
-137
@@ -105,9 +105,12 @@ enum rpc_cmd {
|
||||
RPC_CMD_INIT_TENSOR,
|
||||
RPC_CMD_GET_ALLOC_SIZE,
|
||||
RPC_CMD_HELLO,
|
||||
RPC_CMD_DEVICE_COUNT,
|
||||
RPC_CMD_COUNT,
|
||||
};
|
||||
|
||||
static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14");
|
||||
|
||||
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
|
||||
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
|
||||
|
||||
@@ -117,7 +120,12 @@ struct rpc_msg_hello_rsp {
|
||||
uint8_t patch;
|
||||
};
|
||||
|
||||
struct rpc_msg_device_count_rsp {
|
||||
uint32_t device_count;
|
||||
};
|
||||
|
||||
struct rpc_msg_get_alloc_size_req {
|
||||
uint32_t device;
|
||||
rpc_tensor tensor;
|
||||
};
|
||||
|
||||
@@ -130,6 +138,7 @@ struct rpc_msg_init_tensor_req {
|
||||
};
|
||||
|
||||
struct rpc_msg_alloc_buffer_req {
|
||||
uint32_t device;
|
||||
uint64_t size;
|
||||
};
|
||||
|
||||
@@ -138,10 +147,18 @@ struct rpc_msg_alloc_buffer_rsp {
|
||||
uint64_t remote_size;
|
||||
};
|
||||
|
||||
struct rpc_msg_get_alignment_req {
|
||||
uint32_t device;
|
||||
};
|
||||
|
||||
struct rpc_msg_get_alignment_rsp {
|
||||
uint64_t alignment;
|
||||
};
|
||||
|
||||
struct rpc_msg_get_max_size_req {
|
||||
uint32_t device;
|
||||
};
|
||||
|
||||
struct rpc_msg_get_max_size_rsp {
|
||||
uint64_t max_size;
|
||||
};
|
||||
@@ -192,6 +209,10 @@ struct rpc_msg_graph_compute_rsp {
|
||||
uint8_t result;
|
||||
};
|
||||
|
||||
struct rpc_msg_get_device_memory_req {
|
||||
uint32_t device;
|
||||
};
|
||||
|
||||
struct rpc_msg_get_device_memory_rsp {
|
||||
uint64_t free_mem;
|
||||
uint64_t total_mem;
|
||||
@@ -207,13 +228,15 @@ static ggml_guid_t ggml_backend_rpc_guid() {
|
||||
|
||||
struct ggml_backend_rpc_buffer_type_context {
|
||||
std::string endpoint;
|
||||
uint32_t device;
|
||||
std::string name;
|
||||
size_t alignment;
|
||||
size_t max_size;
|
||||
size_t alignment;
|
||||
size_t max_size;
|
||||
};
|
||||
|
||||
struct ggml_backend_rpc_context {
|
||||
std::string endpoint;
|
||||
uint32_t device;
|
||||
std::string name;
|
||||
};
|
||||
|
||||
@@ -608,23 +631,30 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
|
||||
RPC_STATUS_ASSERT(status);
|
||||
}
|
||||
|
||||
static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
|
||||
return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
|
||||
}
|
||||
|
||||
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
||||
// check if src and dst are on the same server
|
||||
ggml_backend_buffer_t src_buffer = src->buffer;
|
||||
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
|
||||
ggml_backend_buffer_t dst_buffer = dst->buffer;
|
||||
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
|
||||
if (src_ctx->sock != dst_ctx->sock) {
|
||||
return false;
|
||||
if (ggml_backend_buffer_is_rpc(src->buffer)) {
|
||||
// check if src and dst are on the same server
|
||||
ggml_backend_buffer_t src_buffer = src->buffer;
|
||||
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
|
||||
ggml_backend_buffer_t dst_buffer = dst->buffer;
|
||||
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
|
||||
if (src_ctx->sock != dst_ctx->sock) {
|
||||
return false;
|
||||
}
|
||||
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
||||
rpc_msg_copy_tensor_req request;
|
||||
request.src = serialize_tensor(src);
|
||||
request.dst = serialize_tensor(dst);
|
||||
rpc_msg_copy_tensor_rsp response;
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
|
||||
RPC_STATUS_ASSERT(status);
|
||||
return response.result;
|
||||
}
|
||||
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
||||
rpc_msg_copy_tensor_req request;
|
||||
request.src = serialize_tensor(src);
|
||||
request.dst = serialize_tensor(dst);
|
||||
rpc_msg_copy_tensor_rsp response;
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
|
||||
RPC_STATUS_ASSERT(status);
|
||||
return response.result;
|
||||
return false;
|
||||
}
|
||||
|
||||
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||
@@ -653,7 +683,7 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
||||
rpc_msg_alloc_buffer_req request = {size};
|
||||
rpc_msg_alloc_buffer_req request = {buft_ctx->device, size};
|
||||
rpc_msg_alloc_buffer_rsp response;
|
||||
auto sock = get_socket(buft_ctx->endpoint);
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
|
||||
@@ -669,9 +699,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
|
||||
}
|
||||
}
|
||||
|
||||
static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
|
||||
static size_t get_alignment(const std::shared_ptr<socket_t> & sock, uint32_t device) {
|
||||
rpc_msg_get_alignment_req request = {device};
|
||||
rpc_msg_get_alignment_rsp response;
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response));
|
||||
RPC_STATUS_ASSERT(status);
|
||||
return response.alignment;
|
||||
}
|
||||
@@ -681,9 +712,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
|
||||
return buft_ctx->alignment;
|
||||
}
|
||||
|
||||
static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
|
||||
static size_t get_max_size(const std::shared_ptr<socket_t> & sock, uint32_t device) {
|
||||
rpc_msg_get_max_size_req request = {device};
|
||||
rpc_msg_get_max_size_rsp response;
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response));
|
||||
RPC_STATUS_ASSERT(status);
|
||||
return response.max_size;
|
||||
}
|
||||
@@ -700,7 +732,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty
|
||||
auto sock = get_socket(buft_ctx->endpoint);
|
||||
|
||||
rpc_msg_get_alloc_size_req request;
|
||||
|
||||
request.device = buft_ctx->device;
|
||||
request.tensor = serialize_tensor(tensor);
|
||||
|
||||
rpc_msg_get_alloc_size_rsp response;
|
||||
@@ -754,7 +786,7 @@ static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors,
|
||||
tensors.push_back(serialize_tensor(tensor));
|
||||
}
|
||||
|
||||
static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
|
||||
static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
|
||||
uint32_t n_nodes = cgraph->n_nodes;
|
||||
std::vector<rpc_tensor> tensors;
|
||||
std::unordered_set<ggml_tensor*> visited;
|
||||
@@ -762,24 +794,29 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
|
||||
add_tensor(cgraph->nodes[i], tensors, visited);
|
||||
}
|
||||
// serialization format:
|
||||
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
||||
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
||||
uint32_t n_tensors = tensors.size();
|
||||
int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
|
||||
int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
|
||||
output.resize(output_size, 0);
|
||||
memcpy(output.data(), &n_nodes, sizeof(n_nodes));
|
||||
uint8_t * dest = output.data();
|
||||
memcpy(dest, &device, sizeof(device));
|
||||
dest += sizeof(device);
|
||||
memcpy(dest, &n_nodes, sizeof(n_nodes));
|
||||
dest += sizeof(n_nodes);
|
||||
for (uint32_t i = 0; i < n_nodes; i++) {
|
||||
memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
|
||||
memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
|
||||
}
|
||||
uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
|
||||
*out_ntensors = n_tensors;
|
||||
rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
|
||||
dest += n_nodes * sizeof(uint64_t);
|
||||
memcpy(dest, &n_tensors, sizeof(n_tensors));
|
||||
dest += sizeof(n_tensors);
|
||||
rpc_tensor * out_tensors = (rpc_tensor *)dest;
|
||||
memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
|
||||
}
|
||||
|
||||
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
||||
std::vector<uint8_t> input;
|
||||
serialize_graph(cgraph, input);
|
||||
serialize_graph(rpc_ctx->device, cgraph, input);
|
||||
rpc_msg_graph_compute_rsp response;
|
||||
auto sock = get_socket(rpc_ctx->endpoint);
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
|
||||
@@ -804,12 +841,13 @@ static ggml_backend_i ggml_backend_rpc_interface = {
|
||||
/* .graph_optimize = */ NULL,
|
||||
};
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
|
||||
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) {
|
||||
static std::mutex mutex;
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
std::string buft_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
|
||||
// NOTE: buffer types are allocated and never freed; this is by design
|
||||
static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
|
||||
auto it = buft_map.find(endpoint);
|
||||
auto it = buft_map.find(buft_name);
|
||||
if (it != buft_map.end()) {
|
||||
return it->second;
|
||||
}
|
||||
@@ -818,34 +856,37 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
|
||||
GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
|
||||
return nullptr;
|
||||
}
|
||||
size_t alignment = get_alignment(sock);
|
||||
size_t max_size = get_max_size(sock);
|
||||
size_t alignment = get_alignment(sock, device);
|
||||
size_t max_size = get_max_size(sock, device);
|
||||
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
|
||||
/* .endpoint = */ endpoint,
|
||||
/* .name = */ "RPC[" + std::string(endpoint) + "]",
|
||||
/* .device = */ device,
|
||||
/* .name = */ buft_name,
|
||||
/* .alignment = */ alignment,
|
||||
/* .max_size = */ max_size
|
||||
};
|
||||
|
||||
auto reg = ggml_backend_rpc_add_server(endpoint);
|
||||
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
|
||||
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
|
||||
/* .device = */ ggml_backend_rpc_add_device(endpoint),
|
||||
/* .device = */ ggml_backend_reg_dev_get(reg, device),
|
||||
/* .context = */ buft_ctx
|
||||
};
|
||||
buft_map[endpoint] = buft;
|
||||
buft_map[buft_name] = buft;
|
||||
return buft;
|
||||
}
|
||||
|
||||
ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
|
||||
ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
|
||||
std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
|
||||
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
|
||||
/* .endpoint = */ endpoint,
|
||||
/* .name = */ "RPC[" + std::string(endpoint) + "]",
|
||||
/* .endpoint = */ endpoint,
|
||||
/* .device = */ device,
|
||||
/* .name = */ dev_name
|
||||
};
|
||||
|
||||
auto reg = ggml_backend_rpc_add_server(endpoint);
|
||||
ggml_backend_t backend = new ggml_backend {
|
||||
/* .guid = */ ggml_backend_rpc_guid(),
|
||||
/* .iface = */ ggml_backend_rpc_interface,
|
||||
/* .device = */ ggml_backend_rpc_add_device(endpoint),
|
||||
/* .device = */ ggml_backend_reg_dev_get(reg, device),
|
||||
/* .context = */ ctx
|
||||
};
|
||||
return backend;
|
||||
@@ -855,37 +896,39 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
|
||||
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
|
||||
}
|
||||
|
||||
static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
|
||||
static void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {
|
||||
rpc_msg_get_device_memory_req request;
|
||||
request.device = device;
|
||||
rpc_msg_get_device_memory_rsp response;
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response));
|
||||
RPC_STATUS_ASSERT(status);
|
||||
*free = response.free_mem;
|
||||
*total = response.total_mem;
|
||||
}
|
||||
|
||||
void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
|
||||
void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
|
||||
auto sock = get_socket(endpoint);
|
||||
if (sock == nullptr) {
|
||||
*free = 0;
|
||||
*total = 0;
|
||||
return;
|
||||
}
|
||||
get_device_memory(sock, free, total);
|
||||
get_device_memory(sock, device, free, total);
|
||||
}
|
||||
|
||||
// RPC server-side implementation
|
||||
|
||||
class rpc_server {
|
||||
public:
|
||||
rpc_server(ggml_backend_t backend, const char * cache_dir)
|
||||
: backend(backend), cache_dir(cache_dir) {
|
||||
rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
|
||||
: backends(std::move(backends)), cache_dir(cache_dir) {
|
||||
}
|
||||
~rpc_server();
|
||||
|
||||
void hello(rpc_msg_hello_rsp & response);
|
||||
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
|
||||
void get_alignment(rpc_msg_get_alignment_rsp & response);
|
||||
void get_max_size(rpc_msg_get_max_size_rsp & response);
|
||||
bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
|
||||
bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response);
|
||||
bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response);
|
||||
bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
|
||||
bool free_buffer(const rpc_msg_free_buffer_req & request);
|
||||
bool buffer_clear(const rpc_msg_buffer_clear_req & request);
|
||||
@@ -906,7 +949,7 @@ private:
|
||||
std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
|
||||
|
||||
|
||||
ggml_backend_t backend;
|
||||
std::vector<ggml_backend_t> backends;
|
||||
const char * cache_dir;
|
||||
std::unordered_set<ggml_backend_buffer_t> buffers;
|
||||
};
|
||||
@@ -919,6 +962,10 @@ void rpc_server::hello(rpc_msg_hello_rsp & response) {
|
||||
}
|
||||
|
||||
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
|
||||
uint32_t dev_id = request.device;
|
||||
if (dev_id >= backends.size()) {
|
||||
return false;
|
||||
}
|
||||
ggml_backend_buffer_type_t buft;
|
||||
struct ggml_init_params params {
|
||||
/*.mem_size =*/ ggml_tensor_overhead(),
|
||||
@@ -935,10 +982,10 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
|
||||
GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
|
||||
return false;
|
||||
}
|
||||
LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data);
|
||||
LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
|
||||
if (tensor->buffer == nullptr) {
|
||||
//No buffer allocated.
|
||||
buft = ggml_backend_get_default_buffer_type(backend);
|
||||
buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
|
||||
} else {
|
||||
buft = tensor->buffer->buft;
|
||||
}
|
||||
@@ -948,33 +995,49 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
|
||||
return true;
|
||||
}
|
||||
|
||||
void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
|
||||
bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
|
||||
uint32_t dev_id = request.device;
|
||||
if (dev_id >= backends.size()) {
|
||||
return false;
|
||||
}
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
|
||||
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
|
||||
response.remote_ptr = 0;
|
||||
response.remote_size = 0;
|
||||
if (buffer != nullptr) {
|
||||
response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
|
||||
response.remote_size = buffer->size;
|
||||
LOG_DBG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
|
||||
LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n",
|
||||
__func__, dev_id, request.size, response.remote_ptr, response.remote_size);
|
||||
buffers.insert(buffer);
|
||||
} else {
|
||||
LOG_DBG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
|
||||
LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
|
||||
bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) {
|
||||
uint32_t dev_id = request.device;
|
||||
if (dev_id >= backends.size()) {
|
||||
return false;
|
||||
}
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
|
||||
size_t alignment = ggml_backend_buft_get_alignment(buft);
|
||||
LOG_DBG("[%s] alignment: %lu\n", __func__, alignment);
|
||||
LOG_DBG("[%s] device: %d, alignment: %lu\n", __func__, dev_id, alignment);
|
||||
response.alignment = alignment;
|
||||
return true;
|
||||
}
|
||||
|
||||
void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
|
||||
bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) {
|
||||
uint32_t dev_id = request.device;
|
||||
if (dev_id >= backends.size()) {
|
||||
return false;
|
||||
}
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
|
||||
size_t max_size = ggml_backend_buft_get_max_size(buft);
|
||||
LOG_DBG("[%s] max_size: %lu\n", __func__, max_size);
|
||||
LOG_DBG("[%s] device: %d, max_size: %lu\n", __func__, dev_id, max_size);
|
||||
response.max_size = max_size;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
|
||||
@@ -1332,23 +1395,33 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
|
||||
|
||||
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
|
||||
// serialization format:
|
||||
// | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
||||
if (input.size() < sizeof(uint32_t)) {
|
||||
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
|
||||
if (input.size() < 2*sizeof(uint32_t)) {
|
||||
return false;
|
||||
}
|
||||
const uint8_t * src = input.data();
|
||||
uint32_t device;
|
||||
memcpy(&device, src, sizeof(device));
|
||||
src += sizeof(device);
|
||||
if (device >= backends.size()) {
|
||||
return false;
|
||||
}
|
||||
uint32_t n_nodes;
|
||||
memcpy(&n_nodes, input.data(), sizeof(n_nodes));
|
||||
if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
|
||||
memcpy(&n_nodes, src, sizeof(n_nodes));
|
||||
src += sizeof(n_nodes);
|
||||
if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
|
||||
return false;
|
||||
}
|
||||
const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
|
||||
const uint64_t * nodes = (const uint64_t *)src;
|
||||
src += n_nodes*sizeof(uint64_t);
|
||||
uint32_t n_tensors;
|
||||
memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
|
||||
if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
|
||||
memcpy(&n_tensors, src, sizeof(n_tensors));
|
||||
src += sizeof(n_tensors);
|
||||
if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
|
||||
return false;
|
||||
}
|
||||
const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
|
||||
LOG_DBG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
|
||||
const rpc_tensor * tensors = (const rpc_tensor *)src;
|
||||
LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
|
||||
|
||||
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
|
||||
|
||||
@@ -1380,7 +1453,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
|
||||
return false;
|
||||
}
|
||||
}
|
||||
ggml_status status = ggml_backend_graph_compute(backend, graph);
|
||||
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
|
||||
response.result = status;
|
||||
return true;
|
||||
}
|
||||
@@ -1391,9 +1464,9 @@ rpc_server::~rpc_server() {
|
||||
}
|
||||
}
|
||||
|
||||
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
||||
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
|
||||
rpc_server server(backend, cache_dir);
|
||||
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
|
||||
sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) {
|
||||
rpc_server server(backends, cache_dir);
|
||||
uint8_t cmd;
|
||||
if (!recv_data(sockfd, &cmd, 1)) {
|
||||
return;
|
||||
@@ -1425,13 +1498,26 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
||||
// HELLO command is handled above
|
||||
return;
|
||||
}
|
||||
case RPC_CMD_DEVICE_COUNT: {
|
||||
if (!recv_msg(sockfd, nullptr, 0)) {
|
||||
return;
|
||||
}
|
||||
rpc_msg_device_count_rsp response;
|
||||
response.device_count = backends.size();
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case RPC_CMD_ALLOC_BUFFER: {
|
||||
rpc_msg_alloc_buffer_req request;
|
||||
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
||||
return;
|
||||
}
|
||||
rpc_msg_alloc_buffer_rsp response;
|
||||
server.alloc_buffer(request, response);
|
||||
if (!server.alloc_buffer(request, response)) {
|
||||
return;
|
||||
}
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
return;
|
||||
}
|
||||
@@ -1452,22 +1538,28 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
||||
break;
|
||||
}
|
||||
case RPC_CMD_GET_ALIGNMENT: {
|
||||
if (!recv_msg(sockfd, nullptr, 0)) {
|
||||
rpc_msg_get_alignment_req request;
|
||||
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
||||
return;
|
||||
}
|
||||
rpc_msg_get_alignment_rsp response;
|
||||
server.get_alignment(response);
|
||||
if (!server.get_alignment(request, response)) {
|
||||
return;
|
||||
}
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case RPC_CMD_GET_MAX_SIZE: {
|
||||
if (!recv_msg(sockfd, nullptr, 0)) {
|
||||
rpc_msg_get_max_size_req request;
|
||||
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
||||
return;
|
||||
}
|
||||
rpc_msg_get_max_size_rsp response;
|
||||
server.get_max_size(response);
|
||||
if (!server.get_max_size(request, response)) {
|
||||
return;
|
||||
}
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
return;
|
||||
}
|
||||
@@ -1593,12 +1685,19 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
||||
break;
|
||||
}
|
||||
case RPC_CMD_GET_DEVICE_MEMORY: {
|
||||
if (!recv_msg(sockfd, nullptr, 0)) {
|
||||
rpc_msg_get_device_memory_req request;
|
||||
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
||||
return;
|
||||
}
|
||||
auto dev_id = request.device;
|
||||
if (dev_id >= backends.size()) {
|
||||
return;
|
||||
}
|
||||
rpc_msg_get_device_memory_rsp response;
|
||||
response.free_mem = free_mem;
|
||||
response.total_mem = total_mem;
|
||||
response.free_mem = free_mem[dev_id];
|
||||
response.total_mem = total_mem[dev_id];
|
||||
LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id,
|
||||
response.free_mem, response.total_mem);
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
return;
|
||||
}
|
||||
@@ -1612,16 +1711,41 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
|
||||
const char * cache_dir,
|
||||
size_t free_mem, size_t total_mem) {
|
||||
void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
|
||||
size_t n_threads, size_t n_devices,
|
||||
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
|
||||
if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) {
|
||||
fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
|
||||
return;
|
||||
}
|
||||
std::vector<ggml_backend_t> backends;
|
||||
std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices);
|
||||
std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices);
|
||||
printf("Starting RPC server v%d.%d.%d\n",
|
||||
RPC_PROTO_MAJOR_VERSION,
|
||||
RPC_PROTO_MINOR_VERSION,
|
||||
RPC_PROTO_PATCH_VERSION);
|
||||
printf(" endpoint : %s\n", endpoint);
|
||||
printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
|
||||
printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024));
|
||||
printf("Devices:\n");
|
||||
for (size_t i = 0; i < n_devices; i++) {
|
||||
auto dev = devices[i];
|
||||
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
|
||||
total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024);
|
||||
auto backend = ggml_backend_dev_init(dev, nullptr);
|
||||
if (!backend) {
|
||||
fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
|
||||
return;
|
||||
}
|
||||
backends.push_back(backend);
|
||||
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
|
||||
if (reg) {
|
||||
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
|
||||
if (ggml_backend_set_n_threads_fn) {
|
||||
ggml_backend_set_n_threads_fn(backend, n_threads);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string host;
|
||||
int port;
|
||||
@@ -1649,22 +1773,27 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
|
||||
fprintf(stderr, "Failed to accept client connection\n");
|
||||
return;
|
||||
}
|
||||
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
|
||||
printf("Accepted client connection\n");
|
||||
fflush(stdout);
|
||||
rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
|
||||
rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec);
|
||||
printf("Client connection closed\n");
|
||||
fflush(stdout);
|
||||
}
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
for (auto backend : backends) {
|
||||
ggml_backend_free(backend);
|
||||
}
|
||||
}
|
||||
|
||||
// device interface
|
||||
|
||||
struct ggml_backend_rpc_device_context {
|
||||
std::string endpoint;
|
||||
uint32_t device;
|
||||
std::string name;
|
||||
std::string description;
|
||||
};
|
||||
|
||||
static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
|
||||
@@ -1676,15 +1805,13 @@ static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
|
||||
static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
|
||||
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
|
||||
|
||||
return ctx->name.c_str();
|
||||
return ctx->description.c_str();
|
||||
}
|
||||
|
||||
static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
|
||||
|
||||
ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total);
|
||||
}
|
||||
|
||||
static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
|
||||
@@ -1710,7 +1837,7 @@ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggm
|
||||
static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
|
||||
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
|
||||
|
||||
return ggml_backend_rpc_init(ctx->endpoint.c_str());
|
||||
return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device);
|
||||
|
||||
GGML_UNUSED(params);
|
||||
}
|
||||
@@ -1718,7 +1845,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const
|
||||
static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
|
||||
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
|
||||
|
||||
return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
|
||||
return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device);
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
@@ -1736,7 +1863,7 @@ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_b
|
||||
}
|
||||
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
||||
ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
|
||||
return buft_ctx->endpoint == dev_ctx->endpoint;
|
||||
return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device;
|
||||
}
|
||||
|
||||
static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
|
||||
@@ -1759,28 +1886,34 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
|
||||
|
||||
// backend reg interface
|
||||
|
||||
static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
|
||||
return "RPC";
|
||||
struct ggml_backend_rpc_reg_context {
|
||||
std::string name;
|
||||
std::vector<ggml_backend_dev_t> devices;
|
||||
};
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
|
||||
ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
|
||||
return ctx ? ctx->name.c_str() : "RPC";
|
||||
}
|
||||
|
||||
static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
|
||||
return 0;
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
|
||||
return ctx ? ctx->devices.size() : 0;
|
||||
}
|
||||
|
||||
static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
|
||||
GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
GGML_UNUSED(index);
|
||||
ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
|
||||
if (ctx == nullptr) {
|
||||
GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead");
|
||||
} else {
|
||||
GGML_ASSERT(index < ctx->devices.size());
|
||||
return ctx->devices[index];
|
||||
}
|
||||
}
|
||||
|
||||
static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
|
||||
if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
|
||||
return (void *)ggml_backend_rpc_add_device;
|
||||
if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) {
|
||||
return (void *)ggml_backend_rpc_add_server;
|
||||
}
|
||||
if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
|
||||
return (void *)ggml_backend_rpc_start_server;
|
||||
@@ -1807,30 +1940,61 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
|
||||
return &ggml_backend_rpc_reg;
|
||||
}
|
||||
|
||||
ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
|
||||
static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
|
||||
|
||||
static std::mutex mutex;
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
|
||||
if (dev_map.find(endpoint) != dev_map.end()) {
|
||||
return dev_map[endpoint];
|
||||
}
|
||||
|
||||
ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
|
||||
/* .endpoint = */ endpoint,
|
||||
/* .name = */ "RPC[" + std::string(endpoint) + "]",
|
||||
};
|
||||
|
||||
ggml_backend_dev_t dev = new ggml_backend_device {
|
||||
/* .iface = */ ggml_backend_rpc_device_i,
|
||||
/* .reg = */ ggml_backend_rpc_reg(),
|
||||
/* .context = */ ctx,
|
||||
};
|
||||
|
||||
dev_map[endpoint] = dev;
|
||||
|
||||
return dev;
|
||||
static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
|
||||
auto sock = get_socket(endpoint);
|
||||
rpc_msg_device_count_rsp response;
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
|
||||
RPC_STATUS_ASSERT(status);
|
||||
return response.device_count;
|
||||
}
|
||||
|
||||
static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = {
|
||||
/* .get_name = */ ggml_backend_rpc_reg_get_name,
|
||||
/* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
|
||||
/* .get_device = */ ggml_backend_rpc_reg_get_device,
|
||||
/* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
|
||||
};
|
||||
|
||||
ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
|
||||
static std::unordered_map<std::string, ggml_backend_reg_t> reg_map;
|
||||
static std::mutex mutex;
|
||||
static uint32_t dev_id = 0;
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
if (reg_map.find(endpoint) != reg_map.end()) {
|
||||
return reg_map[endpoint];
|
||||
}
|
||||
uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint);
|
||||
if (dev_count == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context;
|
||||
ctx->name = "RPC[" + std::string(endpoint) + "]";
|
||||
for (uint32_t ind = 0; ind < dev_count; ind++) {
|
||||
std::string dev_name = "RPC" + std::to_string(dev_id);
|
||||
std::string dev_desc = std::string(endpoint);
|
||||
ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {
|
||||
/* .endpoint = */ endpoint,
|
||||
/* .device = */ ind,
|
||||
/* .name = */ dev_name,
|
||||
/* .description = */ dev_desc
|
||||
};
|
||||
|
||||
ggml_backend_dev_t dev = new ggml_backend_device {
|
||||
/* .iface = */ ggml_backend_rpc_device_i,
|
||||
/* .reg = */ ggml_backend_rpc_reg(),
|
||||
/* .context = */ dev_ctx,
|
||||
};
|
||||
ctx->devices.push_back(dev);
|
||||
dev_id++;
|
||||
}
|
||||
ggml_backend_reg_t reg = new ggml_backend_reg {
|
||||
/* .api_version = */ GGML_BACKEND_API_VERSION,
|
||||
/* .iface = */ ggml_backend_rpc_reg_interface,
|
||||
/* .context = */ ctx
|
||||
};
|
||||
reg_map[endpoint] = reg;
|
||||
return reg;
|
||||
}
|
||||
|
||||
|
||||
GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "concat.hpp"
|
||||
#include "conv.hpp"
|
||||
#include "convert.hpp"
|
||||
#include "count-equal.hpp"
|
||||
#include "cpy.hpp"
|
||||
#include "dequantize.hpp"
|
||||
#include "dmmv.hpp"
|
||||
@@ -28,6 +29,7 @@
|
||||
#include "mmvq.hpp"
|
||||
#include "norm.hpp"
|
||||
#include "outprod.hpp"
|
||||
#include "pad.hpp"
|
||||
#include "quantize.hpp"
|
||||
#include "quants.hpp"
|
||||
#include "rope.hpp"
|
||||
|
||||
@@ -303,10 +303,6 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_count_equal>>(ctx, dst->src[0], dst->src[1], dst);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
|
||||
@@ -332,11 +328,6 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_sub(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
ggml_sycl_op_count_equal(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
ggml_sycl_op_mul(ctx, dst);
|
||||
|
||||
@@ -16,12 +16,6 @@ static __dpct_inline__ float op_sub(const float a, const float b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
static __dpct_inline__ float op_count_equal(const float a, const float b) {
|
||||
return (a == b) ? 1.0f : 0.0f;
|
||||
}
|
||||
|
||||
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
static __dpct_inline__ float op_mul(const float a, const float b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
@@ -195,8 +195,10 @@ struct optimize_feature {
|
||||
|
||||
struct sycl_device_info {
|
||||
int cc; // compute capability
|
||||
// int nsm; // number of streaming multiprocessors
|
||||
int nsm; // number of streaming multiprocessors (CUDA) maps to the maximum
|
||||
// number of compute units on a SYCL device.
|
||||
// size_t smpb; // max. shared memory per block
|
||||
size_t smpbo; // max. shared memory per block (with opt-in)
|
||||
bool vmm; // virtual memory support
|
||||
size_t total_vram;
|
||||
//sycl_hw_info hw_info; \\ device id and aarch, currently not used
|
||||
@@ -416,13 +418,6 @@ static __dpct_inline__ float warp_reduce_sum(float x,
|
||||
const sycl::nd_item<3>& item_ct1) {
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
/*
|
||||
DPCT1096:98: The right-most dimension of the work-group used in the SYCL
|
||||
kernel that calls this function may be less than "32". The function
|
||||
"dpct::permute_sub_group_by_xor" may return an unexpected result on the
|
||||
CPU device. Modify the size of the work-group to ensure that the value
|
||||
of the right-most dimension is a multiple of "32".
|
||||
*/
|
||||
x += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), x, mask);
|
||||
}
|
||||
return x;
|
||||
@@ -440,17 +435,67 @@ warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) {
|
||||
return a;
|
||||
}
|
||||
|
||||
template <int width = WARP_SIZE>
|
||||
static __dpct_inline__ int warp_reduce_sum(int x) {
|
||||
return sycl::reduce_over_group(
|
||||
sycl::ext::oneapi::this_work_item::get_sub_group(), x, sycl::plus<>());
|
||||
}
|
||||
|
||||
template <int width = WARP_SIZE>
|
||||
static __dpct_inline__ float warp_reduce_sum(float x) {
|
||||
#pragma unroll
|
||||
for (int offset = width / 2; offset > 0; offset >>= 1) {
|
||||
x += dpct::permute_sub_group_by_xor(
|
||||
sycl::ext::oneapi::this_work_item::get_sub_group(), x, offset, width);
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
template <int width = WARP_SIZE>
|
||||
static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) {
|
||||
#pragma unroll
|
||||
for (int offset = width / 2; offset > 0; offset >>= 1) {
|
||||
a.x() += dpct::permute_sub_group_by_xor(
|
||||
sycl::ext::oneapi::this_work_item::get_sub_group(), a.x(), offset,
|
||||
width);
|
||||
a.y() += dpct::permute_sub_group_by_xor(
|
||||
sycl::ext::oneapi::this_work_item::get_sub_group(), a.y(), offset,
|
||||
width);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
template <int width = WARP_SIZE>
|
||||
static __dpct_inline__ sycl::half2 warp_reduce_sum(sycl::half2 a) {
|
||||
#pragma unroll
|
||||
for (int offset = width / 2; offset > 0; offset >>= 1) {
|
||||
a = a + dpct::permute_sub_group_by_xor(
|
||||
sycl::ext::oneapi::this_work_item::get_sub_group(), a, offset,
|
||||
width);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
static constexpr int ggml_sycl_get_physical_warp_size() {
|
||||
// todo: for old iGPU + dGPU case, need to be changed.
|
||||
return WARP_SIZE;
|
||||
}
|
||||
|
||||
template <int width = WARP_SIZE>
|
||||
static __dpct_inline__ float warp_reduce_max(float x) {
|
||||
#pragma unroll
|
||||
for (int offset = width / 2; offset > 0; offset >>= 1) {
|
||||
x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
|
||||
sycl::ext::oneapi::this_work_item::get_sub_group(), x,
|
||||
offset, width));
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
static __dpct_inline__ float warp_reduce_max(float x,
|
||||
const sycl::nd_item<3>& item_ct1) {
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
||||
/*
|
||||
DPCT1096:97: The right-most dimension of the work-group used in the SYCL
|
||||
kernel that calls this function may be less than "32". The function
|
||||
"dpct::permute_sub_group_by_xor" may return an unexpected result on the
|
||||
CPU device. Modify the size of the work-group to ensure that the value
|
||||
of the right-most dimension is a multiple of "32".
|
||||
*/
|
||||
x = sycl::fmax(x, dpct::permute_sub_group_by_xor(
|
||||
item_ct1.get_sub_group(), x, mask));
|
||||
}
|
||||
@@ -558,4 +603,18 @@ struct scope_op_debug_print {
|
||||
std::string_view func_suffix;
|
||||
};
|
||||
|
||||
static __dpct_inline__ float get_alibi_slope(const float max_bias,
|
||||
const uint32_t h,
|
||||
const uint32_t n_head_log2,
|
||||
const float m0,
|
||||
const float m1) {
|
||||
if (max_bias <= 0.0f) {
|
||||
return 1.0f;
|
||||
}
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
return dpct::pow(base, exph);
|
||||
}
|
||||
|
||||
#endif // GGML_SYCL_COMMON_HPP
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
#include "count-equal.hpp"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
template <typename T>
|
||||
static void count_equal(const T *__restrict__ x, const T *__restrict__ y,
|
||||
int64_t *__restrict__ dst, const int64_t dk,
|
||||
const int64_t k) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int64_t i0 = (int64_t)item_ct1.get_group(2) * dk;
|
||||
const int64_t i1 = sycl::min(i0 + dk, k);
|
||||
|
||||
int nequal = 0;
|
||||
|
||||
for (int64_t i = i0 + item_ct1.get_local_id(2); i < i1; i += WARP_SIZE) {
|
||||
const T xi = x[i];
|
||||
const T yi = y[i];
|
||||
nequal += xi == yi;
|
||||
}
|
||||
|
||||
nequal = warp_reduce_sum(nequal);
|
||||
|
||||
if (item_ct1.get_local_id(2) != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
|
||||
(int *)dst, nequal);
|
||||
}
|
||||
|
||||
void ggml_sycl_count_equal(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(src0->type == src1->type);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_I64);
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
int64_t * dst_d = (int64_t *) dst->data;
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
const int id = get_current_device_id();
|
||||
const int nsm = ggml_sycl_info().devices[id].nsm;
|
||||
|
||||
const int64_t ne = ggml_nelements(src0);
|
||||
GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int");
|
||||
const int64_t dne =
|
||||
GGML_PAD((ne + 4 * nsm - 1) / (4 * nsm), SYCL_COUNT_EQUAL_CHUNK_SIZE);
|
||||
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(stream->memset(dst_d, 0, ggml_nbytes(dst))));
|
||||
|
||||
const dpct::dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dpct::dim3 block_nums(
|
||||
std::min((int64_t)4 * nsm, (ne + SYCL_COUNT_EQUAL_CHUNK_SIZE - 1) /
|
||||
SYCL_COUNT_EQUAL_CHUNK_SIZE),
|
||||
1, 1);
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_I32: {
|
||||
const int *src0_d = (const int *)src0->data;
|
||||
const int *src1_d = (const int *)src1->data;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
count_equal(src0_d, src1_d, dst_d, dne, ne);
|
||||
GGML_UNUSED(item_ct1);
|
||||
});
|
||||
|
||||
} break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
#ifndef GGML_SYCL_COUNT_EQUAL_HPP
|
||||
#define GGML_SYCL_COUNT_EQUAL_HPP
|
||||
#include "common.hpp"
|
||||
|
||||
#define SYCL_COUNT_EQUAL_CHUNK_SIZE 128
|
||||
|
||||
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif //GGML_SYCL_COUNT_EQUAL_HPP
|
||||
@@ -277,6 +277,26 @@ namespace dpct
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// COPY from DPCT head files
|
||||
/// dim3 is used to store 3 component dimensions.
|
||||
class dim3 {
|
||||
public:
|
||||
unsigned x, y, z;
|
||||
|
||||
constexpr dim3(unsigned x = 1, unsigned y = 1, unsigned z = 1)
|
||||
: x(x), y(y), z(z) {}
|
||||
|
||||
dim3(const sycl::id<3> &r) : dim3(r[2], r[1], r[0]) {}
|
||||
|
||||
operator sycl::range<3>() const { return sycl::range<3>(z, y, x); }
|
||||
}; // namespace dim3
|
||||
|
||||
inline dim3 operator*(const dim3 &a, const dim3 &b) {
|
||||
return dim3{a.x * b.x, a.y * b.y, a.z * b.z};
|
||||
}
|
||||
// COPY from DPCT head files
|
||||
|
||||
|
||||
/// Pitched 2D/3D memory data.
|
||||
class pitched_data
|
||||
{
|
||||
|
||||
@@ -328,26 +328,6 @@ static void upscale(const T *x, T *dst, const int nb00, const int nb01,
|
||||
dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void pad(const T *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2);
|
||||
if (nidx >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// operation
|
||||
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
|
||||
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
||||
if (nidx < ne00 && item_ct1.get_group(1) < (size_t) ne01 && item_ct1.get_group(0) < (size_t) ne02) {
|
||||
int offset_src = nidx + item_ct1.get_group(1) * ne00 +
|
||||
item_ct1.get_group(0) * ne00 * ne01;
|
||||
dst[offset_dst] = x[offset_src];
|
||||
} else {
|
||||
dst[offset_dst] = static_cast<T>(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void clamp(const T * x, T * dst, const float min, const float max, const int k,
|
||||
const sycl::nd_item<1> &item_ct1) {
|
||||
@@ -431,18 +411,6 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void pad_sycl(const T *x, T *dst, const int ne00,
|
||||
const int ne01, const int ne02, const int ne0,
|
||||
const int ne1, const int ne2, queue_ptr stream) {
|
||||
int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE);
|
||||
sycl::range<3> gridDim(ne2, ne1, num_blocks);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
@@ -596,40 +564,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
|
||||
}
|
||||
}
|
||||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
#else
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
#endif
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
switch (dst->type) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
auto data_pts = cast_data<sycl::half>(dst);
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
|
||||
(int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
auto data_pts = cast_data<float>(dst);
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
|
||||
(int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ggml_sycl_detail
|
||||
|
||||
@@ -919,14 +853,6 @@ static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_te
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2,
|
||||
queue_ptr stream) {
|
||||
ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
float min_val;
|
||||
float max_val;
|
||||
@@ -1119,10 +1045,6 @@ void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_upscale(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_pad(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
|
||||
@@ -67,8 +67,6 @@ void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@@ -85,7 +85,10 @@ static ggml_sycl_device_info ggml_sycl_init() {
|
||||
|
||||
info.devices[i].cc =
|
||||
100 * prop.get_major_version() + 10 * prop.get_minor_version();
|
||||
info.devices[i].nsm = prop.get_max_compute_units();
|
||||
info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
|
||||
info.devices[i].smpbo = prop.get_local_mem_size();
|
||||
|
||||
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
|
||||
}
|
||||
|
||||
@@ -1511,60 +1514,70 @@ static inline void ggml_sycl_swap(T & a, T & b) {
|
||||
template <ggml_sort_order order>
|
||||
__dpct_inline__ static void
|
||||
k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
|
||||
const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
|
||||
const int tasks_per_thread, const sycl::nd_item<3> &item_ct1,
|
||||
uint8_t *dpct_local) {
|
||||
// bitonic sort
|
||||
int col = item_ct1.get_local_id(2);
|
||||
int col_index = item_ct1.get_local_id(2);
|
||||
int row = item_ct1.get_group(1);
|
||||
|
||||
if (col >= ncols_pad) {
|
||||
return;
|
||||
for (int i = 0; i < tasks_per_thread; i++) {
|
||||
int col = col_index * tasks_per_thread + i;
|
||||
if (col >= ncols_pad) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const float * x_row = x + row * ncols;
|
||||
auto dst_row = (int *)dpct_local;
|
||||
|
||||
// initialize indices
|
||||
dst_row[col] = col;
|
||||
for (int i=0;i<tasks_per_thread;i++){
|
||||
int col = col_index*tasks_per_thread+i;
|
||||
dst_row[col] = col;
|
||||
}
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
for (int k = 2; k <= ncols_pad; k *= 2) {
|
||||
for (int j = k / 2; j > 0; j /= 2) {
|
||||
int ixj = col ^ j;
|
||||
if (ixj > col) {
|
||||
if ((col & k) == 0) {
|
||||
if (dst_row[col] >= ncols ||
|
||||
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
||||
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
||||
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
||||
) {
|
||||
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
||||
}
|
||||
} else {
|
||||
if (dst_row[ixj] >= ncols ||
|
||||
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
||||
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
||||
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
||||
) {
|
||||
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
||||
for (int i = 0; i < tasks_per_thread; i++) {
|
||||
int col = col_index * tasks_per_thread + i;
|
||||
int ixj = col ^ j;
|
||||
if (ixj > col) {
|
||||
if ((col & k) == 0) {
|
||||
if (dst_row[col] >= ncols ||
|
||||
(dst_row[ixj] < ncols &&
|
||||
(order == GGML_SORT_ORDER_ASC
|
||||
? x_row[dst_row[col]] > x_row[dst_row[ixj]]
|
||||
: x_row[dst_row[col]] <
|
||||
x_row[dst_row[ixj]]))) {
|
||||
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
||||
}
|
||||
} else {
|
||||
if (dst_row[ixj] >= ncols ||
|
||||
(dst_row[col] < ncols &&
|
||||
(order == GGML_SORT_ORDER_ASC
|
||||
? x_row[dst_row[col]] < x_row[dst_row[ixj]]
|
||||
: x_row[dst_row[col]] >
|
||||
x_row[dst_row[ixj]]))) {
|
||||
ggml_sycl_swap(dst_row[col], dst_row[ixj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
}
|
||||
/*
|
||||
DPCT1118:1: SYCL group functions and algorithms must be encountered
|
||||
in converged control flow. You may need to adjust the code.
|
||||
*/
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
}
|
||||
}
|
||||
|
||||
// copy the result to dst without the padding
|
||||
if (col < ncols) {
|
||||
dst[row * ncols + col] = dst_row[col];
|
||||
for (int i = 0; i < tasks_per_thread; i++) {
|
||||
int col = col_index * tasks_per_thread + i;
|
||||
if (col < ncols) {
|
||||
dst[row * ncols + col] = dst_row[col];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
||||
@@ -1737,11 +1750,20 @@ static int next_power_of_2(int x) {
|
||||
|
||||
static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
||||
const int nrows, ggml_sort_order order,
|
||||
queue_ptr stream) {
|
||||
queue_ptr stream, int device) {
|
||||
// bitonic sort requires ncols to be power of 2
|
||||
const int ncols_pad = next_power_of_2(ncols);
|
||||
|
||||
const sycl::range<3> block_dims(1, 1, ncols_pad);
|
||||
int nth = 1;
|
||||
int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||
while (nth < ncols_pad && nth < max_block_size)
|
||||
nth *= 2;
|
||||
if (nth > max_block_size)
|
||||
nth = max_block_size;
|
||||
|
||||
const int tasks_per_thread = ncols_pad / nth;
|
||||
|
||||
const sycl::range<3> block_dims(1, 1, nth);
|
||||
const sycl::range<3> block_nums(1, nrows, 1);
|
||||
const size_t shared_mem = ncols_pad * sizeof(int);
|
||||
|
||||
@@ -1754,8 +1776,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
|
||||
x, dst, ncols, ncols_pad, item_ct1,
|
||||
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
|
||||
x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
|
||||
dpct_local_acc_ct1
|
||||
.get_multi_ptr<sycl::access::decorated::no>()
|
||||
.get());
|
||||
});
|
||||
});
|
||||
@@ -1768,8 +1791,9 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
|
||||
x, dst, ncols, ncols_pad, item_ct1,
|
||||
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
|
||||
x, dst, ncols, ncols_pad, tasks_per_thread, item_ct1,
|
||||
dpct_local_acc_ct1
|
||||
.get_multi_ptr<sycl::access::decorated::no>()
|
||||
.get());
|
||||
});
|
||||
});
|
||||
@@ -2141,7 +2165,8 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
|
||||
|
||||
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
||||
|
||||
argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream);
|
||||
argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order,
|
||||
main_stream, ctx.device);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
@@ -3741,6 +3766,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
case GGML_OP_SOFT_MAX:
|
||||
ggml_sycl_op_soft_max(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
ggml_sycl_op_soft_max_back(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ROPE:
|
||||
ggml_sycl_rope(ctx, dst);
|
||||
break;
|
||||
@@ -3778,6 +3806,7 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
return true;
|
||||
} catch (sycl::exception & e) {
|
||||
std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
|
||||
std::cerr << "Error OP "<<ggml_op_name(dst->op)<< std::endl;
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
@@ -4386,19 +4415,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return true;
|
||||
case GGML_OP_CONT:
|
||||
return op->src[0]->type != GGML_TYPE_BF16;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
// TODO: support batching
|
||||
if (op->src[0]->ne[3] != 1) {
|
||||
return false;
|
||||
}
|
||||
// TODO: support attention sinks [TAG_ATTN_SINKS]
|
||||
if (op->src[2]) {
|
||||
return false;
|
||||
}
|
||||
// TODO: support broadcast
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
|
||||
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
return true;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
return true;
|
||||
case GGML_OP_SOFT_MAX_BACK: {
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
|
||||
return max_bias == 0.0f;
|
||||
}
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_IM2COL:
|
||||
return true;
|
||||
@@ -4412,8 +4437,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_ACC:
|
||||
return true;
|
||||
case GGML_OP_PAD:
|
||||
return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
|
||||
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2025 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
//#include "common.hpp"
|
||||
#include "pad.hpp"
|
||||
|
||||
static void pad_f32(const float * src, float * dst,
|
||||
const int lp0, const int rp0, const int lp1, const int rp1,
|
||||
const int lp2, const int rp2, const int lp3, const int rp3,
|
||||
const int ne0, const int ne1, const int ne2, const int ne3) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
int i0 = item_ct1.get_local_id(2) +
|
||||
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
||||
int i1 = item_ct1.get_group(1);
|
||||
int i2 = item_ct1.get_group(0) % ne2;
|
||||
int i3 = item_ct1.get_group(0) / ne2;
|
||||
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
// operation
|
||||
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
|
||||
if ((i0 >= lp0 && i0 < ne0 - rp0) &&
|
||||
(i1 >= lp1 && i1 < ne1 - rp1) &&
|
||||
(i2 >= lp2 && i2 < ne2 - rp2) &&
|
||||
(i3 >= lp3 && i3 < ne3 - rp3)) {
|
||||
const int64_t i00 = i0 - lp0;
|
||||
const int64_t i01 = i1 - lp1;
|
||||
const int64_t i02 = i2 - lp2;
|
||||
const int64_t i03 = i3 - lp3;
|
||||
const int64_t ne02 = ne2 - lp2 - rp2;
|
||||
const int64_t ne01 = ne1 - lp1 - rp1;
|
||||
const int64_t ne00 = ne0 - lp0 - rp0;
|
||||
|
||||
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) +
|
||||
i02 * (ne00 * ne01) + i01 * ne00 + i00;
|
||||
|
||||
dst[dst_idx] = src[src_idx];
|
||||
} else {
|
||||
dst[dst_idx] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
static void pad_f32_sycl(const float *src, float *dst, const int lp0,
|
||||
const int rp0, const int lp1, const int rp1,
|
||||
const int lp2, const int rp2, const int lp3,
|
||||
const int rp3, const int ne0, const int ne1,
|
||||
const int ne2, const int ne3,
|
||||
dpct::queue_ptr stream) {
|
||||
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
|
||||
dpct::dim3 gridDim(num_blocks, ne1, ne2 * ne3);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
pad_f32(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1,
|
||||
ne2, ne3);
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
float * dst_d = (float *)dst->data;
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
|
||||
const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
|
||||
const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
|
||||
const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
|
||||
const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
|
||||
const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
|
||||
const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
|
||||
const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
|
||||
|
||||
pad_f32_sycl(src0_d, dst_d,
|
||||
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
|
||||
}
|
||||
|
||||
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_pad(ctx, dst);
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2025 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
#ifndef GGML_SYCL_PAD_HPP
|
||||
#define GGML_SYCL_PAD_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#define SYCL_PAD_BLOCK_SIZE 256
|
||||
|
||||
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_PAD_HPP
|
||||
+328
-163
@@ -1,37 +1,94 @@
|
||||
#include "softmax.hpp"
|
||||
#include <cstdint>
|
||||
#include <utility>
|
||||
#include <cmath>
|
||||
|
||||
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
|
||||
static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par,
|
||||
const int nrows_y, const float scale, const float max_bias, const float m0,
|
||||
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
|
||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
||||
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int rowx = item_ct1.get_group(2);
|
||||
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
|
||||
template <typename T> static __dpct_inline__ float t2f32(T val) {
|
||||
return (float) val;
|
||||
}
|
||||
|
||||
const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
|
||||
template <> float __dpct_inline__ t2f32<sycl::half>(sycl::half val) {
|
||||
return sycl::vec<sycl::half, 1>(val)
|
||||
.convert<float, sycl::rounding_mode::automatic>()[0];
|
||||
}
|
||||
|
||||
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
||||
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
||||
struct soft_max_params {
|
||||
|
||||
int64_t nheads;
|
||||
uint32_t n_head_log2;
|
||||
int64_t ncols;
|
||||
int64_t nrows_x;
|
||||
int64_t nrows_y;
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
int64_t ne03;
|
||||
int64_t nb11;
|
||||
int64_t nb12;
|
||||
int64_t nb13;
|
||||
|
||||
int64_t ne12;
|
||||
int64_t ne13;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
float m1;
|
||||
};
|
||||
|
||||
// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
|
||||
// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wpass-failed"
|
||||
#endif // __clang__
|
||||
template <bool use_shared, int ncols_template, int block_size_template, typename T>
|
||||
static void soft_max_f32(const float * x,
|
||||
const T * mask,
|
||||
const float * sinks,
|
||||
float * dst,
|
||||
const soft_max_params p,
|
||||
uint8_t * dpct_local) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
|
||||
const int block_size = block_size_template == 0
|
||||
? item_ct1.get_local_range(2)
|
||||
: block_size_template;
|
||||
const int nthreads = block_size;
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
size_t nreduce = nwarps / WARP_SIZE;
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const uint32_t h = rowx/nrows_y; // head index
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
const int64_t i03 = item_ct1.get_group(0);
|
||||
const int64_t i02 = item_ct1.get_group(1);
|
||||
const int64_t i01 = item_ct1.get_group(2);
|
||||
|
||||
slope = sycl::pow(base, float(exp));
|
||||
}
|
||||
//TODO: noncontigous inputs/outputs
|
||||
const int rowx = item_ct1.get_group(2) +
|
||||
item_ct1.get_group(1) * item_ct1.get_group_range(2) +
|
||||
item_ct1.get_group(0) * item_ct1.get_group_range(2) *
|
||||
item_ct1.get_group_range(1);
|
||||
|
||||
float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
|
||||
float max_val = -INFINITY;
|
||||
const int64_t i11 = i01;
|
||||
const int64_t i12 = i02 % p.ne12;
|
||||
const int64_t i13 = i03 % p.ne13;
|
||||
|
||||
x += int64_t(rowx)*ncols;
|
||||
mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
|
||||
dst += int64_t(rowx)*ncols;
|
||||
|
||||
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
||||
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
||||
|
||||
const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
|
||||
|
||||
float * buf_iw = (float *) dpct_local;
|
||||
|
||||
// shared memory buffer to cache values between iterations:
|
||||
float *vals = use_shared ? buf_iw + sycl::max(nwarps, WARP_SIZE) : dst;
|
||||
float max_val = sinks ? sinks[i02] : -INFINITY;
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
const int col = col0 + tid;
|
||||
|
||||
@@ -39,42 +96,35 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
|
||||
break;
|
||||
}
|
||||
|
||||
const int ix = rowx*ncols + col;
|
||||
const int iy = rowy*ncols + col;
|
||||
|
||||
const float val = x[ix]*scale + (mask ? slope*static_cast<float>(mask[iy]) : 0.0f);
|
||||
const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
|
||||
|
||||
vals[col] = val;
|
||||
max_val = sycl::max(max_val, val);
|
||||
max_val = sycl::max(max_val, val);
|
||||
}
|
||||
|
||||
// find the max value in the block
|
||||
max_val = warp_reduce_max(max_val, item_ct1);
|
||||
max_val = warp_reduce_max(max_val);
|
||||
|
||||
if (block_size > WARP_SIZE) {
|
||||
if (warp_id == 0) {
|
||||
buf[lane_id] = -INFINITY;
|
||||
for (size_t i = 1; i < nreduce; i += 1) {
|
||||
buf[lane_id + i * WARP_SIZE] = -INFINITY;
|
||||
}
|
||||
buf_iw[lane_id] = -INFINITY;
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
item_ct1.barrier();
|
||||
|
||||
if (lane_id == 0) {
|
||||
buf[warp_id] = max_val;
|
||||
buf_iw[warp_id] = max_val;
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
max_val = buf[lane_id];
|
||||
for (size_t i = 1; i < nreduce; i += 1) {
|
||||
max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]);
|
||||
}
|
||||
max_val = warp_reduce_max(max_val, item_ct1);
|
||||
}
|
||||
item_ct1.barrier();
|
||||
|
||||
max_val = buf_iw[lane_id];
|
||||
max_val = warp_reduce_max(max_val);
|
||||
}
|
||||
float tmp = 0.0f; // partial sum
|
||||
|
||||
float tmp = 0.f;
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
const int col = col0 + tid;
|
||||
if (ncols_template == 0 && col >= ncols) {
|
||||
|
||||
if (ncols_template == 0 && col >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -82,32 +132,33 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
|
||||
tmp += val;
|
||||
vals[col] = val;
|
||||
}
|
||||
|
||||
// find the sum of exps in the block
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
if (block_size > WARP_SIZE) {
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
item_ct1.barrier();
|
||||
if (warp_id == 0) {
|
||||
buf[lane_id] = 0.f;
|
||||
buf_iw[lane_id] = 0.0f;
|
||||
for (size_t i = 1; i < nreduce; i += 1) {
|
||||
buf[lane_id + i * WARP_SIZE] = 0.f;
|
||||
buf_iw[lane_id + i * WARP_SIZE] = 0.f;
|
||||
}
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
item_ct1.barrier();
|
||||
|
||||
if (lane_id == 0) {
|
||||
buf[warp_id] = tmp;
|
||||
buf_iw[warp_id] = tmp;
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
item_ct1.barrier();
|
||||
|
||||
tmp = buf[lane_id];
|
||||
tmp = buf_iw[lane_id];
|
||||
for (size_t i = 1; i < nreduce; i += 1) {
|
||||
tmp += buf[lane_id + i * WARP_SIZE];
|
||||
tmp += buf_iw[lane_id + i * WARP_SIZE];
|
||||
}
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
}
|
||||
|
||||
const float inv_sum = 1.f / tmp;
|
||||
if (sinks) {
|
||||
tmp += sycl::native::exp(sinks[i02] - max_val);
|
||||
}
|
||||
const float inv_sum = 1.0f / tmp;
|
||||
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
@@ -117,145 +168,259 @@ static void soft_max_f32(const float * x, const T * mask, float * dst, const int
|
||||
return;
|
||||
}
|
||||
|
||||
const int idst = rowx*ncols + col;
|
||||
dst[idst] = vals[col] * inv_sum;
|
||||
dst[col] = vals[col] * inv_sum;
|
||||
}
|
||||
}
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic pop
|
||||
#endif // __clang__
|
||||
|
||||
static void soft_max_back_f32(const float *grad, const float *dstf, float *dst,
|
||||
const int ncols, const float scale) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int rowx = item_ct1.get_group(2);
|
||||
|
||||
grad += int64_t(rowx)*ncols;
|
||||
dstf += int64_t(rowx)*ncols;
|
||||
dst += int64_t(rowx)*ncols;
|
||||
|
||||
float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients
|
||||
|
||||
for (int col = tid; col < ncols; col += WARP_SIZE) {
|
||||
dgf_dot += dstf[col]*grad[col];
|
||||
}
|
||||
|
||||
dgf_dot = warp_reduce_sum(dgf_dot);
|
||||
|
||||
for (int col = tid; col < ncols; col += WARP_SIZE) {
|
||||
dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];
|
||||
}
|
||||
}
|
||||
|
||||
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
|
||||
static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par,
|
||||
const int nrows_y, const float scale, const float max_bias, const float m0,
|
||||
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
|
||||
const size_t n_local_scratch, queue_ptr stream) {
|
||||
template <int... Ns, typename T>
|
||||
static void launch_soft_max_kernels(const float * x,
|
||||
const T * mask,
|
||||
const float * sinks,
|
||||
float * dst,
|
||||
const soft_max_params & p,
|
||||
dpct::queue_ptr stream,
|
||||
dpct::dim3 block_dims,
|
||||
dpct::dim3 block_nums,
|
||||
size_t nbytes_shared)
|
||||
{
|
||||
auto launch_kernel = [=](auto I) -> bool {
|
||||
constexpr int ncols = decltype(I)::value;
|
||||
constexpr int block = (ncols > 1024 ? 1024 : ncols);
|
||||
if (p.ncols == ncols) {
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
|
||||
sycl::range<1>(nbytes_shared), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
|
||||
WARP_SIZE)]] {
|
||||
soft_max_f32<true, ncols, block>(
|
||||
x, mask, sinks, dst, p,
|
||||
dpct_local_acc_ct1
|
||||
.get_multi_ptr<sycl::access::decorated::no>()
|
||||
.get());
|
||||
GGML_UNUSED(item_ct1);
|
||||
});
|
||||
});
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
// unary fold over launch_kernel
|
||||
if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
|
||||
return;
|
||||
}
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
|
||||
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
|
||||
sycl::range<1>(nbytes_shared), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
|
||||
nrows_y, scale, max_bias, m0,
|
||||
m1, n_head_log2, item_ct1,
|
||||
get_pointer(local_buf_acc));
|
||||
});
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
soft_max_f32<true, 0, 0>(
|
||||
x, mask, sinks, dst, p,
|
||||
dpct_local_acc_ct1
|
||||
.get_multi_ptr<sycl::access::decorated::no>()
|
||||
.get());
|
||||
GGML_UNUSED(item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void soft_max_f32_sycl(const float * x, const T * mask,
|
||||
float * dst, const int ncols_x, const int nrows_x,
|
||||
const int nrows_y, const float scale, const float max_bias,
|
||||
queue_ptr stream, int device) {
|
||||
template <typename T>
|
||||
static void soft_max_f32_sycl(const float *x, const T *mask,
|
||||
const float *sinks, float *dst,
|
||||
const soft_max_params ¶ms,
|
||||
dpct::queue_ptr stream, int device) {
|
||||
int nth = WARP_SIZE;
|
||||
int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||
const int64_t ncols_x = params.ncols;
|
||||
|
||||
while (nth < ncols_x && nth < max_block_size) nth *= 2;
|
||||
if (nth>max_block_size) nth = max_block_size;
|
||||
|
||||
const sycl::range<3> block_dims(1, 1, nth);
|
||||
const sycl::range<3> block_nums(1, 1, nrows_x);
|
||||
const size_t n_val_tmp = nth / WARP_SIZE;
|
||||
const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + n_val_tmp);
|
||||
const dpct::dim3 block_dims(nth, 1, 1);
|
||||
const dpct::dim3 block_nums(params.ne01, params.ne02, params.ne03);
|
||||
const size_t nbytes_shared =
|
||||
(GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE) * sizeof(float);
|
||||
|
||||
const uint32_t n_head_kv = nrows_x/nrows_y;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
||||
const int id = get_current_device_id();
|
||||
const size_t smpbo = ggml_sycl_info().devices[id].smpbo;
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
|
||||
if (n_local_scratch*sizeof(float) < local_mem_size) {
|
||||
if (ncols_x > max_block_size) {
|
||||
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
return;
|
||||
}
|
||||
switch (ncols_x) {
|
||||
case 32:
|
||||
soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 64:
|
||||
soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 128:
|
||||
soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 256:
|
||||
soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 512:
|
||||
soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 1024:
|
||||
soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 2048:
|
||||
soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 4096:
|
||||
soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
default:
|
||||
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
}
|
||||
if (nbytes_shared <= smpbo) {
|
||||
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(
|
||||
x, mask, sinks, dst, params, stream, block_dims, block_nums,
|
||||
nbytes_shared);
|
||||
} else {
|
||||
soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, WARP_SIZE, stream);
|
||||
const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
|
||||
sycl::range<1>(nbytes_shared_low), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
soft_max_f32<false, 0, 0>(
|
||||
x, mask, sinks, dst, params,
|
||||
dpct_local_acc_ct1
|
||||
.get_multi_ptr<sycl::access::decorated::no>()
|
||||
.get());
|
||||
GGML_UNUSED(item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static void soft_max_back_f32_sycl(const float * grad,
|
||||
const float * dstf,
|
||||
float * dst,
|
||||
const int ncols,
|
||||
const int nrows,
|
||||
const float scale,
|
||||
dpct::queue_ptr stream) {
|
||||
const dpct::dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dpct::dim3 block_nums(nrows, 1, 1);
|
||||
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
soft_max_back_f32(grad, dstf, dst, ncols, scale);
|
||||
GGML_UNUSED(item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const ggml_tensor * src2 = dst->src[2];
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const void * src1_d = src1 ? (const void *) src1->data : nullptr;
|
||||
const void * src2_d = src2 ? (const void *) src2->data : nullptr;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
||||
// src1 contains mask and it is optional
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
||||
|
||||
const int64_t ne00 = dst->src[0]->ne[0];
|
||||
const int64_t nrows_x = ggml_nrows(dst->src[0]);
|
||||
const int64_t nrows_y = dst->src[0]->ne[1];
|
||||
const int64_t nrows_x = ggml_nrows(src0);
|
||||
const int64_t nrows_y = src0->ne[1];
|
||||
|
||||
float scale = 1.0f;
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
memcpy(&scale, dst->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
|
||||
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
|
||||
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
||||
|
||||
ggml_sycl_set_device(ctx.device);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
const int64_t nb11 = src1 ? src1->nb[1] : 1;
|
||||
const int64_t nb12 = src1 ? src1->nb[2] : 1;
|
||||
const int64_t nb13 = src1 ? src1->nb[3] : 1;
|
||||
|
||||
if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) {
|
||||
const sycl::half * src1_dd = static_cast<sycl::half *>(dst->src[1]->data);
|
||||
soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
|
||||
main_stream, ctx.device);
|
||||
} else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) {
|
||||
const float * src1_dd = static_cast<const float *>(dst->src[1]->data);
|
||||
soft_max_f32_sycl<float>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
|
||||
const int64_t ne12 = src1 ? src1->ne[2] : 1;
|
||||
const int64_t ne13 = src1 ? src1->ne[3] : 1;
|
||||
|
||||
const uint32_t n_head = src0->ne[2];
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
|
||||
soft_max_params params = {};
|
||||
params.nheads = src0->ne[2];
|
||||
params.n_head_log2 = n_head_log2;
|
||||
params.ncols = ne00;
|
||||
params.nrows_x = nrows_x;
|
||||
params.nrows_y = nrows_y;
|
||||
params.ne00 = src0->ne[0];
|
||||
params.ne01 = src0->ne[1];
|
||||
params.ne02 = src0->ne[2];
|
||||
params.ne03 = src0->ne[3];
|
||||
params.nb11 = nb11;
|
||||
params.nb12 = nb12;
|
||||
params.nb13 = nb13;
|
||||
params.ne12 = ne12;
|
||||
params.ne13 = ne13;
|
||||
params.scale = scale;
|
||||
params.max_bias = max_bias;
|
||||
params.m0 = m0;
|
||||
params.m1 = m1;
|
||||
|
||||
if (use_f16) {
|
||||
soft_max_f32_sycl(src0_d, (const sycl::half *)src1_d,
|
||||
(const float *)src2_d, dst_d, params, stream,
|
||||
ctx.device);
|
||||
} else {
|
||||
/* mask unavailable */
|
||||
soft_max_f32_sycl<float>(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
|
||||
soft_max_f32_sycl(src0_d, (const float *)src1_d, (const float *)src2_d,
|
||||
dst_d, params, stream, ctx.device);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_soft_max_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
const ggml_tensor * src0 = dst->src[0]; // grad
|
||||
const ggml_tensor * src1 = dst->src[1]; // forward pass output
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int64_t ncols = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
|
||||
|
||||
GGML_ASSERT(max_bias == 0.0f);
|
||||
|
||||
soft_max_back_f32_sycl(src0_d, src1_d, dst_d, ncols, nrows, scale, stream);
|
||||
}
|
||||
|
||||
@@ -15,6 +15,10 @@
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#define SYCL_SOFT_MAX_BLOCK_SIZE 1024
|
||||
|
||||
void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst);
|
||||
|
||||
void ggml_sycl_op_soft_max_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_SOFTMAX_HPP
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
cmake_minimum_required(VERSION 3.19)
|
||||
cmake_policy(SET CMP0114 NEW)
|
||||
cmake_policy(SET CMP0116 NEW)
|
||||
|
||||
find_package(Vulkan COMPONENTS glslc REQUIRED)
|
||||
|
||||
@@ -54,25 +55,25 @@ if (Vulkan_FOUND)
|
||||
# Test all shader extensions
|
||||
test_shader_extension_support(
|
||||
"GL_KHR_cooperative_matrix"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat.comp"
|
||||
"GGML_VULKAN_COOPMAT_GLSLC_SUPPORT"
|
||||
)
|
||||
|
||||
test_shader_extension_support(
|
||||
"GL_NV_cooperative_matrix2"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat2.comp"
|
||||
"GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT"
|
||||
)
|
||||
|
||||
test_shader_extension_support(
|
||||
"GL_EXT_integer_dot_product"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/integer_dot.comp"
|
||||
"GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT"
|
||||
)
|
||||
|
||||
test_shader_extension_support(
|
||||
"GL_EXT_bfloat16"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/bfloat16.comp"
|
||||
"GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT"
|
||||
)
|
||||
|
||||
@@ -160,7 +161,6 @@ if (Vulkan_FOUND)
|
||||
set (_ggml_vk_genshaders_dir "${CMAKE_BINARY_DIR}/$<CONFIG>")
|
||||
set (_ggml_vk_genshaders_cmd "${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}")
|
||||
set (_ggml_vk_header "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp")
|
||||
set (_ggml_vk_source "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp")
|
||||
set (_ggml_vk_input_dir "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders")
|
||||
set (_ggml_vk_output_dir "${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv")
|
||||
|
||||
@@ -176,24 +176,35 @@ if (Vulkan_FOUND)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${_ggml_vk_header}
|
||||
${_ggml_vk_source}
|
||||
|
||||
COMMAND ${_ggml_vk_genshaders_cmd}
|
||||
--glslc ${Vulkan_GLSLC_EXECUTABLE}
|
||||
--input-dir ${_ggml_vk_input_dir}
|
||||
--output-dir ${_ggml_vk_output_dir}
|
||||
--target-hpp ${_ggml_vk_header}
|
||||
--target-cpp ${_ggml_vk_source}
|
||||
--no-clean
|
||||
|
||||
DEPENDS ${_ggml_vk_shader_files}
|
||||
${_ggml_vk_shaders_gen_sources}
|
||||
DEPENDS ${_ggml_vk_shaders_gen_sources}
|
||||
vulkan-shaders-gen
|
||||
|
||||
COMMENT "Generate vulkan shaders"
|
||||
COMMENT "Generate vulkan shaders header"
|
||||
)
|
||||
target_sources(ggml-vulkan PRIVATE ${_ggml_vk_header})
|
||||
|
||||
target_sources(ggml-vulkan PRIVATE ${_ggml_vk_source} ${_ggml_vk_header})
|
||||
foreach (file_full ${_ggml_vk_shader_files})
|
||||
get_filename_component(file ${file_full} NAME)
|
||||
set (_ggml_vk_target_cpp "${CMAKE_CURRENT_BINARY_DIR}/${file}.cpp")
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${_ggml_vk_target_cpp}
|
||||
DEPFILE ${_ggml_vk_target_cpp}.d
|
||||
COMMAND ${_ggml_vk_genshaders_cmd}
|
||||
--glslc ${Vulkan_GLSLC_EXECUTABLE}
|
||||
--source ${file_full}
|
||||
--output-dir ${_ggml_vk_output_dir}
|
||||
--target-hpp ${_ggml_vk_header}
|
||||
--target-cpp ${_ggml_vk_target_cpp}
|
||||
DEPENDS ${file_full}
|
||||
${_ggml_vk_shaders_gen_sources}
|
||||
vulkan-shaders-gen
|
||||
COMMENT "Generate vulkan shaders for ${file}"
|
||||
)
|
||||
target_sources(ggml-vulkan PRIVATE ${_ggml_vk_target_cpp})
|
||||
endforeach()
|
||||
|
||||
else()
|
||||
message(WARNING "Vulkan not found")
|
||||
|
||||
@@ -393,6 +393,7 @@ struct vk_device_struct {
|
||||
vk::PhysicalDeviceProperties properties;
|
||||
std::string name;
|
||||
uint64_t max_memory_allocation_size;
|
||||
uint64_t max_buffer_size;
|
||||
uint64_t suballocation_block_size;
|
||||
bool fp16;
|
||||
bool bf16;
|
||||
@@ -1563,6 +1564,12 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
|
||||
static void ggml_backend_vk_free(ggml_backend_t backend);
|
||||
|
||||
static VkDeviceSize ggml_vk_get_max_buffer_range(const ggml_backend_vk_context * ctx, const vk_buffer &buf, const VkDeviceSize offset) {
|
||||
const VkDeviceSize range = std::min(VkDeviceSize{buf->size - offset},
|
||||
VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});
|
||||
return range;
|
||||
}
|
||||
|
||||
// Wait for ctx->fence to be signaled.
|
||||
static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
|
||||
// Use waitForFences while most of the graph executes. Hopefully the CPU can sleep
|
||||
@@ -2012,8 +2019,8 @@ static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_pr
|
||||
|
||||
static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list) {
|
||||
VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")");
|
||||
if (size > device->max_memory_allocation_size) {
|
||||
throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit");
|
||||
if (size > device->max_buffer_size) {
|
||||
throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit");
|
||||
}
|
||||
|
||||
vk_buffer buf = std::make_shared<vk_buffer_struct>();
|
||||
@@ -2159,8 +2166,8 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf) {
|
||||
buf.reset();
|
||||
}
|
||||
|
||||
static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) {
|
||||
return { buf, 0, VK_WHOLE_SIZE };
|
||||
static vk_subbuffer ggml_vk_subbuffer(const ggml_backend_vk_context* ctx, const vk_buffer& buf, size_t offset = 0) {
|
||||
return { buf, offset, ggml_vk_get_max_buffer_range(ctx, buf, offset) };
|
||||
}
|
||||
|
||||
static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) {
|
||||
@@ -2614,8 +2621,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
const uint32_t D_lsb = D ^ (D & (D-1));
|
||||
uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
|
||||
|
||||
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
|
||||
GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
|
||||
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
|
||||
};
|
||||
|
||||
@@ -3855,17 +3860,27 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
|
||||
|
||||
if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
|
||||
device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
|
||||
device->max_memory_allocation_size = std::stoull(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
|
||||
} else if (maintenance4_support) {
|
||||
device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
|
||||
} else {
|
||||
device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
|
||||
}
|
||||
|
||||
const char* GGML_VK_FORCE_MAX_BUFFER_SIZE = getenv("GGML_VK_FORCE_MAX_BUFFER_SIZE");
|
||||
|
||||
if (GGML_VK_FORCE_MAX_BUFFER_SIZE != nullptr) {
|
||||
device->max_buffer_size = std::stoull(GGML_VK_FORCE_MAX_BUFFER_SIZE);
|
||||
} else if (maintenance4_support) {
|
||||
device->max_buffer_size = props4.maxBufferSize;
|
||||
} else {
|
||||
device->max_buffer_size = device->max_memory_allocation_size;
|
||||
}
|
||||
|
||||
const char* GGML_VK_SUBALLOCATION_BLOCK_SIZE = getenv("GGML_VK_SUBALLOCATION_BLOCK_SIZE");
|
||||
|
||||
if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) {
|
||||
device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
|
||||
device->suballocation_block_size = std::stoull(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
|
||||
} else {
|
||||
// Limit batching of allocations to 1GB by default to avoid fragmentation issues
|
||||
device->suballocation_block_size = 1024*1024*1024;
|
||||
@@ -6150,9 +6165,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
}
|
||||
const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
|
||||
if (
|
||||
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
|
||||
(qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||
|
||||
(split_k > 1 && split_k_size > ctx->device->max_memory_allocation_size)) {
|
||||
(qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
|
||||
(qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
|
||||
(split_k > 1 && split_k_size > ctx->device->properties.limits.maxStorageBufferRange)) {
|
||||
GGML_ABORT("Requested preallocation size is too large");
|
||||
}
|
||||
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
|
||||
@@ -6227,7 +6242,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
}
|
||||
|
||||
if (x_non_contig) {
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
|
||||
} else if (qx_needs_dequant) {
|
||||
const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
|
||||
@@ -6239,7 +6254,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
if (ctx->prealloc_y_need_sync) {
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
|
||||
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
|
||||
ctx->prealloc_y_last_tensor_used = src1;
|
||||
}
|
||||
@@ -6250,7 +6265,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
if (ctx->prealloc_y_need_sync) {
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true);
|
||||
ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
|
||||
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
|
||||
ctx->prealloc_y_last_tensor_used = src1;
|
||||
}
|
||||
@@ -6272,14 +6287,11 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
|
||||
}
|
||||
|
||||
// No bounds checking is needed for dst. This is basically VK_WHOLE_SIZE but clamped to maxStorageBufferRange.
|
||||
VkDeviceSize d_range = std::min(VkDeviceSize{d_D->size - d_buf_offset}, VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});
|
||||
|
||||
// compute
|
||||
ggml_vk_matmul(
|
||||
ctx, subctx, pipeline,
|
||||
{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
|
||||
{ d_D, d_buf_offset, d_range }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
|
||||
ggml_vk_subbuffer(ctx, d_D, d_buf_offset), { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
|
||||
ne01, ne11, ne10,
|
||||
ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d,
|
||||
split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
|
||||
@@ -6446,8 +6458,8 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
|
||||
}
|
||||
if (
|
||||
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
|
||||
(qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
|
||||
(qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
|
||||
(qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
|
||||
GGML_ABORT("Requested preallocation size is too large");
|
||||
}
|
||||
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
|
||||
@@ -6512,7 +6524,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
}
|
||||
|
||||
GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
|
||||
}
|
||||
if (y_non_contig) {
|
||||
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
|
||||
@@ -6521,7 +6533,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
if (ctx->prealloc_y_need_sync) {
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
|
||||
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
|
||||
ctx->prealloc_y_last_tensor_used = src1;
|
||||
}
|
||||
@@ -6532,7 +6544,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
if (ctx->prealloc_y_need_sync) {
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true);
|
||||
ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
|
||||
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
|
||||
ctx->prealloc_y_last_tensor_used = src1;
|
||||
}
|
||||
@@ -6931,8 +6943,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
|
||||
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
|
||||
if (
|
||||
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
|
||||
(qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
|
||||
(qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
|
||||
(qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
|
||||
GGML_ABORT("Requested preallocation size is too large");
|
||||
}
|
||||
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
|
||||
@@ -6999,7 +7011,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
}
|
||||
|
||||
if (x_non_contig) {
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
|
||||
} else if (qx_needs_dequant) {
|
||||
const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
|
||||
@@ -7012,7 +7024,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||
if (ctx->prealloc_y_need_sync) {
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
|
||||
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
|
||||
ctx->prealloc_y_last_tensor_used = src1;
|
||||
}
|
||||
@@ -7145,8 +7157,8 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
||||
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
|
||||
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
|
||||
if (
|
||||
(qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
|
||||
(qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
|
||||
(qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
|
||||
(qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
|
||||
GGML_ABORT("Requested preallocation size is too large");
|
||||
}
|
||||
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
|
||||
@@ -7212,7 +7224,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
||||
|
||||
if (x_non_contig) {
|
||||
GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
|
||||
}
|
||||
if (y_non_contig) {
|
||||
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
|
||||
@@ -7221,7 +7233,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
|
||||
if (ctx->prealloc_y_need_sync) {
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
|
||||
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
|
||||
ctx->prealloc_y_last_tensor_used = src1;
|
||||
}
|
||||
@@ -7457,8 +7469,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) {
|
||||
aligned = false;
|
||||
}
|
||||
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
|
||||
GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
|
||||
|
||||
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
|
||||
|
||||
@@ -7498,7 +7508,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
|
||||
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
|
||||
const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
|
||||
if (split_k_size > ctx->device->max_memory_allocation_size) {
|
||||
if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) {
|
||||
GGML_ABORT("Requested preallocation size is too large");
|
||||
}
|
||||
if (ctx->prealloc_size_split_k < split_k_size) {
|
||||
@@ -7620,12 +7630,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||
{
|
||||
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
|
||||
vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
|
||||
vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
|
||||
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
|
||||
vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
|
||||
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
|
||||
ggml_vk_subbuffer(ctx, d_Q, q_buf_offset),
|
||||
ggml_vk_subbuffer(ctx, d_K, k_buf_offset),
|
||||
ggml_vk_subbuffer(ctx, d_V, v_buf_offset),
|
||||
ggml_vk_subbuffer(ctx, d_M, m_buf_offset),
|
||||
ggml_vk_subbuffer(ctx, d_S, s_buf_offset),
|
||||
ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0),
|
||||
},
|
||||
// We only use split_k when group query attention is enabled, which means
|
||||
// there's no more than one tile of rows (i.e. workgroups_x would have been
|
||||
@@ -7637,21 +7647,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
|
||||
{
|
||||
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
|
||||
vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
|
||||
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
|
||||
ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0),
|
||||
ggml_vk_subbuffer(ctx, d_S, s_buf_offset),
|
||||
ggml_vk_subbuffer(ctx, d_D, d_buf_offset),
|
||||
},
|
||||
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
|
||||
ctx->prealloc_split_k_need_sync = true;
|
||||
} else {
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||
{
|
||||
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
|
||||
vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
|
||||
vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
|
||||
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
|
||||
vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
|
||||
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
|
||||
ggml_vk_subbuffer(ctx, d_Q, q_buf_offset),
|
||||
ggml_vk_subbuffer(ctx, d_K, k_buf_offset),
|
||||
ggml_vk_subbuffer(ctx, d_V, v_buf_offset),
|
||||
ggml_vk_subbuffer(ctx, d_M, m_buf_offset),
|
||||
ggml_vk_subbuffer(ctx, d_S, s_buf_offset),
|
||||
ggml_vk_subbuffer(ctx, d_D, d_buf_offset),
|
||||
},
|
||||
pc, { workgroups_x, workgroups_y, workgroups_z });
|
||||
}
|
||||
@@ -8360,18 +8370,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0;
|
||||
uint64_t y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 : 0;
|
||||
uint64_t z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 : 0;
|
||||
uint64_t d_sz = ggml_type_size(dst->type) * ned;
|
||||
|
||||
vk_buffer d_D = dst_buf_ctx->dev_buffer;
|
||||
|
||||
// Workaround for tiny tensor inputs on ROPE
|
||||
if (op == GGML_OP_ROPE && use_src1 && y_sz > d_D->size) {
|
||||
y_sz = VK_WHOLE_SIZE;
|
||||
}
|
||||
|
||||
GGML_ASSERT(d_D != nullptr);
|
||||
uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
|
||||
if(!src0_uma) {
|
||||
@@ -8396,26 +8396,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
|
||||
d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
|
||||
|
||||
if (op_supports_incontiguous) {
|
||||
x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
|
||||
y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
|
||||
z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
|
||||
d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);
|
||||
|
||||
if (x_buf_offset + x_sz >= d_X->size) {
|
||||
x_sz = VK_WHOLE_SIZE;
|
||||
}
|
||||
if (use_src1 && y_buf_offset + y_sz >= d_Y->size) {
|
||||
y_sz = VK_WHOLE_SIZE;
|
||||
}
|
||||
if (use_src2 && z_buf_offset + z_sz >= d_Z->size) {
|
||||
z_sz = VK_WHOLE_SIZE;
|
||||
}
|
||||
if (d_buf_offset + d_sz >= d_D->size) {
|
||||
d_sz = VK_WHOLE_SIZE;
|
||||
}
|
||||
}
|
||||
|
||||
std::array<uint32_t, 3> elements;
|
||||
|
||||
// Single call if dimension 2 is contiguous
|
||||
@@ -8606,19 +8586,31 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
break;
|
||||
}
|
||||
|
||||
if (!op_supports_incontiguous) {
|
||||
if (x_sz != VK_WHOLE_SIZE) {
|
||||
x_sz *= ne02 * ne03;
|
||||
uint64_t x_sz, y_sz, z_sz, d_sz;
|
||||
|
||||
if (op_supports_incontiguous) {
|
||||
x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
|
||||
y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
|
||||
z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
|
||||
d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);
|
||||
|
||||
if (x_buf_offset + x_sz >= d_X->size) {
|
||||
x_sz = ggml_vk_get_max_buffer_range(ctx, d_X, x_buf_offset);
|
||||
}
|
||||
if (use_src1 && y_sz != VK_WHOLE_SIZE) {
|
||||
y_sz *= ne12 * ne13;
|
||||
if (use_src1 && y_buf_offset + y_sz >= d_Y->size) {
|
||||
y_sz = ggml_vk_get_max_buffer_range(ctx, d_Y, y_buf_offset);
|
||||
}
|
||||
if (use_src2 && z_sz != VK_WHOLE_SIZE) {
|
||||
z_sz *= ne22 * ne23;
|
||||
if (use_src2 && z_buf_offset + z_sz >= d_Z->size) {
|
||||
z_sz = ggml_vk_get_max_buffer_range(ctx, d_Z, z_buf_offset);
|
||||
}
|
||||
if (d_sz != VK_WHOLE_SIZE) {
|
||||
d_sz *= ned2 * ned3;
|
||||
if (d_buf_offset + d_sz >= d_D->size) {
|
||||
d_sz = ggml_vk_get_max_buffer_range(ctx, d_D, d_buf_offset);
|
||||
}
|
||||
} else {
|
||||
x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0 * ne02 * ne03;
|
||||
y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 * ne12 * ne13 : 0;
|
||||
z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 * ne22 * ne23 : 0;
|
||||
d_sz = ggml_type_size(dst->type) * ned * ned2 * ned3;
|
||||
}
|
||||
|
||||
if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
|
||||
@@ -8628,7 +8620,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
{ vk_subbuffer{ d_X, x_buf_offset, x_sz },
|
||||
vk_subbuffer{ d_Y, y_buf_offset, y_sz },
|
||||
vk_subbuffer{ d_D, d_buf_offset, d_sz },
|
||||
vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE },
|
||||
ggml_vk_subbuffer(ctx, d_A, a_buf_offset),
|
||||
}, pc, elements);
|
||||
} else if (op == GGML_OP_GLU) {
|
||||
// Empty src1 is possible in glu, but the shader needs a buffer
|
||||
@@ -8821,18 +8813,18 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||
static_assert(MAX_PARAMETER_COUNT == 12);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||
{
|
||||
vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE },
|
||||
vk_subbuffer{ buf[1], offset[1], VK_WHOLE_SIZE },
|
||||
vk_subbuffer{ buf[2], offset[2], VK_WHOLE_SIZE },
|
||||
vk_subbuffer{ buf[3], offset[3], VK_WHOLE_SIZE },
|
||||
vk_subbuffer{ buf[4], offset[4], VK_WHOLE_SIZE },
|
||||
vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE },
|
||||
vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE },
|
||||
vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE },
|
||||
vk_subbuffer{ buf[8], offset[8], VK_WHOLE_SIZE },
|
||||
vk_subbuffer{ buf[9], offset[9], VK_WHOLE_SIZE },
|
||||
vk_subbuffer{ buf[10], offset[10], VK_WHOLE_SIZE },
|
||||
vk_subbuffer{ buf[11], offset[11], VK_WHOLE_SIZE },
|
||||
ggml_vk_subbuffer(ctx, buf[0], offset[0]),
|
||||
ggml_vk_subbuffer(ctx, buf[1], offset[1]),
|
||||
ggml_vk_subbuffer(ctx, buf[2], offset[2]),
|
||||
ggml_vk_subbuffer(ctx, buf[3], offset[3]),
|
||||
ggml_vk_subbuffer(ctx, buf[4], offset[4]),
|
||||
ggml_vk_subbuffer(ctx, buf[5], offset[5]),
|
||||
ggml_vk_subbuffer(ctx, buf[6], offset[6]),
|
||||
ggml_vk_subbuffer(ctx, buf[7], offset[7]),
|
||||
ggml_vk_subbuffer(ctx, buf[8], offset[8]),
|
||||
ggml_vk_subbuffer(ctx, buf[9], offset[9]),
|
||||
ggml_vk_subbuffer(ctx, buf[10], offset[10]),
|
||||
ggml_vk_subbuffer(ctx, buf[11], offset[11]),
|
||||
}, pc, elements);
|
||||
}
|
||||
|
||||
@@ -10006,7 +9998,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
||||
ggml_vk_ctx_begin(ctx->device, subctx);
|
||||
for (size_t i = 0; i < num_it; i++) {
|
||||
ggml_vk_matmul(
|
||||
ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
|
||||
ctx, subctx, p, ggml_vk_subbuffer(ctx, d_X), ggml_vk_subbuffer(ctx, d_Y), ggml_vk_subbuffer(ctx, d_D), ggml_vk_subbuffer(ctx, ctx->prealloc_split_k),
|
||||
m, n, k,
|
||||
k, k, m, k*m, k*n, m*n,
|
||||
split_k, batch, batch, batch, 1, 1, n
|
||||
@@ -10317,7 +10309,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
|
||||
//
|
||||
// vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
|
||||
// ggml_vk_ctx_begin(ctx->device, subctx);
|
||||
// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne);
|
||||
// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, x_buf), ggml_vk_subbuffer(ctx, qx_buf), ne);
|
||||
// ggml_vk_ctx_end(subctx);
|
||||
//
|
||||
// auto begin = std::chrono::high_resolution_clock::now();
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#version 450
|
||||
|
||||
#include "types.comp"
|
||||
#include "generic_binary_head.comp"
|
||||
#include "types.glsl"
|
||||
#include "generic_binary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#endif
|
||||
|
||||
#include "types.comp"
|
||||
#include "generic_binary_head.comp"
|
||||
#include "types.glsl"
|
||||
#include "generic_binary_head.glsl"
|
||||
|
||||
const uint num_threads = 256;
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
|
||||
#include "types.comp"
|
||||
#include "types.glsl"
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.comp"
|
||||
#include "types.comp"
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#version 450
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
#include "types.comp"
|
||||
#include "types.glsl"
|
||||
|
||||
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
|
||||
layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user