mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-27 16:17:40 +02:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f7c1df6502 | |||
| 9bebfcb4bc | |||
| 0b6529d818 | |||
| c299a92c38 | |||
| 0275c0f800 | |||
| 83d385b429 | |||
| 050ee92d04 | |||
| 3fc4e10527 | |||
| 5d8ccdf9d1 | |||
| 024930c6ad | |||
| 5397c36194 | |||
| e7ea94afcb | |||
| 96183e9820 | |||
| 487a6cc164 | |||
| 5a6a0dd7e1 | |||
| ded1561b42 |
@@ -145,7 +145,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
# ==============================================================================
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
ENTRYPOINT [ "/app/llama-cli" ]
|
||||
|
||||
@@ -156,7 +156,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
HEALTHCHECK --interval=5m CMD [ "curl", "-f", "http://localhost:8080/health" ]
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -115,7 +115,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -124,7 +124,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -141,7 +141,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/lib/ /app
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -153,7 +153,7 @@ FROM base AS server
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/lib/ /app
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -115,7 +115,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -126,7 +126,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
ARG OPENVINO_VERSION_MAJOR=2026.2
|
||||
ARG OPENVINO_VERSION_FULL=2026.2.0.21903.52ddc073857
|
||||
ARG OPENVINO_VERSION_MAJOR=2026.2.1
|
||||
ARG OPENVINO_VERSION_FULL=2026.2.1.21919.ede283a88e3
|
||||
ARG UBUNTU_VERSION=24.04
|
||||
|
||||
# Intel GPU driver versions. https://github.com/intel/compute-runtime/releases
|
||||
ARG IGC_VERSION=v2.34.4
|
||||
ARG IGC_VERSION_FULL=2_2.34.4+21428
|
||||
ARG COMPUTE_RUNTIME_VERSION=26.18.38308.1
|
||||
ARG COMPUTE_RUNTIME_VERSION_FULL=26.18.38308.1-0
|
||||
ARG IGC_VERSION=v2.36.3
|
||||
ARG IGC_VERSION_FULL=2_2.36.3+21719
|
||||
ARG COMPUTE_RUNTIME_VERSION=26.22.38646.4
|
||||
ARG COMPUTE_RUNTIME_VERSION_FULL=26.22.38646.4-0
|
||||
ARG IGDGMM_VERSION=22.10.0
|
||||
|
||||
# Intel NPU driver versions. https://github.com/intel/linux-npu-driver/releases
|
||||
@@ -214,7 +214,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app/
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -225,7 +225,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app/
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -127,7 +127,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -138,7 +138,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -124,7 +124,7 @@ WORKDIR /llama.cpp/bin
|
||||
|
||||
# Copy llama.cpp binaries and libraries
|
||||
COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
|
||||
COPY --from=collector /llama.cpp/bin/llama-cli /llama.cpp/bin/llama-completion /llama.cpp/bin
|
||||
COPY --from=collector /llama.cpp/bin/llama /llama.cpp/bin/llama-cli /llama.cpp/bin/llama-completion /llama.cpp/bin
|
||||
|
||||
ENTRYPOINT [ "/llama.cpp/bin/llama-cli" ]
|
||||
|
||||
@@ -138,7 +138,7 @@ WORKDIR /llama.cpp/bin
|
||||
|
||||
# Copy llama.cpp binaries and libraries
|
||||
COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
|
||||
COPY --from=collector /llama.cpp/bin/llama-server /llama.cpp/bin
|
||||
COPY --from=collector /llama.cpp/bin/llama /llama.cpp/bin/llama-server /llama.cpp/bin
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -118,7 +118,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -108,7 +108,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -68,8 +68,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -96,8 +96,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
||||
@@ -39,8 +39,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -96,8 +96,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
||||
@@ -266,8 +266,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
||||
@@ -446,8 +446,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Set OpenVINO version output
|
||||
@@ -506,8 +506,11 @@ jobs:
|
||||
cmake -B build/ReleaseOV -G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_OPENVINO=ON \
|
||||
-DHF_UI_VERSION=${{ needs.get-version.outputs.ui_version }}
|
||||
cmake --build build/ReleaseOV --config Release -j $(nproc)
|
||||
-DCMAKE_INSTALL_RPATH='$ORIGIN' \
|
||||
-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
|
||||
-DHF_UI_VERSION=${{ needs.get-version.outputs.ui_version }} \
|
||||
${{ env.CMAKE_ARGS }}
|
||||
cmake --build build/ReleaseOV --config Release --parallel
|
||||
|
||||
- name: ccache-clear
|
||||
uses: ./.github/actions/ccache-clear
|
||||
@@ -521,8 +524,26 @@ jobs:
|
||||
- name: Pack artifacts
|
||||
id: pack_artifacts
|
||||
run: |
|
||||
cp LICENSE ./build/ReleaseOV/bin/
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.tar.gz --transform "s,^\.,llama-${{ steps.tag.outputs.name }}," -C ./build/ReleaseOV/bin .
|
||||
dest=./build/ReleaseOV/bin
|
||||
OPENVINO_ROOT=./openvino_toolkit
|
||||
ov_lib="$OPENVINO_ROOT/runtime/lib/intel64"
|
||||
|
||||
# Bundle OpenVINO runtime libs + TBB. Binaries built with RPATH=$ORIGIN
|
||||
# load these siblings without setupvars.sh / LD_LIBRARY_PATH.
|
||||
cp -P "$ov_lib"/libopenvino.so* \
|
||||
"$ov_lib"/libopenvino_c.so* \
|
||||
"$ov_lib"/libopenvino_*_plugin.so \
|
||||
"$ov_lib"/libopenvino_intel_npu_compiler*.so \
|
||||
"$OPENVINO_ROOT"/runtime/3rdparty/tbb/lib/*.so* \
|
||||
"$dest"
|
||||
cp -P /usr/lib/x86_64-linux-gnu/libOpenCL.so.1* "$dest" 2>/dev/null || true
|
||||
cp "$ov_lib"/cache.json "$dest" 2>/dev/null || true
|
||||
|
||||
# OpenVINO licensing
|
||||
cp -r "$OPENVINO_ROOT"/docs/licensing "$dest"/openvino-licensing
|
||||
|
||||
cp LICENSE "$dest"
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.tar.gz --transform "s,^\.,llama-${{ steps.tag.outputs.name }}," -C "$dest" .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
@@ -531,6 +552,9 @@ jobs:
|
||||
name: llama-bin-ubuntu-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.tar.gz
|
||||
|
||||
windows-openvino:
|
||||
needs: [check-release]
|
||||
if: ${{ needs.check-release.outputs.should_release == 'true' }}
|
||||
|
||||
runs-on: windows-2022
|
||||
|
||||
outputs:
|
||||
@@ -538,8 +562,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Set OpenVINO version output
|
||||
@@ -607,7 +631,9 @@ jobs:
|
||||
-A x64 ^
|
||||
-DCMAKE_BUILD_TYPE=Release ^
|
||||
-DGGML_OPENVINO=ON ^
|
||||
-DCMAKE_TOOLCHAIN_FILE=C:\vcpkg\scripts\buildsystems\vcpkg.cmake
|
||||
-DLLAMA_BUILD_BORINGSSL=ON ^
|
||||
-DCMAKE_TOOLCHAIN_FILE=C:\vcpkg\scripts\buildsystems\vcpkg.cmake ^
|
||||
${{ env.CMAKE_ARGS }}
|
||||
|
||||
cmake --build build\ReleaseOV --config Release -- /m
|
||||
|
||||
@@ -624,8 +650,29 @@ jobs:
|
||||
id: pack_artifacts
|
||||
shell: powershell
|
||||
run: |
|
||||
Copy-Item LICENSE .\build\ReleaseOV\bin\
|
||||
7z a -snl llama-${{ steps.tag.outputs.name }}-bin-win-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.zip .\build\ReleaseOV\bin\*
|
||||
# Locate the extracted OpenVINO toolkit root (same pattern as the Build step).
|
||||
$OPENVINO_ROOT = (Get-ChildItem -Directory openvino_toolkit | Select-Object -First 1).FullName
|
||||
if (-not $OPENVINO_ROOT) {
|
||||
Write-Error "OpenVINO toolkit folder not found under .\openvino_toolkit"
|
||||
exit 1
|
||||
}
|
||||
|
||||
$dest = ".\build\ReleaseOV\bin\Release"
|
||||
|
||||
$ovBin = Join-Path $OPENVINO_ROOT 'runtime\bin\intel64\Release'
|
||||
Copy-Item -Path (Join-Path $ovBin '*.dll') -Destination $dest -Force
|
||||
Copy-Item -Path (Join-Path $ovBin 'cache.json') -Destination $dest -Force
|
||||
|
||||
$tbbBin = Join-Path $OPENVINO_ROOT 'runtime\3rdparty\tbb\bin'
|
||||
Copy-Item -Path (Join-Path $tbbBin 'tbb*.dll') -Destination $dest -Force
|
||||
|
||||
# OpenVINO licensing
|
||||
$licensingDest = Join-Path $dest 'openvino-licensing'
|
||||
New-Item -ItemType Directory -Force -Path $licensingDest | Out-Null
|
||||
Copy-Item -Path (Join-Path $OPENVINO_ROOT 'docs\licensing\*') -Destination $licensingDest -Recurse -Force
|
||||
|
||||
Copy-Item LICENSE $dest
|
||||
7z a -snl llama-${{ steps.tag.outputs.name }}-bin-win-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.zip $dest\*
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
|
||||
+1
-1
@@ -80,7 +80,7 @@ To protect sensitive data from potential leaks or unauthorized access, it is cru
|
||||
### Untrusted environments or networks
|
||||
|
||||
If you can't run your models in a secure and isolated environment or if it must be exposed to an untrusted network, make sure to take the following security precautions:
|
||||
* Do not use the RPC backend, [rpc-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) and [llama-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server) functionality (see https://github.com/ggml-org/llama.cpp/pull/13061).
|
||||
* Do not use the RPC backend, [ggml-rpc-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) and [llama-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server) functionality (see https://github.com/ggml-org/llama.cpp/pull/13061).
|
||||
* Confirm the hash of any downloaded artifact (e.g. pre-trained model weights) matches a known-good value.
|
||||
* Encrypt your data if sending it over the network.
|
||||
|
||||
|
||||
+8
-4
@@ -50,6 +50,7 @@ struct command {
|
||||
std::vector<std::string> aliases;
|
||||
bool hidden;
|
||||
int (*func)(int, char **);
|
||||
bool flags = false; // allow --name
|
||||
};
|
||||
|
||||
#ifdef LLAMA_INSTALL_BUILD
|
||||
@@ -69,9 +70,9 @@ static const command cmds[] = {
|
||||
{"fit-params", "Compute parameters to fit a model in device memory", {}, true, llama_fit_params },
|
||||
{"quantize", "Quantize a model", {}, true, llama_quantize },
|
||||
{"perplexity", "Compute model perplexity and KL divergence", {}, true, llama_perplexity },
|
||||
{"version", "Show version", {}, false, version },
|
||||
{"licenses", "Show third-party licenses", {"credits"}, false, licenses },
|
||||
{"help", "Show available commands", {}, false, help },
|
||||
{"version", "Show version", {}, false, version, true },
|
||||
{"licenses", "Show third-party licenses", {"credits"}, false, licenses, true },
|
||||
{"help", "Show available commands", {}, false, help, true },
|
||||
};
|
||||
|
||||
#undef UPDATE_HIDDEN
|
||||
@@ -108,7 +109,10 @@ static int help(int argc, char ** argv) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static bool matches(const std::string & arg, const command & cmd) {
|
||||
static bool matches(std::string arg, const command & cmd) {
|
||||
if (cmd.flags && arg.size() > 2 && arg[0] == '-' && arg[1] == '-') {
|
||||
arg.erase(0, 2);
|
||||
}
|
||||
if (arg == cmd.name) {
|
||||
return true;
|
||||
}
|
||||
|
||||
+45
-7
@@ -352,6 +352,8 @@ static std::string get_default_local_path(const std::string & url) {
|
||||
|
||||
common_models_handler common_models_handler_init(const common_params & params, llama_example curr_ex) {
|
||||
common_download_hf_plan plan;
|
||||
common_download_hf_plan plan_spec;
|
||||
common_download_hf_plan plan_voc;
|
||||
common_download_opts opts;
|
||||
|
||||
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
|
||||
@@ -377,7 +379,15 @@ common_models_handler common_models_handler_init(const common_params & params, l
|
||||
plan = common_download_get_hf_plan(params.model, opts);
|
||||
}
|
||||
|
||||
return common_models_handler{plan, opts};
|
||||
if (!params.speculative.draft.mparams.hf_repo.empty()) {
|
||||
plan_spec = common_download_get_hf_plan(params.speculative.draft.mparams, opts);
|
||||
}
|
||||
|
||||
if (!params.vocoder.model.hf_repo.empty()) {
|
||||
plan_voc = common_download_get_hf_plan(params.vocoder.model, opts);
|
||||
}
|
||||
|
||||
return common_models_handler{plan, plan_spec, plan_voc, opts};
|
||||
}
|
||||
|
||||
bool common_models_handler_is_preset_repo(const common_models_handler & handler) {
|
||||
@@ -425,7 +435,9 @@ static std::vector<common_download_task> build_url_tasks(const common_params_mod
|
||||
void common_models_handler_apply(common_models_handler & handler, common_params & params, common_download_callback * callback) {
|
||||
std::vector<common_download_task> tasks;
|
||||
|
||||
auto & plan = handler.plan;
|
||||
auto & plan = handler.plan;
|
||||
auto & plan_spec = handler.plan_spec;
|
||||
auto & plan_voc = handler.plan_voc;
|
||||
|
||||
auto opts = handler.opts; // copy
|
||||
opts.callback = callback;
|
||||
@@ -484,19 +496,22 @@ void common_models_handler_apply(common_models_handler & handler, common_params
|
||||
}
|
||||
|
||||
// handle hf_plan tasks
|
||||
if (!plan.model_files.empty()) {
|
||||
for (size_t i = 0; i < plan.model_files.size(); ++i) {
|
||||
auto & model_file = plan.model_files[i];
|
||||
auto add_tasks = [&opts, &tasks](const hf_cache::hf_files & model_files, common_params_model & model) {
|
||||
for (size_t i = 0; i < model_files.size(); ++i) {
|
||||
auto & model_file = model_files[i];
|
||||
bool is_first = (i == 0);
|
||||
tasks.emplace_back(model_file, opts, [&, is_first]() {
|
||||
if (is_first) {
|
||||
// only use first part as model path
|
||||
params.model.path = hf_cache::finalize_file(model_file);
|
||||
model.path = hf_cache::finalize_file(model_file);
|
||||
} else {
|
||||
hf_cache::finalize_file(model_file);
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
if (!plan.model_files.empty()) {
|
||||
add_tasks(plan.model_files, params.model);
|
||||
}
|
||||
if (!plan.mmproj.local_path.empty()) {
|
||||
tasks.emplace_back(plan.mmproj, opts, [&]() {
|
||||
@@ -522,9 +537,31 @@ void common_models_handler_apply(common_models_handler & handler, common_params
|
||||
});
|
||||
}
|
||||
|
||||
// handle plan_spec (e.g. --spec-draft-hf)
|
||||
if (!plan_spec.model_files.empty()) {
|
||||
add_tasks(plan_spec.model_files, params.speculative.draft.mparams);
|
||||
}
|
||||
|
||||
// handle vocoder plan (e.g. --hf-repo-v)
|
||||
if (!plan_voc.model_files.empty()) {
|
||||
add_tasks(plan_voc.model_files, params.vocoder.model);
|
||||
}
|
||||
|
||||
// run all tasks in parallel
|
||||
if (!params.offline) {
|
||||
common_download_run_tasks(tasks);
|
||||
// if duplicated files are found, only download once (but still call on_done for each task)
|
||||
std::unordered_map<std::string, common_download_task *> unique_tasks;
|
||||
for (auto & task : tasks) {
|
||||
auto it = unique_tasks.find(task.local_path);
|
||||
if (it == unique_tasks.end()) {
|
||||
unique_tasks[task.local_path] = &task;
|
||||
}
|
||||
}
|
||||
std::vector<common_download_task> unique_tasks_vec;
|
||||
for (auto & pair : unique_tasks) {
|
||||
unique_tasks_vec.push_back(*pair.second);
|
||||
}
|
||||
common_download_run_tasks(unique_tasks_vec);
|
||||
}
|
||||
|
||||
// download successful, update params with the downloaded paths
|
||||
@@ -3711,6 +3748,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
"draft model for speculative decoding (default: unused)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.draft.mparams.path = value;
|
||||
params.speculative.draft.mparams.hf_file = value; // will be used if --spec-draft-hf is set
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL"));
|
||||
add_opt(common_arg(
|
||||
|
||||
@@ -133,6 +133,8 @@ void common_params_add_preset_options(std::vector<common_arg> & args);
|
||||
|
||||
struct common_models_handler {
|
||||
common_download_hf_plan plan;
|
||||
common_download_hf_plan plan_spec;
|
||||
common_download_hf_plan plan_voc;
|
||||
common_download_opts opts;
|
||||
};
|
||||
|
||||
|
||||
@@ -237,8 +237,8 @@ chmod +x ubuntu-llamacpp-ov-install.sh
|
||||
# ============================================
|
||||
set -euo pipefail
|
||||
|
||||
OPENVINO_VERSION_MAJOR="2026.2"
|
||||
OPENVINO_VERSION_FULL="2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR="2026.2.1"
|
||||
OPENVINO_VERSION_FULL="2026.2.1.21919.ede283a88e3"
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
OPENVINO_INSTALL_DIR="/opt/intel/openvino_${OPENVINO_VERSION_MAJOR}"
|
||||
@@ -334,7 +334,7 @@ echo " ./build/ReleaseOV/bin/llama-cli -m model.gguf"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> The script pins OpenVINO `2026.2` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release.
|
||||
> The script pins OpenVINO `2026.2.1` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release.
|
||||
|
||||
</details>
|
||||
|
||||
@@ -364,8 +364,8 @@ REM ============================================
|
||||
REM llama.cpp OpenVINO Build Script (Ninja)
|
||||
REM ============================================
|
||||
|
||||
set "OPENVINO_VERSION_MAJOR=2026.2"
|
||||
set "OPENVINO_VERSION_FULL=2026.2.0.21903.52ddc073857"
|
||||
set "OPENVINO_VERSION_MAJOR=2026.2.1"
|
||||
set "OPENVINO_VERSION_FULL=2026.2.1.21919.ede283a88e3"
|
||||
|
||||
set "SCRIPT_DIR=%~dp0"
|
||||
set "VCPKG_DIR=C:\vcpkg"
|
||||
@@ -547,7 +547,7 @@ endlocal
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> The script pins OpenVINO `2026.2` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release. From any new shell, source the matching `setupvars` script via the junction — `call "C:\Intel\openvino\setupvars.bat"` from `cmd`, or `& "C:\Intel\openvino\setupvars.ps1"` from PowerShell. If `winget` cannot register Visual Studio Build Tools on first run, install them once manually and re-run the script from an elevated **Developer Command Prompt for VS 2022**.
|
||||
> The script pins OpenVINO `2026.2.1` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release. From any new shell, source the matching `setupvars` script via the junction — `call "C:\Intel\openvino\setupvars.bat"` from `cmd`, or `& "C:\Intel\openvino\setupvars.ps1"` from PowerShell. If `winget` cannot register Visual Studio Build Tools on first run, install them once manually and re-run the script from an elevated **Developer Command Prompt for VS 2022**.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
+1
-4
@@ -5,7 +5,7 @@ project("ggml" C CXX ASM)
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 15)
|
||||
set(GGML_VERSION_PATCH 2)
|
||||
set(GGML_VERSION_PATCH 3)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
@@ -340,9 +340,6 @@ set(GGML_PUBLIC_HEADERS
|
||||
include/gguf.h)
|
||||
|
||||
set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
|
||||
#if (GGML_METAL)
|
||||
# set_target_properties(ggml PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/src/ggml-metal.metal")
|
||||
#endif()
|
||||
install(TARGETS ggml LIBRARY PUBLIC_HEADER)
|
||||
install(TARGETS ggml-base LIBRARY)
|
||||
|
||||
|
||||
@@ -1551,6 +1551,8 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
||||
int split_backend_id = split->backend_id;
|
||||
ggml_backend_t split_backend = sched->backends[split_backend_id];
|
||||
|
||||
ggml_backend_synchronize(split_backend);
|
||||
|
||||
// copy the input tensors to the split backend
|
||||
for (int input_id = 0; input_id < split->n_inputs; input_id++) {
|
||||
ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]);
|
||||
@@ -1561,15 +1563,15 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
||||
// inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
|
||||
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
|
||||
ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
|
||||
} else {
|
||||
} else if (!split_backend->iface.cpy_tensor_async) {
|
||||
ggml_backend_synchronize(split_backend);
|
||||
}
|
||||
ggml_backend_tensor_copy(input, input_cpy);
|
||||
ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy);
|
||||
} else {
|
||||
// wait for the split backend to finish using the input before overwriting it
|
||||
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
|
||||
ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]);
|
||||
} else {
|
||||
} else if (!split_backend->iface.cpy_tensor_async) {
|
||||
ggml_backend_synchronize(split_backend);
|
||||
}
|
||||
|
||||
@@ -1674,6 +1676,8 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_synchronize(split_backend);
|
||||
|
||||
if (!sched->callback_eval) {
|
||||
enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
|
||||
if (ec != GGML_STATUS_SUCCESS) {
|
||||
|
||||
@@ -3192,11 +3192,24 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
|
||||
ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
|
||||
ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
|
||||
|
||||
if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) {
|
||||
// Enables async copies from CPU to CUDA, instead of only CUDA-to-CUDA
|
||||
// Excluding this path for HIP and MUSA as a precaution.
|
||||
// According to the summary in https://github.com/ggml-org/llama.cpp/pull/20793#issuecomment-4275794315, this change is not beneficial for hip anyways.
|
||||
// Additionally, there is a lot of anectodal evidence that hip/musa stream behavior might not always 1:1 match CUDA behavior.
|
||||
// e.g. https://github.com/ROCm/rocm-systems/issues/5109
|
||||
// It thus makes sense to exclude this path for HIP and MUSA. This PR was not aimed these backends, the majority of testing happened on CUDA.
|
||||
// This can be revisited in the future if enabling copy_from_host benefits hip/MUSA, and if the PR author can extensively test on these backends.
|
||||
#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA)
|
||||
const bool copy_from_host = false;
|
||||
#else
|
||||
const bool copy_from_host = ggml_backend_buffer_is_host(buf_src) && ggml_backend_dev_type(backend_src->device) == GGML_BACKEND_DEVICE_TYPE_CPU;
|
||||
#endif
|
||||
|
||||
if (!(copy_from_host || ggml_backend_is_cuda(backend_src)) || !ggml_backend_is_cuda(backend_dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ggml_backend_buffer_is_cuda(buf_src) || !ggml_backend_buffer_is_cuda(buf_dst)) {
|
||||
if (!(copy_from_host || ggml_backend_buffer_is_cuda(buf_src)) || !ggml_backend_buffer_is_cuda(buf_dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -3207,14 +3220,17 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
|
||||
ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *) buf_src->context;
|
||||
ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *) buf_dst->context;
|
||||
|
||||
if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
|
||||
if ((copy_from_host && cuda_ctx_dst->device != buf_ctx_dst->device) ||
|
||||
!copy_from_host && (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device)) {
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__);
|
||||
#endif // NDEBUG
|
||||
return false;
|
||||
}
|
||||
|
||||
if (backend_src != backend_dst) {
|
||||
if (copy_from_host) {
|
||||
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyHostToDevice, cuda_ctx_dst->stream()));
|
||||
} else if (backend_src != backend_dst) {
|
||||
// copy on src stream
|
||||
if (cuda_ctx_src->device == cuda_ctx_dst->device) {
|
||||
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
|
||||
|
||||
@@ -24,62 +24,119 @@ if (GGML_METAL_NDEBUG)
|
||||
endif()
|
||||
|
||||
set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
|
||||
set(METALLIB_KERNELS_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/kernels/common.h")
|
||||
set(METALLIB_KERNELS_DEQUANTIZE "${CMAKE_CURRENT_SOURCE_DIR}/kernels/dequantize.h")
|
||||
set(METALLIB_KERNELS_QUANTIZE "${CMAKE_CURRENT_SOURCE_DIR}/kernels/quantize.h")
|
||||
|
||||
set(METALLIB_KERNEL_SOURCES
|
||||
kernels/fa.metal
|
||||
kernels/mul_mv.metal
|
||||
kernels/mul_mm.metal
|
||||
kernels/quantize.metal
|
||||
kernels/softmax.metal
|
||||
kernels/norm.metal
|
||||
kernels/unary.metal
|
||||
kernels/binbcast.metal
|
||||
kernels/reduce.metal
|
||||
kernels/tri.metal
|
||||
kernels/ssm.metal
|
||||
kernels/wkv.metal
|
||||
kernels/gated_delta_net.metal
|
||||
kernels/solve_tri.metal
|
||||
kernels/rope.metal
|
||||
kernels/conv.metal
|
||||
kernels/upscale.metal
|
||||
kernels/argsort.metal
|
||||
kernels/pool.metal
|
||||
kernels/misc.metal
|
||||
)
|
||||
|
||||
if (GGML_METAL_EMBED_LIBRARY)
|
||||
enable_language(ASM)
|
||||
|
||||
add_compile_definitions(GGML_METAL_EMBED_LIBRARY)
|
||||
|
||||
set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
|
||||
set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
|
||||
set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
|
||||
|
||||
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
|
||||
|
||||
# merge ggml-common.h and ggml-metal.metal into a single file
|
||||
set(METALLIB_EMBED_ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
|
||||
set(METALLIB_SOURCE_EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
|
||||
set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
|
||||
set(METALLIB_EMBED_ASM_FILES "")
|
||||
foreach(src ${METALLIB_KERNEL_SOURCES})
|
||||
get_filename_component(kind ${src} NAME_WE)
|
||||
# symbol names must be valid C identifiers ('-' is not allowed)
|
||||
string(REPLACE "-" "_" kind_sym ${kind})
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT "${METALLIB_EMBED_ASM}"
|
||||
COMMAND echo "Embedding Metal library"
|
||||
COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${METALLIB_SOURCE}" > "${METALLIB_SOURCE_EMBED_TMP}"
|
||||
COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${METALLIB_SOURCE_EMBED_TMP}" > "${METALLIB_SOURCE_EMBED}"
|
||||
COMMAND echo ".section __DATA,__ggml_metallib" > "${METALLIB_EMBED_ASM}"
|
||||
COMMAND echo ".globl _ggml_metallib_start" >> "${METALLIB_EMBED_ASM}"
|
||||
COMMAND echo "_ggml_metallib_start:" >> "${METALLIB_EMBED_ASM}"
|
||||
COMMAND echo .incbin "\"${METALLIB_SOURCE_EMBED}\"" >> "${METALLIB_EMBED_ASM}"
|
||||
COMMAND echo ".globl _ggml_metallib_end" >> "${METALLIB_EMBED_ASM}"
|
||||
COMMAND echo "_ggml_metallib_end:" >> "${METALLIB_EMBED_ASM}"
|
||||
DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h
|
||||
COMMENT "Generate assembly for embedded Metal library"
|
||||
VERBATIM
|
||||
)
|
||||
set(SRC "${CMAKE_CURRENT_SOURCE_DIR}/kernels/${kind}.metal")
|
||||
set(EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed-${kind}.metal")
|
||||
set(ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed-${kind}.s")
|
||||
|
||||
target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}")
|
||||
# only prepend headers that this source actually includes
|
||||
set(HEADERS_FOR_SRC ${METALLIB_KERNELS_COMMON})
|
||||
file(STRINGS ${SRC} _has_dequantize REGEX "#include \"dequantize\\.h\"")
|
||||
file(STRINGS ${SRC} _has_quantize REGEX "#include \"quantize\\.h\"")
|
||||
if(_has_dequantize)
|
||||
list(APPEND HEADERS_FOR_SRC ${METALLIB_KERNELS_DEQUANTIZE})
|
||||
endif()
|
||||
if(_has_quantize)
|
||||
list(APPEND HEADERS_FOR_SRC ${METALLIB_KERNELS_QUANTIZE})
|
||||
endif()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT "${ASM}"
|
||||
# Step 1: concatenate shared headers + this kernel source
|
||||
COMMAND cat ${HEADERS_FOR_SRC} ${SRC} > "${EMBED}.tmp1"
|
||||
# Step 2: remove internal #include and #pragma once
|
||||
COMMAND sed -e "/\#include \"common.h\"/d" -e "/\#include \"dequantize.h\"/d" -e "/\#include \"quantize.h\"/d" -e "/\#pragma once/d" < "${EMBED}.tmp1" > "${EMBED}.tmp2"
|
||||
# Step 3: inline ggml-common.h (replacing __embed_ggml-common.h__ sentinel)
|
||||
COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${EMBED}.tmp2" > "${EMBED}.tmp3"
|
||||
# Step 4: inline ggml-metal-impl.h
|
||||
COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${EMBED}.tmp3" > "${EMBED}"
|
||||
# Step 5: emit an asm chunk with kind-specific start/end symbols
|
||||
# note: '-' is illegal in C symbols, so we use kind_sym; the macOS
|
||||
# section name is limited to 16 chars so we keep it shared
|
||||
# across kinds (__ggml_metallib) and only vary the global symbols.
|
||||
COMMAND echo ".section __DATA,__ggml_metallib" > "${ASM}"
|
||||
COMMAND echo ".globl _ggml_metallib_${kind_sym}_start" >> "${ASM}"
|
||||
COMMAND echo "_ggml_metallib_${kind_sym}_start:" >> "${ASM}"
|
||||
COMMAND echo .incbin "\"${EMBED}\"" >> "${ASM}"
|
||||
COMMAND echo ".globl _ggml_metallib_${kind_sym}_end" >> "${ASM}"
|
||||
COMMAND echo "_ggml_metallib_${kind_sym}_end:" >> "${ASM}"
|
||||
DEPENDS ../ggml-common.h ggml-metal-impl.h
|
||||
kernels/common.h kernels/dequantize.h kernels/quantize.h
|
||||
kernels/${kind}.metal
|
||||
COMMENT "Generate embedded Metal library for ${kind}"
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
list(APPEND METALLIB_EMBED_ASM_FILES "${ASM}")
|
||||
endforeach()
|
||||
|
||||
target_sources(ggml-metal PRIVATE ${METALLIB_EMBED_ASM_FILES})
|
||||
else()
|
||||
# copy metal files to bin directory
|
||||
# copy header files to bin directory
|
||||
configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
|
||||
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
|
||||
configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
|
||||
|
||||
file(MAKE_DIRECTORY "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels")
|
||||
configure_file(kernels/common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/common.h COPYONLY)
|
||||
configure_file(kernels/dequantize.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/dequantize.h COPYONLY)
|
||||
configure_file(kernels/quantize.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/quantize.h COPYONLY)
|
||||
|
||||
foreach(src ${METALLIB_KERNEL_SOURCES})
|
||||
configure_file(${src} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${src} COPYONLY)
|
||||
endforeach()
|
||||
|
||||
if (GGML_METAL_SHADER_DEBUG)
|
||||
# custom command to do the following:
|
||||
# xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
|
||||
# xcrun -sdk macosx metallib ggml-metal.air -o default.metallib
|
||||
#
|
||||
# note: this is the only way I found to disable fast-math in Metal. it's ugly, but at least it works
|
||||
# disabling fast math is needed in order to pass tests/test-backend-ops
|
||||
# note: disabling fast math is needed in order to pass tests/test-backend-ops
|
||||
# note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1
|
||||
# note: unfortunately, we have to call it default.metallib instead of ggml.metallib
|
||||
# ref: https://github.com/ggml-org/whisper.cpp/issues/1720
|
||||
# note: adding -g causes segmentation fault during compile
|
||||
#set(XC_FLAGS -fno-fast-math -fno-inline -g)
|
||||
set(XC_FLAGS -fno-fast-math -fno-inline)
|
||||
else()
|
||||
set(XC_FLAGS -O3)
|
||||
endif()
|
||||
|
||||
# Append macOS metal versioning flags
|
||||
if (GGML_METAL_MACOSX_VERSION_MIN)
|
||||
message(STATUS "Adding -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN} flag to metal compilation")
|
||||
list (APPEND XC_FLAGS -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN})
|
||||
@@ -90,35 +147,46 @@ else()
|
||||
list (APPEND XC_FLAGS -std=${GGML_METAL_STD})
|
||||
endif()
|
||||
|
||||
# Compile each kernel source to .air, then link into default.metallib
|
||||
set(AIR_FILES "")
|
||||
foreach(src ${METALLIB_KERNEL_SOURCES})
|
||||
get_filename_component(name ${src} NAME_WE)
|
||||
set(AIR "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${name}.air")
|
||||
list(APPEND AIR_FILES ${AIR})
|
||||
add_custom_command(
|
||||
OUTPUT ${AIR}
|
||||
COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -I ${CMAKE_RUNTIME_OUTPUT_DIRECTORY} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${src} -o ${AIR}
|
||||
DEPENDS ${src} kernels/common.h kernels/dequantize.h kernels/quantize.h ${METALLIB_COMMON} ggml-metal-impl.h
|
||||
COMMENT "Compiling ${src}"
|
||||
VERBATIM
|
||||
)
|
||||
endforeach()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
|
||||
COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o - |
|
||||
xcrun -sdk macosx metallib - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
|
||||
COMMAND xcrun -sdk macosx metallib ${AIR_FILES} -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
|
||||
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
|
||||
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal
|
||||
DEPENDS ggml-metal.metal ${METALLIB_COMMON}
|
||||
COMMENT "Compiling Metal kernels"
|
||||
)
|
||||
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h
|
||||
COMMAND rm -rf ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels
|
||||
DEPENDS ${AIR_FILES}
|
||||
COMMENT "Linking Metal kernels into default.metallib"
|
||||
)
|
||||
|
||||
# FIXME: only add to the ggml-metal target?
|
||||
add_custom_target(
|
||||
ggml-metal-lib ALL
|
||||
DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
|
||||
)
|
||||
)
|
||||
endif() # GGML_METAL_EMBED_LIBRARY
|
||||
|
||||
if (NOT GGML_METAL_EMBED_LIBRARY)
|
||||
install(
|
||||
FILES src/ggml-metal/ggml-metal.metal
|
||||
PERMISSIONS
|
||||
OWNER_READ
|
||||
OWNER_WRITE
|
||||
GROUP_READ
|
||||
WORLD_READ
|
||||
DESTINATION ${CMAKE_INSTALL_BINDIR})
|
||||
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/kernels/
|
||||
DESTINATION ${CMAKE_INSTALL_BINDIR}/kernels
|
||||
FILES_MATCHING PATTERN "*.metal" PATTERN "*.h"
|
||||
)
|
||||
|
||||
install(
|
||||
FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
|
||||
DESTINATION ${CMAKE_INSTALL_BINDIR}
|
||||
)
|
||||
install(
|
||||
FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
|
||||
DESTINATION ${CMAKE_INSTALL_BINDIR}
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -94,8 +94,63 @@ int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_wi
|
||||
return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup;
|
||||
}
|
||||
|
||||
//
|
||||
// MTLLibrary collection (one library per op-source, compiled separately)
|
||||
//
|
||||
|
||||
// Single source of truth for the per-kind metal libraries. The order here
|
||||
// defines the enum values and every per-kind table below, so adding a library
|
||||
// is a one-line change here (plus adding its source to CMakeLists.txt).
|
||||
// X(suffix, name): name is both the kernels/<name>.metal basename and the
|
||||
// ggml_metallib_<name>_{start,end} embed-symbol stem.
|
||||
#define GGML_METAL_LIBS \
|
||||
X(FA, fa) \
|
||||
X(MUL_MV, mul_mv) \
|
||||
X(MUL_MM, mul_mm) \
|
||||
X(QUANTIZE, quantize) \
|
||||
X(SOFTMAX, softmax) \
|
||||
X(NORM, norm) \
|
||||
X(UNARY, unary) \
|
||||
X(BINBCAST, binbcast) \
|
||||
X(REDUCE, reduce) \
|
||||
X(TRI, tri) \
|
||||
X(SSM, ssm) \
|
||||
X(WKV, wkv) \
|
||||
X(GATED_DELTA_NET, gated_delta_net)\
|
||||
X(SOLVE_TRI, solve_tri) \
|
||||
X(ROPE, rope) \
|
||||
X(CONV, conv) \
|
||||
X(UPSCALE, upscale) \
|
||||
X(ARGSORT, argsort) \
|
||||
X(POOL, pool) \
|
||||
X(MISC, misc)
|
||||
|
||||
enum ggml_metal_lib_kind {
|
||||
#define X(e, s) GGML_METAL_LIB_##e,
|
||||
GGML_METAL_LIBS
|
||||
#undef X
|
||||
GGML_METAL_LIB_COUNT,
|
||||
};
|
||||
|
||||
static const char * const k_lib_names[GGML_METAL_LIB_COUNT] = {
|
||||
#define X(e, s) [GGML_METAL_LIB_##e] = #s,
|
||||
GGML_METAL_LIBS
|
||||
#undef X
|
||||
};
|
||||
|
||||
struct ggml_metal_library {
|
||||
id<MTLLibrary> obj;
|
||||
// Per-kind compiled libraries. When single_library is true, the whole library
|
||||
// (e.g. a pre-compiled default.metallib or a from-source build) lives at
|
||||
// objs[0] and the remaining slots are nil.
|
||||
id<MTLLibrary> objs[GGML_METAL_LIB_COUNT];
|
||||
bool single_library; // true: combined library at objs[0]; false: per-kind libs in objs[*]
|
||||
|
||||
// Routing table: kernel function name -> objs[] index, populated from each
|
||||
// compiled library's -[MTLLibrary functionNames]. The actual compiled
|
||||
// libraries are the single source of truth for which library owns a kernel,
|
||||
// so adding kernels later requires no manual routing maintenance.
|
||||
// nil in single_library mode (everything resolves to objs[0]).
|
||||
NSMutableDictionary<NSString *, NSNumber *> * fn_to_lib;
|
||||
|
||||
ggml_metal_device_t dev;
|
||||
ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
|
||||
@@ -103,160 +158,376 @@ struct ggml_metal_library {
|
||||
NSLock * lock;
|
||||
};
|
||||
|
||||
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
|
||||
id<MTLLibrary> library = nil;
|
||||
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
|
||||
// Build the fn_to_lib routing table by querying each compiled library's public
|
||||
// function names. Call once after all per-kind libraries have been compiled.
|
||||
static void ggml_metal_library_build_index(ggml_metal_library_t lib) {
|
||||
@autoreleasepool {
|
||||
NSMutableDictionary<NSString *, NSNumber *> * index = [[NSMutableDictionary alloc] init];
|
||||
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
|
||||
for (NSString * fname in [lib->objs[kind] functionNames]) {
|
||||
index[fname] = @(kind);
|
||||
}
|
||||
}
|
||||
lib->fn_to_lib = index;
|
||||
}
|
||||
}
|
||||
|
||||
// load library
|
||||
//
|
||||
// - first check if the library is embedded
|
||||
// - then check if the library is in the bundle
|
||||
// - if not found, load the source and compile it
|
||||
// - if that fails, return NULL
|
||||
//
|
||||
// TODO: move to a function
|
||||
{
|
||||
const int64_t t_start = ggml_time_us();
|
||||
// Parse a `#include "name"` line. Returns the quoted name in *include_name on
|
||||
// success. Whitespace-tolerant; ignores `#include <...>` (system headers).
|
||||
static bool ggml_metal_library_parse_quoted_include(NSString * line, NSString ** include_name) {
|
||||
NSScanner * scanner = [NSScanner scannerWithString:line];
|
||||
scanner.charactersToBeSkipped = [NSCharacterSet whitespaceCharacterSet];
|
||||
|
||||
NSError * error = nil;
|
||||
NSString * src = nil;
|
||||
if (![scanner scanString:@"#" intoString:NULL] ||
|
||||
![scanner scanString:@"include" intoString:NULL] ||
|
||||
![scanner scanString:@"\"" intoString:NULL]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
#if GGML_METAL_EMBED_LIBRARY
|
||||
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
|
||||
NSString * name = nil;
|
||||
if (![scanner scanUpToString:@"\"" intoString:&name]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
extern const char ggml_metallib_start[];
|
||||
extern const char ggml_metallib_end[];
|
||||
if (include_name) {
|
||||
*include_name = name;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
|
||||
#else
|
||||
// Recursively inline `#include "name"` directives. System includes (<...>),
|
||||
// `#if/#else/#endif`, and other preprocessor lines are passed through to the
|
||||
// Metal compiler unchanged. `#pragma once` is dropped since `seen` already
|
||||
// guards against double-inclusion.
|
||||
static bool ggml_metal_library_flatten_file(NSMutableString * dst, NSString * path,
|
||||
NSArray<NSString *> * search_paths,
|
||||
NSMutableSet<NSString *> * seen, NSError ** error) {
|
||||
NSString * key = [path stringByStandardizingPath];
|
||||
if ([seen containsObject:key]) {
|
||||
return true;
|
||||
}
|
||||
[seen addObject:key];
|
||||
|
||||
#ifdef SWIFT_PACKAGE
|
||||
NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
|
||||
#else
|
||||
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
||||
#endif
|
||||
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:error];
|
||||
if (!src) {
|
||||
return false;
|
||||
}
|
||||
|
||||
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
|
||||
if (path_lib == nil) {
|
||||
// Try to find the resource in the directory where the current binary located.
|
||||
NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
|
||||
NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent];
|
||||
NSFileManager * fm = [NSFileManager defaultManager];
|
||||
for (NSString * line in [src componentsSeparatedByString:@"\n"]) {
|
||||
NSString * trimmed = [line stringByTrimmingCharactersInSet:[NSCharacterSet whitespaceCharacterSet]];
|
||||
if ([trimmed isEqualToString:@"#pragma once"]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
|
||||
if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
|
||||
GGML_LOG_INFO("%s: found '%s'\n", __func__, [path_lib_default UTF8String]);
|
||||
|
||||
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error];
|
||||
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
|
||||
// Optionally, if this is a symlink, try to resolve it.
|
||||
path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error];
|
||||
if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@"/"]) {
|
||||
// It is a relative path, adding the binary directory as directory prefix.
|
||||
path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]];
|
||||
}
|
||||
if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
|
||||
// Link to the resource could not be resolved.
|
||||
path_lib_default = nil;
|
||||
} else {
|
||||
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [path_lib_default UTF8String]);
|
||||
}
|
||||
NSString * include_name = nil;
|
||||
if (ggml_metal_library_parse_quoted_include(line, &include_name)) {
|
||||
NSString * resolved = nil;
|
||||
for (NSString * dir in search_paths) {
|
||||
NSString * candidate = [dir stringByAppendingPathComponent:include_name];
|
||||
if ([fm isReadableFileAtPath:candidate]) {
|
||||
resolved = candidate;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
// The resource couldn't be found in the binary's directory.
|
||||
path_lib_default = nil;
|
||||
}
|
||||
|
||||
path_lib = path_lib_default;
|
||||
if (!resolved) {
|
||||
if (error) {
|
||||
NSString * msg = [NSString stringWithFormat:@"could not resolve include \"%@\" from '%@'", include_name, path];
|
||||
*error = [NSError errorWithDomain:@"ggml-metal-source-flatten" code:1
|
||||
userInfo:@{NSLocalizedDescriptionKey: msg}];
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if (!ggml_metal_library_flatten_file(dst, resolved, search_paths, seen, error)) {
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (path_lib != nil) {
|
||||
// pre-compiled library found
|
||||
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
|
||||
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
|
||||
[dst appendString:line];
|
||||
[dst appendString:@"\n"];
|
||||
}
|
||||
|
||||
library = [device newLibraryWithURL:libURL error:&error];
|
||||
if (error) {
|
||||
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
return nil;
|
||||
}
|
||||
} else {
|
||||
GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
|
||||
return true;
|
||||
}
|
||||
|
||||
NSString * path_source;
|
||||
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
|
||||
static NSString * ggml_metal_library_flatten_source(NSString * path_source, NSError ** error) {
|
||||
// Search paths cover both runtime layout (build/bin/kernels + build/bin)
|
||||
// and source-tree layout (ggml/src/ggml-metal/kernels + ggml/src/ggml-metal + ggml/src).
|
||||
NSString * path_kernels = [path_source stringByDeletingLastPathComponent];
|
||||
NSString * path_base = [path_kernels stringByDeletingLastPathComponent];
|
||||
NSArray<NSString *> * search_paths = @[
|
||||
path_kernels,
|
||||
path_base,
|
||||
[path_base stringByDeletingLastPathComponent],
|
||||
];
|
||||
|
||||
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
|
||||
NSMutableString * src = [[NSMutableString alloc] init];
|
||||
NSMutableSet<NSString *> * seen = [NSMutableSet set];
|
||||
|
||||
if (path_resource) {
|
||||
path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
|
||||
} else {
|
||||
path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
||||
if (!ggml_metal_library_flatten_file(src, path_source, search_paths, seen, error)) {
|
||||
[src release];
|
||||
return nil;
|
||||
}
|
||||
return src;
|
||||
}
|
||||
|
||||
// Compile all per-kind libraries in parallel. `source_for_kind` returns the MSL
|
||||
// source for a kind (the helper takes ownership and releases it), or nil with
|
||||
// *err set on failure. On success the objs[] slots are populated and the routing
|
||||
// index is built; on any failure every error is logged and false is returned
|
||||
// (the caller is responsible for freeing `res`).
|
||||
static bool ggml_metal_library_compile_all(
|
||||
ggml_metal_library_t res,
|
||||
id<MTLDevice> device,
|
||||
NSDictionary * prep,
|
||||
NSString * (^source_for_kind)(int kind, NSError ** err),
|
||||
const char * origin) {
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
int64_t * t_per_lib = calloc(GGML_METAL_LIB_COUNT, sizeof(int64_t));
|
||||
NSError ** err_per_lib = calloc(GGML_METAL_LIB_COUNT, sizeof(NSError *));
|
||||
__block atomic_bool any_failure = false;
|
||||
|
||||
dispatch_group_t group = dispatch_group_create();
|
||||
dispatch_queue_t queue = dispatch_get_global_queue(QOS_CLASS_USER_INITIATED, 0);
|
||||
|
||||
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
|
||||
dispatch_group_async(group, queue, ^{
|
||||
|
||||
const int64_t t0 = ggml_time_us();
|
||||
|
||||
NSError * error = nil;
|
||||
|
||||
NSString * src = source_for_kind(kind, &error);
|
||||
if (!src) {
|
||||
err_per_lib[kind] = [error retain];
|
||||
atomic_store(&any_failure, true);
|
||||
return;
|
||||
}
|
||||
|
||||
if (path_source == nil) {
|
||||
GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
|
||||
path_source = @"ggml-metal.metal";
|
||||
}
|
||||
id<MTLLibrary> lib = nil;
|
||||
|
||||
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
|
||||
|
||||
src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
|
||||
if (error) {
|
||||
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
return nil;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if (!library) {
|
||||
@autoreleasepool {
|
||||
// dictionary of preprocessor macros
|
||||
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
||||
|
||||
if (ggml_metal_device_get_props(dev)->has_bfloat) {
|
||||
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"];
|
||||
}
|
||||
|
||||
if (ggml_metal_device_get_props(dev)->has_tensor) {
|
||||
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"];
|
||||
}
|
||||
|
||||
#if GGML_METAL_EMBED_LIBRARY
|
||||
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
|
||||
#endif
|
||||
|
||||
MTLCompileOptions * options = [MTLCompileOptions new];
|
||||
options.preprocessorMacros = prep;
|
||||
|
||||
//[options setFastMathEnabled:false];
|
||||
lib = [device newLibraryWithSource:src options:options error:&error];
|
||||
|
||||
library = [device newLibraryWithSource:src options:options error:&error];
|
||||
if (error) {
|
||||
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
return nil;
|
||||
}
|
||||
|
||||
#if !__has_feature(objc_arc)
|
||||
[options release];
|
||||
#endif
|
||||
|
||||
// retain the error before the autorelease pool drains it
|
||||
if (!lib) {
|
||||
err_per_lib[kind] = [error retain];
|
||||
}
|
||||
}
|
||||
|
||||
[src release];
|
||||
|
||||
t_per_lib[kind] = ggml_time_us() - t0;
|
||||
|
||||
if (!lib) {
|
||||
atomic_store(&any_failure, true);
|
||||
return;
|
||||
}
|
||||
|
||||
res->objs[kind] = lib;
|
||||
});
|
||||
}
|
||||
dispatch_group_wait(group, DISPATCH_TIME_FOREVER);
|
||||
dispatch_release(group);
|
||||
|
||||
const bool ok = !atomic_load(&any_failure);
|
||||
|
||||
if (ok) {
|
||||
const int64_t t_total = ggml_time_us() - t_start;
|
||||
int64_t t_max = 0;
|
||||
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
|
||||
GGML_LOG_DEBUG("%s: compiled '%s' library in %.3f sec\n",
|
||||
__func__, k_lib_names[kind], t_per_lib[kind] / 1e6);
|
||||
if (t_per_lib[kind] > t_max) t_max = t_per_lib[kind];
|
||||
}
|
||||
GGML_LOG_INFO("%s: loaded %d libraries from %s in %.3f sec (max single = %.3f sec)\n",
|
||||
__func__, GGML_METAL_LIB_COUNT, origin, t_total / 1e6, t_max / 1e6);
|
||||
|
||||
ggml_metal_library_build_index(res);
|
||||
} else {
|
||||
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
|
||||
if (err_per_lib[kind]) {
|
||||
GGML_LOG_ERROR("%s: failed to build '%s' library: %s\n", __func__,
|
||||
k_lib_names[kind], [[err_per_lib[kind] description] UTF8String]);
|
||||
[err_per_lib[kind] release];
|
||||
}
|
||||
}
|
||||
|
||||
#if GGML_METAL_EMBED_LIBRARY
|
||||
[src release];
|
||||
#endif // GGML_METAL_EMBED_LIBRARY
|
||||
|
||||
GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
|
||||
}
|
||||
|
||||
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
|
||||
free(err_per_lib);
|
||||
free(t_per_lib);
|
||||
|
||||
res->obj = library;
|
||||
return ok;
|
||||
}
|
||||
|
||||
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
|
||||
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
|
||||
|
||||
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
|
||||
res->dev = dev;
|
||||
res->pipelines = ggml_metal_pipelines_init();
|
||||
res->lock = [NSLock new];
|
||||
|
||||
// shared MTLCompileOptions preprocessor macros (matches the build-time defines)
|
||||
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
||||
if (ggml_metal_device_get_props(dev)->has_bfloat) {
|
||||
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"];
|
||||
}
|
||||
if (ggml_metal_device_get_props(dev)->has_tensor) {
|
||||
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"];
|
||||
}
|
||||
#if GGML_METAL_EMBED_LIBRARY
|
||||
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
|
||||
#endif
|
||||
|
||||
#if GGML_METAL_EMBED_LIBRARY
|
||||
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
|
||||
|
||||
// start/end symbols emitted by CMake (see CMakeLists.txt), one pair per kind
|
||||
#define X(e, s) extern const char ggml_metallib_##s##_start[]; extern const char ggml_metallib_##s##_end[];
|
||||
GGML_METAL_LIBS
|
||||
#undef X
|
||||
|
||||
static const char * const lib_start[GGML_METAL_LIB_COUNT] = {
|
||||
#define X(e, s) [GGML_METAL_LIB_##e] = ggml_metallib_##s##_start,
|
||||
GGML_METAL_LIBS
|
||||
#undef X
|
||||
};
|
||||
static const char * const lib_end[GGML_METAL_LIB_COUNT] = {
|
||||
#define X(e, s) [GGML_METAL_LIB_##e] = ggml_metallib_##s##_end,
|
||||
GGML_METAL_LIBS
|
||||
#undef X
|
||||
};
|
||||
|
||||
const bool ok = ggml_metal_library_compile_all(res, device, prep,
|
||||
^NSString * (int kind, NSError ** err) {
|
||||
(void) err;
|
||||
return [[NSString alloc] initWithBytes:lib_start[kind]
|
||||
length:(lib_end[kind] - lib_start[kind])
|
||||
encoding:NSUTF8StringEncoding];
|
||||
}, "embedded data");
|
||||
|
||||
if (!ok) {
|
||||
ggml_metal_library_free(res);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return res;
|
||||
#else
|
||||
#ifdef SWIFT_PACKAGE
|
||||
NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
|
||||
#else
|
||||
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
||||
#endif
|
||||
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
NSError * error = nil;
|
||||
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
|
||||
if (path_lib == nil) {
|
||||
// Try to find the resource in the directory where the current binary located.
|
||||
NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
|
||||
NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent];
|
||||
|
||||
NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
|
||||
if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
|
||||
GGML_LOG_INFO("%s: found '%s'\n", __func__, [path_lib_default UTF8String]);
|
||||
|
||||
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error];
|
||||
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
|
||||
// Optionally, if this is a symlink, try to resolve it.
|
||||
path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error];
|
||||
if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@"/"]) {
|
||||
// It is a relative path, adding the binary directory as directory prefix.
|
||||
path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]];
|
||||
}
|
||||
if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
|
||||
// Link to the resource could not be resolved.
|
||||
path_lib_default = nil;
|
||||
} else {
|
||||
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [path_lib_default UTF8String]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// The resource couldn't be found in the binary's directory.
|
||||
path_lib_default = nil;
|
||||
}
|
||||
|
||||
path_lib = path_lib_default;
|
||||
}
|
||||
|
||||
if (path_lib != nil) {
|
||||
// pre-compiled library found: a single combined default.metallib
|
||||
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
|
||||
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
|
||||
|
||||
res->objs[0] = [device newLibraryWithURL:libURL error:&error];
|
||||
res->single_library = true;
|
||||
if (!res->objs[0]) {
|
||||
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
ggml_metal_library_free(res);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
|
||||
return res;
|
||||
}
|
||||
|
||||
// no pre-compiled metallib: fall back to compiling each kernel source separately
|
||||
GGML_LOG_INFO("%s: default.metallib not found, loading kernel sources\n", __func__);
|
||||
|
||||
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
|
||||
if (path_resource) {
|
||||
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, [path_resource UTF8String]);
|
||||
}
|
||||
|
||||
// resolve each kind's source path up front (file lookup/logging stays on the calling thread)
|
||||
NSString ** path_per_kind = calloc(GGML_METAL_LIB_COUNT, sizeof(NSString *));
|
||||
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
|
||||
NSString * rel = [NSString stringWithFormat:@"kernels/%s.metal", k_lib_names[kind]];
|
||||
|
||||
NSString * path_source = nil;
|
||||
if (path_resource) {
|
||||
path_source = [path_resource stringByAppendingPathComponent:rel];
|
||||
} else {
|
||||
NSString * stem = [NSString stringWithFormat:@"kernels/%s", k_lib_names[kind]];
|
||||
path_source = [bundle pathForResource:stem ofType:@"metal"];
|
||||
}
|
||||
|
||||
if (path_source == nil || ![[NSFileManager defaultManager] isReadableFileAtPath:path_source]) {
|
||||
GGML_LOG_WARN("%s: could not locate %s in bundle, falling back to cwd\n", __func__, [rel UTF8String]);
|
||||
path_source = rel;
|
||||
}
|
||||
|
||||
GGML_LOG_DEBUG("%s: loading '%s'\n", __func__, [path_source UTF8String]);
|
||||
|
||||
path_per_kind[kind] = [path_source retain];
|
||||
}
|
||||
|
||||
const bool ok = ggml_metal_library_compile_all(res, device, prep,
|
||||
^NSString * (int kind, NSError ** err) {
|
||||
return ggml_metal_library_flatten_source(path_per_kind[kind], err);
|
||||
}, "source");
|
||||
|
||||
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
|
||||
[path_per_kind[kind] release];
|
||||
}
|
||||
free(path_per_kind);
|
||||
|
||||
if (!ok) {
|
||||
ggml_metal_library_free(res);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return res;
|
||||
#endif
|
||||
}
|
||||
|
||||
ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose) {
|
||||
@@ -318,10 +589,11 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
|
||||
return NULL;
|
||||
}
|
||||
|
||||
res->obj = library;
|
||||
res->dev = dev;
|
||||
res->pipelines = ggml_metal_pipelines_init();
|
||||
res->lock = [NSLock new];
|
||||
res->objs[0] = library;
|
||||
res->single_library = true;
|
||||
res->dev = dev;
|
||||
res->pipelines = ggml_metal_pipelines_init();
|
||||
res->lock = [NSLock new];
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -331,8 +603,14 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (lib->obj) {
|
||||
[lib->obj release];
|
||||
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
|
||||
if (lib->objs[kind]) {
|
||||
[lib->objs[kind] release];
|
||||
}
|
||||
}
|
||||
|
||||
if (lib->fn_to_lib) {
|
||||
[lib->fn_to_lib release];
|
||||
}
|
||||
|
||||
ggml_metal_pipelines_free(lib->pipelines);
|
||||
@@ -393,11 +671,28 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_
|
||||
|
||||
GGML_LOG_DEBUG("%s: compiling pipeline: base = '%s', name = '%s'\n", __func__, base, name);
|
||||
|
||||
// route to the library that actually defines this kernel; fn_to_lib is
|
||||
// built from -[MTLLibrary functionNames] so it's always in sync
|
||||
int lib_idx = 0;
|
||||
if (!lib->single_library) {
|
||||
NSNumber * idx = lib->fn_to_lib[base_func];
|
||||
if (!idx) {
|
||||
[lib->lock unlock];
|
||||
|
||||
GGML_LOG_ERROR("%s: kernel not found in any metal library: base = '%s', name = '%s'\n", __func__, base, name);
|
||||
|
||||
return res;
|
||||
}
|
||||
lib_idx = [idx intValue];
|
||||
}
|
||||
|
||||
id<MTLLibrary> mtl_lib = lib->objs[lib_idx];
|
||||
|
||||
id<MTLFunction> mtl_function;
|
||||
if (!cv) {
|
||||
mtl_function = [lib->obj newFunctionWithName:base_func];
|
||||
mtl_function = [mtl_lib newFunctionWithName:base_func];
|
||||
} else {
|
||||
mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
|
||||
mtl_function = [mtl_lib newFunctionWithName:base_func constantValues:cv->obj error:&error];
|
||||
}
|
||||
if (!mtl_function) {
|
||||
[lib->lock unlock];
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,232 @@
|
||||
#include "common.h"
|
||||
|
||||
// bitonic sort implementation following the CUDA kernels as reference
|
||||
typedef void (argsort_t)(
|
||||
constant ggml_metal_kargs_argsort & args,
|
||||
device const char * src0,
|
||||
device int32_t * dst,
|
||||
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template<ggml_sort_order order>
|
||||
kernel void kernel_argsort_f32_i32(
|
||||
constant ggml_metal_kargs_argsort & args,
|
||||
device const char * src0,
|
||||
device int32_t * dst,
|
||||
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
// bitonic sort
|
||||
const int col = tpitg[0];
|
||||
const int ib = tgpig[0] / args.ne01;
|
||||
|
||||
const int i00 = ib*ntg.x;
|
||||
const int i01 = tgpig[0] % args.ne01;
|
||||
const int i02 = tgpig[1];
|
||||
const int i03 = tgpig[2];
|
||||
|
||||
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
|
||||
|
||||
// initialize indices
|
||||
shmem_i32[col] = i00 + col;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int k = 2; k <= ntg.x; k *= 2) {
|
||||
for (int j = k / 2; j > 0; j /= 2) {
|
||||
int ixj = col ^ j;
|
||||
if (ixj > col) {
|
||||
if ((col & k) == 0) {
|
||||
if (shmem_i32[col] >= args.ne00 ||
|
||||
(shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
|
||||
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
|
||||
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
|
||||
) {
|
||||
SWAP(shmem_i32[col], shmem_i32[ixj]);
|
||||
}
|
||||
} else {
|
||||
if (shmem_i32[ixj] >= args.ne00 ||
|
||||
(shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
|
||||
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
|
||||
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
|
||||
) {
|
||||
SWAP(shmem_i32[col], shmem_i32[ixj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t i0 = ib*args.top_k;
|
||||
|
||||
// copy the result to dst without the padding
|
||||
if (i0 + col < args.ne0 && col < args.top_k) {
|
||||
dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
|
||||
|
||||
dst[col] = shmem_i32[col];
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
|
||||
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
|
||||
|
||||
typedef void (argsort_merge_t)(
|
||||
constant ggml_metal_kargs_argsort_merge & args,
|
||||
device const char * src0,
|
||||
device const int32_t * tmp,
|
||||
device int32_t * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template<ggml_sort_order order>
|
||||
kernel void kernel_argsort_merge_f32_i32(
|
||||
constant ggml_metal_kargs_argsort_merge & args,
|
||||
device const char * src0,
|
||||
device const int32_t * tmp,
|
||||
device int32_t * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int im = tgpig[0] / args.ne01;
|
||||
const int i01 = tgpig[0] % args.ne01;
|
||||
const int i02 = tgpig[1];
|
||||
const int i03 = tgpig[2];
|
||||
|
||||
const int start = im * (2 * args.len);
|
||||
|
||||
const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
|
||||
const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
|
||||
|
||||
const int total = len0 + len1;
|
||||
|
||||
device const int32_t * tmp0 = tmp + start
|
||||
+ i01*args.ne0
|
||||
+ i02*args.ne0*args.ne01
|
||||
+ i03*args.ne0*args.ne01*args.ne02;
|
||||
|
||||
device const int32_t * tmp1 = tmp0 + args.len;
|
||||
|
||||
dst += start
|
||||
+ i01*args.top_k
|
||||
+ i02*args.top_k*args.ne01
|
||||
+ i03*args.top_k*args.ne01*args.ne02;
|
||||
|
||||
device const float * src0_row = (device const float *)(src0
|
||||
+ args.nb01*i01
|
||||
+ args.nb02*i02
|
||||
+ args.nb03*i03);
|
||||
|
||||
if (total == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int chunk = (total + ntg.x - 1) / ntg.x;
|
||||
|
||||
const int k0 = tpitg.x * chunk;
|
||||
const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
|
||||
|
||||
if (k0 >= args.top_k) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (k0 >= total) {
|
||||
return;
|
||||
}
|
||||
|
||||
int low = k0 > len1 ? k0 - len1 : 0;
|
||||
int high = MIN(k0, len0);
|
||||
|
||||
// binary-search partition (i, j) such that i + j = k
|
||||
while (low < high) {
|
||||
const int mid = (low + high) >> 1;
|
||||
|
||||
const int32_t idx0 = tmp0[mid];
|
||||
const int32_t idx1 = tmp1[k0 - mid - 1];
|
||||
|
||||
const float val0 = src0_row[idx0];
|
||||
const float val1 = src0_row[idx1];
|
||||
|
||||
bool take_left;
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
take_left = (val0 <= val1);
|
||||
} else {
|
||||
take_left = (val0 >= val1);
|
||||
}
|
||||
|
||||
if (take_left) {
|
||||
low = mid + 1;
|
||||
} else {
|
||||
high = mid;
|
||||
}
|
||||
}
|
||||
|
||||
int i = low;
|
||||
int j = k0 - i;
|
||||
|
||||
// keep the merge fronts into registers
|
||||
int32_t idx0 = 0;
|
||||
float val0 = 0.0f;
|
||||
if (i < len0) {
|
||||
idx0 = tmp0[i];
|
||||
val0 = src0_row[idx0];
|
||||
}
|
||||
|
||||
int32_t idx1 = 0;
|
||||
float val1 = 0.0f;
|
||||
if (j < len1) {
|
||||
idx1 = tmp1[j];
|
||||
val1 = src0_row[idx1];
|
||||
}
|
||||
|
||||
for (int k = k0; k < k1; ++k) {
|
||||
int32_t out_idx;
|
||||
|
||||
if (i >= len0) {
|
||||
while (k < k1) {
|
||||
dst[k++] = tmp1[j++];
|
||||
}
|
||||
break;
|
||||
} else if (j >= len1) {
|
||||
while (k < k1) {
|
||||
dst[k++] = tmp0[i++];
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
bool take_left;
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
take_left = (val0 <= val1);
|
||||
} else {
|
||||
take_left = (val0 >= val1);
|
||||
}
|
||||
|
||||
if (take_left) {
|
||||
out_idx = idx0;
|
||||
++i;
|
||||
if (i < len0) {
|
||||
idx0 = tmp0[i];
|
||||
val0 = src0_row[idx0];
|
||||
}
|
||||
} else {
|
||||
out_idx = idx1;
|
||||
++j;
|
||||
if (j < len1) {
|
||||
idx1 = tmp1[j];
|
||||
val1 = src0_row[idx1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dst[k] = out_idx;
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
|
||||
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
|
||||
@@ -0,0 +1,226 @@
|
||||
#include "common.h"
|
||||
|
||||
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
|
||||
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
|
||||
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
|
||||
constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
|
||||
constant bool FC_bin_cb [[function_constant(FC_BIN + 3)]];
|
||||
|
||||
template <typename T0, typename T1, typename T>
|
||||
kernel void kernel_bin_fuse_impl(
|
||||
constant ggml_metal_kargs_bin & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
#define FC_OP FC_bin_op
|
||||
#define FC_F FC_bin_f
|
||||
#define FC_RB FC_bin_rb
|
||||
#define FC_CB FC_bin_cb
|
||||
|
||||
if (FC_RB) {
|
||||
// row broadcast
|
||||
const uint i0 = tgpig.y*args.ne00 + tgpig.x;
|
||||
const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x;
|
||||
|
||||
device const T0 * src0_row = (device const T0 *) (src0);
|
||||
device T * dst_row = (device T *) (dst);
|
||||
|
||||
if (FC_F == 1) {
|
||||
device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
|
||||
|
||||
if (FC_OP == 0) {
|
||||
dst_row[i0] = src0_row[i0] + src1_row[i1];
|
||||
}
|
||||
|
||||
if (FC_OP == 1) {
|
||||
dst_row[i0] = src0_row[i0] - src1_row[i1];
|
||||
}
|
||||
|
||||
if (FC_OP == 2) {
|
||||
dst_row[i0] = src0_row[i0] * src1_row[i1];
|
||||
}
|
||||
|
||||
if (FC_OP == 3) {
|
||||
dst_row[i0] = src0_row[i0] / src1_row[i1];
|
||||
}
|
||||
} else {
|
||||
T0 res = src0_row[i0];
|
||||
|
||||
if (FC_OP == 0) {
|
||||
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||
res += ((device const T1 *) (src1 + args.o1[j]))[i1];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_OP == 1) {
|
||||
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||
res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_OP == 2) {
|
||||
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||
res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_OP == 3) {
|
||||
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||
res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
|
||||
}
|
||||
}
|
||||
|
||||
dst_row[i0] = res;
|
||||
}
|
||||
} else {
|
||||
const int i03 = tgpig.z;
|
||||
const int i02 = tgpig.y;
|
||||
const int i01 = tgpig.x;
|
||||
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i13 = i03%args.ne13;
|
||||
const int i12 = i02%args.ne12;
|
||||
const int i11 = i01%args.ne11;
|
||||
|
||||
device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
|
||||
device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
|
||||
|
||||
if (FC_F == 1) {
|
||||
device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = FC_CB ? i0%args.ne10 : i0;
|
||||
|
||||
if (FC_OP == 0) {
|
||||
dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
|
||||
}
|
||||
|
||||
if (FC_OP == 1) {
|
||||
dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
|
||||
}
|
||||
|
||||
if (FC_OP == 2) {
|
||||
dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
|
||||
}
|
||||
|
||||
if (FC_OP == 3) {
|
||||
dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
device const T1 * src1_ptr[8];
|
||||
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||
src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
|
||||
}
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i10 = FC_CB ? i0%args.ne10 : i0;
|
||||
|
||||
T res = src0_ptr[i0];
|
||||
|
||||
if (FC_OP == 0) {
|
||||
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||
res += src1_ptr[j][i10];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_OP == 1) {
|
||||
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||
res -= src1_ptr[j][i10];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_OP == 2) {
|
||||
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||
res *= src1_ptr[j][i10];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_OP == 3) {
|
||||
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
|
||||
res /= src1_ptr[j][i10];
|
||||
}
|
||||
}
|
||||
|
||||
dst_ptr[i0] = res;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#undef FC_OP
|
||||
#undef FC_F
|
||||
#undef FC_RB
|
||||
#undef FC_CB
|
||||
}
|
||||
|
||||
typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
|
||||
|
||||
template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>;
|
||||
template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>;
|
||||
|
||||
kernel void kernel_add_id(
|
||||
constant ggml_metal_kargs_add_id & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i1 = tgpig.x;
|
||||
const int i2 = tgpig.y;
|
||||
|
||||
const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
|
||||
|
||||
const size_t nb1 = args.ne0 * sizeof(float);
|
||||
const size_t nb2 = args.ne1 * nb1;
|
||||
|
||||
device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
|
||||
device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
|
||||
device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
dst_row[i0] = src0_row[i0] + src1_row[i0];
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_repeat(
|
||||
constant ggml_metal_kargs_repeat & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i3 = tgpig.z;
|
||||
const int i2 = tgpig.y;
|
||||
const int i1 = tgpig.x;
|
||||
|
||||
const int i03 = i3%args.ne03;
|
||||
const int i02 = i2%args.ne02;
|
||||
const int i01 = i1%args.ne01;
|
||||
|
||||
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
|
||||
device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int i00 = i0%args.ne00;
|
||||
*((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_repeat<float>) kernel_repeat_t;
|
||||
|
||||
template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
|
||||
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_repeat_bf16")]] kernel kernel_repeat_t kernel_repeat<bfloat>;
|
||||
#endif
|
||||
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
|
||||
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
|
||||
@@ -0,0 +1,126 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml-metal-impl.h"
|
||||
|
||||
#include <metal_stdlib>
|
||||
|
||||
#ifdef GGML_METAL_HAS_TENSOR
|
||||
#include <metal_tensor>
|
||||
|
||||
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
|
||||
#endif
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
||||
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
|
||||
|
||||
#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1))
|
||||
|
||||
#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)
|
||||
|
||||
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
||||
|
||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
||||
//
|
||||
// cmd:
|
||||
// .../usr/bin/metal -dM -E -c ggml/src/ggml-metal/kernels/<src>.metal
|
||||
// .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/kernels/<src>.metal
|
||||
//
|
||||
#if __METAL_VERSION__ < 310 && defined(GGML_METAL_HAS_BF16)
|
||||
#undef GGML_METAL_HAS_BF16
|
||||
#endif
|
||||
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
typedef matrix<bfloat, 4, 4> bfloat4x4;
|
||||
typedef matrix<bfloat, 2, 4> bfloat2x4;
|
||||
#endif
|
||||
|
||||
constexpr constant static float kvalues_iq4nl_f[16] = {
|
||||
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
||||
};
|
||||
|
||||
constexpr constant static float kvalues_mxfp4_f[16] = {
|
||||
0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
|
||||
};
|
||||
|
||||
static inline int best_index_int8(int n, constant float * val, float x) {
|
||||
if (x <= val[0]) return 0;
|
||||
if (x >= val[n-1]) return n-1;
|
||||
int ml = 0, mu = n-1;
|
||||
while (mu-ml > 1) {
|
||||
int mav = (ml+mu)/2;
|
||||
if (x < val[mav]) mu = mav; else ml = mav;
|
||||
}
|
||||
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
|
||||
}
|
||||
|
||||
static inline float e8m0_to_fp32(uint8_t x) {
|
||||
uint32_t bits;
|
||||
|
||||
if (x == 0) {
|
||||
bits = 0x00400000;
|
||||
} else {
|
||||
bits = (uint32_t) x << 23;
|
||||
}
|
||||
|
||||
return as_type<float>(bits);
|
||||
}
|
||||
|
||||
static inline float dot(float x, float y) {
|
||||
return x*y;
|
||||
}
|
||||
|
||||
static inline float sum(float x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline float sum(float4 x) {
|
||||
return x[0] + x[1] + x[2] + x[3];
|
||||
}
|
||||
|
||||
enum ggml_sort_order {
|
||||
GGML_SORT_ORDER_ASC,
|
||||
GGML_SORT_ORDER_DESC,
|
||||
};
|
||||
|
||||
constant float GELU_COEF_A = 0.044715f;
|
||||
constant float GELU_QUICK_COEF = -1.702f;
|
||||
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
||||
|
||||
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
|
||||
// ref: https://www.johndcook.com/blog/python_erf/
|
||||
constant float p_erf = 0.3275911f;
|
||||
constant float a1_erf = 0.254829592f;
|
||||
constant float a2_erf = -0.284496736f;
|
||||
constant float a3_erf = 1.421413741f;
|
||||
constant float a4_erf = -1.453152027f;
|
||||
constant float a5_erf = 1.061405429f;
|
||||
|
||||
template<typename T>
|
||||
inline T erf_approx(T x) {
|
||||
T sign_x = sign(x);
|
||||
x = fabs(x);
|
||||
T t = 1.0f / (1.0f + p_erf * x);
|
||||
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
|
||||
return sign_x * y;
|
||||
}
|
||||
|
||||
template<typename T> T elu_approx(T x);
|
||||
|
||||
template<> inline float elu_approx<float>(float x) {
|
||||
return (x > 0.f) ? x : (exp(x) - 1);
|
||||
}
|
||||
|
||||
template<> inline float4 elu_approx<float4>(float4 x) {
|
||||
float4 res;
|
||||
|
||||
res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
|
||||
res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
|
||||
res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
|
||||
res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -0,0 +1,485 @@
|
||||
#include "common.h"
|
||||
|
||||
typedef void (im2col_t)(
|
||||
constant ggml_metal_kargs_im2col & args,
|
||||
device const float * x,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template <typename T>
|
||||
kernel void kernel_im2col(
|
||||
constant ggml_metal_kargs_im2col & args,
|
||||
device const float * x,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
// const int64_t IC = tgpg[0];
|
||||
const int64_t OH = tgpg[1];
|
||||
const int64_t OW = tgpg[2];
|
||||
|
||||
const int64_t KH = ntg[1];
|
||||
const int64_t KW = ntg[2];
|
||||
|
||||
int64_t in = tpitg[0];
|
||||
const int64_t ikh = tpitg[1];
|
||||
const int64_t ikw = tpitg[2];
|
||||
|
||||
const int64_t iic = tgpig[0];
|
||||
const int64_t ioh = tgpig[1];
|
||||
const int64_t iow = tgpig[2];
|
||||
|
||||
const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
|
||||
const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
|
||||
|
||||
int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
|
||||
|
||||
device T * pdst = (device T *) (dst);
|
||||
|
||||
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
|
||||
while (in < args.N) {
|
||||
pdst[offset_dst] = 0.0f;
|
||||
offset_dst += ntg[0]*args.CHW*OH*OW;
|
||||
|
||||
in += ntg[0];
|
||||
}
|
||||
} else {
|
||||
int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
|
||||
|
||||
while (in < args.N) {
|
||||
pdst[offset_dst] = x[offset_src];
|
||||
|
||||
offset_dst += ntg[0]*args.CHW*OH*OW;
|
||||
offset_src += ntg[0]*args.ofs0;
|
||||
|
||||
in += ntg[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
|
||||
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
|
||||
|
||||
// TODO: optimize
|
||||
typedef void (im2col_ext_t)(
|
||||
constant ggml_metal_kargs_im2col & args,
|
||||
device const float * x,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template <typename T>
|
||||
kernel void kernel_im2col_ext(
|
||||
constant ggml_metal_kargs_im2col & args,
|
||||
device const float * x,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
|
||||
const int64_t KHW = (int64_t)args.KHW;
|
||||
|
||||
const int64_t d = tgpig[0] / args.CHW;
|
||||
const int64_t chw = tgpig[0] % args.CHW;
|
||||
const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
|
||||
const int64_t HW = tgpig[0] % KHW;
|
||||
|
||||
const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
|
||||
if (tpitg_0 >= args.N) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t tpitg_1 = HW / args.KW;
|
||||
const int64_t tpitg_2 = HW % args.KW;
|
||||
|
||||
const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
|
||||
const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
|
||||
|
||||
const int64_t offset_dst =
|
||||
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
|
||||
(tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
|
||||
|
||||
device T * pdst = (device T *) (dst);
|
||||
|
||||
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
|
||||
pdst[offset_dst] = 0.0f;
|
||||
} else {
|
||||
const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
|
||||
pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
|
||||
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
|
||||
|
||||
template <typename TK>
|
||||
kernel void kernel_conv_2d(
|
||||
constant ggml_metal_kargs_conv_2d & args,
|
||||
device const char * weights,
|
||||
device const char * src,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const uint threads_per_tg = ntg.x * ntg.y * ntg.z;
|
||||
const uint tg_index = (tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
|
||||
const uint local_thread = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;
|
||||
const uint thread_index = tg_index * threads_per_tg + local_thread;
|
||||
const uint64_t total_threads = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z;
|
||||
const uint64_t total_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW;
|
||||
|
||||
for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {
|
||||
uint64_t tmp = index;
|
||||
|
||||
const int32_t ow = tmp % args.OW; tmp /= args.OW;
|
||||
const int32_t oh = tmp % args.OH; tmp /= args.OH;
|
||||
const int32_t oc = tmp % args.OC; tmp /= args.OC;
|
||||
const int32_t n = tmp;
|
||||
|
||||
float acc = 0.0f;
|
||||
|
||||
const int32_t base_x = ow*args.s0 - args.p0;
|
||||
const int32_t base_y = oh*args.s1 - args.p1;
|
||||
|
||||
int32_t ky_start = 0;
|
||||
if (base_y < 0) {
|
||||
ky_start = (-base_y + args.d1 - 1)/args.d1;
|
||||
}
|
||||
int32_t ky_end = args.KH;
|
||||
const int32_t y_max = args.IH - 1 - base_y;
|
||||
if (y_max < 0) {
|
||||
ky_end = ky_start;
|
||||
} else if (base_y + (args.KH - 1)*args.d1 >= args.IH) {
|
||||
ky_end = min(ky_end, y_max/args.d1 + 1);
|
||||
}
|
||||
|
||||
int32_t kx_start = 0;
|
||||
if (base_x < 0) {
|
||||
kx_start = (-base_x + args.d0 - 1)/args.d0;
|
||||
}
|
||||
int32_t kx_end = args.KW;
|
||||
const int32_t x_max = args.IW - 1 - base_x;
|
||||
if (x_max < 0) {
|
||||
kx_end = kx_start;
|
||||
} else if (base_x + (args.KW - 1)*args.d0 >= args.IW) {
|
||||
kx_end = min(kx_end, x_max/args.d0 + 1);
|
||||
}
|
||||
|
||||
if (ky_start < ky_end && kx_start < kx_end) {
|
||||
const uint64_t src_base_n = (uint64_t) n * args.nb13;
|
||||
const uint64_t w_base_oc = (uint64_t) oc * args.nb03;
|
||||
|
||||
for (int32_t ic = 0; ic < args.IC; ++ic) {
|
||||
const uint64_t src_base_nc = src_base_n + (uint64_t) ic * args.nb12;
|
||||
const uint64_t w_base_ocic = w_base_oc + (uint64_t) ic * args.nb02;
|
||||
|
||||
for (int32_t ky = ky_start; ky < ky_end; ++ky) {
|
||||
const int32_t iy = base_y + ky*args.d1;
|
||||
const uint64_t src_base_row = src_base_nc + (uint64_t) iy * args.nb11;
|
||||
const uint64_t w_base_row = w_base_ocic + (uint64_t) ky * args.nb01;
|
||||
|
||||
for (int32_t kx = kx_start; kx < kx_end; ++kx) {
|
||||
const int32_t ix = base_x + kx*args.d0;
|
||||
const uint64_t src_offs = src_base_row + (uint64_t) ix * args.nb10;
|
||||
const uint64_t w_offs = w_base_row + (uint64_t) kx * args.nb00;
|
||||
|
||||
const float x = *(device const float *)(src + src_offs);
|
||||
const float w = (float) (*(device const TK *)(weights + w_offs));
|
||||
|
||||
acc += x * w;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const uint64_t dst_offs =
|
||||
(uint64_t) n * args.nb3 +
|
||||
(uint64_t) oc * args.nb2 +
|
||||
(uint64_t) oh * args.nb1 +
|
||||
(uint64_t) ow * args.nb0;
|
||||
|
||||
*(device float *)(dst + dst_offs) = acc;
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_conv_2d_f32_f32")]]
|
||||
kernel void kernel_conv_2d<float>(
|
||||
constant ggml_metal_kargs_conv_2d & args,
|
||||
device const char * weights,
|
||||
device const char * src,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template [[host_name("kernel_conv_2d_f16_f32")]]
|
||||
kernel void kernel_conv_2d<half>(
|
||||
constant ggml_metal_kargs_conv_2d & args,
|
||||
device const char * weights,
|
||||
device const char * src,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
typedef void (conv_transpose_1d_t)(
|
||||
constant ggml_metal_kargs_conv_transpose_1d & args,
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]]);
|
||||
|
||||
template <typename T>
|
||||
kernel void kernel_conv_transpose_1d(
|
||||
constant ggml_metal_kargs_conv_transpose_1d & args,
|
||||
device const T * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]]) {
|
||||
|
||||
// For output position j on the time axis, only input positions
|
||||
// i such that i*s0 <= j < i*s0 + K
|
||||
// contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)]
|
||||
// intersected with [0, IL-1]. That's at most ceil(K/s0) values
|
||||
// (typically 2 for stride==K/2 transposed convs).
|
||||
const int32_t j = tgpig[0];
|
||||
const int32_t s0 = args.s0;
|
||||
const int32_t K = args.K;
|
||||
const int32_t IL = args.IL;
|
||||
|
||||
int32_t i_min;
|
||||
{
|
||||
int32_t a = j - K + 1;
|
||||
i_min = a <= 0 ? 0 : (a + s0 - 1) / s0; // ceil(a/s0) for a>0
|
||||
}
|
||||
int32_t i_max = j / s0;
|
||||
if (i_max > IL - 1) i_max = IL - 1;
|
||||
|
||||
float v = 0.0f;
|
||||
if (i_min <= i_max) {
|
||||
for (int64_t c = 0; c < args.IC; c++) {
|
||||
const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
|
||||
const int32_t input_offset = c * IL;
|
||||
|
||||
for (int32_t i = i_min; i <= i_max; i++) {
|
||||
v += float(src0[kernel_offset + j - i * s0]) * src1[input_offset + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);
|
||||
|
||||
dst_ptr[0] = v;
|
||||
}
|
||||
|
||||
template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
|
||||
kernel void kernel_conv_transpose_1d<float>(
|
||||
constant ggml_metal_kargs_conv_transpose_1d & args,
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]]);
|
||||
|
||||
template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
|
||||
kernel void kernel_conv_transpose_1d<half>(
|
||||
constant ggml_metal_kargs_conv_transpose_1d & args,
|
||||
device const half * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]]);
|
||||
|
||||
|
||||
typedef void (conv_transpose_2d_t)(
|
||||
constant ggml_metal_kargs_conv_transpose_2d & args,
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]]);
|
||||
|
||||
template <typename T>
|
||||
kernel void kernel_conv_transpose_2d(
|
||||
constant ggml_metal_kargs_conv_transpose_2d & args,
|
||||
device const T * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
threadgroup float * shared_sum [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t out_x = tgpig[0];
|
||||
const int64_t out_y = tgpig[1];
|
||||
const int64_t out_c = tgpig[2];
|
||||
|
||||
const int64_t kw = tpitg[0];
|
||||
const int64_t kh = tpitg[1];
|
||||
|
||||
float v = 0.0f;
|
||||
|
||||
for (int64_t in_c = 0; in_c < args.IC; in_c++) {
|
||||
int64_t in_y = out_y - kh;
|
||||
|
||||
if (in_y < 0 || in_y % args.s0) continue;
|
||||
|
||||
in_y /= args.s0;
|
||||
|
||||
if (in_y >= args.IH) continue;
|
||||
|
||||
int64_t in_x = out_x - kw;
|
||||
|
||||
if (in_x < 0 || in_x % args.s0) continue;
|
||||
|
||||
in_x /= args.s0;
|
||||
|
||||
if (in_x >= args.IW) continue;
|
||||
|
||||
const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
|
||||
const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
|
||||
|
||||
v += (float)src0[kernel_idx] * src1[input_idx];
|
||||
}
|
||||
|
||||
const uint tid = tpitg.y * ntg.x + tpitg.x;
|
||||
shared_sum[tid] = v;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tid == 0) {
|
||||
float total = 0.0f;
|
||||
const uint num_threads = ntg.x * ntg.y;
|
||||
for (uint i = 0; i < num_threads; i++) {
|
||||
total += shared_sum[i];
|
||||
}
|
||||
|
||||
device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
|
||||
dst_ptr[0] = total;
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
|
||||
kernel void kernel_conv_transpose_2d<float>(
|
||||
constant ggml_metal_kargs_conv_transpose_2d & args,
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
threadgroup float * shared_sum [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
|
||||
kernel void kernel_conv_transpose_2d<half>(
|
||||
constant ggml_metal_kargs_conv_transpose_2d & args,
|
||||
device const half * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
threadgroup float * shared_sum [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template <typename T>
|
||||
kernel void kernel_conv_3d(
|
||||
constant ggml_metal_kargs_conv_3d & args,
|
||||
device const char * src0, // Weights [IC * OC, KD, KH, KW]
|
||||
device const char * src1, // Inputs [IC * N, ID, IH, IW]
|
||||
device char * dst, // Outputs [OC * N, OD, OH, OW]
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
||||
|
||||
// 1. Un-flatten the spatial dimension from Grid X
|
||||
int64_t spatial_idx = tgpig.x * 32 + tpitg.x;
|
||||
|
||||
if (spatial_idx >= args.OW * args.OH * args.OD) {
|
||||
return; // Thread falls outside the spatial volume
|
||||
}
|
||||
|
||||
int64_t od = spatial_idx / (args.OW * args.OH);
|
||||
int64_t oh = (spatial_idx / args.OW) % args.OH;
|
||||
int64_t ow = spatial_idx % args.OW;
|
||||
|
||||
// 2. Map Y to Channels, Z to Batch
|
||||
int64_t oc = tgpig.y;
|
||||
int64_t batch_idx = tgpig.z;
|
||||
|
||||
// 3. Calculate anchor coordinates in the Input volume
|
||||
int64_t i_w_base = ow * args.s0 - args.p0;
|
||||
int64_t i_h_base = oh * args.s1 - args.p1;
|
||||
int64_t i_d_base = od * args.s2 - args.p2;
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
// 4. Gather Loop (Iterate over Input Channels -> Depth -> Height -> Width)
|
||||
for (int64_t ic = 0; ic < args.IC; ++ic) {
|
||||
|
||||
// ggml packs batch and channel together in the 4th dimension
|
||||
int64_t src_cn_idx = batch_idx * args.IC + ic;
|
||||
int64_t w_cn_idx = oc * args.IC + ic;
|
||||
|
||||
for (int64_t kz = 0; kz < args.KD; ++kz) {
|
||||
int64_t id = i_d_base + kz * args.d2;
|
||||
if (id < 0 || id >= args.ID) continue; // Boundary check (Padding)
|
||||
|
||||
for (int64_t ky = 0; ky < args.KH; ++ky) {
|
||||
int64_t ih = i_h_base + ky * args.d1;
|
||||
if (ih < 0 || ih >= args.IH) continue;
|
||||
|
||||
for (int64_t kx = 0; kx < args.KW; ++kx) {
|
||||
int64_t iw = i_w_base + kx * args.d0;
|
||||
if (iw < 0 || iw >= args.IW) continue;
|
||||
|
||||
// Convert multi-dimensional coordinates to flat byte offsets
|
||||
int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03;
|
||||
int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13;
|
||||
|
||||
// Dereference memory and cast weights to f32 if they were f16
|
||||
float w_val = (float)*(device const T*)((device const char*)src0 + w_idx);
|
||||
float i_val = *(device const float*)((device const char*)src1 + i_idx);
|
||||
|
||||
sum += w_val * i_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Write the accumulated value out to RAM
|
||||
int64_t dst_cn_idx = batch_idx * args.OC + oc;
|
||||
int64_t d_idx = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3;
|
||||
|
||||
*(device float*)(dst + d_idx) = sum;
|
||||
}
|
||||
|
||||
// Explicit instantiations so the JIT compiler can find them by name
|
||||
template [[host_name("kernel_conv_3d_f32_f32")]]
|
||||
kernel void kernel_conv_3d<float>(
|
||||
constant ggml_metal_kargs_conv_3d & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]]);
|
||||
|
||||
// Explicit instantiation for f16 weights
|
||||
template [[host_name("kernel_conv_3d_f16_f32")]]
|
||||
kernel void kernel_conv_3d<half>(
|
||||
constant ggml_metal_kargs_conv_3d & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]]);
|
||||
@@ -0,0 +1,686 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#define GGML_COMMON_DECL_METAL
|
||||
#define GGML_COMMON_IMPL_METAL
|
||||
#if defined(GGML_METAL_EMBED_LIBRARY)
|
||||
__embed_ggml-common.h__
|
||||
#else
|
||||
#include "ggml-common.h"
|
||||
#endif
|
||||
|
||||
#define QK_NL 16 // shared by mul_mm and get_rows_q instantiations
|
||||
|
||||
// NOTE: this is not dequantizing - we are simply fitting the template
|
||||
template <typename type4x4>
|
||||
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
||||
reg = (type4x4)(*src);
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
|
||||
reg = (type4)(*src);
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
|
||||
reg = (type4x4)(*src);
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
|
||||
reg = (type4)(*(src));
|
||||
}
|
||||
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template <typename type4x4>
|
||||
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
|
||||
reg = (type4x4)(*src);
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
|
||||
reg = (type4)(*(src));
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint8_t * qs = xb->qs;
|
||||
const float d = xb->d;
|
||||
const float neg_d = -d;
|
||||
|
||||
const int byte_offset = il * 2; // il*16 bits = il*2 bytes
|
||||
const uint8_t b0 = qs[byte_offset];
|
||||
const uint8_t b1 = qs[byte_offset + 1];
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
|
||||
reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
|
||||
reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
|
||||
reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
|
||||
reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
|
||||
reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
|
||||
reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
|
||||
reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));
|
||||
|
||||
reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
|
||||
reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
|
||||
reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
|
||||
reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
|
||||
reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
|
||||
reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
|
||||
reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
|
||||
reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
|
||||
const float d = xb->d;
|
||||
const float neg_d = -d;
|
||||
const int base = il * 4;
|
||||
const uint8_t byte = xb->qs[base / 8];
|
||||
const int s = base % 8;
|
||||
|
||||
float4 reg_f;
|
||||
reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1));
|
||||
reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
|
||||
reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
|
||||
reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));
|
||||
|
||||
reg = (type4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
||||
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
||||
const float d2 = d1 / 256.f;
|
||||
const float md = -8.h * xb->d;
|
||||
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
||||
const ushort mask1 = mask0 << 8;
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
|
||||
reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
|
||||
}
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
||||
const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
|
||||
const float d2 = d1 / 256.f;
|
||||
const float md = -8.h * xb->d;
|
||||
const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
|
||||
const ushort mask1 = mask0 << 8;
|
||||
|
||||
for (int i = 0; i < 2; i++) {
|
||||
reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
|
||||
reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
||||
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
||||
const float d2 = d1 / 256.f;
|
||||
const float m = xb->m;
|
||||
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
||||
const ushort mask1 = mask0 << 8;
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
|
||||
reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
|
||||
}
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
||||
const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
|
||||
const float d2 = d1 / 256.f;
|
||||
const float m = xb->m;
|
||||
const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
|
||||
const ushort mask1 = mask0 << 8;
|
||||
|
||||
for (int i = 0; i < 2; i++) {
|
||||
reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
|
||||
reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
|
||||
const float d = xb->d;
|
||||
const float md = -16.h * xb->d;
|
||||
const ushort mask = il ? 0x00F0 : 0x000F;
|
||||
|
||||
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
||||
|
||||
const int x_mv = il ? 4 : 0;
|
||||
|
||||
const int gh_mv = il ? 12 : 0;
|
||||
const int gh_bk = il ? 0 : 4;
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
// extract the 5-th bits for x0 and x1
|
||||
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
||||
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
||||
|
||||
// combine the 4-bits from qs with the 5th bit
|
||||
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
||||
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
||||
|
||||
reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
|
||||
reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
|
||||
}
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
|
||||
const float d = xb->d;
|
||||
const float md = -16.h * xb->d;
|
||||
const ushort mask = (il/4) ? 0x00F0 : 0x000F;
|
||||
|
||||
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
||||
|
||||
const int x_mv = (il/4) ? 4 : 0;
|
||||
|
||||
const int gh_mv = (il/4) ? 12 : 0;
|
||||
const int gh_bk = (il/4) ? 0 : 4;
|
||||
|
||||
for (int ii = 0; ii < 2; ii++) {
|
||||
int i = 2*(il%4) + ii;
|
||||
|
||||
// extract the 5-th bits for x0 and x1
|
||||
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
||||
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
||||
|
||||
// combine the 4-bits from qs with the 5th bit
|
||||
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
||||
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
||||
|
||||
reg[2*ii + 0] = d * x0 + md;
|
||||
reg[2*ii + 1] = d * x1 + md;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
|
||||
const float d = xb->d;
|
||||
const float m = xb->m;
|
||||
const ushort mask = il ? 0x00F0 : 0x000F;
|
||||
|
||||
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
||||
|
||||
const int x_mv = il ? 4 : 0;
|
||||
|
||||
const int gh_mv = il ? 12 : 0;
|
||||
const int gh_bk = il ? 0 : 4;
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
// extract the 5-th bits for x0 and x1
|
||||
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
||||
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
||||
|
||||
// combine the 4-bits from qs with the 5th bit
|
||||
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
||||
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
||||
|
||||
reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
|
||||
reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
|
||||
}
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
|
||||
const float d = xb->d;
|
||||
const float m = xb->m;
|
||||
const ushort mask = (il/4) ? 0x00F0 : 0x000F;
|
||||
|
||||
const uint32_t qh = *((device const uint32_t *)xb->qh);
|
||||
|
||||
const int x_mv = (il/4) ? 4 : 0;
|
||||
|
||||
const int gh_mv = (il/4) ? 12 : 0;
|
||||
const int gh_bk = (il/4) ? 0 : 4;
|
||||
|
||||
for (int ii = 0; ii < 2; ii++) {
|
||||
int i = 2*(il%4) + ii;
|
||||
|
||||
// extract the 5-th bits for x0 and x1
|
||||
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
||||
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
||||
|
||||
// combine the 4-bits from qs with the 5th bit
|
||||
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
||||
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
||||
|
||||
reg[2*ii + 0] = d * x0 + m;
|
||||
reg[2*ii + 1] = d * x1 + m;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
||||
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
||||
const float d = xb->d;
|
||||
|
||||
float4x4 reg_f;
|
||||
|
||||
for (int i = 0; i < 16; i++) {
|
||||
reg_f[i/4][i%4] = (qs[i + 16*il] * d);
|
||||
}
|
||||
|
||||
reg = (type4x4) reg_f;
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
|
||||
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
||||
const float d = xb->d;
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint8_t * q2 = (device const uint8_t *)xb->qs;
|
||||
|
||||
const float d = e8m0_to_fp32(xb->e);
|
||||
const uint8_t shr = il >= 1 ? 4 : 0;
|
||||
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
|
||||
reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
|
||||
reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
|
||||
reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
|
||||
device const uint8_t * q2 = (device const uint8_t *)xb->qs;
|
||||
|
||||
const float d = e8m0_to_fp32(xb->e);
|
||||
const short il4 = il%4;
|
||||
|
||||
const uint8_t shr = il >= 4 ? 4 : 0;
|
||||
|
||||
reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
|
||||
reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
|
||||
reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
|
||||
reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
||||
const float d = xb->d;
|
||||
const float min = xb->dmin;
|
||||
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
||||
float dl, ml;
|
||||
uint8_t sc = xb->scales[il];
|
||||
|
||||
q = q + 32*(il/8) + 16*(il&1);
|
||||
il = (il/2)%4;
|
||||
|
||||
half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
||||
uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
||||
dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
|
||||
const half d_all = xb->d;
|
||||
device const uint8_t * q = (device const uint8_t *)xb->qs;
|
||||
device const uint8_t * h = (device const uint8_t *)xb->hmask;
|
||||
device const int8_t * scales = (device const int8_t *)xb->scales;
|
||||
|
||||
q = q + 32 * (il/8) + 16 * (il&1);
|
||||
h = h + 16 * (il&1);
|
||||
uint8_t m = 1 << (il/2);
|
||||
uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
|
||||
((il/4)>0 ? 12 : 3);
|
||||
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
|
||||
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
|
||||
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
|
||||
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
|
||||
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
|
||||
const float ml = 4.f * dl;
|
||||
|
||||
il = (il/2) & 3;
|
||||
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
|
||||
const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
|
||||
dl *= coef;
|
||||
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
|
||||
}
|
||||
}
|
||||
|
||||
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
|
||||
return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
|
||||
: uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
|
||||
device const uchar * q = xb->qs;
|
||||
|
||||
short is = (il/4) * 2;
|
||||
q = q + (il/4) * 32 + 16 * (il&1);
|
||||
il = il & 3;
|
||||
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
||||
const float d = il < 2 ? xb->d : xb->d / 16.h;
|
||||
const float min = xb->dmin;
|
||||
const float dl = d * sc[0];
|
||||
const float ml = min * sc[1];
|
||||
|
||||
const ushort mask = il < 2 ? 0x0F : 0xF0;
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
|
||||
device const uint8_t * q = xb->qs;
|
||||
device const uint8_t * qh = xb->qh;
|
||||
|
||||
short is = (il/4) * 2;
|
||||
q = q + 32 * (il/4) + 16 * (il&1);
|
||||
qh = qh + 16 * (il&1);
|
||||
uint8_t ul = 1 << (il/2);
|
||||
il = il & 3;
|
||||
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
|
||||
const float d = il < 2 ? xb->d : xb->d / 16.f;
|
||||
const float min = xb->dmin;
|
||||
const float dl = d * sc[0];
|
||||
const float ml = min * sc[1];
|
||||
|
||||
const ushort mask = il<2 ? 0x0F : 0xF0;
|
||||
const float qh_val = il<2 ? 16.f : 256.f;
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
|
||||
const half d_all = xb->d;
|
||||
device const uint16_t * ql = (device const uint16_t *)xb->ql;
|
||||
device const uint16_t * qh = (device const uint16_t *)xb->qh;
|
||||
device const int8_t * scales = (device const int8_t *)xb->scales;
|
||||
|
||||
ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
|
||||
qh = qh + 16*(il/8) + 8*(il&1);
|
||||
float sc = scales[(il%2) + 2 * ((il/2))];
|
||||
il = (il/2) & 3;
|
||||
|
||||
const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
|
||||
const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F;
|
||||
const float ml = d_all * sc * 32.f;
|
||||
const float dl0 = d_all * sc;
|
||||
const float dl1 = dl0 / 256.f;
|
||||
const float dl2 = dl0 / (256.f * 256.f);
|
||||
const float dl3 = dl0 / (256.f * 256.f * 256.f);
|
||||
const uint8_t shr_h = il>2 ? 2 : 0;
|
||||
const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
|
||||
const uint8_t shr_l = il>1 ? 4 : 0;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
|
||||
const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
|
||||
const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
|
||||
reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml;
|
||||
reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml;
|
||||
reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml;
|
||||
reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const float d = xb->d;
|
||||
const int ib32 = il/2;
|
||||
il = il%2;
|
||||
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
||||
// each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
|
||||
device const uint16_t * q2 = xb->qs + 4*ib32;
|
||||
const uint32_t aux32_g = q2[0] | (q2[1] << 16);
|
||||
const uint32_t aux32_s = q2[2] | (q2[3] << 16);
|
||||
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
|
||||
const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
|
||||
constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
|
||||
uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
||||
}
|
||||
grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
|
||||
signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const float d = xb->d;
|
||||
const int ib32 = il/2;
|
||||
il = il%2;
|
||||
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
||||
device const uint16_t * q2 = xb->qs + 4*ib32;
|
||||
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
|
||||
constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
|
||||
uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
||||
}
|
||||
grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
|
||||
signs = ksigns_iq2xs[q2[2*il+1] >> 9];
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const float d = xb->d;
|
||||
const int ib32 = il/2;
|
||||
il = il%2;
|
||||
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
||||
device const uint8_t * q3 = xb->qs + 8*ib32;
|
||||
device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
|
||||
const uint32_t aux32 = gas[0] | (gas[1] << 16);
|
||||
const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
|
||||
constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
|
||||
constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
|
||||
uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
|
||||
reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
|
||||
}
|
||||
grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
|
||||
grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
|
||||
signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
|
||||
reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const float d = xb->d;
|
||||
const int ib32 = il/2;
|
||||
il = il%2;
|
||||
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
||||
device const uint8_t * qs = xb->qs + 8*ib32;
|
||||
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
|
||||
const uint8_t qh = xb->qh[ib32] >> 4*il;
|
||||
const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
|
||||
constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
|
||||
constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
|
||||
reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
|
||||
}
|
||||
grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
|
||||
grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
|
||||
reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const float d = xb->d;
|
||||
const int ib32 = il/2;
|
||||
il = il%2;
|
||||
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
||||
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
|
||||
device const uint8_t * signs = qs + QK_K/8;
|
||||
const uint8_t qh = xb->qh[ib32] >> 4*il;
|
||||
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
|
||||
constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
|
||||
constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
|
||||
reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const int ib32 = il/2;
|
||||
il = il%2;
|
||||
const float d = xb->d;
|
||||
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
|
||||
device const uint16_t * qh = xb->qh;
|
||||
const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
|
||||
const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
|
||||
const uint16_t h = qh[ib32] >> 6*il;
|
||||
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
|
||||
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
reg[0][i] = dl * (grid1[i] & 0xf) + ml;
|
||||
reg[1][i] = dl * (grid1[i] >> 4) + ml;
|
||||
reg[2][i] = dl * (grid2[i] & 0xf) + ml;
|
||||
reg[3][i] = dl * (grid2[i] >> 4) + ml;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const int ib32 = il/2;
|
||||
il = il%2;
|
||||
device const uint16_t * sc = (device const uint16_t *)xb->scales;
|
||||
|
||||
iq1m_scale_t scale;
|
||||
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
||||
const float d = scale.f16;
|
||||
|
||||
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
|
||||
device const uint8_t * qh = xb->qh + 2*ib32 + il;
|
||||
|
||||
const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
|
||||
const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
||||
const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
||||
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
||||
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
|
||||
reg[1][i] = dl * (grid1[i] >> 4) + ml1;
|
||||
reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
|
||||
reg[3][i] = dl * (grid2[i] >> 4) + ml2;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
|
||||
const float d = xb->d;
|
||||
uint32_t aux32;
|
||||
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
|
||||
reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
|
||||
reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
|
||||
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
||||
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
|
||||
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
|
||||
const float d = xb->d;
|
||||
uint32_t aux32;
|
||||
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
|
||||
aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;
|
||||
reg[0] = d * kvalues_iq4nl_f[q8[0]];
|
||||
reg[1] = d * kvalues_iq4nl_f[q8[1]];
|
||||
reg[2] = d * kvalues_iq4nl_f[q8[2]];
|
||||
reg[3] = d * kvalues_iq4nl_f[q8[3]];
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const int ib32 = il/2;
|
||||
il = il%2;
|
||||
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
||||
device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
|
||||
const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
|
||||
const float d = (float)xb->d * (ls - 32);
|
||||
uint32_t aux32;
|
||||
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
|
||||
reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
|
||||
reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
|
||||
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
|
||||
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,250 @@
|
||||
#include "common.h"
|
||||
|
||||
constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
|
||||
constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
|
||||
constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]];
|
||||
|
||||
#if 1
|
||||
template<short NSG>
|
||||
kernel void kernel_gated_delta_net_impl(
|
||||
constant ggml_metal_kargs_gated_delta_net & args,
|
||||
device const char * q,
|
||||
device const char * k,
|
||||
device const char * v,
|
||||
device const char * g,
|
||||
device const char * b,
|
||||
device const char * s,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
#define S_v FC_gated_delta_net_ne20
|
||||
#define G FC_gated_delta_net_ne30
|
||||
#define K FC_gated_delta_net_K
|
||||
|
||||
const uint tx = tpitg.x;
|
||||
const uint ty = tpitg.y;
|
||||
|
||||
const uint i23 = tgpig.z; // B (n_seqs)
|
||||
const uint i21 = tgpig.y; // H (head)
|
||||
const uint i20 = tgpig.x*NSG + ty; // row within S_v
|
||||
|
||||
const uint i01 = i21 % args.ne01;
|
||||
const uint i11 = i21 % args.ne11;
|
||||
|
||||
const float scale = 1.0f / sqrt((float)S_v);
|
||||
|
||||
// input state layout [S_v, S_v, H, n_seqs] (s0 only): per-seq stride is H*D.
|
||||
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
|
||||
const uint state_in_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
||||
device const float * s_ptr = (device const float *) (s) + state_in_base;
|
||||
|
||||
float ls[NSG];
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
ls[j] = s_ptr[is];
|
||||
}
|
||||
|
||||
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
|
||||
|
||||
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
|
||||
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
|
||||
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
|
||||
|
||||
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
|
||||
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
|
||||
|
||||
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
|
||||
// When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned.
|
||||
|
||||
// output state base offset: after attention scores
|
||||
const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23;
|
||||
// output state per-slot size: S_v * S_v * H * n_seqs
|
||||
const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23;
|
||||
// per-(seq,head) offset within a slot
|
||||
const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
||||
|
||||
for (short t = 0; t < args.ne22; t++) {
|
||||
float s_k = 0.0f;
|
||||
|
||||
if (G == 1) {
|
||||
const float g_exp = exp(g_ptr[0]);
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
ls[j] *= g_exp;
|
||||
|
||||
s_k += ls[j]*k_ptr[is];
|
||||
}
|
||||
} else {
|
||||
// KDA
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
ls[j] *= exp(g_ptr[is]);
|
||||
|
||||
s_k += ls[j]*k_ptr[is];
|
||||
}
|
||||
}
|
||||
|
||||
s_k = simd_sum(s_k);
|
||||
|
||||
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
|
||||
|
||||
float y = 0.0f;
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
ls[j] += k_ptr[is]*d;
|
||||
|
||||
y += ls[j]*q_ptr[is];
|
||||
}
|
||||
|
||||
y = simd_sum(y);
|
||||
|
||||
if (tx == 0) {
|
||||
dst_attn[t*args.ne21*S_v] = y*scale;
|
||||
}
|
||||
|
||||
q_ptr += args.ns02;
|
||||
k_ptr += args.ns12;
|
||||
v_ptr += args.ns22;
|
||||
|
||||
b_ptr += args.ne21;
|
||||
g_ptr += args.ne21*G;
|
||||
|
||||
if (K > 1) {
|
||||
const int target_slot = (int)args.ne22 - 1 - (int)t;
|
||||
if (target_slot >= 0 && target_slot < (int)K) {
|
||||
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
dst_state[is] = ls[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (K == 1) {
|
||||
device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
dst_state[is] = ls[j];
|
||||
}
|
||||
}
|
||||
|
||||
#undef S_v
|
||||
#undef G
|
||||
#undef K
|
||||
}
|
||||
|
||||
typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
|
||||
|
||||
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;
|
||||
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;
|
||||
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;
|
||||
|
||||
#else
|
||||
// a simplified version of the above
|
||||
// no performance improvement, so keep the above version for now
|
||||
|
||||
template<typename T, short NSG>
|
||||
kernel void kernel_gated_delta_net_impl(
|
||||
constant ggml_metal_kargs_gated_delta_net & args,
|
||||
device const char * q,
|
||||
device const char * k,
|
||||
device const char * v,
|
||||
device const char * g,
|
||||
device const char * b,
|
||||
device const char * s,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
#define S_v FC_gated_delta_net_ne20
|
||||
#define G FC_gated_delta_net_ne30
|
||||
|
||||
const uint tx = tpitg.x;
|
||||
const uint ty = tpitg.y;
|
||||
|
||||
const uint i23 = tgpig.z; // B
|
||||
const uint i21 = tgpig.y; // H
|
||||
const uint i20 = tgpig.x*NSG + ty;
|
||||
|
||||
const uint i01 = i21 % args.ne01;
|
||||
const uint i11 = i21 % args.ne11;
|
||||
|
||||
const float scale = 1.0f / sqrt((float)S_v);
|
||||
|
||||
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
|
||||
|
||||
float lsf[NSG];
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
lsf[j] = s_ptr[is*S_v];
|
||||
}
|
||||
|
||||
thread T * ls = (thread T *) (lsf);
|
||||
|
||||
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
|
||||
|
||||
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
|
||||
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
|
||||
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
|
||||
|
||||
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
|
||||
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
|
||||
|
||||
for (short t = 0; t < args.ne22; t++) {
|
||||
device const T * qt_ptr = (device const T *) (q_ptr);
|
||||
device const T * kt_ptr = (device const T *) (k_ptr);
|
||||
device const T * gt_ptr = (device const T *) (g_ptr);
|
||||
|
||||
if (G == 1) {
|
||||
*ls *= exp(g_ptr[0]);
|
||||
} else {
|
||||
// KDA
|
||||
*ls *= exp(gt_ptr[tx]);
|
||||
}
|
||||
|
||||
const float s_k = simd_sum(dot(*ls, kt_ptr[tx]));
|
||||
|
||||
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
|
||||
|
||||
*ls += kt_ptr[tx]*d;
|
||||
|
||||
const float y = simd_sum(dot(*ls, qt_ptr[tx]));
|
||||
|
||||
if (tx == 0) {
|
||||
*dst_attn = y*scale;
|
||||
}
|
||||
|
||||
q_ptr += args.ns02;
|
||||
k_ptr += args.ns12;
|
||||
v_ptr += args.ns22;
|
||||
|
||||
b_ptr += args.ne21;
|
||||
g_ptr += args.ne21*G;
|
||||
|
||||
dst_attn += args.ne21*S_v;
|
||||
}
|
||||
|
||||
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
|
||||
device T * dstt_state = (device T *) (dst_state);
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
dst_state[is*S_v] = lsf[j];
|
||||
}
|
||||
|
||||
#undef S_v
|
||||
#undef G
|
||||
}
|
||||
|
||||
typedef decltype(kernel_gated_delta_net_impl<float4, 4>) kernel_gated_delta_net_t;
|
||||
|
||||
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float, 1>;
|
||||
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float2, 2>;
|
||||
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float4, 4>;
|
||||
#endif
|
||||
@@ -0,0 +1,347 @@
|
||||
#include "common.h"
|
||||
|
||||
kernel void kernel_argmax_f32(
|
||||
constant ggml_metal_kargs_argmax & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const float * x_row = (device const float *) ((device const char *) src0 + tgpig * args.nb01);
|
||||
|
||||
float lmax = -INFINITY;
|
||||
int32_t larg = -1;
|
||||
|
||||
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
|
||||
if (x_row[i00] > lmax) {
|
||||
lmax = x_row[i00];
|
||||
larg = i00;
|
||||
}
|
||||
}
|
||||
|
||||
// find the argmax value in the block
|
||||
float max_val = simd_max(lmax);
|
||||
int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
|
||||
|
||||
device int32_t * dst_i32 = (device int32_t *) dst;
|
||||
|
||||
threadgroup float * shared_maxval = (threadgroup float *) shmem;
|
||||
threadgroup int32_t * shared_argmax = (threadgroup int32_t *) shmem + N_SIMDWIDTH;
|
||||
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
shared_maxval[tiisg] = -INFINITY;
|
||||
shared_argmax[tiisg] = -1;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shared_maxval[sgitg] = max_val;
|
||||
shared_argmax[sgitg] = arg_val;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
max_val = shared_maxval[tiisg];
|
||||
arg_val = shared_argmax[tiisg];
|
||||
|
||||
float max_val_reduced = simd_max(max_val);
|
||||
int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
|
||||
|
||||
dst_i32[tgpig] = arg_val_reduced;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
dst_i32[tgpig] = arg_val;
|
||||
}
|
||||
|
||||
kernel void kernel_diag_f32(
|
||||
constant ggml_metal_kargs_diag & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]]) {
|
||||
constexpr short NW = N_SIMDWIDTH;
|
||||
|
||||
const int32_t i3 = tgpig.z;
|
||||
const int32_t i2 = tgpig.y;
|
||||
const int32_t i1 = tgpig.x;
|
||||
|
||||
device const float * src0_ptr = (device const float *)(src0 + i2*args.nb02 + i3*args.nb03);
|
||||
device float * dst_ptr = (device float *)(dst + i1*args.nb01 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {
|
||||
dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_roll_f32(
|
||||
constant ggml_metal_kargs_roll & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i1 = tgpig.x;
|
||||
|
||||
device const float * src0_ptr = (device const float *) src0;
|
||||
device float * dst_ptr = (device float *) dst;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
// apply shifts and wrap around
|
||||
int64_t i00 = i0 - args.s0;
|
||||
int64_t i01 = i1 - args.s1;
|
||||
int64_t i02 = i2 - args.s2;
|
||||
int64_t i03 = i3 - args.s3;
|
||||
|
||||
if (i00 < 0) { i00 += args.ne00; } else if (i00 >= args.ne00) { i00 -= args.ne00; }
|
||||
if (i01 < 0) { i01 += args.ne01; } else if (i01 >= args.ne01) { i01 -= args.ne01; }
|
||||
if (i02 < 0) { i02 += args.ne02; } else if (i02 >= args.ne02) { i02 -= args.ne02; }
|
||||
if (i03 < 0) { i03 += args.ne03; } else if (i03 >= args.ne03) { i03 -= args.ne03; }
|
||||
|
||||
int64_t src_idx = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00 + i00;
|
||||
int64_t dst_idx = i3 *args.ne2 *args.ne1 *args.ne0 + i2 *args.ne1 *args.ne0 + i1 *args.ne0 + i0;
|
||||
|
||||
dst_ptr[dst_idx] = src0_ptr[src_idx];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
kernel void kernel_pad_impl(
|
||||
constant ggml_metal_kargs_pad & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int32_t i3 = tgpig.z;
|
||||
const int32_t i2 = tgpig.y;
|
||||
const int32_t k0 = tgpig.x/args.ne1;
|
||||
const int32_t i1 = tgpig.x - k0*args.ne1;
|
||||
|
||||
const int32_t i03 = i3;
|
||||
const int32_t i02 = i2;
|
||||
const int32_t i01 = i1;
|
||||
|
||||
device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
||||
|
||||
for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) {
|
||||
const int32_t i0 = k0*1024 + tpitg.x + l0;
|
||||
if (i0 >= args.ne0) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
|
||||
dst_ptr[i0] = src0_ptr[i0];
|
||||
} else {
|
||||
dst_ptr[i0] = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_pad_impl<float>) kernel_pad_t;
|
||||
|
||||
template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl<float>;
|
||||
template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl<float4>;
|
||||
|
||||
// TODO: this is slow - optimize
|
||||
kernel void kernel_pad_reflect_1d_f32(
|
||||
constant ggml_metal_kargs_pad_reflect_1d & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i1 = tgpig.x;
|
||||
|
||||
const int64_t i03 = i3;
|
||||
const int64_t i02 = i2;
|
||||
const int64_t i01 = i1;
|
||||
|
||||
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
||||
|
||||
if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
if (i0 < args.p0) {
|
||||
dst_ptr[i0] = src0_ptr[args.p0 - i0];
|
||||
} else if (i0 < args.ne0 - args.p1) {
|
||||
dst_ptr[i0] = src0_ptr[i0 - args.p0];
|
||||
} else {
|
||||
dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_arange_f32(
|
||||
constant ggml_metal_kargs_arange & args,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
device float * dst_ptr = (device float *) dst;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
dst_ptr[i0] = args.start + args.step * i0;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_timestep_embedding_f32(
|
||||
constant ggml_metal_kargs_timestep_embedding & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
int i = tgpig.x;
|
||||
device float * embed_data = (device float *)(dst + i*args.nb1);
|
||||
|
||||
int half_ = args.dim / 2;
|
||||
for (int j = tpitg.x; j < half_; j += ntg.x) {
|
||||
float timestep = ((device float *)src0)[i];
|
||||
float freq = (float)exp(-log((float)args.max_period) * j / half_);
|
||||
float arg = timestep * freq;
|
||||
embed_data[j ] = cos(arg);
|
||||
embed_data[j + half_] = sin(arg);
|
||||
}
|
||||
|
||||
if (args.dim % 2 != 0 && tpitg.x == 0) {
|
||||
embed_data[2 * half_] = 0.f;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_opt_step_adamw_f32(
|
||||
constant ggml_metal_kargs_opt_step_adamw & args,
|
||||
device float * x,
|
||||
device const float * g,
|
||||
device float * g_m,
|
||||
device float * g_v,
|
||||
device const float * pars,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float alpha = pars[0];
|
||||
const float beta1 = pars[1];
|
||||
const float beta2 = pars[2];
|
||||
const float eps = pars[3];
|
||||
const float wd = pars[4];
|
||||
const float beta1h = pars[5];
|
||||
const float beta2h = pars[6];
|
||||
|
||||
const float gi = g[gid];
|
||||
const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
|
||||
const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);
|
||||
|
||||
g_m[gid] = gmi;
|
||||
g_v[gid] = gvi;
|
||||
|
||||
const float mh = gmi * beta1h;
|
||||
const float vh = sqrt(gvi * beta2h) + eps;
|
||||
|
||||
x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
|
||||
}
|
||||
|
||||
kernel void kernel_opt_step_sgd_f32(
|
||||
constant ggml_metal_kargs_opt_step_sgd & args,
|
||||
device float * x,
|
||||
device const float * g,
|
||||
device const float * pars,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_memset(
|
||||
constant ggml_metal_kargs_memset & args,
|
||||
device T * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = args.val;
|
||||
}
|
||||
|
||||
typedef decltype(kernel_memset<int64_t>) kernel_memset_t;
|
||||
|
||||
template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;
|
||||
|
||||
constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_count_equal(
|
||||
constant ggml_metal_kargs_count_equal & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device atomic_int * dst,
|
||||
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const short NSG = FC_count_equal_nsg;
|
||||
|
||||
const int i3 = tgpig.z;
|
||||
const int i2 = tgpig.y;
|
||||
const int i1 = tgpig.x;
|
||||
|
||||
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
int sum = 0;
|
||||
|
||||
device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
|
||||
device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;
|
||||
|
||||
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
const T v0 = *(device const T *)(base0 + i0*args.nb00);
|
||||
const T v1 = *(device const T *)(base1 + i0*args.nb10);
|
||||
sum += (v0 == v1);
|
||||
}
|
||||
|
||||
sum = simd_sum(sum);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_i32[sgitg] = sum;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sgitg == 0) {
|
||||
float v = 0.0f;
|
||||
if (tpitg.x < NSG) {
|
||||
v = shmem_i32[tpitg.x];
|
||||
}
|
||||
|
||||
float total = simd_sum(v);
|
||||
if (tpitg.x == 0) {
|
||||
atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;
|
||||
|
||||
template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;
|
||||
@@ -0,0 +1,838 @@
|
||||
#include "common.h"
|
||||
#include "dequantize.h"
|
||||
|
||||
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
|
||||
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
|
||||
constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]];
|
||||
constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]];
|
||||
constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]];
|
||||
constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]];
|
||||
|
||||
// each block_q contains 16*nl weights
|
||||
#ifdef GGML_METAL_HAS_TENSOR
|
||||
template<
|
||||
typename SA, typename SA_4x4, typename SA_8x8,
|
||||
typename SB, typename SB_2x4, typename SB_8x8,
|
||||
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &),
|
||||
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
||||
kernel void kernel_mul_mm(
|
||||
constant ggml_metal_kargs_mul_mm & args,
|
||||
device const char * srcA,
|
||||
device const char * srcB,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig [[threadgroup_position_in_grid]],
|
||||
ushort tiitg [[thread_index_in_threadgroup]],
|
||||
ushort sgitg [[simdgroup_index_in_threadgroup]]) {
|
||||
(void) sgitg;
|
||||
|
||||
// Matrix dimensions: A(M,K) x B(K,N) -> C(M,N)
|
||||
const int K = args.ne00;
|
||||
const int M = args.ne0;
|
||||
const int N = args.ne1;
|
||||
|
||||
// Batch dimension handling
|
||||
const int im = tgpig.z;
|
||||
const int i12 = im % FC_mul_mm_ne12;
|
||||
const int i13 = im / FC_mul_mm_ne12;
|
||||
|
||||
// Batch offsets for srcA and srcB
|
||||
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
|
||||
|
||||
// Tile dimensions
|
||||
constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X;
|
||||
constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y;
|
||||
|
||||
// Tile offsets in output matrix
|
||||
const int ra = tgpig.y * NRA;
|
||||
const int rb = tgpig.x * NRB;
|
||||
|
||||
// Threadgroup memory for dequantized A tile only
|
||||
threadgroup SA * sa = (threadgroup SA *)(shmem);
|
||||
|
||||
// Work-item count for A loading
|
||||
constexpr int A_WORK_ITEMS = NRA * N_MM_NK;
|
||||
constexpr int NUM_THREADS = N_SIMDWIDTH * N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y;
|
||||
|
||||
// tA wraps threadgroup memory
|
||||
auto tA = tensor(sa, dextents<int32_t, 2>(N_MM_NK_TOTAL, NRA));
|
||||
|
||||
// tB wraps device memory directly
|
||||
device T1 * ptrB = (device T1 *)(srcB + args.nb12*i12 + args.nb13*i13);
|
||||
const int strideB = args.nb11 / sizeof(T1);
|
||||
auto tB = tensor(ptrB, dextents<int32_t, 2>(K, N), array<int, 2>({1, strideB}));
|
||||
|
||||
// Configure matmul operation
|
||||
mpp::tensor_ops::matmul2d<
|
||||
mpp::tensor_ops::matmul2d_descriptor(
|
||||
NRB, NRA, N_MM_NK_TOTAL, false, true, true,
|
||||
mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
|
||||
execution_simdgroups<N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y>> mm;
|
||||
|
||||
auto cT = mm.get_destination_cooperative_tensor<decltype(tB), decltype(tA), float>();
|
||||
|
||||
// Accumulate partial results over K dimension
|
||||
for (int loop_k = 0; loop_k < K; loop_k += N_MM_NK_TOTAL) {
|
||||
// === PHASE 1: Dequantization of A into threadgroup memory ===
|
||||
for (int work = tiitg; work < A_WORK_ITEMS; work += NUM_THREADS) {
|
||||
const int row = work / N_MM_NK;
|
||||
const int k_chunk = work % N_MM_NK;
|
||||
const int k_pos = loop_k + k_chunk * 16;
|
||||
const short k_base = k_chunk * 16;
|
||||
|
||||
// Bounds check: skip device read if row is out of matrix bounds
|
||||
if (ra + row < M) {
|
||||
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||
// Element-wise reads when K is not aligned (nb01 not aligned for half4x4/float4x4).
|
||||
// MSL spec Table 2.5: half4x4 requires 8-byte alignment. When K is odd,
|
||||
// nb01 = K*2 is not 8-byte aligned, so odd-row pointers are misaligned.
|
||||
// Mirrors the legacy kernel's existing guard.
|
||||
device const T0 * row_ptr = (device const T0 *)(srcA + args.nb01 * (ra + row) + offset0);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? (SA) row_ptr[k_pos + i] : (SA)0;
|
||||
}
|
||||
} else {
|
||||
const int block_idx = k_pos / (16 * nl);
|
||||
const short il = (k_pos / 16) % nl;
|
||||
|
||||
device const block_q * row_ptr = (device const block_q *)(srcA + args.nb01 * (ra + row) + offset0);
|
||||
|
||||
SA_4x4 temp_a;
|
||||
dequantize_func(row_ptr + block_idx, il, temp_a);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
// Zero-pad A for K positions beyond valid range (handles partial K iterations)
|
||||
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Zero-pad rows beyond matrix bounds
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (SA)0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// === PHASE 2: Tensor matmul ===
|
||||
auto mA = tA.slice(0, 0);
|
||||
auto mB = tB.slice(loop_k, rb);
|
||||
|
||||
mm.run(mB, mA, cT);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
// Store result tile to output matrix (with batch offset)
|
||||
// cT.store handles bounds checking via tD's extents (M, N)
|
||||
device float * dstBatch = (device float *)dst + im * N * M;
|
||||
|
||||
auto tD = tensor(dstBatch, dextents<int32_t, 2>(M, N), array<int, 2>({1, M}));
|
||||
cT.store(tD.slice(ra, rb));
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
template<
|
||||
typename S0, typename S0_4x4, typename S0_8x8,
|
||||
typename S1, typename S1_2x4, typename S1_8x8,
|
||||
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &),
|
||||
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
||||
kernel void kernel_mul_mm(
|
||||
constant ggml_metal_kargs_mul_mm & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
||||
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
||||
|
||||
constexpr int NR0 = 64;
|
||||
constexpr int NR1 = 32;
|
||||
|
||||
constexpr int NK = 32;
|
||||
constexpr int NL0 = NK/16;
|
||||
constexpr int NL1 = NK/8;
|
||||
|
||||
const int im = tgpig.z;
|
||||
const int r0 = tgpig.y*NR0;
|
||||
const int r1 = tgpig.x*NR1;
|
||||
|
||||
// if this block is of 64x32 shape or smaller
|
||||
const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
|
||||
const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1;
|
||||
|
||||
// a thread shouldn't load data outside of the matrix
|
||||
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
|
||||
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
|
||||
|
||||
const short il0 = (tiitg % NL0);
|
||||
|
||||
short il = il0;
|
||||
|
||||
const int i12 = im % FC_mul_mm_ne12;
|
||||
const int i13 = im / FC_mul_mm_ne12;
|
||||
|
||||
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
|
||||
const short offset1 = il0/nl;
|
||||
|
||||
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
|
||||
|
||||
const short iy = 8*(tiitg % NL1);
|
||||
|
||||
device const T1 * y = (device const T1 *)(src1
|
||||
+ args.nb13*i13
|
||||
+ args.nb12*i12
|
||||
+ args.nb11*(r1 + lr1)
|
||||
+ args.nb10*iy);
|
||||
|
||||
S0_8x8 ma[4];
|
||||
S1_8x8 mb[2];
|
||||
|
||||
simdgroup_float8x8 mc[8];
|
||||
|
||||
for (short i = 0; i < 8; i++){
|
||||
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
||||
}
|
||||
|
||||
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
|
||||
// load data and store to threadgroup memory
|
||||
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// no need for dequantization
|
||||
for (short i = 0; i < 16; i++) {
|
||||
const short sx = 2*il0 + i/8;
|
||||
const short sy = (tiitg/NL0)/8;
|
||||
|
||||
//const short lx = i%8;
|
||||
//const short ly = (tiitg/NL0)%8;
|
||||
const short lx = (tiitg/NL0)%8;
|
||||
const short ly = i%8;
|
||||
|
||||
const short ib = 8*sx + sy;
|
||||
|
||||
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
|
||||
}
|
||||
} else {
|
||||
S0_4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
const short sx = 2*il0 + i/8;
|
||||
const short sy = (tiitg/NL0)/8;
|
||||
|
||||
//const short lx = i%8;
|
||||
//const short ly = (tiitg/NL0)%8;
|
||||
const short lx = (tiitg/NL0)%8;
|
||||
const short ly = i%8;
|
||||
|
||||
const short ib = 8*sx + sy;
|
||||
|
||||
// NOTE: this is massively slower.. WTF?
|
||||
//sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
|
||||
|
||||
*(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_mul_mm_bc_inp) {
|
||||
for (short i = 0; i < 8; ++i) {
|
||||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
const short lx = i;
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
//const short lx = (tiitg/NL1)%8;
|
||||
//const short ly = i;
|
||||
|
||||
const short ib = 4*sx + sy;
|
||||
|
||||
*(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
|
||||
}
|
||||
} else {
|
||||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
//const short dx = sx;
|
||||
//const short dy = sy;
|
||||
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
|
||||
const short ib = 4*sx + sy;
|
||||
|
||||
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
|
||||
}
|
||||
|
||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||
|
||||
y += NK;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// load matrices from threadgroup memory and conduct outer products
|
||||
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
|
||||
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
|
||||
|
||||
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 4; i++) {
|
||||
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 2; i++) {
|
||||
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 8; i++){
|
||||
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
||||
}
|
||||
|
||||
lsma += 8*64;
|
||||
lsmb += 4*64;
|
||||
}
|
||||
}
|
||||
|
||||
if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
|
||||
// if no bounds checks on the output are needed, we can directly write to device memory
|
||||
device float * C = (device float *) dst +
|
||||
(r0 + 32*(sgitg & 1)) + \
|
||||
(r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
|
||||
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
|
||||
}
|
||||
} else {
|
||||
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
|
||||
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sgitg == 0) {
|
||||
for (int j = tiitg; j < nr1; j += NR1) {
|
||||
device float * D = (device float *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0;
|
||||
device float4 * D4 = (device float4 *) D;
|
||||
|
||||
threadgroup float * C = temp_str + (j*NR0);
|
||||
threadgroup float4 * C4 = (threadgroup float4 *) C;
|
||||
|
||||
int i = 0;
|
||||
for (; i < nr0/4; i++) {
|
||||
*(D4 + i) = *(C4 + i);
|
||||
}
|
||||
|
||||
i *= 4;
|
||||
for (; i < nr0; i++) {
|
||||
*(D + i) = *(C + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // GGML_METAL_HAS_TENSOR
|
||||
|
||||
template<short ne20> // n_expert_used
|
||||
kernel void kernel_mul_mm_id_map0(
|
||||
constant ggml_metal_kargs_mul_mm_id_map0 & args,
|
||||
device const char * src2,
|
||||
device char * htpe,
|
||||
device char * hids,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
ushort tpitg[[thread_position_in_threadgroup]],
|
||||
ushort ntg[[threads_per_threadgroup]]) {
|
||||
const short ide = tpitg; // expert id
|
||||
|
||||
uint32_t n_all = 0;
|
||||
|
||||
device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;
|
||||
|
||||
for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
|
||||
if (i21 + tpitg < args.ne21) {
|
||||
device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
|
||||
|
||||
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;
|
||||
|
||||
#pragma unroll(ne20)
|
||||
for (short i20 = 0; i20 < ne20; i20++) {
|
||||
sids[i20] = src2_i32[i20];
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (short t = 0; t < ntg; t++) {
|
||||
if (i21 + t >= args.ne21) {
|
||||
break;
|
||||
}
|
||||
|
||||
threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;
|
||||
|
||||
short sel = 0;
|
||||
#pragma unroll(ne20)
|
||||
for (short i20 = 0; i20 < ne20; i20++) {
|
||||
sel += (sids[i20] == ide)*(i20 + 1);
|
||||
}
|
||||
|
||||
ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;
|
||||
|
||||
n_all += sel > 0;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
|
||||
tpe_u32[ide] = n_all;
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
|
||||
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
|
||||
|
||||
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
|
||||
kernel void kernel_mul_mm_id(
|
||||
constant ggml_metal_kargs_mul_mm_id & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * htpe,
|
||||
device const char * hids,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
||||
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
||||
|
||||
#ifdef GGML_METAL_HAS_TENSOR
|
||||
threadgroup float * sc = (threadgroup float *)(shmem);
|
||||
#endif
|
||||
|
||||
constexpr int NR0 = 64;
|
||||
constexpr int NR1 = 32;
|
||||
|
||||
constexpr int NK = 32;
|
||||
constexpr int NL0 = NK/16;
|
||||
constexpr int NL1 = NK/8;
|
||||
|
||||
const int im = tgpig.z; // expert
|
||||
const int r0 = tgpig.y*NR0;
|
||||
const int r1 = tgpig.x*NR1;
|
||||
|
||||
device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
|
||||
device const int32_t * ids_i32 = (device const int32_t *) (hids);
|
||||
|
||||
const int32_t neh1 = tpe_u32[im];
|
||||
|
||||
if (r1 >= neh1) {
|
||||
return;
|
||||
}
|
||||
|
||||
// if this block is of 64x32 shape or smaller
|
||||
const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
|
||||
const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1;
|
||||
|
||||
// a thread shouldn't load data outside of the matrix
|
||||
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
|
||||
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
|
||||
|
||||
const short il0 = (tiitg % NL0);
|
||||
|
||||
short il = il0;
|
||||
|
||||
const int id = ids_i32[im*args.ne21 + r1 + lr1];
|
||||
|
||||
const short i11 = (id % args.ne20) % args.ne11;
|
||||
const short i12 = (id / args.ne20);
|
||||
const short i13 = 0;
|
||||
|
||||
const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
|
||||
const short offset1 = il0/nl;
|
||||
|
||||
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
|
||||
|
||||
const short iy = 8*(tiitg % NL1);
|
||||
|
||||
device const T1 * y = (device const T1 *)(src1
|
||||
+ args.nb13*i13
|
||||
+ args.nb12*i12
|
||||
+ args.nb11*i11
|
||||
+ args.nb10*iy);
|
||||
|
||||
#ifndef GGML_METAL_HAS_TENSOR
|
||||
S0_8x8 ma[4];
|
||||
S1_8x8 mb[2];
|
||||
|
||||
simdgroup_float8x8 mc[8];
|
||||
|
||||
for (short i = 0; i < 8; i++){
|
||||
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
||||
}
|
||||
#else
|
||||
auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
|
||||
auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
|
||||
|
||||
mpp::tensor_ops::matmul2d<
|
||||
mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
|
||||
execution_simdgroups<4>> mm;
|
||||
|
||||
auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
|
||||
#endif
|
||||
|
||||
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
|
||||
#ifndef GGML_METAL_HAS_TENSOR
|
||||
// load data and store to threadgroup memory
|
||||
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// no need for dequantization
|
||||
for (short i = 0; i < 16; i++) {
|
||||
const short sx = 2*il0 + i/8;
|
||||
const short sy = (tiitg/NL0)/8;
|
||||
|
||||
//const short lx = i%8;
|
||||
//const short ly = (tiitg/NL0)%8;
|
||||
const short lx = (tiitg/NL0)%8;
|
||||
const short ly = i%8;
|
||||
|
||||
const short ib = 8*sx + sy;
|
||||
|
||||
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? (S0) *((device T0 *) x + i) : (S0) 0;
|
||||
}
|
||||
} else {
|
||||
S0_4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
const short sx = 2*il0 + i/8;
|
||||
const short sy = (tiitg/NL0)/8;
|
||||
|
||||
//const short lx = i%8;
|
||||
//const short ly = (tiitg/NL0)%8;
|
||||
const short lx = (tiitg/NL0)%8;
|
||||
const short ly = i%8;
|
||||
|
||||
const short ib = 8*sx + sy;
|
||||
|
||||
// NOTE: this is massively slower.. WTF?
|
||||
//sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
|
||||
|
||||
*(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_mul_mm_bc_inp) {
|
||||
for (short i = 0; i < 8; ++i) {
|
||||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
const short lx = i;
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
//const short lx = (tiitg/NL1)%8;
|
||||
//const short ly = i;
|
||||
|
||||
const short ib = 4*sx + sy;
|
||||
|
||||
*(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
|
||||
}
|
||||
} else {
|
||||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
//const short dx = sx;
|
||||
//const short dy = sy;
|
||||
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
|
||||
const short ib = 4*sx + sy;
|
||||
|
||||
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
|
||||
}
|
||||
#else
|
||||
// load data and store to threadgroup memory
|
||||
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// no need for dequantization
|
||||
for (short i = 0; i < 16; i++) {
|
||||
const short sx = 2*il0 + i/8;
|
||||
const short sy = (tiitg/NL0)/8;
|
||||
|
||||
const short lx = i%8;
|
||||
const short ly = (tiitg/NL0)%8;
|
||||
//const short lx = (tiitg/NL0)%8;
|
||||
//const short ly = i%8;
|
||||
|
||||
*(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
|
||||
}
|
||||
} else {
|
||||
S0_4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 16; i++) {
|
||||
const short sx = 2*il0 + i/8;
|
||||
const short sy = (tiitg/NL0)/8;
|
||||
|
||||
const short lx = i%8;
|
||||
const short ly = (tiitg/NL0)%8;
|
||||
//const short lx = (tiitg/NL0)%8;
|
||||
//const short ly = i%8;
|
||||
|
||||
*(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
|
||||
}
|
||||
}
|
||||
|
||||
if (FC_mul_mm_bc_inp) {
|
||||
for (short i = 0; i < 8; ++i) {
|
||||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
const short lx = i;
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
//const short lx = (tiitg/NL1)%8;
|
||||
//const short ly = i;
|
||||
|
||||
*(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
|
||||
}
|
||||
} else {
|
||||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
//const short lx = i;
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
//const short lx = (tiitg/NL1)%8;
|
||||
//const short ly = i;
|
||||
|
||||
*(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
|
||||
}
|
||||
#endif
|
||||
|
||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||
|
||||
y += NK;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
#ifndef GGML_METAL_HAS_TENSOR
|
||||
// load matrices from threadgroup memory and conduct outer products
|
||||
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
|
||||
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
|
||||
|
||||
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 4; i++) {
|
||||
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 2; i++) {
|
||||
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
|
||||
}
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
FOR_UNROLL (short i = 0; i < 8; i++){
|
||||
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
||||
}
|
||||
|
||||
lsma += 8*64;
|
||||
lsmb += 4*64;
|
||||
}
|
||||
#else
|
||||
auto sA = tA.slice(0, 0);
|
||||
auto sB = tB.slice(0, 0);
|
||||
|
||||
mm.run(sB, sA, cT);
|
||||
#endif
|
||||
}
|
||||
|
||||
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
#ifdef GGML_METAL_HAS_TENSOR
|
||||
auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
|
||||
cT.store(tC);
|
||||
#else
|
||||
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
|
||||
|
||||
for (short i = 0; i < 8; i++) {
|
||||
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
|
||||
}
|
||||
#endif
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (short j = sgitg; j < nr1; j += 4) {
|
||||
const int id = ids_i32[im*args.ne21 + r1 + j];
|
||||
|
||||
const short ide = id % args.ne20;
|
||||
const short idt = id / args.ne20;
|
||||
|
||||
device float * D = (device float *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0;
|
||||
device float4 * D4 = (device float4 *) D;
|
||||
|
||||
threadgroup float * C = (threadgroup float *) shmem + j*NR0;
|
||||
threadgroup float4 * C4 = (threadgroup float4 *) C;
|
||||
|
||||
int i = tiisg;
|
||||
for (; i < nr0/4; i += 32) {
|
||||
*(D4 + i) = *(C4 + i);
|
||||
}
|
||||
|
||||
i = (4*(nr0/4)) + tiisg;
|
||||
for (; i < nr0; i += 32) {
|
||||
*(D + i) = *(C + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// matrix-matrix multiplication
|
||||
//
|
||||
|
||||
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_t;
|
||||
|
||||
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
|
||||
|
||||
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
|
||||
|
||||
//
|
||||
// indirect matrix-matrix multiplication
|
||||
//
|
||||
|
||||
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_id;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
|
||||
|
||||
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,308 @@
|
||||
#include "common.h"
|
||||
|
||||
// F == 1 : norm (no fuse)
|
||||
// F == 2 : norm + mul
|
||||
// F == 3 : norm + mul + add
|
||||
template <typename T, short F>
|
||||
kernel void kernel_norm_fuse_impl(
|
||||
constant ggml_metal_kargs_norm & args,
|
||||
device const char * src0,
|
||||
device const char * src1_0,
|
||||
device const char * src1_1,
|
||||
device char * dst,
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
const int i01 = tgpig.x;
|
||||
const int i02 = tgpig.y;
|
||||
const int i03 = tgpig.z;
|
||||
|
||||
device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
|
||||
|
||||
device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
|
||||
device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
|
||||
|
||||
T sumft(0.0f);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
|
||||
sumft += x[i00];
|
||||
}
|
||||
sumf = dot(sumft, T(1.0f));
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
const float mean = sumf/args.ne00;
|
||||
|
||||
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
|
||||
|
||||
sumf = 0.0f;
|
||||
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
|
||||
y[i00] = x[i00] - mean;
|
||||
sumf += dot(y[i00], y[i00]);
|
||||
}
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
const float variance = sumf/args.ne00;
|
||||
|
||||
const float scale = 1.0f/sqrt(variance + args.eps);
|
||||
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
|
||||
if (F == 1) {
|
||||
y[i00] = (y[i00]*scale);
|
||||
}
|
||||
if (F == 2) {
|
||||
y[i00] = (y[i00]*scale)*f0[i00];
|
||||
}
|
||||
if (F == 3) {
|
||||
y[i00] = (y[i00]*scale)*f0[i00] + f1[i00];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_norm_fuse_impl<float4, 1>) kernel_norm_fuse_t;
|
||||
|
||||
template [[host_name("kernel_norm_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 1>;
|
||||
template [[host_name("kernel_norm_mul_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 2>;
|
||||
template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 3>;
|
||||
|
||||
template [[host_name("kernel_norm_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 1>;
|
||||
template [[host_name("kernel_norm_mul_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 2>;
|
||||
template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 3>;
|
||||
|
||||
// F == 1 : rms_norm (no fuse)
|
||||
// F == 2 : rms_norm + mul
|
||||
// F == 3 : rms_norm + mul + add
|
||||
template <typename T, short F>
|
||||
kernel void kernel_rms_norm_fuse_impl(
|
||||
constant ggml_metal_kargs_norm & args,
|
||||
device const char * src0,
|
||||
device const char * src1_0,
|
||||
device const char * src1_1,
|
||||
device char * dst,
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
const int i01 = tgpig.x;
|
||||
const int i02 = tgpig.y;
|
||||
const int i03 = tgpig.z;
|
||||
|
||||
device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
|
||||
|
||||
device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
|
||||
device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
// parallel sum
|
||||
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
|
||||
sumf += dot(x[i00], x[i00]);
|
||||
}
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
const float mean = sumf/args.ne00;
|
||||
const float scale = 1.0f/sqrt(mean + args.eps);
|
||||
|
||||
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
|
||||
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
|
||||
if (F == 1) {
|
||||
y[i00] = (x[i00]*scale);
|
||||
}
|
||||
if (F == 2) {
|
||||
y[i00] = (x[i00]*scale)*f0[i00];
|
||||
}
|
||||
if (F == 3) {
|
||||
y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_rms_norm_fuse_impl<float4, 1>) kernel_rms_norm_fuse_t;
|
||||
|
||||
template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 1>;
|
||||
template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 2>;
|
||||
template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 3>;
|
||||
|
||||
template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 1>;
|
||||
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
|
||||
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
|
||||
|
||||
template <typename T0, typename T>
|
||||
kernel void kernel_l2_norm_impl(
|
||||
constant ggml_metal_kargs_l2_norm & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig.z;
|
||||
const int i02 = tgpig.y;
|
||||
const int i01 = tgpig.x;
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
// parallel sum
|
||||
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
|
||||
sumf += dot(x[i00], x[i00]);
|
||||
}
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
const float scale = 1.0f/max(sqrt(sumf), args.eps);
|
||||
|
||||
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
|
||||
y[i00] = x[i00] * scale;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
|
||||
|
||||
template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
|
||||
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
|
||||
|
||||
kernel void kernel_group_norm_f32(
|
||||
constant ggml_metal_kargs_group_norm & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t ne = args.ne00*args.ne01*args.ne02;
|
||||
const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.ngrp - 1) / args.ngrp);
|
||||
|
||||
int start = tgpig * gs;
|
||||
int end = start + gs;
|
||||
|
||||
start += tpitg;
|
||||
|
||||
if (end >= ne) {
|
||||
end = ne;
|
||||
}
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int j = start; j < end; j += ntg) {
|
||||
tmp += src0[j];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
tmp = simd_sum(tmp);
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = tmp;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
tmp = buf[tiisg];
|
||||
tmp = simd_sum(tmp);
|
||||
}
|
||||
|
||||
const float mean = tmp / gs;
|
||||
tmp = 0.0f;
|
||||
|
||||
for (int j = start; j < end; j += ntg) {
|
||||
float xi = src0[j] - mean;
|
||||
dst[j] = xi;
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
tmp = simd_sum(tmp);
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = tmp;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
tmp = buf[tiisg];
|
||||
tmp = simd_sum(tmp);
|
||||
}
|
||||
|
||||
const float variance = tmp / gs;
|
||||
const float scale = 1.0f/sqrt(variance + args.eps);
|
||||
for (int j = start; j < end; j += ntg) {
|
||||
dst[j] *= scale;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
#include "common.h"
|
||||
|
||||
kernel void kernel_pool_2d_max_f32(
|
||||
constant ggml_metal_kargs_pool_2d & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int idx = gid;
|
||||
const int I_HW = args.IH * args.IW;
|
||||
const int O_HW = args.OH * args.OW;
|
||||
const int nc = idx / O_HW;
|
||||
const int cur_oh = idx % O_HW / args.OW;
|
||||
const int cur_ow = idx % O_HW % args.OW;
|
||||
|
||||
device const float * i_ptr = src0 + nc * I_HW;
|
||||
device float * o_ptr = dst + nc * O_HW;
|
||||
|
||||
const int start_h = cur_oh * args.s1 - args.p1;
|
||||
const int bh = MAX(0, start_h);
|
||||
const int eh = MIN(args.IH, start_h + args.k1);
|
||||
const int start_w = cur_ow * args.s0 - args.p0;
|
||||
const int bw = MAX(0, start_w);
|
||||
const int ew = MIN(args.IW, start_w + args.k0);
|
||||
|
||||
float res = -INFINITY;
|
||||
|
||||
for (int i = bh; i < eh; i += 1) {
|
||||
for (int j = bw; j < ew; j += 1) {
|
||||
res = MAX(res, i_ptr[i * args.IW + j]);
|
||||
}
|
||||
}
|
||||
|
||||
o_ptr[cur_oh * args.OW + cur_ow] = res;
|
||||
}
|
||||
|
||||
kernel void kernel_pool_2d_avg_f32(
|
||||
constant ggml_metal_kargs_pool_2d & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint gid[[thread_position_in_grid]]) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int idx = gid;
|
||||
const int I_HW = args.IH * args.IW;
|
||||
const int O_HW = args.OH * args.OW;
|
||||
const int nc = idx / O_HW;
|
||||
const int cur_oh = idx % O_HW / args.OW;
|
||||
const int cur_ow = idx % O_HW % args.OW;
|
||||
|
||||
device const float * i_ptr = src0 + nc * I_HW;
|
||||
device float * o_ptr = dst + nc * O_HW;
|
||||
|
||||
const int start_h = cur_oh * args.s1 - args.p1;
|
||||
const int bh = MAX(0, start_h);
|
||||
const int eh = MIN(args.IH, start_h + args.k1);
|
||||
const int start_w = cur_ow * args.s0 - args.p0;
|
||||
const int bw = MAX(0, start_w);
|
||||
const int ew = MIN(args.IW, start_w + args.k0);
|
||||
// const float scale = 1. / ((eh - bh) * (ew - bw));
|
||||
const float scale = 1. / (args.k0 * args.k1);
|
||||
|
||||
float res = 0;
|
||||
|
||||
for (int i = bh; i < eh; i += 1) {
|
||||
for (int j = bw; j < ew; j += 1) {
|
||||
float cur = i_ptr[i * args.IW + j];
|
||||
res += cur * scale;
|
||||
}
|
||||
}
|
||||
|
||||
o_ptr[cur_oh * args.OW + cur_ow] = res;
|
||||
}
|
||||
|
||||
|
||||
kernel void kernel_pool_1d_max_f32(
|
||||
constant ggml_metal_kargs_pool_1d & args,
|
||||
device const float * src,
|
||||
device float * dst,
|
||||
uint gid [[thread_position_in_grid]]
|
||||
) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int ow = (int)gid % args.OW;
|
||||
const int row = (int)gid / args.OW;
|
||||
|
||||
const int base = ow * args.s0 - args.p0;
|
||||
|
||||
float acc = -INFINITY;
|
||||
|
||||
const int src_off = row * args.IW;
|
||||
const int dst_off = row * args.OW;
|
||||
|
||||
for (int ki = 0; ki < args.k0; ++ki) {
|
||||
int j = base + ki;
|
||||
if (j < 0 || j >= args.IW){
|
||||
continue;
|
||||
}
|
||||
float v = src[src_off + j];
|
||||
acc = max(acc, v);
|
||||
}
|
||||
|
||||
dst[dst_off + ow] = acc;
|
||||
}
|
||||
|
||||
kernel void kernel_pool_1d_avg_f32(
|
||||
constant ggml_metal_kargs_pool_1d & args,
|
||||
device const float * src,
|
||||
device float * dst,
|
||||
uint gid [[thread_position_in_grid]]
|
||||
) {
|
||||
|
||||
if (gid >= args.np) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int ow = (int)gid % args.OW;
|
||||
const int row = (int)gid / args.OW;
|
||||
|
||||
const int base = ow * args.s0 - args.p0;
|
||||
|
||||
float acc = 0.0f;
|
||||
int cnt = 0;
|
||||
|
||||
const int src_off = row * args.IW;
|
||||
const int dst_off = row * args.OW;
|
||||
|
||||
for (int ki = 0; ki < args.k0; ++ki) {
|
||||
const int j = base + ki;
|
||||
if (j < 0 || j >= args.IW) {
|
||||
continue;
|
||||
}
|
||||
acc += src[src_off + j];
|
||||
cnt += 1;
|
||||
}
|
||||
|
||||
dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
|
||||
}
|
||||
@@ -0,0 +1,213 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
void quantize_q1_0(device const float * src, device block_q1_0 & dst) {
|
||||
float sum_abs = 0.0f;
|
||||
for (int j = 0; j < QK1_0; j++) {
|
||||
sum_abs += fabs(src[j]);
|
||||
}
|
||||
dst.d = sum_abs / QK1_0;
|
||||
|
||||
for (int j = 0; j < QK1_0 / 8; j++) {
|
||||
dst.qs[j] = 0;
|
||||
}
|
||||
for (int j = 0; j < QK1_0; j++) {
|
||||
if (src[j] >= 0.0f) {
|
||||
dst.qs[j / 8] |= (1 << (j % 8));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
|
||||
for (int j = 0; j < QK4_0; j++) {
|
||||
const float v = src[j];
|
||||
if (amax < fabs(v)) {
|
||||
amax = fabs(v);
|
||||
max = v;
|
||||
}
|
||||
}
|
||||
|
||||
const float d = max / -8;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst.d = d;
|
||||
|
||||
for (int j = 0; j < QK4_0/2; ++j) {
|
||||
const float x0 = src[0 + j]*id;
|
||||
const float x1 = src[QK4_0/2 + j]*id;
|
||||
|
||||
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
|
||||
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
|
||||
|
||||
dst.qs[j] = xi0;
|
||||
dst.qs[j] |= xi1 << 4;
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float min = FLT_MAX;
|
||||
float max = -FLT_MAX;
|
||||
|
||||
for (int j = 0; j < QK4_1; j++) {
|
||||
const float v = src[j];
|
||||
if (min > v) min = v;
|
||||
if (max < v) max = v;
|
||||
}
|
||||
|
||||
const float d = (max - min) / ((1 << 4) - 1);
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst.d = d;
|
||||
dst.m = min;
|
||||
|
||||
for (int j = 0; j < QK4_1/2; ++j) {
|
||||
const float x0 = (src[0 + j] - min)*id;
|
||||
const float x1 = (src[QK4_1/2 + j] - min)*id;
|
||||
|
||||
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
|
||||
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
|
||||
|
||||
dst.qs[j] = xi0;
|
||||
dst.qs[j] |= xi1 << 4;
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
|
||||
for (int j = 0; j < QK5_0; j++) {
|
||||
const float v = src[j];
|
||||
if (amax < fabs(v)) {
|
||||
amax = fabs(v);
|
||||
max = v;
|
||||
}
|
||||
}
|
||||
|
||||
const float d = max / -16;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst.d = d;
|
||||
|
||||
uint32_t qh = 0;
|
||||
for (int j = 0; j < QK5_0/2; ++j) {
|
||||
const float x0 = src[0 + j]*id;
|
||||
const float x1 = src[QK5_0/2 + j]*id;
|
||||
|
||||
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
|
||||
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
|
||||
|
||||
dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
||||
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
||||
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
|
||||
}
|
||||
|
||||
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
||||
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
dst.qh[j] = qh8[j];
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float max = src[0];
|
||||
float min = src[0];
|
||||
|
||||
for (int j = 1; j < QK5_1; j++) {
|
||||
const float v = src[j];
|
||||
min = v < min ? v : min;
|
||||
max = v > max ? v : max;
|
||||
}
|
||||
|
||||
const float d = (max - min) / 31;
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst.d = d;
|
||||
dst.m = min;
|
||||
|
||||
uint32_t qh = 0;
|
||||
for (int j = 0; j < QK5_1/2; ++j) {
|
||||
const float x0 = (src[0 + j] - min)*id;
|
||||
const float x1 = (src[QK5_1/2 + j] - min)*id;
|
||||
|
||||
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
|
||||
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
|
||||
|
||||
dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
|
||||
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
|
||||
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
|
||||
}
|
||||
|
||||
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
|
||||
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
dst.qh[j] = qh8[j];
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float amax = 0.0f; // absolute max
|
||||
|
||||
for (int j = 0; j < QK8_0; j++) {
|
||||
const float v = src[j];
|
||||
amax = MAX(amax, fabs(v));
|
||||
}
|
||||
|
||||
const float d = amax / ((1 << 7) - 1);
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
dst.d = d;
|
||||
|
||||
for (int j = 0; j < QK8_0; ++j) {
|
||||
const float x0 = src[j]*id;
|
||||
|
||||
dst.qs[j] = round(x0);
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
|
||||
#pragma METAL fp math_mode(safe)
|
||||
float amax = 0.0f; // absolute max
|
||||
float max = 0.0f;
|
||||
|
||||
for (int j = 0; j < QK4_NL; j++) {
|
||||
const float v = src[j];
|
||||
if (amax < fabs(v)) {
|
||||
amax = fabs(v);
|
||||
max = v;
|
||||
}
|
||||
}
|
||||
|
||||
const float d = max / kvalues_iq4nl_f[0];
|
||||
const float id = d ? 1.0f/d : 0.0f;
|
||||
|
||||
float sumqx = 0, sumq2 = 0;
|
||||
for (int j = 0; j < QK4_NL/2; ++j) {
|
||||
const float x0 = src[0 + j]*id;
|
||||
const float x1 = src[QK4_NL/2 + j]*id;
|
||||
|
||||
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
|
||||
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
|
||||
|
||||
dst.qs[j] = xi0 | (xi1 << 4);
|
||||
|
||||
const float v0 = kvalues_iq4nl_f[xi0];
|
||||
const float v1 = kvalues_iq4nl_f[xi1];
|
||||
const float w0 = src[0 + j]*src[0 + j];
|
||||
const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
|
||||
sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
|
||||
sumq2 += w0*v0*v0 + w1*v1*v1;
|
||||
|
||||
}
|
||||
|
||||
dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
|
||||
}
|
||||
@@ -0,0 +1,389 @@
|
||||
#include "common.h"
|
||||
#include "dequantize.h"
|
||||
#include "quantize.h"
|
||||
|
||||
template<typename T0, typename T1>
|
||||
kernel void kernel_cpy_t_t(
|
||||
constant ggml_metal_kargs_cpy & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int32_t i03 = tgpig[2];
|
||||
const int32_t i02 = tgpig[1];
|
||||
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
|
||||
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
||||
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
|
||||
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
|
||||
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
|
||||
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
|
||||
|
||||
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) {
|
||||
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
dst_data[i00] = (T1) src[0];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;
|
||||
|
||||
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, float>;
|
||||
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
|
||||
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>;
|
||||
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
|
||||
template [[host_name("kernel_cpy_i32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
|
||||
#endif
|
||||
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<half, float>;
|
||||
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<half, half>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, float>;
|
||||
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, bfloat>;
|
||||
#endif
|
||||
|
||||
template<short QK,
|
||||
typename block_q,
|
||||
void (*quantize_func)(device const float *, device block_q &)>
|
||||
kernel void kernel_cpy_f32_q(
|
||||
constant ggml_metal_kargs_cpy & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int32_t i03 = tgpig[2];
|
||||
const int32_t i02 = tgpig[1];
|
||||
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
|
||||
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
||||
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int32_t i3 = n / (args.ne2*args.ne1*args.ne0);
|
||||
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
|
||||
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
|
||||
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
|
||||
|
||||
device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
|
||||
device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
|
||||
|
||||
quantize_func(src, dst_data[i00]);
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
|
||||
|
||||
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
|
||||
template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
|
||||
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
|
||||
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
|
||||
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
|
||||
template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1, block_q5_1, quantize_q5_1>;
|
||||
template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;
|
||||
|
||||
template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
||||
kernel void kernel_cpy_q_f32(
|
||||
constant ggml_metal_kargs_cpy & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int32_t i03 = tgpig[2];
|
||||
const int32_t i02 = tgpig[1];
|
||||
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
|
||||
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
|
||||
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
|
||||
|
||||
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
|
||||
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
|
||||
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
|
||||
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
|
||||
|
||||
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
|
||||
T4x4 temp;
|
||||
dequantize_func(src_data + i00/nl, i00%nl, temp);
|
||||
dst_data[i00] = temp;
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
|
||||
|
||||
template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>;
|
||||
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
|
||||
|
||||
template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>;
|
||||
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_concat(
|
||||
constant ggml_metal_kargs_concat & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int i3 = tgpig.z;
|
||||
const int i2 = tgpig.y;
|
||||
const int i1 = ntg.y == 1 ? tgpig.x : tgpig.x*ntg.y + tpitg.y;
|
||||
|
||||
if (i1 >= args.ne1) {
|
||||
return;
|
||||
}
|
||||
|
||||
int o[4] = {0, 0, 0, 0};
|
||||
o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
device const T * x;
|
||||
|
||||
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
|
||||
x = (device const T *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00);
|
||||
} else {
|
||||
x = (device const T *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);
|
||||
}
|
||||
|
||||
device T * y = (device T *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
*y = *x;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_concat<float>) kernel_concat_t;
|
||||
|
||||
template [[host_name("kernel_concat_f32")]] kernel kernel_concat_t kernel_concat<float>;
|
||||
template [[host_name("kernel_concat_f16")]] kernel kernel_concat_t kernel_concat<half>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_concat_bf16")]] kernel kernel_concat_t kernel_concat<bfloat>;
|
||||
#endif
|
||||
template [[host_name("kernel_concat_i8")]] kernel kernel_concat_t kernel_concat<char>;
|
||||
template [[host_name("kernel_concat_i16")]] kernel kernel_concat_t kernel_concat<short>;
|
||||
template [[host_name("kernel_concat_i32")]] kernel kernel_concat_t kernel_concat<int>;
|
||||
template [[host_name("kernel_concat_i64")]] kernel kernel_concat_t kernel_concat<long>;
|
||||
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
||||
kernel void kernel_get_rows_q(
|
||||
constant ggml_metal_kargs_get_rows & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device void * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 ntg [[threads_per_threadgroup]]) {
|
||||
const int32_t iw0 = tgpig.x/args.ne10;
|
||||
const int32_t i10 = tgpig.x%args.ne10;
|
||||
const int32_t i11 = tgpig.y;
|
||||
const int32_t i12 = tgpig.z;
|
||||
|
||||
const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
|
||||
|
||||
const int32_t i02 = i11;
|
||||
const int32_t i03 = i12;
|
||||
|
||||
auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
|
||||
auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
|
||||
|
||||
for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
|
||||
float4x4 temp;
|
||||
dequantize_func(psrc + ind/nl, ind%nl, temp);
|
||||
pdst[ind] = temp;
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T0, typename T>
|
||||
kernel void kernel_get_rows_f(
|
||||
constant ggml_metal_kargs_get_rows & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device void * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 ntg [[threads_per_threadgroup]]) {
|
||||
const int32_t iw0 = tgpig.x/args.ne10;
|
||||
const int32_t i10 = tgpig.x%args.ne10;
|
||||
const int32_t i11 = tgpig.y;
|
||||
const int32_t i12 = tgpig.z;
|
||||
|
||||
const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
|
||||
|
||||
const int32_t i02 = i11;
|
||||
const int32_t i03 = i12;
|
||||
|
||||
auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
|
||||
auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
|
||||
|
||||
for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
|
||||
pdst[ind] = psrc[ind];
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)>
|
||||
kernel void kernel_set_rows_q32(
|
||||
constant ggml_metal_kargs_set_rows & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device float * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint3 tptg [[threads_per_threadgroup]]) {
|
||||
const int32_t i03 = tgpig.z;
|
||||
const int32_t i02 = tgpig.y;
|
||||
|
||||
const int32_t i12 = i03%args.ne12;
|
||||
const int32_t i11 = i02%args.ne11;
|
||||
|
||||
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t i10 = i01;
|
||||
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
|
||||
|
||||
device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
||||
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
||||
|
||||
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
|
||||
quantize_func(src_row + 32*ind, dst_row[ind]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename TI>
|
||||
kernel void kernel_set_rows_f(
|
||||
constant ggml_metal_kargs_set_rows & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device float * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint3 tptg [[threads_per_threadgroup]]) {
|
||||
const int32_t i03 = tgpig.z;
|
||||
const int32_t i02 = tgpig.y;
|
||||
|
||||
const int32_t i12 = i03%args.ne12;
|
||||
const int32_t i11 = i02%args.ne11;
|
||||
|
||||
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
|
||||
if (i01 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t i10 = i01;
|
||||
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
|
||||
|
||||
device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
||||
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
||||
|
||||
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
|
||||
dst_row[ind] = (T) src_row[ind];
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// get rows
|
||||
//
|
||||
|
||||
typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;
|
||||
|
||||
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float, float>;
|
||||
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half, float>;
|
||||
template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>;
|
||||
#endif
|
||||
|
||||
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
||||
|
||||
template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>;
|
||||
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
|
||||
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
|
||||
//
|
||||
// set rows
|
||||
//
|
||||
|
||||
typedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t;
|
||||
|
||||
template [[host_name("kernel_set_rows_f32_i64")]] kernel set_rows_f_t kernel_set_rows_f<float, int64_t>;
|
||||
template [[host_name("kernel_set_rows_f32_i32")]] kernel set_rows_f_t kernel_set_rows_f<float, int32_t>;
|
||||
template [[host_name("kernel_set_rows_f16_i64")]] kernel set_rows_f_t kernel_set_rows_f<half, int64_t>;
|
||||
template [[host_name("kernel_set_rows_f16_i32")]] kernel set_rows_f_t kernel_set_rows_f<half, int32_t>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int64_t>;
|
||||
template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int32_t>;
|
||||
#endif
|
||||
|
||||
typedef decltype(kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>) set_rows_q32_t;
|
||||
|
||||
template [[host_name("kernel_set_rows_q8_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>;
|
||||
template [[host_name("kernel_set_rows_q8_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q8_0, quantize_q8_0>;
|
||||
template [[host_name("kernel_set_rows_q4_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_0, quantize_q4_0>;
|
||||
template [[host_name("kernel_set_rows_q4_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_0, quantize_q4_0>;
|
||||
template [[host_name("kernel_set_rows_q4_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_1, quantize_q4_1>;
|
||||
template [[host_name("kernel_set_rows_q4_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_1, quantize_q4_1>;
|
||||
template [[host_name("kernel_set_rows_q5_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_0, quantize_q5_0>;
|
||||
template [[host_name("kernel_set_rows_q5_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_0, quantize_q5_0>;
|
||||
template [[host_name("kernel_set_rows_q5_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_1, quantize_q5_1>;
|
||||
template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_1, quantize_q5_1>;
|
||||
template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>;
|
||||
template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>;
|
||||
@@ -0,0 +1,228 @@
|
||||
#include "common.h"
|
||||
|
||||
kernel void kernel_op_sum_f32(
|
||||
constant ggml_metal_kargs_sum & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
if (args.np == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: become function constant
|
||||
const uint nsg = (ntg.x + 31) / 32;
|
||||
|
||||
float sumf = 0;
|
||||
|
||||
for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
|
||||
sumf += src0[i0];
|
||||
}
|
||||
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float total = 0;
|
||||
|
||||
if (sgitg == 0) {
|
||||
float v = 0;
|
||||
|
||||
if (tpitg.x < nsg) {
|
||||
v = shmem_f32[tpitg.x];
|
||||
}
|
||||
|
||||
total = simd_sum(v);
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
dst[0] = total;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
|
||||
|
||||
template <typename T0, typename T>
|
||||
kernel void kernel_sum_rows_impl(
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
#define FC_OP FC_sum_rows_op
|
||||
|
||||
const int i3 = tgpig.z;
|
||||
const int i2 = tgpig.y;
|
||||
const int i1 = tgpig.x;
|
||||
|
||||
threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_t[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
T0 sumf = T0(0.0f);
|
||||
|
||||
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
sumf += src_row[i0];
|
||||
}
|
||||
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_t[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_t[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
|
||||
if (is_same<float4, T0>::value) {
|
||||
dst_row[0] = sum(sumf) / (4*args.ne00);
|
||||
} else {
|
||||
dst_row[0] = sum(sumf) / args.ne00;
|
||||
}
|
||||
} else {
|
||||
dst_row[0] = sum(sumf);
|
||||
}
|
||||
}
|
||||
|
||||
#undef FC_OP
|
||||
}
|
||||
|
||||
typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
|
||||
|
||||
template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
|
||||
template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_cumsum_blk(
|
||||
constant ggml_metal_kargs_cumsum_blk & args,
|
||||
device const char * src0,
|
||||
device char * tmp,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int ib = tgpig[0]/args.ne01;
|
||||
|
||||
const int i00 = ib*ntg.x;
|
||||
const int i01 = tgpig[0]%args.ne01;
|
||||
const int i02 = tgpig[1];
|
||||
const int i03 = tgpig[2];
|
||||
|
||||
device const float * src0_row = (device const float *) (src0 +
|
||||
args.nb01*i01 +
|
||||
args.nb02*i02 +
|
||||
args.nb03*i03);
|
||||
|
||||
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
||||
|
||||
float v = 0.0f;
|
||||
|
||||
if (i00 + tpitg.x < args.ne00) {
|
||||
v = src0_row[i00 + tpitg.x];
|
||||
}
|
||||
|
||||
float s = simd_prefix_inclusive_sum(v);
|
||||
|
||||
if (tiisg == N_SIMDWIDTH - 1) {
|
||||
shmem_f32[sgitg] = s;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
s += shmem_f32[sgitg];
|
||||
|
||||
device float * dst_row = (device float *) dst +
|
||||
args.ne00*i01 +
|
||||
args.ne00*args.ne01*i02 +
|
||||
args.ne00*args.ne01*args.ne02*i03;
|
||||
|
||||
if (i00 + tpitg.x < args.ne00) {
|
||||
dst_row[i00 + tpitg.x] = s;
|
||||
}
|
||||
|
||||
if (args.outb && tpitg.x == ntg.x - 1) {
|
||||
device float * tmp_row = (device float *) tmp +
|
||||
args.net0*i01 +
|
||||
args.net0*args.net1*i02 +
|
||||
args.net0*args.net1*args.net2*i03;
|
||||
|
||||
tmp_row[ib] = s;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;
|
||||
|
||||
template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_cumsum_add(
|
||||
constant ggml_metal_kargs_cumsum_add & args,
|
||||
device const char * tmp,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int ib = tgpig[0]/args.ne01;
|
||||
|
||||
if (ib == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i00 = ib*ntg.x;
|
||||
const int i01 = tgpig[0]%args.ne01;
|
||||
const int i02 = tgpig[1];
|
||||
const int i03 = tgpig[2];
|
||||
|
||||
device const float * tmp_row = (device const float *) (tmp +
|
||||
args.nbt1*i01 +
|
||||
args.nbt2*i02 +
|
||||
args.nbt3*i03);
|
||||
|
||||
device float * dst_row = (device float *) dst +
|
||||
args.ne00*i01 +
|
||||
args.ne00*args.ne01*i02 +
|
||||
args.ne00*args.ne01*args.ne02*i03;
|
||||
|
||||
if (i00 + tpitg.x < args.ne00) {
|
||||
dst_row[i00 + tpitg.x] += tmp_row[ib - 1];
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
|
||||
|
||||
template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
|
||||
@@ -0,0 +1,318 @@
|
||||
#include "common.h"
|
||||
|
||||
constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
|
||||
constant bool FC_rope_is_back [[function_constant(FC_ROPE + 1)]];
|
||||
|
||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||
return 1.0f - min(1.0f, max(0.0f, y));
|
||||
}
|
||||
|
||||
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
||||
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
||||
static void rope_yarn(
|
||||
float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,
|
||||
thread float * cos_theta, thread float * sin_theta) {
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float theta = theta_interp;
|
||||
if (ext_factor != 0.0f) {
|
||||
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
||||
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
|
||||
}
|
||||
*cos_theta = cos(theta) * mscale;
|
||||
*sin_theta = sin(theta) * mscale;
|
||||
if (FC_rope_is_back) {
|
||||
*sin_theta *= -1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
||||
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
|
||||
static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
|
||||
return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
|
||||
}
|
||||
|
||||
static void rope_yarn_corr_dims(
|
||||
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
|
||||
) {
|
||||
// start and end correction dims
|
||||
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
|
||||
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_rope_norm(
|
||||
constant ggml_metal_kargs_rope & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int i3 = tgpig[2];
|
||||
const int i2 = tgpig[1];
|
||||
const int i1 = tgpig[0];
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = (device const int32_t *) src1;
|
||||
|
||||
const float theta_base = (float) pos[i2];
|
||||
const float inv_ndims = -1.f/args.n_dims;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||
if (i0 < args.n_dims) {
|
||||
const int ic = i0/2;
|
||||
|
||||
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[1];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
||||
} else {
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_rope_neox(
|
||||
constant ggml_metal_kargs_rope & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int i3 = tgpig[2];
|
||||
const int i2 = tgpig[1];
|
||||
const int i1 = tgpig[0];
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = (device const int32_t *) src1;
|
||||
|
||||
const float theta_base = (float) pos[i2];
|
||||
const float inv_ndims = -1.f/args.n_dims;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||
if (i0 < args.n_dims) {
|
||||
const int ic = i0/2;
|
||||
|
||||
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[args.n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
} else {
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_rope_multi(
|
||||
constant ggml_metal_kargs_rope & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int i3 = tgpig[2];
|
||||
const int i2 = tgpig[1];
|
||||
const int i1 = tgpig[0];
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = (device const int32_t *) src1;
|
||||
|
||||
const float inv_ndims = -1.f/args.n_dims;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||
if (i0 < args.n_dims) {
|
||||
const int ic = i0/2;
|
||||
|
||||
// mrope theta calculations
|
||||
// note: the rest is the same as kernel_rope_neox
|
||||
const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
|
||||
const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1
|
||||
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
|
||||
const int sector = ic % sect_dims;
|
||||
|
||||
float theta_base;
|
||||
if (FC_rope_is_imrope) {
|
||||
if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
|
||||
theta_base = (float) pos[i2 + args.ne02 * 1];
|
||||
} else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
|
||||
theta_base = (float) pos[i2 + args.ne02 * 2];
|
||||
} else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
|
||||
theta_base = (float) pos[i2 + args.ne02 * 0];
|
||||
} else { // e
|
||||
theta_base = (float) pos[i2 + args.ne02 * 3];
|
||||
}
|
||||
} else {
|
||||
if (sector < args.sect_0) {
|
||||
theta_base = (float) pos[i2];
|
||||
} else if (sector < sec_w01) {
|
||||
theta_base = (float) pos[i2 + args.ne02 * 1];
|
||||
} else if (sector < sec_w012) {
|
||||
theta_base = (float) pos[i2 + args.ne02 * 2];
|
||||
} else {
|
||||
theta_base = (float) pos[i2 + args.ne02 * 3];
|
||||
}
|
||||
}
|
||||
// end of mrope
|
||||
|
||||
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[args.n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
} else {
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_rope_vision(
|
||||
constant ggml_metal_kargs_rope & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int i3 = tgpig[2];
|
||||
const int i2 = tgpig[1];
|
||||
const int i1 = tgpig[0];
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = (device const int32_t *) src1;
|
||||
|
||||
const float inv_ndims = -1.f/args.n_dims;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||
if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
|
||||
const int ic = i0/2;
|
||||
|
||||
// mrope theta calculations (only support 2 dimensions)
|
||||
const int sect_dims = args.sect_0 + args.sect_1;
|
||||
const int sector = ic % sect_dims;
|
||||
|
||||
float p;
|
||||
float theta_base;
|
||||
if (sector < args.sect_1) {
|
||||
p = (float) sector;
|
||||
theta_base = (float) pos[i2];
|
||||
} else {
|
||||
p = (float) sector - args.sect_0;
|
||||
theta_base = (float) pos[i2 + args.ne02];
|
||||
}
|
||||
|
||||
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
|
||||
// end of mrope
|
||||
|
||||
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[args.n_dims]; // different from kernel_rope_multi
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
|
||||
} else {
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
||||
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
|
||||
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
|
||||
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
|
||||
|
||||
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
|
||||
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
|
||||
|
||||
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
|
||||
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
|
||||
|
||||
template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
|
||||
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
|
||||
|
||||
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
|
||||
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
|
||||
@@ -0,0 +1,223 @@
|
||||
#include "common.h"
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
constant ggml_metal_kargs_soft_max & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint3 tptg[[threads_per_threadgroup]]) {
|
||||
const int32_t i03 = tgpig.z;
|
||||
const int32_t i02 = tgpig.y;
|
||||
const int32_t i01 = tgpig.x;
|
||||
|
||||
const int32_t i13 = i03%args.ne13;
|
||||
const int32_t i12 = i02%args.ne12;
|
||||
const int32_t i11 = i01;
|
||||
|
||||
device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
||||
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
|
||||
device const float * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr;
|
||||
device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (args.max_bias > 0.0f) {
|
||||
const int32_t h = i02;
|
||||
|
||||
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
||||
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
||||
// parallel max
|
||||
float lmax = psrc2 ? psrc2[i02] : -INFINITY;
|
||||
|
||||
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
|
||||
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
|
||||
}
|
||||
|
||||
// find the max value in the block
|
||||
float max_val = simd_max(lmax);
|
||||
if (tptg.x > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = -INFINITY;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = max_val;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
max_val = buf[tiisg];
|
||||
max_val = simd_max(max_val);
|
||||
}
|
||||
|
||||
// parallel sum
|
||||
float lsum = 0.0f;
|
||||
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
|
||||
const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
||||
lsum += exp_psrc0;
|
||||
pdst[i00] = exp_psrc0;
|
||||
}
|
||||
|
||||
// This barrier fixes a failing test
|
||||
// ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
float sum = simd_sum(lsum);
|
||||
|
||||
if (tptg.x > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = sum;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sum = buf[tiisg];
|
||||
sum = simd_sum(sum);
|
||||
}
|
||||
|
||||
if (psrc2) {
|
||||
sum += exp(psrc2[i02] - max_val);
|
||||
}
|
||||
|
||||
const float inv_sum = 1.0f/sum;
|
||||
|
||||
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
|
||||
pdst[i00] *= inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max_4(
|
||||
constant ggml_metal_kargs_soft_max & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint3 tptg[[threads_per_threadgroup]]) {
|
||||
const int32_t i03 = tgpig.z;
|
||||
const int32_t i02 = tgpig.y;
|
||||
const int32_t i01 = tgpig.x;
|
||||
|
||||
const int32_t i13 = i03%args.ne13;
|
||||
const int32_t i12 = i02%args.ne12;
|
||||
const int32_t i11 = i01;
|
||||
|
||||
device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
||||
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
|
||||
device const float * psrc2 = src2 != src0 ? (device const float * ) (src2) : nullptr;
|
||||
device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
if (args.max_bias > 0.0f) {
|
||||
const int32_t h = i02;
|
||||
|
||||
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
||||
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
||||
// parallel max
|
||||
float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
|
||||
|
||||
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
||||
}
|
||||
|
||||
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||
|
||||
float max_val = simd_max(lmax);
|
||||
if (tptg.x > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = -INFINITY;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = max_val;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
max_val = buf[tiisg];
|
||||
max_val = simd_max(max_val);
|
||||
}
|
||||
|
||||
// parallel sum
|
||||
float4 lsum4 = 0.0f;
|
||||
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
||||
lsum4 += exp_psrc4;
|
||||
pdst4[i00] = exp_psrc4;
|
||||
}
|
||||
|
||||
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
|
||||
|
||||
// This barrier fixes a failing test
|
||||
// ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
float sum = simd_sum(lsum);
|
||||
|
||||
if (tptg.x > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
buf[sgitg] = sum;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sum = buf[tiisg];
|
||||
sum = simd_sum(sum);
|
||||
}
|
||||
|
||||
if (psrc2) {
|
||||
sum += exp(psrc2[i02] - max_val);
|
||||
}
|
||||
|
||||
const float inv_sum = 1.0f/sum;
|
||||
|
||||
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
|
||||
pdst4[i00] *= inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
|
||||
typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
|
||||
|
||||
template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
|
||||
template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
|
||||
template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
|
||||
template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
|
||||
@@ -0,0 +1,75 @@
|
||||
#include "common.h"
|
||||
|
||||
constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
|
||||
constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]];
|
||||
constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]];
|
||||
|
||||
kernel void kernel_solve_tri_f32(
|
||||
constant ggml_metal_kargs_solve_tri & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
ushort3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
constexpr short NW = N_SIMDWIDTH;
|
||||
|
||||
const short NSG = FC_solve_tri_nsg;
|
||||
const short N = FC_solve_tri_n;
|
||||
const short K = FC_solve_tri_k;
|
||||
const short NP = PAD2(N, NW);
|
||||
|
||||
const int32_t i03 = tgpig.z;
|
||||
const int32_t i02 = tgpig.y;
|
||||
const int32_t i01 = tgpig.x*NSG + sgitg;
|
||||
|
||||
threadgroup float * sh0 = (threadgroup float *) shmem;
|
||||
|
||||
device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N;
|
||||
device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01;
|
||||
device float * dst_ptr = (device float *)(dst + i02 * args.nb2 + i03 * args.nb3) + i01;
|
||||
|
||||
for (short rr = 0; rr < N; rr += NSG) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
{
|
||||
threadgroup float * sh0_cur = sh0 + sgitg*NP;
|
||||
|
||||
for (short t = 0; t*NW < N; ++t) {
|
||||
const short idx = t*NW + tiisg;
|
||||
sh0_cur[idx] = src0_ptr[idx];
|
||||
}
|
||||
|
||||
src0_ptr += NSG*N;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (i01 >= args.ne10) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (short ir = 0; ir < NSG && rr + ir < N; ++ir) {
|
||||
const short r = rr + ir;
|
||||
|
||||
threadgroup float * sh0_cur = sh0 + ir*NP;
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
for (short t = 0; t*NW < r; ++t) {
|
||||
const short idx = t*NW + tiisg;
|
||||
sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r);
|
||||
}
|
||||
|
||||
sum = simd_sum(sum);
|
||||
|
||||
if (tiisg == 0) {
|
||||
const float diag = sh0_cur[r];
|
||||
|
||||
dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,279 @@
|
||||
#include "common.h"
|
||||
|
||||
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
|
||||
kernel void kernel_ssm_conv_f32_f32(
|
||||
constant ggml_metal_kargs_ssm_conv & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device float * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t ir = tgpig.x;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i3 = tgpig.z;
|
||||
|
||||
const int64_t nc = args.ne10;
|
||||
//const int64_t ncs = args.ne00;
|
||||
//const int64_t nr = args.ne01;
|
||||
//const int64_t n_t = args.ne1;
|
||||
//const int64_t n_s = args.ne2;
|
||||
|
||||
device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
|
||||
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
|
||||
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
||||
sumf += s[i0] * c[i0];
|
||||
}
|
||||
|
||||
x[0] = sumf;
|
||||
}
|
||||
|
||||
kernel void kernel_ssm_conv_f32_f32_4(
|
||||
constant ggml_metal_kargs_ssm_conv & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device float * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t ir = tgpig.x;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i3 = tgpig.z;
|
||||
|
||||
const int64_t nc = args.ne10;
|
||||
//const int64_t ncs = args.ne00;
|
||||
//const int64_t nr = args.ne01;
|
||||
//const int64_t n_t = args.ne1;
|
||||
//const int64_t n_s = args.ne2;
|
||||
|
||||
device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
|
||||
device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
|
||||
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int64_t i0 = 0; i0 < nc/4; ++i0) {
|
||||
sumf += dot(s[i0], c[i0]);
|
||||
}
|
||||
|
||||
x[0] = sumf;
|
||||
}
|
||||
|
||||
constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]];
|
||||
|
||||
// Batched version: each threadgroup processes multiple tokens for better efficiency
|
||||
// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
|
||||
kernel void kernel_ssm_conv_f32_f32_batched(
|
||||
constant ggml_metal_kargs_ssm_conv & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device float * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
// tgpig.x = row index (ir)
|
||||
// tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
|
||||
// tgpig.z = sequence index (i3)
|
||||
// tpitg.x = thread within batch (0..BATCH_SIZE-1)
|
||||
const short BATCH_SIZE = FC_ssm_conv_bs;
|
||||
|
||||
const int64_t ir = tgpig.x;
|
||||
const int64_t i2_base = tgpig.y * BATCH_SIZE;
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2_off = tpitg.x;
|
||||
const int64_t i2 = i2_base + i2_off;
|
||||
|
||||
const int64_t nc = args.ne10; // conv kernel size (typically 4)
|
||||
const int64_t n_t = args.ne1; // number of tokens
|
||||
|
||||
// Bounds check for partial batches at the end
|
||||
if (i2 >= n_t) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Load conv weights (shared across all tokens for this row)
|
||||
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
|
||||
|
||||
// Load source for this specific token
|
||||
device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
|
||||
|
||||
// Output location for this token
|
||||
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
||||
|
||||
float sumf = 0.0f;
|
||||
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
||||
sumf += s[i0] * c[i0];
|
||||
}
|
||||
|
||||
x[0] = sumf;
|
||||
}
|
||||
|
||||
kernel void kernel_ssm_conv_f32_f32_batched_4(
|
||||
constant ggml_metal_kargs_ssm_conv & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device float * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
// tgpig.x = row index (ir)
|
||||
// tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
|
||||
// tgpig.z = sequence index (i3)
|
||||
// tpitg.x = thread within batch (0..BATCH_SIZE-1)
|
||||
const short BATCH_SIZE = FC_ssm_conv_bs;
|
||||
|
||||
const int64_t ir = tgpig.x;
|
||||
const int64_t i2_base = tgpig.y * BATCH_SIZE;
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2_off = tpitg.x;
|
||||
const int64_t i2 = i2_base + i2_off;
|
||||
|
||||
const int64_t nc = args.ne10; // conv kernel size (typically 4)
|
||||
const int64_t n_t = args.ne1; // number of tokens
|
||||
|
||||
// Bounds check for partial batches at the end
|
||||
if (i2 >= n_t) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Load conv weights (shared across all tokens for this row)
|
||||
device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
|
||||
|
||||
// Load source for this specific token
|
||||
device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
|
||||
|
||||
// Output location for this token
|
||||
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
|
||||
|
||||
float sumf = 0.0f;
|
||||
for (int64_t i0 = 0; i0 < nc/4; ++i0) {
|
||||
sumf += dot(s[i0], c[i0]);
|
||||
}
|
||||
|
||||
x[0] = sumf;
|
||||
}
|
||||
|
||||
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
|
||||
// Optimized version: reduces redundant memory loads by having one thread load shared values
|
||||
kernel void kernel_ssm_scan_f32(
|
||||
constant ggml_metal_kargs_ssm_scan & args,
|
||||
device const void * src0,
|
||||
device const void * src1,
|
||||
device const void * src2,
|
||||
device const void * src3,
|
||||
device const void * src4,
|
||||
device const void * src5,
|
||||
device const void * src6,
|
||||
device float * dst,
|
||||
threadgroup float * shared [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgptg[[simdgroups_per_threadgroup]],
|
||||
uint3 tgpg[[threadgroups_per_grid]]) {
|
||||
constexpr short NW = N_SIMDWIDTH;
|
||||
|
||||
// Shared memory layout:
|
||||
// [0..sgptg*NW-1]: partial sums for reduction (existing)
|
||||
// [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
|
||||
// [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch
|
||||
threadgroup float * shared_sums = shared;
|
||||
threadgroup float * shared_x_dt = shared + sgptg * NW;
|
||||
threadgroup float * shared_dA = shared + sgptg * NW + sgptg;
|
||||
|
||||
shared_sums[tpitg.x] = 0.0f;
|
||||
|
||||
const int32_t i0 = tpitg.x;
|
||||
const int32_t i1 = tgpig.x;
|
||||
const int32_t ir = tgpig.y; // current head
|
||||
const int32_t i3 = tgpig.z; // current seq
|
||||
|
||||
const int32_t nc = args.d_state;
|
||||
const int32_t nr = args.d_inner;
|
||||
const int32_t nh = args.n_head;
|
||||
const int32_t ng = args.n_group;
|
||||
const int32_t n_t = args.n_seq_tokens;
|
||||
|
||||
const int32_t s_off = args.s_off;
|
||||
|
||||
device const int32_t * ids = (device const int32_t *) src6;
|
||||
|
||||
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
|
||||
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
|
||||
|
||||
const int32_t i = i0 + i1*nc;
|
||||
const int32_t g = ir / (nh / ng); // repeat_interleave
|
||||
|
||||
float s0 = s0_buff[i];
|
||||
float s = 0.0f;
|
||||
|
||||
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}
|
||||
|
||||
const float A0 = A[i0%args.ne30];
|
||||
|
||||
device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
|
||||
device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns}
|
||||
device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns}
|
||||
device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns}
|
||||
|
||||
device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}
|
||||
|
||||
for (int i2 = 0; i2 < n_t; i2 += sgptg) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// Pre-compute x_dt and dA for this batch of tokens
|
||||
// Only first sgptg threads do the loads and expensive math
|
||||
if (i0 < sgptg && i2 + i0 < n_t) {
|
||||
// ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
|
||||
device const float * x_t = x + i0 * args.ns12;
|
||||
device const float * dt_t = dt + i0 * args.ns21;
|
||||
|
||||
const float dt0 = dt_t[0];
|
||||
const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
|
||||
shared_x_dt[i0] = x_t[0] * dtsp;
|
||||
shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
|
||||
const float x_dt = shared_x_dt[t];
|
||||
const float dA = exp(shared_dA[t] * A0);
|
||||
|
||||
s = (s0 * dA) + (B[i0] * x_dt);
|
||||
|
||||
const float sumf = simd_sum(s * C[i0]);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shared_sums[t*NW + sgitg] = sumf;
|
||||
}
|
||||
|
||||
// recurse
|
||||
s0 = s;
|
||||
|
||||
B += args.ns42;
|
||||
C += args.ns52;
|
||||
}
|
||||
|
||||
// Advance pointers for next batch
|
||||
x += sgptg * args.ns12;
|
||||
dt += sgptg * args.ns21;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
|
||||
|
||||
if (tiisg == 0 && i2 + sgitg < n_t) {
|
||||
y[sgitg*nh*nr] = sumf;
|
||||
}
|
||||
|
||||
y += sgptg*nh*nr;
|
||||
}
|
||||
|
||||
s_buff[i] = s;
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
#include "common.h"
|
||||
|
||||
template<uint32_t ttype>
|
||||
bool _ggml_vec_tri_cmp(const int i, const int r);
|
||||
|
||||
template<>
|
||||
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {
|
||||
return i < r;
|
||||
}
|
||||
|
||||
template<>
|
||||
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {
|
||||
return i <= r;
|
||||
}
|
||||
|
||||
template<>
|
||||
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {
|
||||
return i > r;
|
||||
}
|
||||
|
||||
template<>
|
||||
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {
|
||||
return i >= r;
|
||||
}
|
||||
|
||||
template<typename T, int ttype>
|
||||
kernel void kernel_tri(
|
||||
constant ggml_metal_kargs_tri & args,
|
||||
device const char * src0,
|
||||
device const char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i3 = tgpig.z;
|
||||
const int i2 = tgpig.y;
|
||||
const int i1 = tgpig.x;
|
||||
|
||||
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
// Each thread is a single element of the row if ne00 < max threads per
|
||||
// threadgroup, so this will loop once for each index that this thread is
|
||||
// responsible for
|
||||
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
// Use the comparison as a mask for branchless
|
||||
dst_row[i0] = static_cast<T>(_ggml_vec_tri_cmp<ttype>(i0, i1)) * src_row[i0];
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_tri<float, 0>) kernel_tri_t;
|
||||
|
||||
template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri<float, 0>;
|
||||
template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri<float, 1>;
|
||||
template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri<float, 2>;
|
||||
template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri<float, 3>;
|
||||
template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri<half, 0>;
|
||||
template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri<half, 1>;
|
||||
template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri<half, 2>;
|
||||
template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri<half, 3>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
|
||||
template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;
|
||||
template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;
|
||||
template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;
|
||||
#endif
|
||||
@@ -0,0 +1,360 @@
|
||||
#include "common.h"
|
||||
|
||||
constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
|
||||
constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
|
||||
|
||||
template <typename T0, typename T, typename TC>
|
||||
kernel void kernel_unary_impl(
|
||||
constant ggml_metal_kargs_unary & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
#define FC_OP FC_unary_op
|
||||
#define FC_CNT FC_unary_cnt
|
||||
|
||||
device const T0 * src0_ptr;
|
||||
device T * dst_ptr;
|
||||
|
||||
int i0;
|
||||
|
||||
if (FC_CNT) {
|
||||
i0 = tgpig.x;
|
||||
|
||||
src0_ptr = (device const T0 *) (src0);
|
||||
dst_ptr = (device T *) (dst);
|
||||
} else {
|
||||
const int i03 = tgpig.z;
|
||||
const int i02 = tgpig.y;
|
||||
const int k0 = tgpig.x/args.ne01;
|
||||
const int i01 = tgpig.x - k0*args.ne01;
|
||||
|
||||
i0 = k0*ntg.x + tpitg.x;
|
||||
|
||||
src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 );
|
||||
}
|
||||
|
||||
{
|
||||
//threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
if (!FC_CNT) {
|
||||
if (i0 >= args.ne0) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const TC x = (TC) src0_ptr[i0];
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SCALE) {
|
||||
dst_ptr[i0] = (T) (args.scale * x + args.bias);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_FILL) {
|
||||
dst_ptr[i0] = (T) args.val;
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_CLAMP) {
|
||||
dst_ptr[i0] = (T) clamp(x, args.min, args.max);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SQR) {
|
||||
dst_ptr[i0] = (T) (x * x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SQRT) {
|
||||
dst_ptr[i0] = (T) sqrt(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SIN) {
|
||||
dst_ptr[i0] = (T) sin(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_COS) {
|
||||
dst_ptr[i0] = (T) cos(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_LOG) {
|
||||
dst_ptr[i0] = (T) log(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
|
||||
dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_TANH) {
|
||||
dst_ptr[i0] = (T) precise::tanh(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_RELU) {
|
||||
dst_ptr[i0] = (T) fmax(0, x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SIGMOID) {
|
||||
dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_GELU) {
|
||||
dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
|
||||
dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
|
||||
dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SILU) {
|
||||
dst_ptr[i0] = (T) (x / (1 + exp(-x)));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_ELU) {
|
||||
dst_ptr[i0] = (T) elu_approx(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_NEG) {
|
||||
dst_ptr[i0] = (T) -x;
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_ABS) {
|
||||
dst_ptr[i0] = (T) fabs(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SGN) {
|
||||
dst_ptr[i0] = T(x > 0) - T(x < 0);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_STEP) {
|
||||
dst_ptr[i0] = T(x > 0);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
|
||||
dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
|
||||
dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_EXP) {
|
||||
dst_ptr[i0] = (T) exp(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
|
||||
dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_EXPM1) {
|
||||
// TODO: precise implementation
|
||||
dst_ptr[i0] = (T) (exp(x) - 1);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_FLOOR) {
|
||||
dst_ptr[i0] = (T) floor(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_CEIL) {
|
||||
dst_ptr[i0] = (T) ceil(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_ROUND) {
|
||||
dst_ptr[i0] = (T) round(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_TRUNC) {
|
||||
dst_ptr[i0] = (T) trunc(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_XIELU) {
|
||||
const TC xi = x;
|
||||
const TC gate = TC(xi > TC(0.0f));
|
||||
const TC clamped = fmin(xi, TC(args.val));
|
||||
const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi;
|
||||
const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi;
|
||||
dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg);
|
||||
}
|
||||
}
|
||||
|
||||
#undef FC_OP
|
||||
#undef FC_CNT
|
||||
}
|
||||
|
||||
typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
|
||||
|
||||
template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>;
|
||||
template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
|
||||
template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>;
|
||||
template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_reglu(
|
||||
constant ggml_metal_kargs_glu & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
||||
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
||||
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
||||
|
||||
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
||||
const float x0 = src0_row[i0];
|
||||
const float x1 = src1_row[i0];
|
||||
|
||||
dst_row[i0] = (T)(x0*x1*(x0 > 0.0f));
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_reglu<float>) kernel_reglu_t;
|
||||
|
||||
template [[host_name("kernel_reglu_f32")]] kernel kernel_reglu_t kernel_reglu<float>;
|
||||
template [[host_name("kernel_reglu_f16")]] kernel kernel_reglu_t kernel_reglu<half>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_geglu(
|
||||
constant ggml_metal_kargs_glu & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
||||
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
||||
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
||||
|
||||
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
||||
const float x0 = src0_row[i0];
|
||||
const float x1 = src1_row[i0];
|
||||
|
||||
const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
|
||||
|
||||
dst_row[i0] = (T)(gelu*x1);
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_geglu<float>) kernel_geglu_t;
|
||||
|
||||
template [[host_name("kernel_geglu_f32")]] kernel kernel_geglu_t kernel_geglu<float>;
|
||||
template [[host_name("kernel_geglu_f16")]] kernel kernel_geglu_t kernel_geglu<half>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_swiglu(
|
||||
constant ggml_metal_kargs_glu & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
||||
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
||||
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
||||
|
||||
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
||||
const float x0 = src0_row[i0];
|
||||
const float x1 = src1_row[i0];
|
||||
|
||||
const float silu = x0 / (1.0f + exp(-x0));
|
||||
|
||||
dst_row[i0] = (T)(silu*x1);
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_swiglu<float>) kernel_swiglu_t;
|
||||
|
||||
template [[host_name("kernel_swiglu_f32")]] kernel kernel_swiglu_t kernel_swiglu<float>;
|
||||
template [[host_name("kernel_swiglu_f16")]] kernel kernel_swiglu_t kernel_swiglu<half>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_swiglu_oai(
|
||||
constant ggml_metal_kargs_glu & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
||||
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
||||
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
||||
|
||||
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
||||
float x0 = src0_row[i0];
|
||||
float x1 = src1_row[i0];
|
||||
|
||||
x0 = min(x0, args.limit);
|
||||
x1 = max(min(x1, args.limit), -args.limit);
|
||||
|
||||
float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
|
||||
out_glu = out_glu * (1.0f + x1);
|
||||
|
||||
dst_row[i0] = (T)out_glu;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_swiglu_oai<float>) kernel_swiglu_oai_t;
|
||||
|
||||
template [[host_name("kernel_swiglu_oai_f32")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<float>;
|
||||
template [[host_name("kernel_swiglu_oai_f16")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<half>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_geglu_erf(
|
||||
constant ggml_metal_kargs_glu & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
||||
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
||||
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
||||
|
||||
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
||||
const float x0 = src0_row[i0];
|
||||
const float x1 = src1_row[i0];
|
||||
|
||||
const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
|
||||
|
||||
dst_row[i0] = (T)(gelu_erf*x1);
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_geglu_erf<float>) kernel_geglu_erf_t;
|
||||
|
||||
template [[host_name("kernel_geglu_erf_f32")]] kernel kernel_geglu_erf_t kernel_geglu_erf<float>;
|
||||
template [[host_name("kernel_geglu_erf_f16")]] kernel kernel_geglu_erf_t kernel_geglu_erf<half>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_geglu_quick(
|
||||
constant ggml_metal_kargs_glu & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
||||
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
||||
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
|
||||
|
||||
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
||||
const float x0 = src0_row[i0];
|
||||
const float x1 = src1_row[i0];
|
||||
|
||||
const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
|
||||
|
||||
dst_row[i0] = (T)(gelu_quick*x1);
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_geglu_quick<float>) kernel_geglu_quick_t;
|
||||
|
||||
template [[host_name("kernel_geglu_quick_f32")]] kernel kernel_geglu_quick_t kernel_geglu_quick<float>;
|
||||
template [[host_name("kernel_geglu_quick_f16")]] kernel kernel_geglu_quick_t kernel_geglu_quick<half>;
|
||||
@@ -0,0 +1,179 @@
|
||||
#include "common.h"
|
||||
|
||||
constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]];
|
||||
|
||||
kernel void kernel_upscale_nearest_f32(
|
||||
constant ggml_metal_kargs_upscale & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i1 = tgpig.x;
|
||||
|
||||
const int64_t i03 = i3/args.sf3;
|
||||
const int64_t i02 = i2/args.sf2;
|
||||
const int64_t i01 = i1/args.sf1;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const int64_t i00 = i0/args.sf0;
|
||||
|
||||
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
|
||||
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
dst_ptr[0] = src0_ptr[0];
|
||||
}
|
||||
}
|
||||
|
||||
static inline float bilinear_tri(float x) {
|
||||
return MAX(0.0f, 1.0f - fabs(x));
|
||||
}
|
||||
|
||||
kernel void kernel_upscale_bilinear_f32(
|
||||
constant ggml_metal_kargs_upscale & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i1 = tgpig.x;
|
||||
|
||||
const int64_t i03 = i3 / args.sf3;
|
||||
const int64_t i02 = i2 / args.sf2;
|
||||
|
||||
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
|
||||
const int64_t i01 = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01)));
|
||||
const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1));
|
||||
const float fd1 = MAX(0.0f, MIN(1.0f, f01 - (float)i01));
|
||||
|
||||
src0 += i03*args.nb03 + i02*args.nb02;
|
||||
|
||||
device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
|
||||
|
||||
if (FC_upscale_aa) {
|
||||
const float support0 = MAX(1.0f, 1.0f / args.sf0);
|
||||
const float invscale0 = 1.0f / support0;
|
||||
const float support1 = MAX(1.0f, 1.0f / args.sf1);
|
||||
const float invscale1 = 1.0f / support1;
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
|
||||
|
||||
int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs));
|
||||
int64_t x_max = MIN(args.ne00, (int64_t)ceil (f00 + support0 + args.poffs));
|
||||
|
||||
int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs));
|
||||
int64_t y_max = MIN(args.ne01, (int64_t)ceil (f01 + support1 + args.poffs));
|
||||
|
||||
float sum = 0.0f;
|
||||
float wsum = 0.0f;
|
||||
|
||||
for (int64_t sy = y_min; sy < y_max; ++sy) {
|
||||
const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1);
|
||||
for (int64_t sx = x_min; sx < x_max; ++sx) {
|
||||
const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
|
||||
const float w = wx * wy;
|
||||
device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
|
||||
sum += (*src_ptr) * w;
|
||||
wsum += w;
|
||||
}
|
||||
}
|
||||
|
||||
const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f;
|
||||
dst_ptr[i0] = v;
|
||||
}
|
||||
} else {
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
|
||||
const int64_t i00 = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00)));
|
||||
const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1));
|
||||
const float fd0 = MAX(0.0f, MIN(1.0f, f00 - (float)i00));
|
||||
|
||||
device const float * src00 = (device const float *)(src0 + i01*args.nb01 + i00*args.nb00);
|
||||
device const float * src10 = (device const float *)(src0 + i01*args.nb01 + i00p*args.nb00);
|
||||
device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00);
|
||||
device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00);
|
||||
|
||||
const float v =
|
||||
(*src00) * (1.0f - fd0) * (1.0f - fd1) +
|
||||
(*src10) * fd0 * (1.0f - fd1) +
|
||||
(*src01) * (1.0f - fd0) * fd1 +
|
||||
(*src11) * fd0 * fd1;
|
||||
|
||||
dst_ptr[i0] = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static inline float bicubic_weight1(float x) {
|
||||
const float a = -0.75f;
|
||||
return ((a + 2) * x - (a + 3)) * x * x + 1;
|
||||
}
|
||||
|
||||
static inline float bicubic_weight2(float x) {
|
||||
const float a = -0.75f;
|
||||
return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
|
||||
}
|
||||
|
||||
kernel void kernel_upscale_bicubic_f32(
|
||||
constant ggml_metal_kargs_upscale & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t i3 = tgpig.z;
|
||||
const int64_t i2 = tgpig.y;
|
||||
const int64_t i1 = tgpig.x;
|
||||
|
||||
const int64_t i03 = i3 / args.sf3;
|
||||
const int64_t i02 = i2 / args.sf2;
|
||||
|
||||
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
|
||||
const int64_t i01 = (int64_t)floor(f01);
|
||||
const float fd1 = f01 - (float)i01;
|
||||
|
||||
const float w_y0 = bicubic_weight2(fd1 + 1.0f);
|
||||
const float w_y1 = bicubic_weight1(fd1);
|
||||
const float w_y2 = bicubic_weight1(1.0f - fd1);
|
||||
const float w_y3 = bicubic_weight2(2.0f - fd1);
|
||||
|
||||
const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02;
|
||||
|
||||
device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1);
|
||||
|
||||
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
|
||||
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
|
||||
const int64_t i00 = (int64_t)floor(f00);
|
||||
const float fd0 = f00 - (float)i00;
|
||||
|
||||
const float w_x0 = bicubic_weight2(fd0 + 1.0f);
|
||||
const float w_x1 = bicubic_weight1(fd0);
|
||||
const float w_x2 = bicubic_weight1(1.0f - fd0);
|
||||
const float w_x3 = bicubic_weight2(2.0f - fd0);
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
for (int dy = -1; dy <= 2; ++dy) {
|
||||
const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy));
|
||||
const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3;
|
||||
|
||||
for (int dx = -1; dx <= 2; ++dx) {
|
||||
const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
|
||||
const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
|
||||
|
||||
device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
|
||||
sum += (*src_ptr) * wx * wy;
|
||||
}
|
||||
}
|
||||
|
||||
dst_ptr[i0] = sum;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
#include "common.h"
|
||||
|
||||
kernel void kernel_rwkv_wkv6_f32(
|
||||
device const float * k,
|
||||
device const float * v,
|
||||
device const float * r,
|
||||
device const float * tf,
|
||||
device const float * td,
|
||||
device const float * state_in,
|
||||
device float * dst,
|
||||
constant uint & B,
|
||||
constant uint & T,
|
||||
constant uint & C,
|
||||
constant uint & H,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const uint head_size = 64; // TODO: support head_size = 128
|
||||
const uint batch_id = tgpig.x / H;
|
||||
const uint head_id = tgpig.x % H;
|
||||
const uint tid = tpitg.x;
|
||||
|
||||
if (batch_id >= B || head_id >= H) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint state_size = C * head_size;
|
||||
const uint n_seq_tokens = T / B;
|
||||
|
||||
threadgroup float _k[head_size];
|
||||
threadgroup float _r[head_size];
|
||||
threadgroup float _tf[head_size];
|
||||
threadgroup float _td[head_size];
|
||||
|
||||
float state[head_size];
|
||||
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
||||
+ i * head_size + tid];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
_tf[tid] = tf[head_id * head_size + tid];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
|
||||
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
||||
|
||||
for (uint t = start_t; t < end_t; t += C) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const float v_val = v[t];
|
||||
float y = 0.0;
|
||||
|
||||
for (uint j = 0; j < head_size; j += 4) {
|
||||
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
||||
float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
||||
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
float4 kv = k_vec * v_val;
|
||||
|
||||
float4 temp = tf_vec * kv + s_vec;
|
||||
y += dot(r_vec, temp);
|
||||
|
||||
s_vec = s_vec * td_vec + kv;
|
||||
state[j] = s_vec[0];
|
||||
state[j+1] = s_vec[1];
|
||||
state[j+2] = s_vec[2];
|
||||
state[j+3] = s_vec[3];
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
||||
+ i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_rwkv_wkv7_f32(
|
||||
device const float * r,
|
||||
device const float * w,
|
||||
device const float * k,
|
||||
device const float * v,
|
||||
device const float * a,
|
||||
device const float * b,
|
||||
device const float * state_in,
|
||||
device float * dst,
|
||||
constant uint & B,
|
||||
constant uint & T,
|
||||
constant uint & C,
|
||||
constant uint & H,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const uint head_size = 64; // TODO: support head_size = 128
|
||||
const uint batch_id = tgpig.x / H;
|
||||
const uint head_id = tgpig.x % H;
|
||||
const uint tid = tpitg.x;
|
||||
|
||||
if (batch_id >= B || head_id >= H) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint state_size = C * head_size;
|
||||
const uint n_seq_tokens = T / B;
|
||||
|
||||
threadgroup float _r[head_size];
|
||||
threadgroup float _w[head_size];
|
||||
threadgroup float _k[head_size];
|
||||
threadgroup float _a[head_size];
|
||||
threadgroup float _b[head_size];
|
||||
|
||||
float state[head_size];
|
||||
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
|
||||
+ tid * head_size + i];
|
||||
}
|
||||
|
||||
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
|
||||
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
|
||||
|
||||
for (uint t = start_t; t < end_t; t += C) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
_r[tid] = r[t];
|
||||
_w[tid] = w[t];
|
||||
_k[tid] = k[t];
|
||||
_a[tid] = a[t];
|
||||
_b[tid] = b[t];
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const float v_val = v[t];
|
||||
float y = 0.0, sa = 0.0;
|
||||
|
||||
float4 sa_vec(0.0);
|
||||
|
||||
for (uint j = 0; j < head_size; j += 4) {
|
||||
float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
|
||||
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
sa_vec += a_vec * s_vec;
|
||||
}
|
||||
sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
|
||||
|
||||
for (uint j = 0; j < head_size; j += 4) {
|
||||
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
|
||||
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
|
||||
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
float4 kv = k_vec * v_val;
|
||||
|
||||
s_vec = s_vec * w_vec + kv + sa * b_vec;
|
||||
y += dot(s_vec, r_vec);
|
||||
|
||||
state[j] = s_vec[0];
|
||||
state[j+1] = s_vec[1];
|
||||
state[j+2] = s_vec[2];
|
||||
state[j+3] = s_vec[3];
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
for (uint i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_id * state_size + head_id * head_size * head_size
|
||||
+ tid * head_size + i] = state[i];
|
||||
}
|
||||
}
|
||||
@@ -1270,77 +1270,14 @@ void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecode
|
||||
}
|
||||
|
||||
std::string GgmlOvDecoder::compute_op_type(const ggml_tensor * node) {
|
||||
static const std::map<ggml_op, std::string> ops = {
|
||||
{GGML_OP_NONE, "GGML_OP_NONE" },
|
||||
{GGML_OP_ACC, "GGML_OP_ACC" },
|
||||
{GGML_OP_ADD, "GGML_OP_ADD" },
|
||||
{GGML_OP_ADD1, "GGML_OP_ADD1" },
|
||||
{GGML_OP_ADD_ID, "GGML_OP_ADD_ID" },
|
||||
{GGML_OP_CONCAT, "GGML_OP_CONCAT" },
|
||||
{GGML_OP_CONT, "GGML_OP_CONT" },
|
||||
{GGML_OP_DIV, "GGML_OP_DIV" },
|
||||
{GGML_OP_DUP, "GGML_OP_DUP" },
|
||||
{GGML_OP_GET_ROWS, "GGML_OP_GET_ROWS" },
|
||||
{GGML_OP_MUL, "GGML_OP_MUL" },
|
||||
{GGML_OP_MUL_MAT, "GGML_OP_MUL_MAT" },
|
||||
{GGML_OP_MUL_MAT_ID, "GGML_OP_MUL_MAT_ID" },
|
||||
{GGML_OP_PERMUTE, "GGML_OP_PERMUTE" },
|
||||
{GGML_OP_RESHAPE, "GGML_OP_RESHAPE" },
|
||||
{GGML_OP_RMS_NORM, "GGML_OP_RMS_NORM" },
|
||||
{GGML_OP_NORM, "GGML_OP_NORM" },
|
||||
{GGML_OP_ROPE, "GGML_OP_ROPE" },
|
||||
{GGML_OP_SCALE, "GGML_OP_SCALE" },
|
||||
{GGML_OP_SOFT_MAX, "GGML_OP_SOFT_MAX" },
|
||||
{GGML_OP_SUM_ROWS, "GGML_OP_SUM_ROWS" },
|
||||
{GGML_OP_SUB, "GGML_OP_SUB" },
|
||||
{GGML_OP_TRANSPOSE, "GGML_OP_TRANSPOSE" },
|
||||
{GGML_OP_VIEW, "GGML_OP_VIEW" },
|
||||
{GGML_OP_SET_ROWS, "GGML_OP_SET_ROWS" },
|
||||
{GGML_OP_CPY, "GGML_OP_CPY" },
|
||||
{GGML_OP_FLASH_ATTN_EXT, "GGML_OP_FLASH_ATTN_EXT" },
|
||||
{GGML_OP_L2_NORM, "GGML_OP_L2_NORM" },
|
||||
{GGML_OP_CLAMP, "GGML_OP_CLAMP" },
|
||||
{GGML_OP_PAD, "GGML_OP_PAD" },
|
||||
{GGML_OP_SSM_CONV, "GGML_OP_SSM_CONV" },
|
||||
{GGML_OP_GATED_DELTA_NET, "GGML_OP_GATED_DELTA_NET"},
|
||||
{GGML_OP_ARGSORT, "GGML_OP_ARGSORT" },
|
||||
{GGML_OP_REPEAT, "GGML_OP_REPEAT" },
|
||||
{GGML_OP_IM2COL, "GGML_OP_IM2COL" }
|
||||
};
|
||||
static const std::map<ggml_unary_op, std::string> unary_ops = {
|
||||
{GGML_UNARY_OP_ABS, "GGML_UNARY_OP_ABS" },
|
||||
{GGML_UNARY_OP_SGN, "GGML_UNARY_OP_SGN" },
|
||||
{GGML_UNARY_OP_NEG, "GGML_UNARY_OP_NEG" },
|
||||
{GGML_UNARY_OP_STEP, "GGML_UNARY_OP_STEP" },
|
||||
{GGML_UNARY_OP_TANH, "GGML_UNARY_OP_TANH" },
|
||||
{GGML_UNARY_OP_ELU, "GGML_UNARY_OP_ELU" },
|
||||
{GGML_UNARY_OP_RELU, "GGML_UNARY_OP_RELU" },
|
||||
{GGML_UNARY_OP_SIGMOID, "GGML_UNARY_OP_SIGMOID" },
|
||||
{GGML_UNARY_OP_GELU, "GGML_UNARY_OP_GELU" },
|
||||
{GGML_UNARY_OP_GELU_QUICK, "GGML_UNARY_OP_GELU_QUICK" },
|
||||
{GGML_UNARY_OP_SILU, "GGML_UNARY_OP_SILU" },
|
||||
{GGML_UNARY_OP_SOFTPLUS, "GGML_UNARY_OP_SOFTPLUS" },
|
||||
{GGML_UNARY_OP_HARDSWISH, "GGML_UNARY_OP_HARDSWISH" },
|
||||
{GGML_UNARY_OP_HARDSIGMOID, "GGML_UNARY_OP_HARDSIGMOID"},
|
||||
{GGML_UNARY_OP_EXP, "GGML_UNARY_OP_EXP" },
|
||||
{GGML_UNARY_OP_COUNT, "GGML_UNARY_OP_COUNT" }
|
||||
};
|
||||
static const std::map<ggml_glu_op, std::string> glu_ops = {
|
||||
{GGML_GLU_OP_SWIGLU, "GGML_GLU_OP_SWIGLU"},
|
||||
{GGML_GLU_OP_GEGLU, "GGML_GLU_OP_GEGLU" },
|
||||
{GGML_GLU_OP_REGLU, "GGML_GLU_OP_REGLU" }
|
||||
};
|
||||
|
||||
switch (node->op) {
|
||||
case GGML_OP_UNARY:
|
||||
return unary_ops.at(ggml_get_unary_op(node));
|
||||
return std::string("GGML_UNARY_OP_") + ggml_unary_op_name(ggml_get_unary_op(node));
|
||||
case GGML_OP_GLU:
|
||||
return glu_ops.at(ggml_get_glu_op(node));
|
||||
return std::string("GGML_GLU_OP_") + ggml_glu_op_name(ggml_get_glu_op(node));
|
||||
default:
|
||||
return ops.at(node->op);
|
||||
return std::string("GGML_OP_") + ggml_op_name(node->op);
|
||||
}
|
||||
static const std::string unknown_op = "UNKNOWN_GGML_OP";
|
||||
return unknown_op;
|
||||
}
|
||||
|
||||
const std::string & GgmlOvDecoder::get_op_type(int node_idx) const {
|
||||
|
||||
@@ -17,6 +17,22 @@ namespace frontend {
|
||||
namespace ggml {
|
||||
namespace op {
|
||||
|
||||
static ov::Output<ov::Node> reshape_add_id_input_to_2d(const ov::Output<ov::Node> & input,
|
||||
const ov::PartialShape & input_shape,
|
||||
const std::vector<int> & dims) {
|
||||
const auto actual_shape = input.get_partial_shape();
|
||||
if (actual_shape.rank().is_static() && actual_shape.rank().get_length() == 2) {
|
||||
return input;
|
||||
}
|
||||
|
||||
if (input_shape.rank().is_static() && input_shape.rank().get_length() == 2) {
|
||||
return input;
|
||||
}
|
||||
|
||||
auto shape = std::make_shared<ov::op::v3::ShapeOf>(input, ov::element::i64);
|
||||
return std::make_shared<ov::op::v1::Reshape>(input, get_dimensions(shape, dims), false);
|
||||
}
|
||||
|
||||
OutputVector translate_add_id(const NodeContext & context) {
|
||||
num_inputs_check(context, 3, 3);
|
||||
|
||||
@@ -28,11 +44,9 @@ OutputVector translate_add_id(const NodeContext & context) {
|
||||
// input: [1, n_token, n_used, n_embd]
|
||||
// bias: [1, 1, n_expert, n_embd]
|
||||
// ids: [1, 1, n_token, n_used]
|
||||
auto bias_shape_4d = std::make_shared<ov::op::v3::ShapeOf>(bias, ov::element::i64);
|
||||
auto ids_shape_4d = std::make_shared<ov::op::v3::ShapeOf>(ids, ov::element::i64);
|
||||
|
||||
bias = std::make_shared<ov::op::v1::Reshape>(bias, get_dimensions(bias_shape_4d, {2, 3}), false);
|
||||
ids = std::make_shared<ov::op::v1::Reshape>(ids, get_dimensions(ids_shape_4d, {2, 3}), false);
|
||||
// Model bias constants may already be stored as [n_expert, n_embd].
|
||||
bias = reshape_add_id_input_to_2d(bias, context.get_input_shape(1), {2, 3});
|
||||
ids = reshape_add_id_input_to_2d(ids, context.get_input_shape(2), {2, 3});
|
||||
|
||||
if (ids.get_element_type() != ov::element::i32 && ids.get_element_type() != ov::element::i64) {
|
||||
ids = std::make_shared<ov::op::v0::Convert>(ids, ov::element::i32);
|
||||
|
||||
@@ -3,8 +3,11 @@
|
||||
#include "../utils.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <openvino/core/node_output.hpp>
|
||||
#include <openvino/op/add.hpp>
|
||||
#include <openvino/op/clamp.hpp>
|
||||
#include <openvino/op/constant.hpp>
|
||||
#include <openvino/op/multiply.hpp>
|
||||
#include <openvino/op/sigmoid.hpp>
|
||||
@@ -15,7 +18,7 @@ namespace frontend {
|
||||
namespace ggml {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_glu_swiglu(const NodeContext & context) {
|
||||
static std::pair<ov::Output<ov::Node>, ov::Output<ov::Node>> get_glu_inputs(const NodeContext & context) {
|
||||
num_inputs_check(context, 1, 2);
|
||||
|
||||
ov::Output<ov::Node> src0;
|
||||
@@ -52,6 +55,12 @@ OutputVector translate_glu_swiglu(const NodeContext & context) {
|
||||
std::swap(src0, src1);
|
||||
}
|
||||
|
||||
return {src0, src1};
|
||||
}
|
||||
|
||||
OutputVector translate_glu_swiglu(const NodeContext & context) {
|
||||
auto [src0, src1] = get_glu_inputs(context);
|
||||
|
||||
auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(src0);
|
||||
auto silu = std::make_shared<ov::op::v1::Multiply>(src0, sigmoid);
|
||||
auto res = std::make_shared<ov::op::v1::Multiply>(silu, src1);
|
||||
@@ -59,6 +68,27 @@ OutputVector translate_glu_swiglu(const NodeContext & context) {
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
}
|
||||
|
||||
OutputVector translate_glu_swiglu_oai(const NodeContext & context) {
|
||||
auto [src0, src1] = get_glu_inputs(context);
|
||||
|
||||
const int32_t * params = context.get_output_op_params();
|
||||
const float alpha = reinterpret_cast<const float *>(params)[2];
|
||||
const float limit = reinterpret_cast<const float *>(params)[3];
|
||||
|
||||
auto gate = std::make_shared<ov::op::v0::Clamp>(src0, -std::numeric_limits<float>::infinity(), limit);
|
||||
auto alpha_const = ov::op::v0::Constant::create(ov::element::f32, {}, {alpha});
|
||||
auto scaled_gate = std::make_shared<ov::op::v1::Multiply>(gate, alpha_const);
|
||||
auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(scaled_gate);
|
||||
auto out_glu = std::make_shared<ov::op::v1::Multiply>(gate, sigmoid);
|
||||
|
||||
auto up = std::make_shared<ov::op::v0::Clamp>(src1, -limit, limit);
|
||||
auto one = ov::op::v0::Constant::create(ov::element::f32, {}, {1.0f});
|
||||
auto up_plus_one = std::make_shared<ov::op::v1::Add>(up, one);
|
||||
auto res = std::make_shared<ov::op::v1::Multiply>(out_glu, up_plus_one);
|
||||
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace ggml
|
||||
} // namespace frontend
|
||||
|
||||
@@ -2,23 +2,135 @@
|
||||
#include "../op_table.h"
|
||||
#include "../utils.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <openvino/op/bitwise_and.hpp>
|
||||
#include <openvino/op/bitwise_right_shift.hpp>
|
||||
#include <openvino/op/broadcast.hpp>
|
||||
#include <openvino/op/concat.hpp>
|
||||
#include <openvino/op/constant.hpp>
|
||||
#include <openvino/op/convert.hpp>
|
||||
#include <openvino/op/gather.hpp>
|
||||
#include <openvino/op/matmul.hpp>
|
||||
#include <openvino/op/multiply.hpp>
|
||||
#include <openvino/op/reshape.hpp>
|
||||
#include <openvino/op/shape_of.hpp>
|
||||
#include <openvino/op/squeeze.hpp>
|
||||
#include <openvino/op/slice.hpp>
|
||||
#include <openvino/op/unsqueeze.hpp>
|
||||
#include <vector>
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace ggml {
|
||||
namespace op {
|
||||
|
||||
namespace {
|
||||
|
||||
std::shared_ptr<ov::op::v0::Constant> const_i64(const std::vector<int64_t> & values) {
|
||||
return ov::op::v0::Constant::create(ov::element::i64, ov::Shape{values.size()}, values);
|
||||
}
|
||||
|
||||
ov::Output<ov::Node> slice_axis(const ov::Output<ov::Node> & input, int64_t axis, int64_t begin, int64_t end) {
|
||||
return std::make_shared<ov::op::v8::Slice>(input, const_i64({begin}), const_i64({end}), const_i64({1}),
|
||||
const_i64({axis}));
|
||||
}
|
||||
|
||||
ov::Output<ov::Node> translate_mul_mat_id_mxfp4_packed(const NodeContext & context,
|
||||
ov::Output<ov::Node> expert_weights,
|
||||
ov::Output<ov::Node> activations,
|
||||
ov::Output<ov::Node> ids) {
|
||||
auto packed_shape = expert_weights.get_partial_shape().to_shape();
|
||||
FRONT_END_OP_CONVERSION_CHECK(packed_shape.size() == 5 && packed_shape[4] == 17,
|
||||
"Expected packed MXFP4 expert weights with shape [1, n_expert, m, k_blocks, 17]");
|
||||
|
||||
const int64_t n_expert = static_cast<int64_t>(packed_shape[1]);
|
||||
const int64_t rows = static_cast<int64_t>(packed_shape[2]);
|
||||
const int64_t k_blocks = static_cast<int64_t>(packed_shape[3]);
|
||||
const int64_t qk = 32;
|
||||
const int64_t cols = k_blocks * qk;
|
||||
|
||||
auto packed_shape_4d = const_i64({n_expert, rows, k_blocks, 17});
|
||||
expert_weights = std::make_shared<ov::op::v1::Reshape>(expert_weights, packed_shape_4d, false);
|
||||
|
||||
auto activations_shape_4d = std::make_shared<ov::op::v3::ShapeOf>(activations, ov::element::i64);
|
||||
auto ids_shape_4d = std::make_shared<ov::op::v3::ShapeOf>(ids, ov::element::i64);
|
||||
auto activations_shape_3d = get_dimensions(activations_shape_4d, {1, 2, 3});
|
||||
auto ids_shape_2d = get_dimensions(ids_shape_4d, {2, 3});
|
||||
|
||||
activations = std::make_shared<ov::op::v1::Reshape>(activations, activations_shape_3d, false);
|
||||
ids = std::make_shared<ov::op::v1::Reshape>(ids, ids_shape_2d, false);
|
||||
if (ids.get_element_type() != ov::element::i32 && ids.get_element_type() != ov::element::i64) {
|
||||
ids = std::make_shared<ov::op::v0::Convert>(ids, ov::element::i32);
|
||||
}
|
||||
|
||||
auto gather_axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
|
||||
|
||||
static const std::vector<float> f4e2m1_lut = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
|
||||
-0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
|
||||
std::vector<float> e8m0_lut(256);
|
||||
for (size_t i = 0; i < e8m0_lut.size(); ++i) {
|
||||
uint32_t bits = static_cast<uint32_t>(i) << 23;
|
||||
memcpy(&e8m0_lut[i], &bits, sizeof(float));
|
||||
}
|
||||
e8m0_lut[0] = std::numeric_limits<float>::min() / 2.0f;
|
||||
e8m0_lut[255] = std::numeric_limits<float>::quiet_NaN();
|
||||
|
||||
auto f4_lut = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{f4e2m1_lut.size()}, f4e2m1_lut);
|
||||
auto scale_lut = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{e8m0_lut.size()}, e8m0_lut);
|
||||
|
||||
auto selected_packed_weights = std::make_shared<ov::op::v8::Gather>(expert_weights, ids, gather_axis);
|
||||
auto scale_byte = slice_axis(selected_packed_weights, 4, 0, 1);
|
||||
auto qs = slice_axis(selected_packed_weights, 4, 1, 17);
|
||||
auto low = std::make_shared<ov::op::v13::BitwiseAnd>(
|
||||
qs, ov::op::v0::Constant::create(ov::element::u8, ov::Shape{}, {0x0F}), ov::op::AutoBroadcastType::NUMPY);
|
||||
auto high_shift = std::make_shared<ov::op::v15::BitwiseRightShift>(
|
||||
qs, ov::op::v0::Constant::create(ov::element::u8, ov::Shape{}, {4}), ov::op::AutoBroadcastType::NUMPY);
|
||||
auto nibbles = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{low, high_shift}, 4);
|
||||
auto nibble_indices = std::make_shared<ov::op::v0::Convert>(nibbles, ov::element::i32);
|
||||
auto weights_f32 = std::make_shared<ov::op::v8::Gather>(f4_lut, nibble_indices, gather_axis);
|
||||
|
||||
auto scale_indices = std::make_shared<ov::op::v0::Convert>(scale_byte, ov::element::i32);
|
||||
auto scales_f32 = std::make_shared<ov::op::v8::Gather>(scale_lut, scale_indices, gather_axis);
|
||||
ov::Output<ov::Node> selected_weights = std::make_shared<ov::op::v1::Multiply>(weights_f32, scales_f32,
|
||||
ov::op::AutoBroadcastType::NUMPY);
|
||||
|
||||
auto ids_shape = std::make_shared<ov::op::v3::ShapeOf>(ids, ov::element::i64);
|
||||
auto selected_weights_target_dims = std::make_shared<ov::op::v0::Concat>(
|
||||
ov::OutputVector{get_dimensions(ids_shape, {0, 1}), const_i64({rows, cols})}, 0);
|
||||
selected_weights = std::make_shared<ov::op::v1::Reshape>(selected_weights, selected_weights_target_dims, false);
|
||||
|
||||
auto activations_shape = std::make_shared<ov::op::v3::ShapeOf>(activations, ov::element::i64);
|
||||
ov::Output<ov::Node> acts_target_dims = std::make_shared<ov::op::v0::Concat>(
|
||||
ov::OutputVector{
|
||||
get_dimensions(activations_shape, {0}),
|
||||
get_dimensions(ids_shape, {1}),
|
||||
get_dimensions(activations_shape, {2}),
|
||||
},
|
||||
0);
|
||||
ov::Output<ov::Node> acts_broadcasted =
|
||||
std::make_shared<ov::op::v3::Broadcast>(activations, acts_target_dims, ov::op::BroadcastType::BIDIRECTIONAL);
|
||||
|
||||
auto activations_expanded = std::make_shared<ov::op::v0::Unsqueeze>(acts_broadcasted, const_i64({2}));
|
||||
ov::Output<ov::Node> result =
|
||||
std::make_shared<ov::op::v0::MatMul>(activations_expanded, selected_weights, false, true);
|
||||
|
||||
auto batch_dim = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
auto row_dim = ov::op::v0::Constant::create(ov::element::i64, {1}, {rows});
|
||||
auto result_target_dims = std::make_shared<ov::op::v0::Concat>(
|
||||
ov::OutputVector{batch_dim, get_dimensions(ids_shape, {0, 1}), row_dim}, 0);
|
||||
result = std::make_shared<ov::op::v1::Reshape>(result, result_target_dims, false);
|
||||
|
||||
const auto output_type = context.get_output_type();
|
||||
if (result.get_element_type() != output_type) {
|
||||
result = std::make_shared<ov::op::v0::Convert>(result, output_type);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_mul_mat_id(const NodeContext & context) {
|
||||
num_inputs_check(context, 3, 3);
|
||||
|
||||
@@ -26,6 +138,12 @@ OutputVector translate_mul_mat_id(const NodeContext & context) {
|
||||
auto activations = process_view_input_new(context, 1);
|
||||
auto ids = process_view_input_new(context, 2);
|
||||
|
||||
if (expert_weights.get_element_type() == ov::element::u8 && expert_weights.get_partial_shape().rank().is_static() &&
|
||||
expert_weights.get_partial_shape().rank().get_length() == 5) {
|
||||
return rename_outputs_with_suffix({translate_mul_mat_id_mxfp4_packed(context, expert_weights, activations, ids)},
|
||||
context.get_name());
|
||||
}
|
||||
|
||||
// OpenVINO sees GGML tensors in reversed dimension order:
|
||||
// weights: [1, n_expert, m, k]
|
||||
// activations: [1, n_tokens, n_used_or_1, k]
|
||||
|
||||
@@ -6,12 +6,16 @@
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <openvino/op/broadcast.hpp>
|
||||
#include <openvino/frontend/exception.hpp>
|
||||
#include <openvino/op/add.hpp>
|
||||
#include <openvino/op/concat.hpp>
|
||||
#include <openvino/op/constant.hpp>
|
||||
#include <openvino/op/convert.hpp>
|
||||
#include <openvino/op/multiply.hpp>
|
||||
#include <openvino/op/reshape.hpp>
|
||||
#include <openvino/op/shape_of.hpp>
|
||||
#include <openvino/op/slice.hpp>
|
||||
#include <openvino/op/softmax.hpp>
|
||||
#include <vector>
|
||||
|
||||
@@ -20,12 +24,31 @@ namespace frontend {
|
||||
namespace ggml {
|
||||
namespace op {
|
||||
|
||||
static bool is_static_one(const ov::Dimension & dim) {
|
||||
return dim.is_static() && dim.get_length() == 1;
|
||||
}
|
||||
|
||||
static bool same_static_dim(const ov::Dimension & lhs, const ov::Dimension & rhs) {
|
||||
return lhs.is_static() && rhs.is_static() && lhs.get_length() == rhs.get_length();
|
||||
}
|
||||
|
||||
static bool is_attention_sinks_input_shape(const ov::PartialShape & candidate, const ov::PartialShape & logits_shape) {
|
||||
if (candidate.rank().is_dynamic() || logits_shape.rank().is_dynamic() || candidate.rank().get_length() != 4 ||
|
||||
logits_shape.rank().get_length() != 4) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return is_static_one(candidate[0]) && is_static_one(candidate[1]) && is_static_one(candidate[2]) &&
|
||||
same_static_dim(candidate[3], logits_shape[1]);
|
||||
}
|
||||
|
||||
// Reimplementation of GGML_OP_SOFT_MAX semantics for OpenVINO backend:
|
||||
// 1) logits = src0 * scale
|
||||
// 2) logits += mask (if provided)
|
||||
// 3) softmax over the last dimension
|
||||
// 3) append attention sinks as hidden logits (if provided)
|
||||
// 4) softmax over the last dimension and remove the hidden sink column
|
||||
OutputVector translate_soft_max(const NodeContext & context) {
|
||||
num_inputs_check(context, 1, 2);
|
||||
num_inputs_check(context, 1, 3);
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
@@ -33,6 +56,11 @@ OutputVector translate_soft_max(const NodeContext & context) {
|
||||
memcpy(&max_bias, (float *) context.get_output_op_params() + 1, sizeof(float));
|
||||
|
||||
ov::Output<ov::Node> logits = context.get_input(0);
|
||||
const bool second_input_is_sinks =
|
||||
context.get_input_size() == 2 && is_attention_sinks_input_shape(context.get_input_shape(1), context.get_output_shape());
|
||||
const bool has_mask = context.get_input_size() > 1 && !second_input_is_sinks;
|
||||
const bool has_sinks = second_input_is_sinks || context.get_input_size() > 2;
|
||||
const size_t sinks_input_idx = second_input_is_sinks ? 1 : 2;
|
||||
|
||||
// Apply scale first: logits = src0 * scale
|
||||
if (scale != 1.0f) {
|
||||
@@ -41,12 +69,12 @@ OutputVector translate_soft_max(const NodeContext & context) {
|
||||
logits = std::make_shared<ov::op::v1::Multiply>(logits, scale_const);
|
||||
}
|
||||
|
||||
FRONT_END_CHECK_IMPLEMENTED(!(max_bias > 0.0f && context.get_input_size() < 2),
|
||||
FRONT_END_CHECK_IMPLEMENTED(!(max_bias > 0.0f && !has_mask),
|
||||
"OpenVINO softmax ALiBi path requires mask input");
|
||||
|
||||
// Optional mask add: logits += mask
|
||||
// For max_bias > 0 (ALiBi), apply per-head slope to mask before adding.
|
||||
if (context.get_input_size() > 1) {
|
||||
if (has_mask) {
|
||||
ov::Output<ov::Node> mask = context.get_input(1);
|
||||
|
||||
// For stateful
|
||||
@@ -94,8 +122,40 @@ OutputVector translate_soft_max(const NodeContext & context) {
|
||||
logits = std::make_shared<ov::op::v1::Add>(logits, mask);
|
||||
}
|
||||
|
||||
ov::Output<ov::Node> softmax_input = logits;
|
||||
if (has_sinks) {
|
||||
ov::Output<ov::Node> sinks = context.get_input(sinks_input_idx);
|
||||
if (sinks.get_element_type() != logits.get_element_type()) {
|
||||
sinks = std::make_shared<ov::op::v0::Convert>(sinks, logits.get_element_type());
|
||||
}
|
||||
|
||||
auto sink_shape = ov::op::v0::Constant::create(ov::element::i64, {4}, {1, -1, 1, 1});
|
||||
auto sinks_4d = std::make_shared<ov::op::v1::Reshape>(sinks, sink_shape, false);
|
||||
|
||||
auto logits_shape = std::make_shared<ov::op::v3::ShapeOf>(logits, ov::element::i64);
|
||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
auto three = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
|
||||
auto four = ov::op::v0::Constant::create(ov::element::i64, {1}, {4});
|
||||
auto shape_axis = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
|
||||
auto sink_prefix_shape = std::make_shared<ov::op::v8::Slice>(logits_shape, zero, three, one, shape_axis);
|
||||
auto sink_last_dim = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
auto sink_broadcast_shape = std::make_shared<ov::op::v0::Concat>(
|
||||
ov::OutputVector{sink_prefix_shape, sink_last_dim}, 0);
|
||||
auto sink_column = std::make_shared<ov::op::v3::Broadcast>(sinks_4d, sink_broadcast_shape,
|
||||
ov::op::BroadcastType::BIDIRECTIONAL);
|
||||
softmax_input = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{logits, sink_column}, 3);
|
||||
|
||||
auto softmax_with_sink = std::make_shared<ov::op::v8::Softmax>(softmax_input, -1);
|
||||
auto original_last_dim = std::make_shared<ov::op::v8::Slice>(logits_shape, three, four, one, shape_axis);
|
||||
auto res = std::make_shared<ov::op::v8::Slice>(softmax_with_sink, zero, original_last_dim, one, three);
|
||||
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
}
|
||||
|
||||
// Softmax along last dimension (equivalent to ggml softmax over ne[0]).
|
||||
auto res = std::make_shared<ov::op::v8::Softmax>(logits, -1);
|
||||
auto res = std::make_shared<ov::op::v8::Softmax>(softmax_input, -1);
|
||||
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ std::unordered_map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"GGML_UNARY_OP_TANH", op::translate_1to1_match_1_input<v0::Tanh> },
|
||||
{"GGML_OP_VIEW", op::translate_view },
|
||||
{"GGML_GLU_OP_SWIGLU", op::translate_glu_swiglu },
|
||||
{"GGML_GLU_OP_SWIGLU_OAI", op::translate_glu_swiglu_oai },
|
||||
{"GGML_GLU_OP_GEGLU", op::translate_glu_geglu },
|
||||
{"GGML_OP_SET_ROWS", op::translate_set_rows },
|
||||
{"GGML_OP_CPY", op::translate_cpy },
|
||||
|
||||
@@ -32,6 +32,7 @@ GGML_OP_CONVERTER(translate_soft_max);
|
||||
GGML_OP_CONVERTER(translate_transpose);
|
||||
GGML_OP_CONVERTER(translate_view);
|
||||
GGML_OP_CONVERTER(translate_glu_swiglu);
|
||||
GGML_OP_CONVERTER(translate_glu_swiglu_oai);
|
||||
GGML_OP_CONVERTER(translate_glu_geglu);
|
||||
GGML_OP_CONVERTER(translate_set_rows);
|
||||
GGML_OP_CONVERTER(translate_cpy);
|
||||
|
||||
+102
-48
@@ -2,8 +2,10 @@
|
||||
#include "ggml-sycl/common.hpp"
|
||||
#include "ggml-sycl/presets.hpp"
|
||||
|
||||
static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
||||
const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
|
||||
static void norm_f32(const float* x, float* dst, const int ncols,
|
||||
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
|
||||
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
|
||||
const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
|
||||
|
||||
const int nrows = item_ct1.get_group_range(2);
|
||||
const int nchannels = item_ct1.get_group_range(1);
|
||||
@@ -16,16 +18,16 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
|
||||
const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
|
||||
const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
|
||||
const auto src_offset = calculate_offset<3>({src_stride_sample, src_stride_channel, src_stride_row}, {sample, channel, row});
|
||||
const auto dst_offset = calculate_offset<3>({dst_stride_sample, dst_stride_channel, dst_stride_row}, {sample, channel, row});
|
||||
|
||||
x += strided_offset;
|
||||
dst += packed_offset;
|
||||
x += src_offset;
|
||||
dst += dst_offset;
|
||||
|
||||
sycl::float2 mean_var = sycl::float2(0.f, 0.f);
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[col];
|
||||
const float xi = x[col * src_stride_col];
|
||||
mean_var.x() += xi;
|
||||
mean_var.y() += xi * xi;
|
||||
}
|
||||
@@ -54,7 +56,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
|
||||
const float inv_std = sycl::rsqrt(var + eps);
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[col] = (x[col] - mean) * inv_std;
|
||||
dst[col * dst_stride_col] = (x[col * src_stride_col] - mean) * inv_std;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,8 +147,10 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
|
||||
}
|
||||
}
|
||||
|
||||
static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
|
||||
const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
|
||||
static void rms_norm_f32(const float* x, float* dst, const int ncols,
|
||||
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
|
||||
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
|
||||
const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
|
||||
|
||||
const int nrows = item_ct1.get_group_range(2);
|
||||
const int nchannels = item_ct1.get_group_range(1);
|
||||
@@ -160,17 +164,17 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
|
||||
const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
|
||||
const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
|
||||
const auto src_offset = calculate_offset<3>({src_stride_sample, src_stride_channel, src_stride_row}, {sample, channel, row});
|
||||
const auto dst_offset = calculate_offset<3>({dst_stride_sample, dst_stride_channel, dst_stride_row}, {sample, channel, row});
|
||||
|
||||
x += strided_offset;
|
||||
dst += packed_offset;
|
||||
x += src_offset;
|
||||
dst += dst_offset;
|
||||
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[col];
|
||||
const float xi = x[col * src_stride_col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
@@ -198,14 +202,15 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
|
||||
const float scale = sycl::rsqrt(mean + eps);
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[col] = scale * x[col];
|
||||
dst[col * dst_stride_col] = scale * x[col * src_stride_col];
|
||||
}
|
||||
}
|
||||
|
||||
template<int warp_size>
|
||||
static void l2_norm_f32(const float * x, float * dst, const int ncols,
|
||||
const int64_t stride_row, const int64_t stride_channel,
|
||||
const int64_t stride_sample, const float eps,
|
||||
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel,
|
||||
const int64_t src_stride_sample, const int64_t dst_stride_col, const int64_t dst_stride_row,
|
||||
const int64_t dst_stride_channel, const int64_t dst_stride_sample, const float eps,
|
||||
const sycl::nd_item<3>& item_ct1, float* s_sum, const int block_size) {
|
||||
const int nrows = item_ct1.get_group_range(2);
|
||||
const int nchannels = item_ct1.get_group_range(1);
|
||||
@@ -215,13 +220,13 @@ static void l2_norm_f32(const float * x, float * dst, const int ncols,
|
||||
const int sample = item_ct1.get_group(0);
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
|
||||
x += sample*stride_sample + channel*stride_channel + row*stride_row;
|
||||
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
|
||||
x += sample*src_stride_sample + channel*src_stride_channel + row*src_stride_row;
|
||||
dst += sample*dst_stride_sample + channel*dst_stride_channel + row*dst_stride_row;
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[col];
|
||||
const float xi = x[col * src_stride_col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
@@ -229,12 +234,13 @@ static void l2_norm_f32(const float * x, float * dst, const int ncols,
|
||||
const float scale = sycl::rsqrt(sycl::fmax(tmp, eps * eps));
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[col] = scale * x[col];
|
||||
dst[col * dst_stride_col] = scale * x[col * src_stride_col];
|
||||
}
|
||||
}
|
||||
|
||||
static void norm_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
|
||||
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
|
||||
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
|
||||
const float eps, queue_ptr stream, int device) {
|
||||
|
||||
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
|
||||
@@ -245,7 +251,10 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
|
||||
sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
|
||||
norm_f32(x, dst, ncols,
|
||||
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
|
||||
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
|
||||
eps, item_ct1, nullptr, WARP_SIZE);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -265,7 +274,10 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
|
||||
sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
norm_f32(x, dst, ncols,
|
||||
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
|
||||
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
|
||||
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -319,7 +331,9 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
||||
}
|
||||
|
||||
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
|
||||
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
|
||||
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
|
||||
const float eps, queue_ptr stream, int device) {
|
||||
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
||||
|
||||
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
|
||||
@@ -330,7 +344,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
|
||||
sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
|
||||
rms_norm_f32(x, dst, ncols,
|
||||
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
|
||||
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
|
||||
eps, item_ct1, nullptr, WARP_SIZE);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -350,7 +367,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
|
||||
sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
rms_norm_f32(x, dst, ncols,
|
||||
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
|
||||
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
|
||||
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -363,9 +383,14 @@ static void l2_norm_f32_sycl(const float * x,
|
||||
const int nrows,
|
||||
const int nchannels,
|
||||
const int nsamples,
|
||||
const int64_t stride_row,
|
||||
const int64_t stride_channel,
|
||||
const int64_t stride_sample,
|
||||
const int64_t src_stride_col,
|
||||
const int64_t src_stride_row,
|
||||
const int64_t src_stride_channel,
|
||||
const int64_t src_stride_sample,
|
||||
const int64_t dst_stride_col,
|
||||
const int64_t dst_stride_row,
|
||||
const int64_t dst_stride_channel,
|
||||
const int64_t dst_stride_sample,
|
||||
const float eps,
|
||||
queue_ptr stream,
|
||||
int device) {
|
||||
@@ -379,7 +404,10 @@ static void l2_norm_f32_sycl(const float * x,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(warp_size)]] {
|
||||
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
|
||||
l2_norm_f32<warp_size>(x, dst, ncols,
|
||||
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
|
||||
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
|
||||
eps, item_ct1,
|
||||
nullptr, warp_size);
|
||||
});
|
||||
});
|
||||
@@ -398,7 +426,9 @@ static void l2_norm_f32_sycl(const float * x,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(warp_size)]] {
|
||||
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample,
|
||||
l2_norm_f32<warp_size>(x, dst, ncols,
|
||||
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
|
||||
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
|
||||
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
@@ -421,12 +451,20 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
const size_t ts0 = ggml_type_size(src0->type);
|
||||
GGML_ASSERT(nb00 == ts0);
|
||||
const int64_t s01 = nb01 / ts0;
|
||||
const int64_t s02 = nb02 / ts0;
|
||||
const int64_t s03 = nb03 / ts0;
|
||||
const size_t tdst = ggml_type_size(dst->type);
|
||||
GGML_ASSERT(nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0);
|
||||
GGML_ASSERT(nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0);
|
||||
const int64_t ss0 = nb00 / ts0;
|
||||
const int64_t ss1 = nb01 / ts0;
|
||||
const int64_t ss2 = nb02 / ts0;
|
||||
const int64_t ss3 = nb03 / ts0;
|
||||
const int64_t ds0 = nb0 / tdst;
|
||||
const int64_t ds1 = nb1 / tdst;
|
||||
const int64_t ds2 = nb2 / tdst;
|
||||
const int64_t ds3 = nb3 / tdst;
|
||||
|
||||
norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
|
||||
norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03,
|
||||
ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, main_stream, ctx.device);
|
||||
}
|
||||
|
||||
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
@@ -465,11 +503,19 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
const size_t ts0 = ggml_type_size(src0->type);
|
||||
GGML_ASSERT(nb00 == ts0);
|
||||
const int64_t s01 = nb01 / ts0;
|
||||
const int64_t s02 = nb02 / ts0;
|
||||
const int64_t s03 = nb03 / ts0;
|
||||
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
|
||||
const size_t tdst = ggml_type_size(dst->type);
|
||||
GGML_ASSERT(nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0);
|
||||
GGML_ASSERT(nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0);
|
||||
const int64_t ss0 = nb00 / ts0;
|
||||
const int64_t ss1 = nb01 / ts0;
|
||||
const int64_t ss2 = nb02 / ts0;
|
||||
const int64_t ss3 = nb03 / ts0;
|
||||
const int64_t ds0 = nb0 / tdst;
|
||||
const int64_t ds1 = nb1 / tdst;
|
||||
const int64_t ds2 = nb2 / tdst;
|
||||
const int64_t ds3 = nb3 / tdst;
|
||||
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03,
|
||||
ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, main_stream, ctx.device);
|
||||
}
|
||||
|
||||
void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
@@ -644,13 +690,21 @@ void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
const size_t ts0 = ggml_type_size(src0->type);
|
||||
GGML_ASSERT(nb00 == ts0);
|
||||
const int64_t s01 = nb01 / ts0;
|
||||
const int64_t s02 = nb02 / ts0;
|
||||
const int64_t s03 = nb03 / ts0;
|
||||
const size_t tdst = ggml_type_size(dst->type);
|
||||
GGML_ASSERT(nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0);
|
||||
GGML_ASSERT(nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0);
|
||||
const int64_t ss0 = nb00 / ts0;
|
||||
const int64_t ss1 = nb01 / ts0;
|
||||
const int64_t ss2 = nb02 / ts0;
|
||||
const int64_t ss3 = nb03 / ts0;
|
||||
const int64_t ds0 = nb0 / tdst;
|
||||
const int64_t ds1 = nb1 / tdst;
|
||||
const int64_t ds2 = nb2 / tdst;
|
||||
const int64_t ds3 = nb3 / tdst;
|
||||
|
||||
/*support both WARP_SIZE or WARP_32_SIZE in code
|
||||
choose by hardware for better performance
|
||||
*/
|
||||
l2_norm_f32_sycl<WARP_SIZE>(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream, ctx.device);
|
||||
l2_norm_f32_sycl<WARP_SIZE>(src0_d, dst_d, ne00, ne01, ne02, ne03,
|
||||
ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, stream, ctx.device);
|
||||
}
|
||||
|
||||
@@ -308,6 +308,7 @@ enum vk_device_architecture {
|
||||
AMD_RDNA1,
|
||||
AMD_RDNA2,
|
||||
AMD_RDNA3,
|
||||
INTEL_XE1,
|
||||
INTEL_XE2,
|
||||
NVIDIA_PRE_TURING,
|
||||
NVIDIA_TURING,
|
||||
@@ -365,21 +366,26 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
|
||||
const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
|
||||
|
||||
bool subgroup_size_control = false;
|
||||
bool integer_dot_product = false;
|
||||
|
||||
for (const auto& properties : ext_props) {
|
||||
if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
|
||||
subgroup_size_control = true;
|
||||
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) {
|
||||
integer_dot_product = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!subgroup_size_control) {
|
||||
if (!subgroup_size_control || !integer_dot_product) {
|
||||
return vk_device_architecture::OTHER;
|
||||
}
|
||||
|
||||
vk::PhysicalDeviceProperties2 props2;
|
||||
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
|
||||
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
|
||||
|
||||
props2.pNext = &subgroup_size_control_props;
|
||||
subgroup_size_control_props.pNext = &integer_dot_props;
|
||||
device.getProperties2(&props2);
|
||||
|
||||
if (subgroup_size_control_props.minSubgroupSize == 16) {
|
||||
@@ -388,6 +394,9 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
|
||||
// https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html
|
||||
// https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
|
||||
return vk_device_architecture::INTEL_XE2;
|
||||
} else if (subgroup_size_control_props.minSubgroupSize == 8 &&
|
||||
integer_dot_product && integer_dot_props.integerDotProduct4x8BitPackedSignedAccelerated) {
|
||||
return vk_device_architecture::INTEL_XE1;
|
||||
}
|
||||
} else if (props.vendorID == VK_VENDOR_ID_NVIDIA) {
|
||||
const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
|
||||
@@ -3837,7 +3846,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
l_warptile = { 256, 128, 128, 16, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
l_warptile_mmq = l_warptile_mmq_int = { 256, 128, 128, 32, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
l_warptile_mmq_int_k = { 256, 128, 128, 32, subgroup_size_16, 64, 1, 4, 2, 1, subgroup_size_16 };
|
||||
} else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support && device->architecture == INTEL_XE2) {
|
||||
} else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support) {
|
||||
// Xe2/Xe3 with coopmat enabled - warptile performance tuning
|
||||
l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
@@ -4710,7 +4719,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
}
|
||||
uint32_t rm_iq = 2 * rm_kq;
|
||||
|
||||
const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
|
||||
const bool use_subgroups = device->subgroup_arithmetic;
|
||||
// Ensure a subgroup size >= 16 is available
|
||||
const bool use_subgroups16 = use_subgroups && subgroup_min_size_16;
|
||||
|
||||
@@ -6361,9 +6370,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
break;
|
||||
case VK_VENDOR_ID_INTEL: {
|
||||
// Current Windows driver does not expose BF16 support.
|
||||
// We only want to use l_warptile if coopmat is available and is Xe2+
|
||||
const bool xe2_with_coopmat = device->coopmat_support && device->architecture == INTEL_XE2;
|
||||
const bool use_l_warptile = (i == GGML_TYPE_BF16) ? (device->coopmat_bf16_support && xe2_with_coopmat) : xe2_with_coopmat;
|
||||
// We only want to use l_warptile if coopmat is available
|
||||
const bool use_l_warptile = (i == GGML_TYPE_BF16) ? (device->coopmat_bf16_support && device->coopmat_support) : device->coopmat_support;
|
||||
device->mul_mat_l[i] = use_l_warptile;
|
||||
device->mul_mat_id_l[i] = use_l_warptile;
|
||||
device->mul_mat_m[i] = true;
|
||||
@@ -17890,9 +17898,9 @@ static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev) {
|
||||
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
|
||||
switch (props.vendorID) {
|
||||
case VK_VENDOR_ID_INTEL:
|
||||
// Only allowing Xe2 GPU at the moment since Xe2 GPU can gain significant performance boost,
|
||||
// while some older hardware (ex. Arc A770) has performance regressions
|
||||
return arch == vk_device_architecture::INTEL_XE2;
|
||||
// Only allowing Xe2/Xe3 GPU and integrated Xe GPUs at the moment since older hardware (ex. Arc A770) has performance regressions.
|
||||
return (arch == vk_device_architecture::INTEL_XE2) ||
|
||||
(arch == vk_device_architecture::INTEL_XE1 && props.deviceType == vk::PhysicalDeviceType::eIntegratedGpu && driver_props.driverID == vk::DriverId::eIntelProprietaryWindows);
|
||||
case VK_VENDOR_ID_AMD:
|
||||
if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
|
||||
// Workaround for AMD proprietary driver reporting support on all GPUs
|
||||
@@ -17940,6 +17948,8 @@ static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev)
|
||||
case 0xE20B: // B580
|
||||
case 0xE211: // Pro B60
|
||||
return 20;
|
||||
case 0xB080: // PTL Xe3 LPG 2x6 (12 subslices)
|
||||
return 12;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -28,13 +28,10 @@ vec2 cache_b_ds;
|
||||
|
||||
#include "mul_mat_vecq_funcs.glsl"
|
||||
|
||||
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
|
||||
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint col, const uint b_qs_idx) {
|
||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||
const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
|
||||
|
||||
// Preload data_b block
|
||||
const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
|
||||
const uint b_qs_idx = tid % (32 / K_PER_ITER);
|
||||
const uint b_block_idx_outer = b_block_idx / 4;
|
||||
const uint b_block_idx_inner = b_block_idx % 4;
|
||||
cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
|
||||
@@ -91,35 +88,35 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||
}
|
||||
}
|
||||
|
||||
uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
|
||||
if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
|
||||
const uint col_stride = K_PER_ITER * BLOCK_SIZE;
|
||||
uint num_iters = p.ncols / col_stride;
|
||||
if (num_iters * col_stride + K_PER_ITER * tid < p.ncols) {
|
||||
num_iters++;
|
||||
}
|
||||
int unroll_count = 4;
|
||||
uint unrolled_iters = num_iters & ~(unroll_count - 1);
|
||||
|
||||
uint i = 0;
|
||||
while (i < unrolled_iters) {
|
||||
const uint b_qs_idx = tid % (32 / K_PER_ITER);
|
||||
uint col = tid * K_PER_ITER;
|
||||
while (num_iters >= 4) {
|
||||
// Manually partially unroll the loop
|
||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
|
||||
i++;
|
||||
[[unroll]] for (uint k = 0; k < 4; ++k) {
|
||||
iter(temp, first_row, num_rows, col, b_qs_idx);
|
||||
col += col_stride;
|
||||
}
|
||||
|
||||
num_iters -= 4;
|
||||
}
|
||||
|
||||
unroll_count = 2;
|
||||
unrolled_iters = num_iters & ~(unroll_count - 1);
|
||||
|
||||
while (i < unrolled_iters) {
|
||||
if (num_iters >= 2) {
|
||||
// Manually partially unroll the loop
|
||||
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
|
||||
i++;
|
||||
}
|
||||
iter(temp, first_row, num_rows, col, b_qs_idx);
|
||||
col += col_stride;
|
||||
iter(temp, first_row, num_rows, col, b_qs_idx);
|
||||
col += col_stride;
|
||||
num_iters -= 2;
|
||||
}
|
||||
while (i < num_iters) {
|
||||
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
|
||||
i++;
|
||||
|
||||
if (num_iters > 0) {
|
||||
iter(temp, first_row, num_rows, col, b_qs_idx);
|
||||
}
|
||||
|
||||
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||
|
||||
@@ -42,7 +42,7 @@ float op_leaky_relu(float x) {
|
||||
}
|
||||
|
||||
float op_step(float x) {
|
||||
return x >= 0.0f ? 1.0f : 0.0f;
|
||||
return x > 0.0f ? 1.0f : 0.0f;
|
||||
}
|
||||
|
||||
float op_tanh(float x) {
|
||||
|
||||
@@ -1 +1 @@
|
||||
707321c4cf6d21cb4bc831aa8b687dbf01a521ce
|
||||
eced84c86f8b012c752c016f7fe789adea168e1e
|
||||
|
||||
@@ -302,9 +302,9 @@ target_link_libraries(${TEST_TARGET} PRIVATE llama)
|
||||
llama_build_and_test(test-alloc.cpp)
|
||||
target_include_directories(test-alloc PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
|
||||
|
||||
llama_build(export-graph-ops.cpp)
|
||||
target_include_directories(export-graph-ops PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
|
||||
llama_build(test-export-graph-ops.cpp)
|
||||
target_include_directories(test-export-graph-ops PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
|
||||
if (TARGET gguf-model-data)
|
||||
target_link_libraries(export-graph-ops PRIVATE gguf-model-data)
|
||||
target_compile_definitions(export-graph-ops PRIVATE LLAMA_HF_FETCH)
|
||||
target_link_libraries(test-export-graph-ops PRIVATE gguf-model-data)
|
||||
target_compile_definitions(test-export-graph-ops PRIVATE LLAMA_HF_FETCH)
|
||||
endif()
|
||||
|
||||
@@ -9943,7 +9943,7 @@ static void usage(char ** argv) {
|
||||
printf(" --output specifies output format (default: console, options: console, sql, csv)\n");
|
||||
printf(" --list-ops lists all available GGML operations\n");
|
||||
printf(" --show-coverage shows test coverage\n");
|
||||
printf(" --test-file reads test operators from a test file generated by llama-export-graph-ops\n");
|
||||
printf(" --test-file reads test operators from a test file generated by test-export-graph-ops\n");
|
||||
printf(" -j <n> runs tests using <n> parallel worker threads (default: 1, test mode only)\n");
|
||||
}
|
||||
|
||||
|
||||
@@ -135,7 +135,7 @@ int main(int argc, char ** argv) {
|
||||
output_path = args[i + 1];
|
||||
i++;
|
||||
} else if (args[i] == "--no-common") {
|
||||
use_common = true;
|
||||
use_common = false;
|
||||
} else if (tmpl_path.empty()) {
|
||||
tmpl_path = args[i];
|
||||
} else {
|
||||
|
||||
@@ -185,7 +185,7 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
#else
|
||||
LOG_ERR("export-graph-ops compiled without HF fetch support\n");
|
||||
LOG_ERR("test-export-graph-ops compiled without HF fetch support\n");
|
||||
return 1;
|
||||
#endif
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
set(TARGET rpc-server)
|
||||
set(TARGET ggml-rpc-server)
|
||||
add_executable(${TARGET} rpc-server.cpp)
|
||||
target_link_libraries(${TARGET} PRIVATE ggml)
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
+17
-17
@@ -4,8 +4,8 @@
|
||||
> This example and the RPC backend are currently in a proof-of-concept development stage. As such, the functionality is fragile and
|
||||
> insecure. **Never run the RPC server on an open network or in a sensitive environment!**
|
||||
|
||||
The `rpc-server` allows exposing `ggml` devices on a remote host.
|
||||
The RPC backend communicates with one or several instances of `rpc-server` and offloads computations to them.
|
||||
The `ggml-rpc-server` allows exposing `ggml` devices on a remote host.
|
||||
The RPC backend communicates with one or several instances of `ggml-rpc-server` and offloads computations to them.
|
||||
This can be used for distributed LLM inference with `llama.cpp` in the following way:
|
||||
|
||||
```mermaid
|
||||
@@ -14,15 +14,15 @@ flowchart TD
|
||||
rpcb<-->|TCP|srvb
|
||||
rpcb<-.->|TCP|srvn
|
||||
subgraph hostn[Host N]
|
||||
srvn[rpc-server]<-.->dev4["CUDA0"]
|
||||
srvn[rpc-server]<-.->dev5["CPU"]
|
||||
srvn[ggml-rpc-server]<-.->dev4["CUDA0"]
|
||||
srvn[ggml-rpc-server]<-.->dev5["CPU"]
|
||||
end
|
||||
subgraph hostb[Host B]
|
||||
srvb[rpc-server]<-->dev3["Metal"]
|
||||
srvb[ggml-rpc-server]<-->dev3["Metal"]
|
||||
end
|
||||
subgraph hosta[Host A]
|
||||
srva[rpc-server]<-->dev["CUDA0"]
|
||||
srva[rpc-server]<-->dev2["CUDA1"]
|
||||
srva[ggml-rpc-server]<-->dev["CUDA0"]
|
||||
srva[ggml-rpc-server]<-->dev2["CUDA1"]
|
||||
end
|
||||
subgraph host[Main Host]
|
||||
local["Local devices"]<-->ggml[llama-cli]
|
||||
@@ -33,7 +33,7 @@ flowchart TD
|
||||
class local,dev,dev2,dev3,dev4,dev5 devcls
|
||||
```
|
||||
|
||||
By default, `rpc-server` exposes all available accelerator devices on the host.
|
||||
By default, `ggml-rpc-server` exposes all available accelerator devices on the host.
|
||||
If there are no accelerators, it exposes a single `CPU` device.
|
||||
|
||||
## Usage
|
||||
@@ -41,7 +41,7 @@ If there are no accelerators, it exposes a single `CPU` device.
|
||||
### Remote hosts
|
||||
|
||||
On each remote host, build the backends for each accelerator by adding `-DGGML_RPC=ON` to the build options.
|
||||
For example, to build the `rpc-server` with support for CUDA accelerators:
|
||||
For example, to build the `ggml-rpc-server` with support for CUDA accelerators:
|
||||
|
||||
```bash
|
||||
mkdir build-rpc-cuda
|
||||
@@ -50,10 +50,10 @@ cmake .. -DGGML_CUDA=ON -DGGML_RPC=ON
|
||||
cmake --build . --config Release
|
||||
```
|
||||
|
||||
When started, the `rpc-server` will detect and expose all available `CUDA` devices:
|
||||
When started, the `ggml-rpc-server` will detect and expose all available `CUDA` devices:
|
||||
|
||||
```bash
|
||||
$ bin/rpc-server
|
||||
$ bin/ggml-rpc-server
|
||||
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
|
||||
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
|
||||
ggml_cuda_init: found 1 CUDA devices:
|
||||
@@ -67,14 +67,14 @@ Devices:
|
||||
|
||||
You can control the set of exposed CUDA devices with the `CUDA_VISIBLE_DEVICES` environment variable or the `--device` command line option. The following two commands have the same effect:
|
||||
```bash
|
||||
$ CUDA_VISIBLE_DEVICES=0 bin/rpc-server -p 50052
|
||||
$ bin/rpc-server --device CUDA0 -p 50052
|
||||
$ CUDA_VISIBLE_DEVICES=0 bin/ggml-rpc-server -p 50052
|
||||
$ bin/ggml-rpc-server --device CUDA0 -p 50052
|
||||
```
|
||||
|
||||
### Main host
|
||||
|
||||
On the main host build `llama.cpp` with the backends for the local devices and add `-DGGML_RPC=ON` to the build options.
|
||||
Finally, when running `llama-cli` or `llama-server`, use the `--rpc` option to specify the host and port of each `rpc-server`:
|
||||
Finally, when running `llama-cli` or `llama-server`, use the `--rpc` option to specify the host and port of each `ggml-rpc-server`:
|
||||
|
||||
```bash
|
||||
$ llama-cli -hf ggml-org/gemma-3-1b-it-GGUF -ngl 99 --rpc 192.168.88.10:50052,192.168.88.11:50052
|
||||
@@ -90,7 +90,7 @@ This can speed up model loading significantly, especially when using large model
|
||||
To enable the cache, use the `-c` option:
|
||||
|
||||
```bash
|
||||
$ bin/rpc-server -c
|
||||
$ bin/ggml-rpc-server -c
|
||||
```
|
||||
|
||||
By default, the cache is stored in the `$HOME/.cache/llama.cpp/rpc` directory and can be controlled via the `LLAMA_CACHE` environment variable.
|
||||
@@ -103,8 +103,8 @@ RDMA is enabled by default when `libibverbs` is found at build time.
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
Use the `GGML_RPC_DEBUG` environment variable to enable debug messages from `rpc-server`:
|
||||
Use the `GGML_RPC_DEBUG` environment variable to enable debug messages from `ggml-rpc-server`:
|
||||
```bash
|
||||
$ GGML_RPC_DEBUG=1 bin/rpc-server
|
||||
$ GGML_RPC_DEBUG=1 bin/ggml-rpc-server
|
||||
```
|
||||
|
||||
|
||||
+1
-1
@@ -33,7 +33,7 @@
|
||||
|
||||
{#if !readonly && onRemove}
|
||||
<div
|
||||
class="absolute top-10 right-2 flex items-center justify-center opacity-0 transition-opacity group-hover:opacity-100"
|
||||
class="absolute top-10 right-2 flex items-center justify-center opacity-0 transition-opacity group-focus-within:opacity-100 group-hover:opacity-100"
|
||||
>
|
||||
<ActionIcon icon={X} tooltip="Remove" stopPropagationOnClick onclick={() => onRemove?.()} />
|
||||
</div>
|
||||
|
||||
+1
-1
@@ -56,7 +56,7 @@
|
||||
<div class="relative flex h-6 items-center justify-between">
|
||||
<div class="right-0 flex items-center gap-2 opacity-100 transition-opacity">
|
||||
<div
|
||||
class="pointer-events-auto inset-0 flex items-center gap-1 opacity-0 transition-all duration-150 group-hover:opacity-100"
|
||||
class="pointer-events-auto inset-0 flex items-center gap-1 opacity-0 transition-all duration-150 group-focus-within:opacity-100 group-hover:opacity-100"
|
||||
>
|
||||
<ActionIcon icon={Edit} tooltip="Edit" onclick={editCtx.handleEdit} />
|
||||
<ActionIcon icon={Trash2} tooltip="Delete" onclick={onDelete} />
|
||||
|
||||
+56
-81
@@ -39,7 +39,6 @@
|
||||
depth = 0
|
||||
}: Props = $props();
|
||||
|
||||
let renderActionsDropdown = $state(false);
|
||||
let dropdownOpen = $state(false);
|
||||
|
||||
let isLoading = $derived(getAllLoadingChats().includes(conversation.id));
|
||||
@@ -71,26 +70,10 @@
|
||||
}
|
||||
}
|
||||
|
||||
function handleMouseLeave() {
|
||||
if (!dropdownOpen) {
|
||||
renderActionsDropdown = false;
|
||||
}
|
||||
}
|
||||
|
||||
function handleMouseOver() {
|
||||
renderActionsDropdown = true;
|
||||
}
|
||||
|
||||
function handleSelect() {
|
||||
onSelect?.(conversation.id);
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (!dropdownOpen) {
|
||||
renderActionsDropdown = false;
|
||||
}
|
||||
});
|
||||
|
||||
onMount(() => {
|
||||
document.addEventListener('edit-active-conversation', handleGlobalEditEvent as EventListener);
|
||||
|
||||
@@ -103,23 +86,19 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<!-- svelte-ignore a11y_mouse_events_have_key_events -->
|
||||
<button
|
||||
class="group flex min-h-9 w-full cursor-pointer items-center justify-between space-x-3 rounded-lg py-1.5 text-left transition-colors hover:bg-foreground/10 {isActive
|
||||
<div
|
||||
class="conversation-item group relative flex min-h-9 w-full items-center justify-between space-x-3 rounded-lg py-1.5 transition-colors hover:bg-foreground/10 {isActive
|
||||
? 'bg-foreground/5 text-accent-foreground'
|
||||
: ''} px-3"
|
||||
onclick={handleSelect}
|
||||
onmouseover={handleMouseOver}
|
||||
onmouseleave={handleMouseLeave}
|
||||
onfocusin={handleMouseOver}
|
||||
onfocusout={(e) => {
|
||||
if (!e.currentTarget.contains(e.relatedTarget as Node | null)) {
|
||||
handleMouseLeave();
|
||||
}
|
||||
}}
|
||||
>
|
||||
<button
|
||||
class="absolute inset-0 z-0 cursor-pointer rounded-lg focus:outline-none focus-visible:ring-2 focus-visible:ring-ring"
|
||||
onclick={handleSelect}
|
||||
aria-label={conversation.name}
|
||||
>
|
||||
</button>
|
||||
<div
|
||||
class="flex min-w-0 flex-1 items-center gap-2"
|
||||
class="pointer-events-none relative z-10 flex min-w-0 flex-1 items-center gap-2"
|
||||
style:padding-left="{depth * FORK_TREE_DEPTH_PADDING}px"
|
||||
>
|
||||
{#if depth > 0}
|
||||
@@ -130,7 +109,7 @@
|
||||
<a
|
||||
{...props}
|
||||
href={RouterService.chat(conversation.forkedFromConversationId)}
|
||||
class="flex shrink-0 items-center text-muted-foreground transition-colors hover:text-foreground"
|
||||
class="pointer-events-auto flex shrink-0 items-center text-muted-foreground transition-colors hover:text-foreground"
|
||||
>
|
||||
<GitBranch class="h-3.5 w-3.5" />
|
||||
</a>
|
||||
@@ -146,18 +125,15 @@
|
||||
{#if isLoading}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<div
|
||||
class="stop-button flex h-4 w-4 shrink-0 cursor-pointer items-center justify-center rounded text-muted-foreground transition-colors hover:text-foreground"
|
||||
<button
|
||||
class="stop-button pointer-events-auto flex h-4 w-4 shrink-0 cursor-pointer items-center justify-center rounded text-muted-foreground transition-colors hover:text-foreground"
|
||||
onclick={handleStop}
|
||||
onkeydown={(e) => e.key === 'Enter' && handleStop(e)}
|
||||
role="button"
|
||||
tabindex="0"
|
||||
aria-label="Stop generation"
|
||||
>
|
||||
<Loader2 class="loading-icon h-3.5 w-3.5 animate-spin" />
|
||||
|
||||
<Square class="stop-icon hidden h-3 w-3 fill-current text-destructive" />
|
||||
</div>
|
||||
</button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
@@ -169,52 +145,50 @@
|
||||
<TruncatedText text={conversation.name} class="text-sm font-medium" showTooltip={false} />
|
||||
</div>
|
||||
|
||||
{#if renderActionsDropdown}
|
||||
<div class="actions flex items-center">
|
||||
<DropdownMenuActions
|
||||
triggerIcon={MoreHorizontal}
|
||||
triggerTooltip="More actions"
|
||||
bind:open={dropdownOpen}
|
||||
actions={[
|
||||
{
|
||||
icon: conversation.pinned ? PinOff : Pin,
|
||||
label: conversation.pinned ? 'Unpin' : 'Pin',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
handleTogglePin();
|
||||
}
|
||||
},
|
||||
{
|
||||
icon: Pencil,
|
||||
label: 'Edit',
|
||||
onclick: handleEdit,
|
||||
shortcut: ['shift', 'cmd', 'e']
|
||||
},
|
||||
{
|
||||
icon: Download,
|
||||
label: 'Export',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
conversationsStore.downloadConversation(conversation.id);
|
||||
},
|
||||
shortcut: ['shift', 'cmd', 's']
|
||||
},
|
||||
{
|
||||
icon: Trash2,
|
||||
label: 'Delete',
|
||||
onclick: handleDelete,
|
||||
variant: 'destructive',
|
||||
shortcut: ['shift', 'cmd', 'd'],
|
||||
separator: true
|
||||
<div class="actions pointer-events-auto relative z-20 flex items-center">
|
||||
<DropdownMenuActions
|
||||
triggerIcon={MoreHorizontal}
|
||||
triggerTooltip="More actions"
|
||||
bind:open={dropdownOpen}
|
||||
actions={[
|
||||
{
|
||||
icon: conversation.pinned ? PinOff : Pin,
|
||||
label: conversation.pinned ? 'Unpin' : 'Pin',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
handleTogglePin();
|
||||
}
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
</button>
|
||||
},
|
||||
{
|
||||
icon: Pencil,
|
||||
label: 'Edit',
|
||||
onclick: handleEdit,
|
||||
shortcut: ['shift', 'cmd', 'e']
|
||||
},
|
||||
{
|
||||
icon: Download,
|
||||
label: 'Export',
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
conversationsStore.downloadConversation(conversation.id);
|
||||
},
|
||||
shortcut: ['shift', 'cmd', 's']
|
||||
},
|
||||
{
|
||||
icon: Trash2,
|
||||
label: 'Delete',
|
||||
onclick: handleDelete,
|
||||
variant: 'destructive',
|
||||
shortcut: ['shift', 'cmd', 'd'],
|
||||
separator: true
|
||||
}
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<style>
|
||||
button {
|
||||
.conversation-item {
|
||||
:global([data-slot='dropdown-menu-trigger']:not([data-state='open'])) {
|
||||
opacity: 0;
|
||||
}
|
||||
@@ -239,7 +213,8 @@
|
||||
}
|
||||
}
|
||||
|
||||
&:is(:hover) .stop-button {
|
||||
&:is(:hover) .stop-button,
|
||||
&:focus-within .stop-button {
|
||||
:global(.stop-icon) {
|
||||
display: block;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user