mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-24 06:37:41 +02:00
Compare commits
75 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ac4105d68b | |||
| be4a6a63eb | |||
| 72a9269172 | |||
| 92e854ab83 | |||
| c5606364b2 | |||
| 0eb874d374 | |||
| 75ad0b23ed | |||
| c926ad0985 | |||
| a3900a6694 | |||
| 7c908502ea | |||
| 035cd8f9a6 | |||
| 73618f27a8 | |||
| 23ee8797e1 | |||
| dec5ca5577 | |||
| 9c0ac887f3 | |||
| 721354fbdf | |||
| 6ee0f65793 | |||
| 099b579acb | |||
| f8cc15f163 | |||
| 37957e8531 | |||
| d0f9d2e5ac | |||
| 0ef6f06d55 | |||
| 52b3df0023 | |||
| 7c082bc417 | |||
| bddfd2b113 | |||
| 0d135df48c | |||
| bf533823cd | |||
| 2f89acc2bc | |||
| bfa3219177 | |||
| d6d899580d | |||
| 8a118ee86c | |||
| d789527482 | |||
| 063d9c156e | |||
| c57607016a | |||
| 4a80943174 | |||
| 84de01a1f1 | |||
| 75f460ac28 | |||
| 8452824611 | |||
| e27f308597 | |||
| 67e9fd3b74 | |||
| 796f41bedc | |||
| 37a77fb057 | |||
| f4043fec01 | |||
| f449e05537 | |||
| 2b686a9120 | |||
| 4b48a53b6c | |||
| e475fa2b5f | |||
| 175147e8f6 | |||
| fabde3bf51 | |||
| 0d2d9ccbf6 | |||
| 8c2d6f6475 | |||
| 38724ab593 | |||
| e2e7a9b2d0 | |||
| b14e3fb90c | |||
| 159d093a43 | |||
| 5fd2dc2c41 | |||
| 1868af13ac | |||
| 5bd21b8555 | |||
| 80452d65b9 | |||
| 8141e730f1 | |||
| db52540f73 | |||
| 3a3edc9ac6 | |||
| 40f3aafc45 | |||
| a6b3260a42 | |||
| 32eddaf2ea | |||
| 060ce1bf72 | |||
| d2c67959b3 | |||
| 7b6c5a2aed | |||
| fe7c8b2414 | |||
| e1efd0991d | |||
| 08023072ef | |||
| 20832179e2 | |||
| 10786217e9 | |||
| 552258c535 | |||
| 968c43891a |
@@ -13,6 +13,20 @@ ARG APP_REVISION=N/A
|
||||
# BUILD STAGE
|
||||
# Compile all binary files and libraries
|
||||
# ==============================================================================
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
FROM ${CANN_BASE_IMAGE} AS build
|
||||
|
||||
# -- Install build dependencies --
|
||||
@@ -26,6 +40,8 @@ WORKDIR /app
|
||||
# -- Copy project files --
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
# -- Set CANN environment variables (required for compilation) --
|
||||
# Using ENV instead of `source` allows environment variables to persist across the entire image layer
|
||||
ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest
|
||||
|
||||
@@ -3,6 +3,20 @@ ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
FROM docker.io/ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
ARG TARGETARCH
|
||||
@@ -16,6 +30,8 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
RUN if [ "$TARGETARCH" = "amd64" ] || [ "$TARGETARCH" = "arm64" ]; then \
|
||||
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON; \
|
||||
else \
|
||||
|
||||
@@ -11,6 +11,20 @@ ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
FROM ${BASE_CUDA_DEV_CONTAINER} AS build
|
||||
|
||||
ARG GCC_VERSION
|
||||
@@ -26,6 +40,8 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \
|
||||
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \
|
||||
fi && \
|
||||
|
||||
@@ -5,6 +5,20 @@ ARG APP_REVISION=N/A
|
||||
|
||||
## Build Image
|
||||
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
FROM docker.io/intel/deep-learning-essentials:$ONEAPI_VERSION AS build
|
||||
|
||||
ARG GGML_SYCL_F16=ON
|
||||
@@ -22,6 +36,8 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \
|
||||
echo "GGML_SYCL_F16 is set" \
|
||||
&& export OPT_SYCL_F16="-DGGML_SYCL_F16=ON" \
|
||||
|
||||
@@ -10,6 +10,20 @@ ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
FROM ${BASE_MUSA_DEV_CONTAINER} AS build
|
||||
|
||||
# MUSA architecture to build for (defaults to all supported archs)
|
||||
@@ -29,6 +43,8 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \
|
||||
export CMAKE_ARGS="-DMUSA_ARCHITECTURES=${MUSA_DOCKER_ARCH}"; \
|
||||
fi && \
|
||||
|
||||
@@ -22,6 +22,20 @@ ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
## Build Image
|
||||
FROM docker.io/ubuntu:${UBUNTU_VERSION} AS build
|
||||
|
||||
@@ -69,6 +83,8 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
# Build Stage
|
||||
RUN bash -c "source ${OpenVINO_DIR}/setupvars.sh && \
|
||||
cmake -B build/ReleaseOV -G Ninja \
|
||||
|
||||
@@ -11,6 +11,20 @@ ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
### Build image
|
||||
FROM ${BASE_ROCM_DEV_CONTAINER} AS build
|
||||
|
||||
@@ -38,6 +52,8 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
|
||||
cmake -S . -B build \
|
||||
-DGGML_HIP=ON \
|
||||
|
||||
@@ -3,6 +3,20 @@ ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
FROM docker.io/ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
# Install build tools
|
||||
@@ -17,6 +31,8 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=ON -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON && \
|
||||
cmake --build build --config Release -j$(nproc)
|
||||
|
||||
|
||||
@@ -3,6 +3,20 @@ ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
FROM docker.io/ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
RUN apt-get update && \
|
||||
@@ -14,6 +28,8 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
RUN cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_ZENDNN=ON && \
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@
|
||||
|
||||
build*/
|
||||
|
||||
tools/ui/node_modules/
|
||||
|
||||
models/*
|
||||
|
||||
/llama-cli
|
||||
|
||||
@@ -58,6 +58,13 @@ jobs:
|
||||
git tag ${{ steps.srctag.outputs.name }} || exit 0
|
||||
git push origin ${{ steps.srctag.outputs.name }} || exit 0
|
||||
|
||||
build_ui:
|
||||
name: Build UI
|
||||
needs: create_tag
|
||||
uses: ./.github/workflows/ui-build.yml
|
||||
with:
|
||||
hf_ui_version: ${{ needs.create_tag.outputs.source_tag }}
|
||||
|
||||
prepare_matrices:
|
||||
name: Prepare Docker matrices
|
||||
runs-on: ubuntu-24.04
|
||||
@@ -79,7 +86,7 @@ jobs:
|
||||
[
|
||||
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-arm" },
|
||||
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x" },
|
||||
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x", "prebuilt_ui": true },
|
||||
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
|
||||
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.3.0", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
@@ -135,7 +142,7 @@ jobs:
|
||||
|
||||
push_to_registry:
|
||||
name: Push Docker image to Docker Registry
|
||||
needs: [prepare_matrices, create_tag]
|
||||
needs: [prepare_matrices, create_tag, build_ui]
|
||||
|
||||
runs-on: ${{ matrix.config.runs_on }}
|
||||
strategy:
|
||||
@@ -150,6 +157,13 @@ jobs:
|
||||
fetch-depth: 0
|
||||
ref: ${{ needs.create_tag.outputs.source_tag }}
|
||||
|
||||
- name: Download prebuilt UI
|
||||
if: ${{ matrix.config.prebuilt_ui == true }}
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
|
||||
with:
|
||||
name: ui-build
|
||||
path: tools/ui/dist
|
||||
|
||||
- name: Set up QEMU
|
||||
if: ${{ contains(matrix.config.platforms, 'linux/amd64') }}
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4
|
||||
|
||||
@@ -1627,6 +1627,7 @@ jobs:
|
||||
**Windows:**
|
||||
- [Windows x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-x64.zip)
|
||||
- [Windows arm64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-arm64.zip)
|
||||
- [Windows arm64 (OpenCL Adreno)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-opencl-adreno-arm64.zip)
|
||||
- [Windows x64 (CUDA 12)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip) - [CUDA 12.4 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-12.4-x64.zip)
|
||||
- [Windows x64 (CUDA 13)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-13.3-x64.zip) - [CUDA 13.3 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-13.3-x64.zip)
|
||||
- [Windows x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-vulkan-x64.zip)
|
||||
|
||||
@@ -25,13 +25,3 @@ Commits:
|
||||
- Do not explicitly set the git author in commits - rely on the default git config
|
||||
- Always use `--no-gpg-sign` when committing
|
||||
- Never `git push` without explicit confirmation from the user
|
||||
|
||||
Resources (read on demand):
|
||||
- [CONTRIBUTING.md](CONTRIBUTING.md)
|
||||
- [Build documentation](docs/build.md)
|
||||
- [Server usage documentation](tools/server/README.md)
|
||||
- [Server development documentation](tools/server/README-dev.md)
|
||||
- [PEG parser](docs/development/parsing.md)
|
||||
- [Auto parser](docs/autoparser.md)
|
||||
- [Jinja engine](common/jinja/README.md)
|
||||
- [PR template](.github/pull_request_template.md)
|
||||
|
||||
+1
-1
@@ -10,7 +10,7 @@
|
||||
# ggml-org/ggml-rpc : rgerganov
|
||||
# ggml-org/ggml-sycl : arthw
|
||||
# ggml-org/ggml-vulkan : 0cc4m, jeffbolznv
|
||||
# ggml-org/ggml-webgpu : reeselevine
|
||||
# ggml-org/ggml-webgpu : reeselevine, yomaytk
|
||||
# ggml-org/ggml-zdnn : taronaeo
|
||||
# ggml-org/llama-common : ggerganov, aldehir, angt, danbev, ngxson, pwilkin
|
||||
# ggml-org/llama-mtmd : ngxson
|
||||
|
||||
+101
-134
@@ -17,6 +17,7 @@
|
||||
# define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#include <shellapi.h>
|
||||
#endif
|
||||
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
@@ -285,67 +286,25 @@ static std::string clean_file_name(const std::string & fname) {
|
||||
return clean_fname;
|
||||
}
|
||||
|
||||
static bool common_params_handle_remote_preset(common_params & params, llama_example ex) {
|
||||
GGML_ASSERT(!params.model.hf_repo.empty());
|
||||
|
||||
// the returned hf_repo is without tag
|
||||
auto [hf_repo, hf_tag] = common_download_split_repo_tag(params.model.hf_repo);
|
||||
|
||||
// "latest" tag (default if not specified) is translated to "default" preset
|
||||
if (hf_tag == "latest") {
|
||||
hf_tag = "default";
|
||||
}
|
||||
|
||||
std::string model_endpoint = common_get_model_endpoint();
|
||||
auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini";
|
||||
|
||||
// prepare local path for caching
|
||||
auto preset_fname = clean_file_name(hf_repo + "_preset.ini");
|
||||
auto preset_path = fs_get_cache_file(preset_fname);
|
||||
common_download_opts opts;
|
||||
opts.bearer_token = params.hf_token;
|
||||
opts.offline = params.offline;
|
||||
|
||||
LOG_TRC("%s: looking for remote preset at %s\n", __func__, preset_url.c_str());
|
||||
const int status = common_download_file_single(preset_url, preset_path, opts);
|
||||
const bool has_preset = status >= 200 && status < 400;
|
||||
|
||||
// remote preset is optional, so we don't error out if not found
|
||||
if (has_preset) {
|
||||
LOG_TRC("%s: applying remote preset from %s\n", __func__, preset_url.c_str());
|
||||
common_preset_context ctx(ex, /* only_remote_allowed */ true);
|
||||
common_preset global;
|
||||
auto remote_presets = ctx.load_from_ini(preset_path, global);
|
||||
remote_presets = ctx.cascade(global, remote_presets);
|
||||
if (remote_presets.find(hf_tag) != remote_presets.end()) {
|
||||
common_preset preset = remote_presets.at(hf_tag);
|
||||
LOG_INF("\n%s", preset.to_ini().c_str()); // to_ini already added trailing newline
|
||||
preset.apply_to_params(params);
|
||||
} else {
|
||||
throw std::runtime_error("Remote preset.ini does not contain [" + std::string(hf_tag) + "] section");
|
||||
}
|
||||
} else {
|
||||
LOG_TRC("%s: no remote preset found, skipping\n", __func__);
|
||||
}
|
||||
|
||||
return has_preset;
|
||||
}
|
||||
|
||||
struct handle_model_result {
|
||||
bool found_mmproj = false;
|
||||
common_params_model mmproj;
|
||||
|
||||
bool found_mtp = false;
|
||||
common_params_model mtp;
|
||||
|
||||
bool found_preset = false;
|
||||
std::string preset_path;
|
||||
};
|
||||
|
||||
static handle_model_result common_params_handle_model(struct common_params_model & model,
|
||||
const common_download_opts & opts) {
|
||||
handle_model_result result;
|
||||
|
||||
// TODO @ngxson : refactor this into a new common_model_download_context
|
||||
|
||||
if (!model.docker_repo.empty()) {
|
||||
model.path = common_docker_resolve_model(model.docker_repo);
|
||||
model.name = model.docker_repo;
|
||||
} else if (!model.hf_repo.empty()) {
|
||||
// If -m was used with -hf, treat the model "path" as the hf_file to download
|
||||
if (model.hf_file.empty() && !model.path.empty()) {
|
||||
@@ -355,11 +314,16 @@ static handle_model_result common_params_handle_model(struct common_params_model
|
||||
common_download_opts hf_opts = opts;
|
||||
auto download_result = common_download_model(model, hf_opts);
|
||||
|
||||
if (!download_result.preset_path.empty()) {
|
||||
result.found_preset = true;
|
||||
result.preset_path = download_result.preset_path;
|
||||
return result; // skip everything else if preset.ini is used
|
||||
}
|
||||
|
||||
if (download_result.model_path.empty()) {
|
||||
throw std::runtime_error("failed to download model from Hugging Face");
|
||||
}
|
||||
|
||||
model.name = model.hf_repo;
|
||||
model.path = download_result.model_path;
|
||||
|
||||
if (!download_result.mmproj_path.empty()) {
|
||||
@@ -434,7 +398,7 @@ static bool parse_bool_value(const std::string & value) {
|
||||
// CLI argument parsing functions
|
||||
//
|
||||
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex) {
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex, const common_params_handle_models_params & handle_params) {
|
||||
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
|
||||
params.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
|
||||
@@ -445,6 +409,11 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex)
|
||||
opts.skip_download = params.skip_download;
|
||||
opts.download_mtp = spec_type_draft_mtp;
|
||||
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
|
||||
opts.preset_only = handle_params.preset_only;
|
||||
|
||||
if (handle_params.callback) {
|
||||
opts.callback = handle_params.callback;
|
||||
}
|
||||
|
||||
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
|
||||
// so we should not auto-discover mtp/mmproj siblings for them
|
||||
@@ -454,6 +423,17 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex)
|
||||
|
||||
try {
|
||||
auto res = common_params_handle_model(params.model, opts);
|
||||
if (res.found_preset) {
|
||||
if (!params.models_preset.empty()) {
|
||||
throw std::invalid_argument("cannot use both --models-preset and -hf with a preset.ini file");
|
||||
}
|
||||
// if HF repo is a preset repo, we simply run server in router mode with the preset.ini file
|
||||
params.models_preset_hf = params.model.hf_repo; // only for showing a warning
|
||||
params.models_preset = res.preset_path;
|
||||
params.model = common_params_model{}; // make sure to clear model, so server starts in router mode
|
||||
return true;
|
||||
}
|
||||
|
||||
if (params.no_mmproj) {
|
||||
params.mmproj = {};
|
||||
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
|
||||
@@ -601,30 +581,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
// parse the first time to get -hf option (used for remote preset)
|
||||
parse_cli_args();
|
||||
|
||||
// export_graph_ops loads only metadata
|
||||
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
|
||||
// maybe handle remote preset
|
||||
if (!params.model.hf_repo.empty() && !skip_model_download) {
|
||||
std::string cli_hf_repo = params.model.hf_repo;
|
||||
bool has_preset = common_params_handle_remote_preset(params, ctx_arg.ex);
|
||||
|
||||
// special case: if hf_repo explicitly set by preset, we need to preserve it (ignore CLI value)
|
||||
// this is useful when we have one HF repo pointing to other HF repos (one model - multiple GGUFs)
|
||||
std::string preset_hf_repo = params.model.hf_repo;
|
||||
bool preset_has_hf_repo = preset_hf_repo != cli_hf_repo;
|
||||
|
||||
if (has_preset) {
|
||||
// re-parse CLI args to override preset values
|
||||
parse_cli_args();
|
||||
}
|
||||
|
||||
// preserve hf_repo from preset if needed
|
||||
if (preset_has_hf_repo) {
|
||||
params.model.hf_repo = preset_hf_repo;
|
||||
}
|
||||
}
|
||||
|
||||
postprocess_cpu_params(params.cpuparams, nullptr);
|
||||
postprocess_cpu_params(params.cpuparams_batch, ¶ms.cpuparams);
|
||||
|
||||
@@ -635,15 +591,23 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
|
||||
}
|
||||
|
||||
// handle model and download
|
||||
if (!skip_model_download) {
|
||||
common_params_handle_models(params, ctx_arg.ex);
|
||||
}
|
||||
const bool skip_model_download =
|
||||
// server will call common_params_handle_models() later, so we skip it here
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_SERVER ||
|
||||
// export_graph_ops loads only metadata
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !skip_model_download && !params.usage && !params.completion) {
|
||||
throw std::invalid_argument("error: --model is required\n");
|
||||
if (!skip_model_download) {
|
||||
// handle model and download
|
||||
common_params_handle_models(params, ctx_arg.ex, {});
|
||||
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
if (params.model.path.empty()
|
||||
&& !params.usage
|
||||
&& !params.completion) {
|
||||
throw std::invalid_argument("error: --model is required\n");
|
||||
}
|
||||
}
|
||||
|
||||
if (params.escape) {
|
||||
@@ -937,7 +901,44 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
struct utf8_argv {
|
||||
std::vector<std::string> buf;
|
||||
std::vector<char*> ptrs;
|
||||
};
|
||||
|
||||
static utf8_argv make_utf8_argv() {
|
||||
utf8_argv out;
|
||||
int wargc = 0;
|
||||
LPWSTR* wargv = CommandLineToArgvW(GetCommandLineW(), &wargc);
|
||||
if (!wargv) return out;
|
||||
|
||||
out.buf.reserve(wargc);
|
||||
for (int i = 0; i < wargc; ++i) {
|
||||
int n = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS, wargv[i], -1, nullptr, 0, nullptr, nullptr);
|
||||
if (n <= 0) { out.buf.emplace_back(); continue; }
|
||||
auto& s = out.buf.emplace_back();
|
||||
s.resize(static_cast<size_t>(n - 1));
|
||||
(void)WideCharToMultiByte(CP_UTF8, 0, wargv[i], -1, s.data(), n, nullptr, nullptr);
|
||||
}
|
||||
LocalFree(wargv);
|
||||
|
||||
out.ptrs.reserve(out.buf.size() + 1);
|
||||
for (auto& s : out.buf) out.ptrs.push_back(s.data());
|
||||
out.ptrs.push_back(nullptr);
|
||||
return out;
|
||||
}
|
||||
#endif
|
||||
|
||||
bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
|
||||
#ifdef _WIN32
|
||||
auto utf8 = make_utf8_argv();
|
||||
// repair argv only when it matches the process command line
|
||||
if (static_cast<int>(utf8.buf.size()) == argc) {
|
||||
argv = utf8.ptrs.data();
|
||||
}
|
||||
#endif
|
||||
|
||||
auto ctx_arg = common_params_parser_init(params, ex, print_usage);
|
||||
const common_params params_org = ctx_arg.params; // the example can modify the default params
|
||||
|
||||
@@ -2874,62 +2875,26 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.api_prefix = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_API_PREFIX"));
|
||||
// Deprecated: use --ui-config instead (kept for backward compat)
|
||||
add_opt(common_arg(
|
||||
{"--webui-config"}, "JSON",
|
||||
"[DEPRECATED: use --ui-config] JSON that provides default WebUI settings (overrides WebUI defaults)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.ui_config_json = value;
|
||||
params.webui_config_json = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_CONFIG"));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--ui-config"}, "JSON",
|
||||
{"--ui-config", "--webui-config"}, "JSON",
|
||||
"JSON that provides default UI settings (overrides UI defaults)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.ui_config_json = value;
|
||||
params.webui_config_json = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI_CONFIG"));
|
||||
|
||||
// Deprecated: use --ui-config-file instead (kept for backward compat)
|
||||
add_opt(common_arg(
|
||||
{"--webui-config-file"}, "PATH",
|
||||
"[DEPRECATED: use --ui-config-file] JSON file that provides default WebUI settings (overrides WebUI defaults)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.ui_config_json = read_file(value);
|
||||
params.webui_config_json = params.ui_config_json;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_CONFIG_FILE"));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--ui-config-file"}, "PATH",
|
||||
{"--ui-config-file", "--webui-config-file"}, "PATH",
|
||||
"JSON file that provides default UI settings (overrides UI defaults)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.ui_config_json = read_file(value);
|
||||
params.webui_config_json = params.ui_config_json;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI_CONFIG_FILE"));
|
||||
|
||||
// Deprecated: use --ui-mcp-proxy instead (kept for backward compat)
|
||||
add_opt(common_arg(
|
||||
{"--webui-mcp-proxy"},
|
||||
{"--no-webui-mcp-proxy"},
|
||||
"[DEPRECATED: use --ui-mcp-proxy/--no-ui-mcp-proxy] experimental: whether to enable MCP CORS proxy",
|
||||
[](common_params & params, bool value) {
|
||||
params.ui_mcp_proxy = value;
|
||||
params.webui_mcp_proxy = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_MCP_PROXY"));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--ui-mcp-proxy"},
|
||||
{"--no-ui-mcp-proxy"},
|
||||
{"--ui-mcp-proxy", "--webui-mcp-proxy"},
|
||||
{"--no-ui-mcp-proxy", "--no-webui-mcp-proxy"},
|
||||
"experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)",
|
||||
[](common_params & params, bool value) {
|
||||
params.ui_mcp_proxy = value;
|
||||
params.webui_mcp_proxy = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI_MCP_PROXY"));
|
||||
add_opt(common_arg(
|
||||
@@ -2941,24 +2906,26 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.server_tools = parse_csv_row(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TOOLS"));
|
||||
// Deprecated: use --ui/--no-ui instead (kept for backward compat)
|
||||
add_opt(common_arg(
|
||||
{"--webui"},
|
||||
{"--no-webui"},
|
||||
"[DEPRECATED: use --ui/--no-ui] whether to enable the Web UI",
|
||||
{"-ag", "--agent"},
|
||||
{"-no-ag", "--no-agent"},
|
||||
"whether to enable CORS proxy and all built-in tools - do not enable in untrusted environments (default: disabled)",
|
||||
[](common_params & params, bool value) {
|
||||
params.ui = value;
|
||||
params.webui = value;
|
||||
if (value) {
|
||||
params.server_tools = {"all"};
|
||||
params.ui_mcp_proxy = true;
|
||||
} else {
|
||||
params.server_tools.clear();
|
||||
params.ui_mcp_proxy = false;
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI"));
|
||||
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_AGENT"));
|
||||
add_opt(common_arg(
|
||||
{"--ui"},
|
||||
{"--no-ui"},
|
||||
{"--ui", "--webui"},
|
||||
{"--no-ui", "--no-webui"},
|
||||
string_format("whether to enable the Web UI (default: %s)", params.ui ? "enabled" : "disabled"),
|
||||
[](common_params & params, bool value) {
|
||||
params.ui = value;
|
||||
params.webui = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI"));
|
||||
add_opt(common_arg(
|
||||
@@ -2989,7 +2956,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY"));
|
||||
add_opt(common_arg(
|
||||
{"--api-key-file"}, "FNAME",
|
||||
"path to file containing API keys (default: none)",
|
||||
"path to file containing API keys, one per line; lines starting with a hash are treated as comments (default: none)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
std::ifstream key_file(value);
|
||||
if (!key_file) {
|
||||
@@ -2997,7 +2964,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
std::string key;
|
||||
while (std::getline(key_file, key)) {
|
||||
if (!key.empty()) {
|
||||
if (!key.empty() && key[0] != '#') {
|
||||
params.api_keys.push_back(key);
|
||||
}
|
||||
}
|
||||
|
||||
+10
-1
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "download.h"
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
@@ -129,11 +130,19 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
|
||||
// see: https://github.com/ggml-org/llama.cpp/issues/18163
|
||||
void common_params_add_preset_options(std::vector<common_arg> & args);
|
||||
|
||||
struct common_params_handle_models_params {
|
||||
common_download_callback * callback = nullptr;
|
||||
bool preset_only = false; // if true, only check & download remote preset (for router mode)
|
||||
};
|
||||
|
||||
// populate model paths (main model, mmproj, etc) from -hf if necessary
|
||||
// return true if the model is ready to use
|
||||
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
|
||||
// if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. ETag check failed)
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex);
|
||||
bool common_params_handle_models(
|
||||
common_params & params,
|
||||
llama_example curr_ex,
|
||||
const common_params_handle_models_params & handle_params);
|
||||
|
||||
// initialize argument parser context - used by test-arg-parser and preset
|
||||
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
||||
|
||||
@@ -395,10 +395,11 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
||||
arguments.name_suffix) +
|
||||
arguments.value_prefix +
|
||||
(schema_info.resolves_to_string(param_schema) ?
|
||||
p.tool_arg_string_value(until_suffix) :
|
||||
p.tool_arg_json_value(p.schema(
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false))) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)));
|
||||
p.ac(p.tool_arg_string_value(until_suffix) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)), arguments.value_suffix) :
|
||||
(p.tool_arg_json_value(p.schema(
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)))));
|
||||
|
||||
auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg);
|
||||
if (is_required) {
|
||||
|
||||
+103
-53
@@ -90,41 +90,93 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const
|
||||
return text;
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims) {
|
||||
if (delims.empty() || prompt.empty()) {
|
||||
return {};
|
||||
common_chat_role common_chat_role_from_string(const std::string & role) {
|
||||
if (role == "system") { return COMMON_CHAT_ROLE_SYSTEM; }
|
||||
if (role == "assistant") { return COMMON_CHAT_ROLE_ASSISTANT; }
|
||||
if (role == "user") { return COMMON_CHAT_ROLE_USER; }
|
||||
if (role == "tool") { return COMMON_CHAT_ROLE_TOOL; }
|
||||
return COMMON_CHAT_ROLE_UNKNOWN;
|
||||
}
|
||||
|
||||
const char * common_chat_role_to_string(common_chat_role role) {
|
||||
switch (role) {
|
||||
case COMMON_CHAT_ROLE_SYSTEM: return "system";
|
||||
case COMMON_CHAT_ROLE_ASSISTANT: return "assistant";
|
||||
case COMMON_CHAT_ROLE_USER: return "user";
|
||||
case COMMON_CHAT_ROLE_TOOL: return "tool";
|
||||
case COMMON_CHAT_ROLE_UNKNOWN: return "";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
json common_chat_msg_delimiters::to_json() const {
|
||||
json result = json::array();
|
||||
for (const auto & d : delimiters) {
|
||||
result.push_back({
|
||||
{ "role", common_chat_role_to_string(d.role) },
|
||||
{ "delimiter", d.delimiter },
|
||||
});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const json & delimiters) {
|
||||
common_chat_msg_delimiters result;
|
||||
|
||||
if (!delimiters.is_array()) {
|
||||
return result;
|
||||
}
|
||||
|
||||
auto parser = build_peg_parser([&](common_peg_parser_builder & p) {
|
||||
std::vector<std::string> all_delims;
|
||||
std::vector<common_peg_parser> tagged_messages;
|
||||
|
||||
all_delims.reserve(delims.size());
|
||||
tagged_messages.reserve(delims.size());
|
||||
for (const auto & d : delims) {
|
||||
all_delims.push_back(d.delimiter);
|
||||
result.delimiters.reserve(delimiters.size());
|
||||
for (const auto & d : delimiters) {
|
||||
if (!d.is_object()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto any_delim = p.until_one_of(all_delims);
|
||||
for (const auto & d : delims) {
|
||||
tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim));
|
||||
}
|
||||
|
||||
return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end();
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx(prompt);
|
||||
const auto result = parser.parse(ctx);
|
||||
if (!result.success()) {
|
||||
return {};
|
||||
result.delimiters.push_back({
|
||||
common_chat_role_from_string(d.value("role", std::string())),
|
||||
d.value("delimiter", std::string()),
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_span> spans;
|
||||
ctx.ast.visit(result, [&](const common_peg_ast_node & node) {
|
||||
if (!node.tag.empty()) {
|
||||
spans.push_back({ node.tag, node.start, node.end - node.start });
|
||||
return result;
|
||||
}
|
||||
|
||||
void common_chat_msg_delimiters::tokenize(const llama_vocab * vocab) {
|
||||
for (auto & d : delimiters) {
|
||||
d.tokens = common_tokenize(vocab, d.delimiter, false, true);
|
||||
}
|
||||
}
|
||||
|
||||
common_chat_msg_spans common_chat_msg_delimiters::split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips) const {
|
||||
std::vector<std::pair<common_chat_role, size_t>> matches;
|
||||
|
||||
auto skip = skips.begin();
|
||||
for (size_t i = 0; i < tokens.size();) {
|
||||
if (skip != skips.end() && i == skip->first) {
|
||||
i += skip->second;
|
||||
++skip;
|
||||
continue;
|
||||
}
|
||||
});
|
||||
for (const auto & d : delimiters) {
|
||||
if (i + d.tokens.size() > tokens.size()) {
|
||||
continue;
|
||||
}
|
||||
if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) {
|
||||
matches.emplace_back(d.role, i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
i++;
|
||||
}
|
||||
|
||||
matches.emplace_back(COMMON_CHAT_ROLE_UNKNOWN, tokens.size());
|
||||
|
||||
common_chat_msg_spans spans;
|
||||
for (size_t i = 0; i + 1 < matches.size(); i++) {
|
||||
const auto & curr = matches[i];
|
||||
const auto & next = matches[i + 1];
|
||||
spans.add(curr.first, curr.second, next.second - curr.second);
|
||||
}
|
||||
|
||||
return spans;
|
||||
}
|
||||
@@ -1081,13 +1133,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
|
||||
data.prompt = prompt;
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
data.message_spans = common_chat_split_by_role(prompt, {
|
||||
{ "assistant", "<|start|>assistant" },
|
||||
{ "user", "<|start|>user" },
|
||||
{ "system", "<|start|>developer" },
|
||||
{ "system", "<|start|>system" },
|
||||
{ "tool", "<|start|>functions" },
|
||||
});
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|start|>assistant" },
|
||||
{ COMMON_CHAT_ROLE_USER, "<|start|>user" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>developer" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>system" },
|
||||
{ COMMON_CHAT_ROLE_TOOL, "<|start|>functions" },
|
||||
};
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
@@ -1228,10 +1280,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
data.message_spans = common_chat_split_by_role(data.prompt, {
|
||||
{ "user", "<|turn>user\n" },
|
||||
{ "assistant", "<|turn>model\n" },
|
||||
});
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_USER, "<|turn>user" },
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|turn>model" },
|
||||
};
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
|
||||
data.supports_thinking = true;
|
||||
@@ -2030,15 +2082,15 @@ static common_chat_params common_chat_params_init_cohere2moe(const common_chat_t
|
||||
RESULT_START, RESULT_END,
|
||||
};
|
||||
|
||||
// Split the rendered prompt into per-role message spans. Tool results are rendered with the
|
||||
// Declare per-role message delimiters. Tool results are rendered with the
|
||||
// system token followed by <|START_TOOL_RESULT|>, so the "tool" delimiter must be listed before
|
||||
// the plain "system" one (it is a strict superset, and the role split tries delimiters in order).
|
||||
data.message_spans = common_chat_split_by_role(data.prompt, {
|
||||
{ "assistant", GEN_PREFIX },
|
||||
{ "user", TURN_START + USER },
|
||||
{ "tool", TURN_START + SYSTEM + RESULT_START },
|
||||
{ "system", TURN_START + SYSTEM },
|
||||
});
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, GEN_PREFIX },
|
||||
{ COMMON_CHAT_ROLE_USER, TURN_START + USER },
|
||||
{ COMMON_CHAT_ROLE_TOOL, TURN_START + SYSTEM + RESULT_START },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, TURN_START + SYSTEM },
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
@@ -2526,17 +2578,15 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
autoparser.analyze_template(tmpl);
|
||||
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
|
||||
|
||||
std::vector<common_chat_msg_delimiter> delimiters;
|
||||
common_chat_msg_delimiters delimiters;
|
||||
if (!autoparser.assistant_start.empty()) {
|
||||
delimiters.push_back({ "assistant", autoparser.assistant_start });
|
||||
delimiters.add(COMMON_CHAT_ROLE_ASSISTANT, autoparser.assistant_start);
|
||||
}
|
||||
if (!autoparser.user_start.empty()) {
|
||||
delimiters.push_back({ "user", autoparser.user_start });
|
||||
delimiters.add(COMMON_CHAT_ROLE_USER, autoparser.user_start);
|
||||
}
|
||||
|
||||
if (!delimiters.empty()) {
|
||||
auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters);
|
||||
}
|
||||
auto_params.message_delimiters = std::move(delimiters);
|
||||
|
||||
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
|
||||
if (auto_params.supports_thinking) {
|
||||
|
||||
+65
-6
@@ -143,15 +143,75 @@ struct common_chat_msg_diff {
|
||||
}
|
||||
};
|
||||
|
||||
enum common_chat_role {
|
||||
COMMON_CHAT_ROLE_UNKNOWN,
|
||||
COMMON_CHAT_ROLE_SYSTEM,
|
||||
COMMON_CHAT_ROLE_ASSISTANT,
|
||||
COMMON_CHAT_ROLE_USER,
|
||||
COMMON_CHAT_ROLE_TOOL
|
||||
};
|
||||
|
||||
common_chat_role common_chat_role_from_string(const std::string & role);
|
||||
const char * common_chat_role_to_string(common_chat_role role);
|
||||
|
||||
struct common_chat_msg_span {
|
||||
std::string role;
|
||||
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
|
||||
std::size_t pos = 0;
|
||||
std::size_t len = 0;
|
||||
|
||||
bool valid() const {
|
||||
return role != COMMON_CHAT_ROLE_UNKNOWN;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_spans {
|
||||
std::vector<common_chat_msg_span> spans;
|
||||
|
||||
void add(common_chat_role role, size_t pos, size_t len) {
|
||||
spans.push_back({ role, pos, len });
|
||||
}
|
||||
|
||||
bool is_user_start(int32_t pos) const {
|
||||
for (auto it = spans.begin(); it != spans.end(); ++it) {
|
||||
if (it->role == COMMON_CHAT_ROLE_USER && pos == (int32_t) it->pos) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t last_user_message_pos() const {
|
||||
for (auto it = spans.rbegin(); it != spans.rend(); ++it) {
|
||||
if (it->role == COMMON_CHAT_ROLE_USER) {
|
||||
return (int32_t) it->pos;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_delimiter {
|
||||
std::string role;
|
||||
std::string delimiter;
|
||||
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
|
||||
std::string delimiter;
|
||||
llama_tokens tokens = {};
|
||||
};
|
||||
|
||||
struct common_chat_msg_delimiters {
|
||||
std::vector<common_chat_msg_delimiter> delimiters;
|
||||
|
||||
common_chat_msg_delimiters() = default;
|
||||
common_chat_msg_delimiters(std::initializer_list<common_chat_msg_delimiter> delims) : delimiters(delims) {}
|
||||
|
||||
void add(common_chat_role role, const std::string & delimiter) {
|
||||
delimiters.push_back({ role, delimiter });
|
||||
}
|
||||
|
||||
void tokenize(const llama_vocab * vocab);
|
||||
|
||||
// split tokens into message spans. skips maps a start index to a length of a region to jump over without matching
|
||||
common_chat_msg_spans split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips = {}) const;
|
||||
|
||||
nlohmann::ordered_json to_json() const;
|
||||
};
|
||||
|
||||
struct common_chat_tool {
|
||||
@@ -219,7 +279,7 @@ struct common_chat_params {
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
std::string parser;
|
||||
std::vector<common_chat_msg_span> message_spans;
|
||||
common_chat_msg_delimiters message_delimiters;
|
||||
};
|
||||
|
||||
// per-message parsing syntax
|
||||
@@ -325,5 +385,4 @@ struct common_chat_prompt_preset {
|
||||
|
||||
common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates);
|
||||
|
||||
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims);
|
||||
|
||||
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const nlohmann::ordered_json & delimiters);
|
||||
|
||||
+15
-1
@@ -1074,6 +1074,18 @@ std::vector<common_file_info> fs_list(const std::string & path, bool include_dir
|
||||
return files;
|
||||
}
|
||||
|
||||
std::ifstream fs_open_ifstream(const std::string & fname, std::ios_base::openmode mode) {
|
||||
#ifdef _WIN32
|
||||
int wlen = MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, NULL, 0);
|
||||
if (!wlen) { return std::ifstream(); }
|
||||
std::vector<wchar_t> wfname(wlen);
|
||||
(void)MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, wfname.data(), wlen);
|
||||
return std::ifstream(wfname.data(), mode);
|
||||
#else
|
||||
return std::ifstream(fname, mode);
|
||||
#endif
|
||||
}
|
||||
|
||||
//
|
||||
// TTY utils
|
||||
//
|
||||
@@ -2034,7 +2046,7 @@ bool common_prompt_batch_decode(
|
||||
}
|
||||
|
||||
size_t common_prompt_checkpoint::size() const {
|
||||
return data_tgt.size() + data_dft.size();
|
||||
return data_tgt.size() + data_dft.size() + data_spec.size();
|
||||
}
|
||||
|
||||
bool common_prompt_checkpoint::empty() const {
|
||||
@@ -2049,6 +2061,7 @@ void common_prompt_checkpoint::clear() {
|
||||
|
||||
data_tgt.clear();
|
||||
data_dft.clear();
|
||||
data_spec.clear();
|
||||
}
|
||||
|
||||
void common_prompt_checkpoint::update_pos(
|
||||
@@ -2138,4 +2151,5 @@ void common_prompt_checkpoint::clear_tgt() {
|
||||
|
||||
void common_prompt_checkpoint::clear_dft() {
|
||||
data_dft.clear();
|
||||
data_spec.clear();
|
||||
}
|
||||
|
||||
+24
-13
@@ -295,7 +295,16 @@ struct common_params_model {
|
||||
std::string hf_repo = ""; // HF repo // NOLINT
|
||||
std::string hf_file = ""; // HF file // NOLINT
|
||||
std::string docker_repo = ""; // Docker repo // NOLINT
|
||||
std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
|
||||
|
||||
std::string get_name() {
|
||||
if (!hf_repo.empty()) {
|
||||
return hf_repo;
|
||||
}
|
||||
if (!docker_repo.empty()) {
|
||||
return docker_repo;
|
||||
}
|
||||
return path;
|
||||
}
|
||||
};
|
||||
|
||||
// draft-model-based speculative decoding parameters
|
||||
@@ -363,7 +372,7 @@ struct common_params_speculative {
|
||||
|
||||
uint32_t need_n_rs_seq() const {
|
||||
bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) {
|
||||
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP;
|
||||
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3;
|
||||
});
|
||||
|
||||
return needs_rs_seq ? draft.n_max : 0u;
|
||||
@@ -600,7 +609,7 @@ struct common_params {
|
||||
bool cache_prompt = true; // whether to enable prompt caching
|
||||
bool cache_idle_slots = true; // save and clear idle slots upon starting a new task
|
||||
int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
|
||||
int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints
|
||||
int32_t checkpoint_min_step = 8192; // minimum spacing between context checkpoints
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
|
||||
std::string hostname = "127.0.0.1";
|
||||
@@ -624,12 +633,6 @@ struct common_params {
|
||||
|
||||
// UI configs
|
||||
bool ui = true;
|
||||
|
||||
// Deprecated: use ui, ui_mcp_proxy, ui_config_json instead
|
||||
bool webui = ui;
|
||||
bool webui_mcp_proxy = false;
|
||||
std::string webui_config_json;
|
||||
|
||||
bool ui_mcp_proxy = false;
|
||||
std::string ui_config_json;
|
||||
|
||||
@@ -642,10 +645,11 @@ struct common_params {
|
||||
std::vector<std::string> server_tools;
|
||||
|
||||
// router server configs
|
||||
std::string models_dir = ""; // directory containing models for the router server
|
||||
std::string models_preset = ""; // directory containing model presets for the router server
|
||||
int models_max = 4; // maximum number of models to load simultaneously
|
||||
bool models_autoload = true; // automatically load models when requested via the router server
|
||||
std::string models_dir = ""; // directory containing models for the router server
|
||||
std::string models_preset = ""; // directory containing model presets for the router server
|
||||
int models_max = 4; // maximum number of models to load simultaneously
|
||||
bool models_autoload = true; // automatically load models when requested via the router server
|
||||
std::string models_preset_hf = ""; // show a warning about remote presets on router loaded (if not empty)
|
||||
|
||||
bool log_json = false;
|
||||
|
||||
@@ -847,6 +851,9 @@ struct common_file_info {
|
||||
};
|
||||
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories);
|
||||
|
||||
// fs open, also handle UTF8 on Windows
|
||||
std::ifstream fs_open_ifstream(const std::string & fname, std::ios_base::openmode mode);
|
||||
|
||||
//
|
||||
// TTY utils
|
||||
//
|
||||
@@ -1064,6 +1071,10 @@ struct common_prompt_checkpoint {
|
||||
std::vector<uint8_t> data_tgt;
|
||||
std::vector<uint8_t> data_dft;
|
||||
|
||||
// (optional) speculative-decoding implementation state stashed with the checkpoint
|
||||
// (e.g. eagle3's deferred-boundary g_embd row)
|
||||
std::vector<uint8_t> data_spec;
|
||||
|
||||
size_t size() const;
|
||||
|
||||
bool empty() const;
|
||||
|
||||
+38
-17
@@ -696,6 +696,7 @@ struct hf_plan {
|
||||
hf_cache::hf_files model_files;
|
||||
hf_cache::hf_file mmproj;
|
||||
hf_cache::hf_file mtp;
|
||||
hf_cache::hf_file preset; // if set, only this file is downloaded
|
||||
};
|
||||
|
||||
static hf_plan get_hf_plan(const common_params_model & model,
|
||||
@@ -717,6 +718,14 @@ static hf_plan get_hf_plan(const common_params_model & model,
|
||||
return plan;
|
||||
}
|
||||
|
||||
// if preset.ini exists in the repo root, download only that file
|
||||
for (const auto & f : all) {
|
||||
if (f.path == "preset.ini") {
|
||||
plan.preset = f;
|
||||
return plan;
|
||||
}
|
||||
}
|
||||
|
||||
hf_cache::hf_file primary;
|
||||
|
||||
if (!model.hf_file.empty()) {
|
||||
@@ -790,18 +799,25 @@ common_download_model_result common_download_model(const common_params_model &
|
||||
|
||||
bool download_mmproj = opts.download_mmproj;
|
||||
bool download_mtp = opts.download_mtp;
|
||||
bool preset_only = opts.preset_only;
|
||||
bool is_hf = !model.hf_repo.empty();
|
||||
|
||||
if (is_hf) {
|
||||
hf = get_hf_plan(model, opts, download_mmproj, download_mtp);
|
||||
for (const auto & f : hf.model_files) {
|
||||
tasks.push_back({f.url, f.local_path});
|
||||
}
|
||||
if (!hf.mmproj.path.empty()) {
|
||||
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
|
||||
}
|
||||
if (!hf.mtp.path.empty()) {
|
||||
tasks.push_back({hf.mtp.url, hf.mtp.local_path});
|
||||
if (!hf.preset.path.empty()) {
|
||||
// if preset.ini exists, only download that file alone
|
||||
tasks.push_back({hf.preset.url, hf.preset.local_path});
|
||||
} else if (!preset_only) {
|
||||
// only add other files if we're NOT in preset-only mode (normal run, non-router)
|
||||
for (const auto & f : hf.model_files) {
|
||||
tasks.push_back({f.url, f.local_path});
|
||||
}
|
||||
if (!hf.mmproj.path.empty()) {
|
||||
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
|
||||
}
|
||||
if (!hf.mtp.path.empty()) {
|
||||
tasks.push_back({hf.mtp.url, hf.mtp.local_path});
|
||||
}
|
||||
}
|
||||
} else if (!model.url.empty()) {
|
||||
tasks = get_url_tasks(model);
|
||||
@@ -835,17 +851,22 @@ common_download_model_result common_download_model(const common_params_model &
|
||||
}
|
||||
|
||||
if (is_hf) {
|
||||
for (const auto & f : hf.model_files) {
|
||||
hf_cache::finalize_file(f);
|
||||
}
|
||||
result.model_path = hf.primary.final_path;
|
||||
if (!hf.preset.path.empty()) {
|
||||
// if preset.ini is used, do not set other paths
|
||||
result.preset_path = hf_cache::finalize_file(hf.preset);
|
||||
} else {
|
||||
for (const auto & f : hf.model_files) {
|
||||
hf_cache::finalize_file(f);
|
||||
}
|
||||
result.model_path = hf.primary.final_path;
|
||||
|
||||
if (!hf.mmproj.path.empty()) {
|
||||
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
|
||||
}
|
||||
if (!hf.mmproj.path.empty()) {
|
||||
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
|
||||
}
|
||||
|
||||
if (!hf.mtp.path.empty()) {
|
||||
result.mtp_path = hf_cache::finalize_file(hf.mtp);
|
||||
if (!hf.mtp.path.empty()) {
|
||||
result.mtp_path = hf_cache::finalize_file(hf.mtp);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result.model_path = model.path;
|
||||
|
||||
@@ -55,6 +55,7 @@ struct common_download_opts {
|
||||
bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid
|
||||
bool download_mmproj = false;
|
||||
bool download_mtp = false;
|
||||
bool preset_only = false; // if true, only check & download remote preset (for router mode)
|
||||
common_download_callback * callback = nullptr;
|
||||
};
|
||||
|
||||
@@ -63,6 +64,7 @@ struct common_download_model_result {
|
||||
std::string model_path;
|
||||
std::string mmproj_path;
|
||||
std::string mtp_path;
|
||||
std::string preset_path;
|
||||
};
|
||||
|
||||
// throw if the file is missing or invalid (e.g. ETag check failed)
|
||||
|
||||
+89
-46
@@ -686,59 +686,62 @@ value set_statement::execute_impl(context & ctx) {
|
||||
return mk_val<value_undefined>();
|
||||
}
|
||||
|
||||
static inline void bind_parameters(const std::string & name, const statements & this_args, const func_args & args, context & ctx) {
|
||||
const size_t expected_count = this_args.size();
|
||||
const size_t input_count = args.count();
|
||||
|
||||
JJ_DEBUG("Invoking '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count);
|
||||
for (size_t i = 0; i < expected_count; ++i) {
|
||||
if (i < input_count) {
|
||||
if (is_stmt<identifier>(this_args[i])) {
|
||||
// normal parameter
|
||||
std::string param_name = cast_stmt<identifier>(this_args[i])->val;
|
||||
value param_value = args.get_kwarg_or_pos(param_name, i);
|
||||
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
|
||||
ctx.set_val(param_name, param_value);
|
||||
} else if (is_stmt<keyword_argument_expression>(this_args[i])) {
|
||||
// default argument used as normal parameter
|
||||
auto kwarg = cast_stmt<keyword_argument_expression>(this_args[i]);
|
||||
if (!is_stmt<identifier>(kwarg->key)) {
|
||||
throw std::runtime_error("Keyword argument key must be an identifier in '" + name + "'");
|
||||
}
|
||||
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
|
||||
value param_value = args.get_kwarg_or_pos(param_name, i);
|
||||
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
|
||||
ctx.set_val(param_name, param_value);
|
||||
} else {
|
||||
throw std::runtime_error("Invalid parameter type in '" + name + "'");
|
||||
}
|
||||
} else {
|
||||
auto & default_arg = this_args[i];
|
||||
if (is_stmt<keyword_argument_expression>(default_arg)) {
|
||||
auto kwarg = cast_stmt<keyword_argument_expression>(default_arg);
|
||||
if (!is_stmt<identifier>(kwarg->key)) {
|
||||
throw std::runtime_error("Keyword argument key must be an identifier in '" + name + "'");
|
||||
}
|
||||
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
|
||||
JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str());
|
||||
ctx.set_val(param_name, kwarg->val->execute(args.ctx));
|
||||
} else {
|
||||
throw std::runtime_error("Not enough arguments provided to '" + name + "'");
|
||||
}
|
||||
//std::string param_name = cast_stmt<identifier>(default_args[i])->val;
|
||||
//JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str());
|
||||
//ctx.var[param_name] = default_args[i]->execute(ctx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
value macro_statement::execute_impl(context & ctx) {
|
||||
if (!is_stmt<identifier>(this->name)) {
|
||||
throw std::runtime_error("Macro name must be an identifier");
|
||||
}
|
||||
std::string name = cast_stmt<identifier>(this->name)->val;
|
||||
|
||||
const func_handler func = [this, name, &ctx](const func_args & args) -> value {
|
||||
size_t expected_count = this->args.size();
|
||||
size_t input_count = args.count();
|
||||
const func_handler func = [this, name](const func_args & args) -> value {
|
||||
context macro_ctx(args.ctx); // new scope for macro execution
|
||||
|
||||
JJ_DEBUG("Invoking macro '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count);
|
||||
context macro_ctx(ctx); // new scope for macro execution
|
||||
|
||||
// bind parameters
|
||||
for (size_t i = 0; i < expected_count; ++i) {
|
||||
if (i < input_count) {
|
||||
if (is_stmt<identifier>(this->args[i])) {
|
||||
// normal parameter
|
||||
std::string param_name = cast_stmt<identifier>(this->args[i])->val;
|
||||
value param_value = args.get_kwarg_or_pos(param_name, i);
|
||||
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
|
||||
macro_ctx.set_val(param_name, param_value);
|
||||
} else if (is_stmt<keyword_argument_expression>(this->args[i])) {
|
||||
// default argument used as normal parameter
|
||||
auto kwarg = cast_stmt<keyword_argument_expression>(this->args[i]);
|
||||
if (!is_stmt<identifier>(kwarg->key)) {
|
||||
throw std::runtime_error("Keyword argument key must be an identifier in macro '" + name + "'");
|
||||
}
|
||||
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
|
||||
value param_value = args.get_kwarg_or_pos(param_name, i);
|
||||
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
|
||||
macro_ctx.set_val(param_name, param_value);
|
||||
} else {
|
||||
throw std::runtime_error("Invalid parameter type in macro '" + name + "'");
|
||||
}
|
||||
} else {
|
||||
auto & default_arg = this->args[i];
|
||||
if (is_stmt<keyword_argument_expression>(default_arg)) {
|
||||
auto kwarg = cast_stmt<keyword_argument_expression>(default_arg);
|
||||
if (!is_stmt<identifier>(kwarg->key)) {
|
||||
throw std::runtime_error("Keyword argument key must be an identifier in macro '" + name + "'");
|
||||
}
|
||||
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
|
||||
JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str());
|
||||
macro_ctx.set_val(param_name, kwarg->val->execute(ctx));
|
||||
} else {
|
||||
throw std::runtime_error("Not enough arguments provided to macro '" + name + "'");
|
||||
}
|
||||
//std::string param_name = cast_stmt<identifier>(default_args[i])->val;
|
||||
//JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str());
|
||||
//macro_ctx.var[param_name] = default_args[i]->execute(ctx);
|
||||
}
|
||||
}
|
||||
bind_parameters(name, this->args, args, macro_ctx);
|
||||
|
||||
// execute macro body
|
||||
JJ_DEBUG("Executing macro '%s' body with %zu statements", name.c_str(), this->body.size());
|
||||
@@ -752,6 +755,46 @@ value macro_statement::execute_impl(context & ctx) {
|
||||
return mk_val<value_undefined>();
|
||||
}
|
||||
|
||||
value call_statement::execute_impl(context & ctx) {
|
||||
auto call_expr = cast_stmt<call_expression>(this->call);
|
||||
if (!call_expr) {
|
||||
throw std::runtime_error("Call statement requires a valid call expression");
|
||||
}
|
||||
|
||||
value callee_val = call_expr->callee->execute(ctx);
|
||||
if (!is_val<value_func>(callee_val)) {
|
||||
throw std::runtime_error("Callee is not a function: got " + callee_val->type());
|
||||
}
|
||||
auto * callee_func = cast_val<value_func>(callee_val);
|
||||
|
||||
context caller_ctx(ctx); // new scope for caller execution
|
||||
|
||||
const func_handler func = [this, caller_ctx = std::move(caller_ctx)](const func_args & args) -> value {
|
||||
context block_ctx(caller_ctx); // new scope for block execution
|
||||
|
||||
bind_parameters("caller", this->caller_args, args, block_ctx);
|
||||
|
||||
JJ_DEBUG("Executing call body with %zu statements", this->body.size());
|
||||
auto res = exec_statements(this->body, block_ctx);
|
||||
JJ_DEBUG("Call body execution complete, result: %s", res->val_str.str().c_str());
|
||||
return res;
|
||||
};
|
||||
|
||||
context call_ctx(ctx);
|
||||
call_ctx.set_val("caller", mk_val<value_func>("caller", func));
|
||||
|
||||
func_args args(call_ctx);
|
||||
|
||||
for (const auto & arg_expr : call_expr->args) {
|
||||
auto arg_val = arg_expr->execute(ctx);
|
||||
JJ_DEBUG(" Argument type: %s", arg_val->type().c_str());
|
||||
args.push_back(arg_val);
|
||||
}
|
||||
|
||||
JJ_DEBUG("Calling macro '%s' with %zu arguments", callee_func->name.c_str(), args.count());
|
||||
return callee_func->invoke(args);
|
||||
}
|
||||
|
||||
value member_expression::execute_impl(context & ctx) {
|
||||
value object = this->object->execute(ctx);
|
||||
|
||||
|
||||
@@ -552,6 +552,7 @@ struct call_statement : public statement {
|
||||
for (const auto & arg : this->caller_args) chk_type<expression>(arg);
|
||||
}
|
||||
std::string type() const override { return "CallStatement"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
};
|
||||
|
||||
struct ternary_expression : public expression {
|
||||
|
||||
@@ -233,27 +233,27 @@ struct BuiltinRule {
|
||||
};
|
||||
|
||||
static std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
|
||||
{"boolean", {"(\"true\" | \"false\") space", {}}},
|
||||
{"boolean", {"(\"true\" | \"false\")", {}}},
|
||||
{"decimal-part", {"[0-9]{1,16}", {}}},
|
||||
{"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}},
|
||||
{"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}},
|
||||
{"integer", {"(\"-\"? integral-part) space", {"integral-part"}}},
|
||||
{"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)?", {"integral-part", "decimal-part"}}},
|
||||
{"integer", {"(\"-\"? integral-part)", {"integral-part"}}},
|
||||
{"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}},
|
||||
{"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}},
|
||||
{"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}},
|
||||
{"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}},
|
||||
{"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? space \"}\"", {"string", "value"}}},
|
||||
{"array", {"\"[\" space ( value (\",\" space value)* )? space \"]\"", {"value"}}},
|
||||
{"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\"", {}}},
|
||||
{"char", {"[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}},
|
||||
{"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}},
|
||||
{"null", {"\"null\" space", {}}},
|
||||
{"string", {"\"\\\"\" char* \"\\\"\"", {"char"}}},
|
||||
{"null", {"\"null\"", {}}},
|
||||
};
|
||||
|
||||
static std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
|
||||
{"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}},
|
||||
{"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}},
|
||||
{"date-time", {"date \"T\" time", {"date", "time"}}},
|
||||
{"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}},
|
||||
{"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}},
|
||||
{"date-time-string", {"\"\\\"\" date-time \"\\\"\" space", {"date-time"}}}
|
||||
{"date-string", {"\"\\\"\" date \"\\\"\"", {"date"}}},
|
||||
{"time-string", {"\"\\\"\" time \"\\\"\"", {"time"}}},
|
||||
{"date-time-string", {"\"\\\"\" date-time \"\\\"\"", {"date-time"}}}
|
||||
};
|
||||
|
||||
static bool is_reserved_name(const std::string & name) {
|
||||
@@ -551,16 +551,16 @@ private:
|
||||
}
|
||||
return join_seq();
|
||||
};
|
||||
return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space");
|
||||
return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\"");
|
||||
}
|
||||
|
||||
/*
|
||||
Returns a rule that matches a JSON string that is none of the provided strings
|
||||
|
||||
not_strings({"a"})
|
||||
-> ["] ( [a] char+ | [^"a] char* )? ["] space
|
||||
-> ["] ( [a] char+ | [^"a] char* )? ["]
|
||||
not_strings({"and", "also"})
|
||||
-> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
|
||||
-> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["]
|
||||
*/
|
||||
std::string _not_strings(const std::vector<std::string> & strings) {
|
||||
|
||||
@@ -619,7 +619,7 @@ private:
|
||||
if (!trie.is_end_of_string) {
|
||||
out << "?";
|
||||
}
|
||||
out << " [\"] space";
|
||||
out << " [\"]";
|
||||
return out.str();
|
||||
}
|
||||
|
||||
@@ -725,7 +725,7 @@ private:
|
||||
rule += " )?";
|
||||
}
|
||||
|
||||
rule += " \"}\" space";
|
||||
rule += " space \"}\"";
|
||||
|
||||
return rule;
|
||||
}
|
||||
@@ -858,14 +858,14 @@ public:
|
||||
return _add_rule(rule_name, _generate_union_rule(name, schema_types));
|
||||
}
|
||||
if (schema.contains("const")) {
|
||||
return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space");
|
||||
return _add_rule(rule_name, _generate_constant_rule(schema["const"]));
|
||||
}
|
||||
if (schema.contains("enum")) {
|
||||
std::vector<std::string> enum_values;
|
||||
for (const auto & v : schema["enum"]) {
|
||||
enum_values.push_back(_generate_constant_rule(v));
|
||||
}
|
||||
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
|
||||
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ")");
|
||||
}
|
||||
if ((schema_type.is_null() || schema_type == "object")
|
||||
&& (schema.contains("properties") ||
|
||||
@@ -933,7 +933,7 @@ public:
|
||||
}
|
||||
}
|
||||
if (!enum_intersection.empty()) {
|
||||
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space");
|
||||
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ")");
|
||||
}
|
||||
}
|
||||
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
|
||||
@@ -948,7 +948,7 @@ public:
|
||||
}
|
||||
rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i));
|
||||
}
|
||||
rule += " \"]\" space";
|
||||
rule += " space \"]\"";
|
||||
return _add_rule(rule_name, rule);
|
||||
}
|
||||
std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
|
||||
@@ -956,7 +956,7 @@ public:
|
||||
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
|
||||
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
|
||||
|
||||
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
|
||||
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " space \"]\"");
|
||||
}
|
||||
if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
|
||||
return _visit_pattern(schema["pattern"], rule_name);
|
||||
@@ -972,7 +972,7 @@ public:
|
||||
std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
|
||||
int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
|
||||
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
|
||||
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
|
||||
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\"");
|
||||
}
|
||||
if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
|
||||
int64_t min_value = std::numeric_limits<int64_t>::min();
|
||||
@@ -990,7 +990,7 @@ public:
|
||||
std::stringstream out;
|
||||
out << "(";
|
||||
build_min_max_int(min_value, max_value, out);
|
||||
out << ") space";
|
||||
out << ")";
|
||||
return _add_rule(rule_name, out.str());
|
||||
}
|
||||
if (schema.empty() || schema_type == "object") {
|
||||
|
||||
+202
-89
@@ -6,13 +6,14 @@
|
||||
#include "unicode.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <deque>
|
||||
#include <initializer_list>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <regex>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
#include <unordered_set>
|
||||
|
||||
// Trick to catch missing branches
|
||||
template <typename T>
|
||||
@@ -88,40 +89,7 @@ struct trie {
|
||||
return match_result{match_result::NO_MATCH};
|
||||
}
|
||||
|
||||
struct prefix_and_next {
|
||||
std::vector<uint32_t> prefix;
|
||||
std::vector<uint32_t> next_chars;
|
||||
};
|
||||
|
||||
std::vector<prefix_and_next> collect_prefix_and_next() {
|
||||
std::vector<uint32_t> prefix;
|
||||
std::vector<prefix_and_next> result;
|
||||
collect_prefix_and_next(0, prefix, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
void collect_prefix_and_next(size_t index, std::vector<uint32_t> & prefix, std::vector<prefix_and_next> & out) {
|
||||
if (!nodes[index].is_word) {
|
||||
if (!nodes[index].children.empty()) {
|
||||
std::vector<uint32_t> chars;
|
||||
chars.reserve(nodes[index].children.size());
|
||||
for (const auto & p : nodes[index].children) {
|
||||
chars.push_back(p.first);
|
||||
}
|
||||
out.emplace_back(prefix_and_next{prefix, chars});
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & p : nodes[index].children) {
|
||||
uint32_t ch = p.first;
|
||||
auto child = p.second;
|
||||
prefix.push_back(ch);
|
||||
collect_prefix_and_next(child, prefix, out);
|
||||
prefix.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
size_t create_node() {
|
||||
size_t index = nodes.size();
|
||||
nodes.emplace_back();
|
||||
@@ -153,6 +121,65 @@ struct trie {
|
||||
}
|
||||
};
|
||||
|
||||
// Aho-Corasick automaton
|
||||
struct aho_corasick {
|
||||
trie t;
|
||||
std::vector<size_t> fail; // failure links
|
||||
std::vector<size_t> order; // states in BFS order
|
||||
std::vector<bool> terminal; // match states (directly or via a suffix link)
|
||||
std::set<uint32_t> alphabet; // every character with a transition
|
||||
|
||||
aho_corasick(const std::vector<std::string> & strings) : t(strings) {
|
||||
const auto & nodes = t.nodes;
|
||||
const size_t n = nodes.size();
|
||||
|
||||
fail.assign(n, 0);
|
||||
order.reserve(n);
|
||||
|
||||
std::deque<size_t> queue{ 0 };
|
||||
while (!queue.empty()) {
|
||||
size_t u = queue.front();
|
||||
queue.pop_front();
|
||||
order.push_back(u);
|
||||
for (const auto & [ch, v] : nodes[u].children) {
|
||||
if (u != 0) {
|
||||
size_t f = fail[u];
|
||||
while (f && nodes[f].children.find(ch) == nodes[f].children.end()) {
|
||||
f = fail[f];
|
||||
}
|
||||
auto it = nodes[f].children.find(ch);
|
||||
fail[v] = (it != nodes[f].children.end() && it->second != v) ? it->second : 0;
|
||||
}
|
||||
queue.push_back(v);
|
||||
}
|
||||
}
|
||||
|
||||
terminal.assign(n, false);
|
||||
for (size_t u : order) {
|
||||
terminal[u] = nodes[u].is_word || (u != 0 && terminal[fail[u]]);
|
||||
}
|
||||
|
||||
for (const auto & node : nodes) {
|
||||
for (const auto & [ch, v] : node.children) {
|
||||
alphabet.insert(ch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t num_states() const { return t.nodes.size(); }
|
||||
bool is_terminal(size_t s) const { return terminal[s]; }
|
||||
|
||||
// follow failure links until a transition on `ch` exists.
|
||||
size_t next(size_t state, uint32_t ch) const {
|
||||
const auto & nodes = t.nodes;
|
||||
while (state && nodes[state].children.find(ch) == nodes[state].children.end()) {
|
||||
state = fail[state];
|
||||
}
|
||||
auto it = nodes[state].children.find(ch);
|
||||
return it != nodes[state].children.end() ? it->second : 0;
|
||||
}
|
||||
};
|
||||
|
||||
static std::pair<uint32_t, size_t> parse_hex_escape(const std::string & str, size_t pos, int hex_count) {
|
||||
if (pos + hex_count > str.length()) {
|
||||
return {0, 0};
|
||||
@@ -894,6 +921,10 @@ struct parser_executor {
|
||||
common_peg_parse_result operator()(const common_peg_gbnf_parser & p) {
|
||||
return arena.parse(p.child, ctx, start_pos);
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_ac_parser & p) {
|
||||
return arena.parse(p.child, ctx, start_pos);
|
||||
}
|
||||
};
|
||||
|
||||
common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const {
|
||||
@@ -962,7 +993,8 @@ void common_peg_arena::resolve_refs() {
|
||||
std::is_same_v<T, common_peg_not_parser> ||
|
||||
std::is_same_v<T, common_peg_tag_parser> ||
|
||||
std::is_same_v<T, common_peg_atomic_parser> ||
|
||||
std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
std::is_same_v<T, common_peg_gbnf_parser> ||
|
||||
std::is_same_v<T, common_peg_ac_parser>) {
|
||||
p.child = resolve_ref(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
p.child = resolve_ref(p.child);
|
||||
@@ -992,12 +1024,12 @@ void common_peg_arena::resolve_refs() {
|
||||
}
|
||||
|
||||
std::string common_peg_arena::dump(common_peg_parser_id id) const {
|
||||
std::unordered_set<common_peg_parser_id> visited;
|
||||
std::set<common_peg_parser_id> visited;
|
||||
return dump_impl(id, visited);
|
||||
}
|
||||
|
||||
std::string common_peg_arena::dump_impl(common_peg_parser_id id,
|
||||
std::unordered_set<common_peg_parser_id> & visited) const {
|
||||
std::set<common_peg_parser_id> & visited) const {
|
||||
// Check for cycles
|
||||
if (visited.count(id)) {
|
||||
return "[cycle]";
|
||||
@@ -1043,6 +1075,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
|
||||
return "Atomic(" + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
|
||||
return "Ac(" + string_join(p.delimiters, " | ") + ", " + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_any_parser>) {
|
||||
return "Any";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_space_parser>) {
|
||||
@@ -1342,7 +1376,7 @@ common_peg_parser common_peg_parser_builder::json_object() {
|
||||
common_peg_parser common_peg_parser_builder::json_array() {
|
||||
return rule("json-array", [this]() {
|
||||
auto ws = space();
|
||||
auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))});
|
||||
auto elements = sequence({json(), zero_or_more(sequence({ws, literal(","), ws, json()}))});
|
||||
return sequence({
|
||||
literal("["),
|
||||
ws,
|
||||
@@ -1452,6 +1486,13 @@ common_peg_parser common_peg_parser_builder::json_member(const std::string & key
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::ac(const common_peg_parser & p, const std::vector<std::string> & delimiters) {
|
||||
if (delimiters.empty()) {
|
||||
throw std::runtime_error("ac parser requires at least one delimiter");
|
||||
}
|
||||
return add(common_peg_ac_parser{p, delimiters});
|
||||
}
|
||||
|
||||
static std::string gbnf_escape_char_class(uint32_t c) {
|
||||
if (c == '-' || c == ']' || c == '[' || c == '\\') {
|
||||
return "\\" + std::string(1, (char) c);
|
||||
@@ -1502,61 +1543,118 @@ static std::string gbnf_escape_char_class(uint32_t c) {
|
||||
return std::string(buf);
|
||||
}
|
||||
|
||||
static std::string gbnf_excluding_pattern(const std::vector<std::string> & strings) {
|
||||
trie matcher(strings);
|
||||
auto pieces = matcher.collect_prefix_and_next();
|
||||
|
||||
std::string pattern;
|
||||
std::string trailing; // optional proper-prefix of a delimiter, allowed only at the very end
|
||||
for (size_t i = 0; i < pieces.size(); ++i) {
|
||||
if (i > 0) {
|
||||
pattern += " | ";
|
||||
}
|
||||
|
||||
const auto & pre = pieces[i].prefix;
|
||||
const auto & chars = pieces[i].next_chars;
|
||||
|
||||
std::string cls;
|
||||
cls.reserve(chars.size());
|
||||
for (uint32_t ch : chars) {
|
||||
cls += gbnf_escape_char_class(ch);
|
||||
}
|
||||
|
||||
if (!pre.empty()) {
|
||||
std::string pre_literal = gbnf_format_literal(common_unicode_cpts_to_utf8(pre));
|
||||
pattern += pre_literal + " [^" + cls + "]";
|
||||
// Each interior alternative consumes a delimiter-prefix plus a disambiguating
|
||||
// char, so the repetition alone cannot match a value that *ends* on a proper
|
||||
// prefix of a delimiter (e.g. a trailing "\n" when the delimiter is
|
||||
// "\n</parameter>\n"). The runtime until() (greedy first-match) accepts such
|
||||
// values, so without this the grammar would reject input the parser accepts.
|
||||
// Allow the value to terminate on any proper prefix as an optional tail.
|
||||
// This makes the grammar a slight superset of the runtime language (a value
|
||||
// may end on the longest prefix, which greedy first-match would not itself
|
||||
// produce); harmless for constrained generation, which only needs to admit
|
||||
// every runtime-valid string.
|
||||
if (!trailing.empty()) {
|
||||
trailing += " | ";
|
||||
}
|
||||
trailing += pre_literal;
|
||||
} else {
|
||||
pattern += "[^" + cls + "]";
|
||||
}
|
||||
static std::string gbnf_char_class(const std::vector<uint32_t> & chars, bool negate) {
|
||||
std::string s = negate ? "[^" : "[";
|
||||
for (uint32_t ch : chars) {
|
||||
s += gbnf_escape_char_class(ch);
|
||||
}
|
||||
|
||||
std::string result = "(" + pattern + ")*";
|
||||
if (!trailing.empty()) {
|
||||
result += " (" + trailing + ")?";
|
||||
}
|
||||
return result;
|
||||
return s + "]";
|
||||
}
|
||||
|
||||
static std::unordered_set<std::string> collect_reachable_rules(
|
||||
static std::string gbnf_ac_grammar(
|
||||
const common_grammar_builder & builder,
|
||||
const std::string & prefix,
|
||||
const std::vector<std::string> & strings,
|
||||
const std::function<std::string(const std::vector<uint32_t> &,
|
||||
const std::map<size_t, std::vector<uint32_t>> &,
|
||||
const std::vector<uint32_t> &,
|
||||
const std::function<std::string(size_t)> &)> & build_rule) {
|
||||
aho_corasick ac(strings);
|
||||
|
||||
auto state_name = [&](size_t s) -> std::string {
|
||||
if (s == 0) {
|
||||
return prefix;
|
||||
}
|
||||
std::string num = std::to_string(s);
|
||||
num = num.size() == 1 ? ("0" + num) : num;
|
||||
return prefix + "-" + num;
|
||||
};
|
||||
|
||||
for (size_t q = 0; q < ac.num_states(); q++) {
|
||||
if (ac.is_terminal(q)) {
|
||||
continue; // match states
|
||||
}
|
||||
|
||||
std::map<size_t, std::vector<uint32_t>> buckets;
|
||||
std::vector<uint32_t> completing; // chars that complete a delimiter
|
||||
std::vector<uint32_t> specific; // chars with an explicit transition
|
||||
for (uint32_t c : ac.alphabet) {
|
||||
size_t d = ac.next(q, c);
|
||||
if (ac.is_terminal(d)) {
|
||||
completing.push_back(c);
|
||||
specific.push_back(c);
|
||||
} else if (d != 0) {
|
||||
buckets[d].push_back(c); // specific non-root destination
|
||||
specific.push_back(c);
|
||||
}
|
||||
}
|
||||
|
||||
builder.add_rule(state_name(q), build_rule(completing, buckets, specific, state_name));
|
||||
}
|
||||
|
||||
// An empty delimiter makes the start state terminal. Emit an entry rule
|
||||
// that matches the empty string so the returned reference stays valid.
|
||||
if (ac.is_terminal(0)) {
|
||||
builder.add_rule(prefix, "|");
|
||||
}
|
||||
|
||||
return state_name(0);
|
||||
}
|
||||
|
||||
// GBNF grammar matching strings that contain no string in `strings` as a
|
||||
// substring. Emits the complement of an Aho-Corasick automaton DFA and returns
|
||||
// the start state rule name.
|
||||
//
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/24839
|
||||
static std::string gbnf_excluding_grammar(const common_grammar_builder & builder,
|
||||
const std::string & prefix,
|
||||
const std::vector<std::string> & strings) {
|
||||
return gbnf_ac_grammar(builder, prefix, strings,
|
||||
[](const std::vector<uint32_t> & /*completing*/,
|
||||
const std::map<size_t, std::vector<uint32_t>> & buckets,
|
||||
const std::vector<uint32_t> & specific,
|
||||
const std::function<std::string(size_t)> & state_name) {
|
||||
// every state is accepting and completing chars get no
|
||||
// alternative, so a forbidden string can never be matched
|
||||
std::string rhs = "|";
|
||||
for (const auto & [d, chars] : buckets) {
|
||||
rhs += " " + gbnf_char_class(chars, false) + " " + state_name(d) + " |";
|
||||
}
|
||||
rhs += " " + gbnf_char_class(specific, true) + " " + state_name(0);
|
||||
return rhs;
|
||||
});
|
||||
}
|
||||
|
||||
// GBNF grammar matching everything up to and including the first occurrence of
|
||||
// any string in `strings`. Emits the Aho-Corasick automaton DFA and returns
|
||||
// the start state rule name.
|
||||
static std::string gbnf_including_grammar(const common_grammar_builder & builder,
|
||||
const std::string & prefix,
|
||||
const std::vector<std::string> & strings) {
|
||||
return gbnf_ac_grammar(builder, prefix, strings,
|
||||
[](const std::vector<uint32_t> & completing,
|
||||
const std::map<size_t, std::vector<uint32_t>> & buckets,
|
||||
const std::vector<uint32_t> & specific,
|
||||
const std::function<std::string(size_t)> & state_name) {
|
||||
std::vector<std::string> alts;
|
||||
if (!completing.empty()) {
|
||||
alts.push_back(gbnf_char_class(completing, false)); // terminate on match
|
||||
}
|
||||
for (const auto & [d, chars] : buckets) {
|
||||
alts.push_back(gbnf_char_class(chars, false) + " " + state_name(d));
|
||||
}
|
||||
// every other character keeps scanning from the start state
|
||||
alts.push_back(gbnf_char_class(specific, true) + " " + state_name(0));
|
||||
return string_join(alts, " | ");
|
||||
});
|
||||
}
|
||||
|
||||
static std::set<std::string> collect_reachable_rules(
|
||||
const common_peg_arena & arena,
|
||||
const common_peg_parser_id & rule
|
||||
) {
|
||||
std::unordered_set<std::string> reachable;
|
||||
std::unordered_set<std::string> visited;
|
||||
std::set<std::string> reachable;
|
||||
std::set<std::string> visited;
|
||||
|
||||
std::function<void(common_peg_parser_id)> visit = [&](common_peg_parser_id id) {
|
||||
const auto & parser = arena.get(id);
|
||||
@@ -1588,6 +1686,7 @@ static std::unordered_set<std::string> collect_reachable_rules(
|
||||
std::is_same_v<T, common_peg_tag_parser> ||
|
||||
std::is_same_v<T, common_peg_atomic_parser> ||
|
||||
std::is_same_v<T, common_peg_gbnf_parser> ||
|
||||
std::is_same_v<T, common_peg_ac_parser> ||
|
||||
std::is_same_v<T, common_peg_schema_parser>) {
|
||||
visit(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
@@ -1765,7 +1864,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
if (p.delimiters.empty()) {
|
||||
return ".*";
|
||||
}
|
||||
return gbnf_excluding_pattern(p.delimiters);
|
||||
return gbnf_excluding_grammar(builder, "until-" + std::to_string(id), p.delimiters);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
|
||||
if (schema_delegates(p)) {
|
||||
return to_gbnf(p.child);
|
||||
@@ -1782,6 +1881,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
return to_gbnf(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return p.grammar;
|
||||
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
|
||||
return gbnf_including_grammar(builder, "ac-" + std::to_string(id), p.delimiters);
|
||||
} else {
|
||||
static_assert(is_always_false_v<T>);
|
||||
}
|
||||
@@ -1789,7 +1890,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
};
|
||||
|
||||
// Collect reachable rules
|
||||
std::unordered_set<std::string> reachable_rules;
|
||||
std::set<std::string> reachable_rules;
|
||||
|
||||
if (lazy) {
|
||||
// Collect rules reachable from trigger rules
|
||||
@@ -1918,6 +2019,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
|
||||
};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
|
||||
return json{{"type", "ac"}, {"child", p.child}, {"delimiters", p.delimiters}};
|
||||
}
|
||||
}, variant);
|
||||
}
|
||||
@@ -2090,6 +2193,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
|
||||
};
|
||||
}
|
||||
|
||||
if (type == "ac") {
|
||||
if (!j.contains("child") || !j.contains("delimiters") || !j["delimiters"].is_array() || j["delimiters"].empty()) {
|
||||
throw std::runtime_error("ac parser requires 'child' and a non-empty 'delimiters' array");
|
||||
}
|
||||
return common_peg_ac_parser{
|
||||
j["child"].get<common_peg_parser_id>(),
|
||||
j["delimiters"].get<std::vector<std::string>>(),
|
||||
};
|
||||
}
|
||||
|
||||
throw std::runtime_error("Unknown parser type: " + type);
|
||||
}
|
||||
|
||||
|
||||
+16
-3
@@ -3,8 +3,8 @@
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <functional>
|
||||
@@ -275,6 +275,11 @@ struct common_peg_gbnf_parser {
|
||||
std::string grammar;
|
||||
};
|
||||
|
||||
struct common_peg_ac_parser {
|
||||
common_peg_parser_id child;
|
||||
std::vector<std::string> delimiters;
|
||||
};
|
||||
|
||||
// Variant holding all parser types
|
||||
using common_peg_parser_variant = std::variant<
|
||||
common_peg_epsilon_parser,
|
||||
@@ -296,7 +301,8 @@ using common_peg_parser_variant = std::variant<
|
||||
common_peg_ref_parser,
|
||||
common_peg_atomic_parser,
|
||||
common_peg_tag_parser,
|
||||
common_peg_gbnf_parser
|
||||
common_peg_gbnf_parser,
|
||||
common_peg_ac_parser
|
||||
>;
|
||||
|
||||
class common_peg_arena {
|
||||
@@ -335,7 +341,7 @@ class common_peg_arena {
|
||||
friend class common_peg_parser_builder;
|
||||
|
||||
private:
|
||||
std::string dump_impl(common_peg_parser_id id, std::unordered_set<common_peg_parser_id> & visited) const;
|
||||
std::string dump_impl(common_peg_parser_id id, std::set<common_peg_parser_id> & visited) const;
|
||||
|
||||
common_peg_parser_id add_parser(common_peg_parser_variant parser);
|
||||
void add_rule(const std::string & name, common_peg_parser_id id);
|
||||
@@ -514,6 +520,13 @@ class common_peg_parser_builder {
|
||||
// the child's grammar. Parsing delegates entirely to the child.
|
||||
common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); }
|
||||
|
||||
// Wraps a child parser but emits a GBNF grammar built from the Aho-Corasick
|
||||
// automaton of `delimiters`, matching everything up to and including the
|
||||
// first delimiter. Parsing delegates entirely to the child, which is
|
||||
// responsible for consuming the delimiter (e.g. until(D) + literal(D)).
|
||||
common_peg_parser ac(const common_peg_parser & p, const std::vector<std::string> & delimiters);
|
||||
common_peg_parser ac(const common_peg_parser & p, const std::string & delimiter) { return ac(p, std::vector<std::string>{delimiter}); }
|
||||
|
||||
void set_root(const common_peg_parser & p);
|
||||
|
||||
common_peg_arena build();
|
||||
|
||||
+1
-49
@@ -16,48 +16,6 @@ static std::string rm_leading_dashes(const std::string & str) {
|
||||
return str.substr(pos);
|
||||
}
|
||||
|
||||
// only allow a subset of args for remote presets for security reasons
|
||||
// do not add more args unless absolutely necessary
|
||||
// args that output to files are strictly prohibited
|
||||
static std::set<std::string> get_remote_preset_whitelist(const std::map<std::string, common_arg> & key_to_opt) {
|
||||
static const std::set<std::string> allowed_options = {
|
||||
"model-url",
|
||||
"hf-repo",
|
||||
"hf-repo-draft",
|
||||
"hf-repo-v", // vocoder
|
||||
"hf-file-v", // vocoder
|
||||
"mmproj-url",
|
||||
"pooling",
|
||||
"jinja",
|
||||
"batch-size",
|
||||
"ubatch-size",
|
||||
"cache-reuse",
|
||||
"chat-template-kwargs",
|
||||
"mmap",
|
||||
// note: sampling params are automatically allowed by default
|
||||
// negated args will be added automatically if the positive arg is specified above
|
||||
};
|
||||
|
||||
std::set<std::string> allowed_keys;
|
||||
|
||||
for (const auto & it : key_to_opt) {
|
||||
const std::string & key = it.first;
|
||||
const common_arg & opt = it.second;
|
||||
if (allowed_options.find(key) != allowed_options.end() || opt.is_sampling) {
|
||||
allowed_keys.insert(key);
|
||||
// also add variant keys (args without leading dashes and env vars)
|
||||
for (const auto & arg : opt.get_args()) {
|
||||
allowed_keys.insert(rm_leading_dashes(arg));
|
||||
}
|
||||
for (const auto & env : opt.get_env()) {
|
||||
allowed_keys.insert(env);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return allowed_keys;
|
||||
}
|
||||
|
||||
std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
|
||||
std::vector<std::string> args;
|
||||
|
||||
@@ -300,16 +258,10 @@ static std::string parse_bool_arg(const common_arg & arg, const std::string & ke
|
||||
return value;
|
||||
}
|
||||
|
||||
common_preset_context::common_preset_context(llama_example ex, bool only_remote_allowed)
|
||||
common_preset_context::common_preset_context(llama_example ex)
|
||||
: ctx_params(common_params_parser_init(default_params, ex)) {
|
||||
common_params_add_preset_options(ctx_params.options);
|
||||
key_to_opt = get_map_key_opt(ctx_params);
|
||||
|
||||
// setup allowed keys if only_remote_allowed is true
|
||||
if (only_remote_allowed) {
|
||||
filter_allowed_keys = true;
|
||||
allowed_keys = get_remote_preset_whitelist(key_to_opt);
|
||||
}
|
||||
}
|
||||
|
||||
common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const {
|
||||
|
||||
+1
-1
@@ -60,7 +60,7 @@ struct common_preset_context {
|
||||
std::set<std::string> allowed_keys;
|
||||
|
||||
// if only_remote_allowed is true, only accept whitelisted keys
|
||||
common_preset_context(llama_example ex, bool only_remote_allowed = false);
|
||||
common_preset_context(llama_example ex);
|
||||
|
||||
// load presets from INI file
|
||||
common_presets load_from_ini(const std::string & path, common_preset & global) const;
|
||||
|
||||
@@ -259,6 +259,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!grmr && !grammar_str.empty()) {
|
||||
throw std::runtime_error("failed to parse grammar");
|
||||
}
|
||||
|
||||
// Compute prefill tokens from the generation prompt
|
||||
std::vector<llama_token> prefill_tokens;
|
||||
|
||||
+174
-35
@@ -161,6 +161,10 @@ struct common_speculative_impl {
|
||||
|
||||
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) = 0;
|
||||
|
||||
// (optional) serialize/restore per-seq internal state (e.g. eagle3's deferred boundary).
|
||||
virtual bool get_state(llama_seq_id /*seq_id*/, std::vector<uint8_t> & /*data*/) const { return false; }
|
||||
virtual void set_state(llama_seq_id /*seq_id*/, const std::vector<uint8_t> & /*data*/) {}
|
||||
|
||||
// true if this implementation requires the target context to extract post-norm embeddings
|
||||
virtual bool need_embd() const = 0;
|
||||
|
||||
@@ -841,6 +845,49 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
(size_t) n_embd_dec * sizeof(float));
|
||||
}
|
||||
|
||||
// we only need to stash the deferred boundary's g_embd row for recurrent/hybrid targets:
|
||||
// their single-position checkpoints drop it on restore
|
||||
bool need_boundary_stash() const {
|
||||
const llama_model * model_tgt = llama_get_model(params.ctx_tgt);
|
||||
return llama_model_is_recurrent(model_tgt) || llama_model_is_hybrid(model_tgt);
|
||||
}
|
||||
|
||||
bool get_state(llama_seq_id seq_id, std::vector<uint8_t> & data) const override {
|
||||
if (!need_boundary_stash()) {
|
||||
return false;
|
||||
}
|
||||
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq || pending_pos_last[seq_id] < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const llama_pos pos = pending_pos_last[seq_id];
|
||||
const std::vector<float> & g = pending_g_last[seq_id];
|
||||
|
||||
data.resize(sizeof(llama_pos) + g.size() * sizeof(float));
|
||||
std::memcpy(data.data(), &pos, sizeof(llama_pos));
|
||||
std::memcpy(data.data() + sizeof(llama_pos), g.data(), g.size() * sizeof(float));
|
||||
return true;
|
||||
}
|
||||
|
||||
void set_state(llama_seq_id seq_id, const std::vector<uint8_t> & data) override {
|
||||
if (!need_boundary_stash()) {
|
||||
return;
|
||||
}
|
||||
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
|
||||
return;
|
||||
}
|
||||
if (data.size() != sizeof(llama_pos) + (size_t) n_embd_dec * sizeof(float)) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_pos pos = -1;
|
||||
std::memcpy(&pos, data.data(), sizeof(llama_pos));
|
||||
|
||||
pending_pos_last[seq_id] = pos;
|
||||
pending_g_last[seq_id].resize(n_embd_dec);
|
||||
std::memcpy(pending_g_last[seq_id].data(), data.data() + sizeof(llama_pos), (size_t) n_embd_dec * sizeof(float));
|
||||
}
|
||||
|
||||
bool need_embd() const override {
|
||||
return false;
|
||||
}
|
||||
@@ -858,7 +905,13 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
int32_t n_embd = 0;
|
||||
|
||||
bool is_mem_shared = false;
|
||||
// One MTP draft driver, three modes (set once in the ctor):
|
||||
// is_mem_shared (gemma4): shares the target KV, runs all heads in one graph.
|
||||
// chain_heads (step35): n_mtp_layers trained heads, one per draft step.
|
||||
// neither (qwen35 / qwen35moe): a single trained MTP head.
|
||||
int32_t n_mtp_layers = 1;
|
||||
bool is_mem_shared = false; // gemma4
|
||||
bool chain_heads = false; // derived in the ctor: n_mtp_layers > 1 && !is_mem_shared
|
||||
|
||||
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
|
||||
// The last h-row of one process() call needs the first token of the NEXT
|
||||
@@ -873,10 +926,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
std::vector<std::vector<float>> verify_h;
|
||||
std::vector<int32_t> verify_h_rows;
|
||||
|
||||
// Per-seq draft length from the last draft() call, used in accept() to
|
||||
// roll back ctx_dft's recurrent state past the AR draft's redundant
|
||||
// pre-advancement before process() mirrored the verify batch.
|
||||
std::vector<uint16_t> last_n_drafted;
|
||||
std::vector<int> i_last;
|
||||
std::vector<std::vector<float>> chain_h;
|
||||
|
||||
common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq)
|
||||
@@ -889,6 +940,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft));
|
||||
GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) &&
|
||||
"MTP input row width must match the target h_nextn width");
|
||||
n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft)));
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
|
||||
@@ -935,16 +987,25 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
|
||||
|
||||
is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt;
|
||||
chain_heads = n_mtp_layers > 1 && !is_mem_shared;
|
||||
|
||||
if (chain_heads) {
|
||||
this->params.n_max = std::min(this->params.n_max, n_mtp_layers);
|
||||
|
||||
chain_h.assign(n_seq, {});
|
||||
for (auto & c : chain_h) {
|
||||
c.reserve((size_t) (this->params.n_max + 1) * n_embd);
|
||||
}
|
||||
}
|
||||
|
||||
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
|
||||
|
||||
i_last.assign(n_seq, -1);
|
||||
i_batch_beg.assign(n_seq, -1);
|
||||
i_batch_end.assign(n_seq, -1);
|
||||
|
||||
verify_h.assign(n_seq, {});
|
||||
verify_h_rows.assign(n_seq, 0);
|
||||
|
||||
last_n_drafted.assign(n_seq, 0);
|
||||
}
|
||||
|
||||
~common_speculative_impl_draft_mtp() override {
|
||||
@@ -1050,9 +1111,34 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
|
||||
}
|
||||
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
|
||||
auto * mem_dft = llama_get_memory(ctx_dft);
|
||||
|
||||
bool ok = true;
|
||||
for (int head = 0; head < n_mtp_layers; ++head) {
|
||||
if (chain_heads) {
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/24340/changes#r3413498544
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
if (i_batch_beg[seq_id] < 0) {
|
||||
continue;
|
||||
}
|
||||
llama_memory_seq_rm(mem_dft, seq_id, batch_in.pos[i_batch_beg[seq_id]], -1);
|
||||
}
|
||||
llama_set_nextn_layer_offset(ctx_dft, head);
|
||||
}
|
||||
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
|
||||
__func__, head, (int) rc, (int) batch_in.pos[0]);
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (chain_heads) {
|
||||
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
|
||||
}
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1087,7 +1173,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
int n_drafting = 0;
|
||||
std::vector<bool> drafting(n_seq);
|
||||
|
||||
const float * h_row = nullptr;
|
||||
const size_t row_bytes = (size_t) n_embd * sizeof(float);
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
@@ -1102,22 +1187,43 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
common_sampler_reset(smpls[seq_id].get());
|
||||
|
||||
common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);
|
||||
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, pending_h[seq_id].data(), row_bytes);
|
||||
|
||||
h_row = pending_h[seq_id].data();
|
||||
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
|
||||
}
|
||||
i_last[seq_id] = batch.n_tokens - 1;
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
|
||||
return;
|
||||
if (chain_heads) {
|
||||
chain_h[seq_id].assign(pending_h[seq_id].begin(), pending_h[seq_id].end());
|
||||
}
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
|
||||
while (n_drafting > 0) {
|
||||
int i_batch = 0;
|
||||
// each step decodes under a different head, i.e. a different decoder layer, and
|
||||
// KV is per layer. process() filled this layer's KV only for positions < n_past
|
||||
// (prompt + accepted prefix) — nothing in the draft region yet. so reset the
|
||||
// draft region (the seq_rm lower bound is n_past, leaving the prompt KV intact)
|
||||
// and select head i so it rebuilds its own layer's KV there; decoding just the
|
||||
// latest token would leave its attention reading cells only another head wrote.
|
||||
if (chain_heads) {
|
||||
auto * mem_dft = llama_get_memory(ctx_dft);
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
if (drafting[seq_id]) {
|
||||
llama_memory_seq_rm(mem_dft, seq_id, dparams[seq_id].n_past, -1);
|
||||
}
|
||||
}
|
||||
llama_set_nextn_layer_offset(ctx_dft, i);
|
||||
}
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
// rebuild the batch for the next step: the growing-KV paths re-add only the
|
||||
// new token (the KV already holds the prefix), while chained heads re-add the
|
||||
// whole prefix at the next head. dropped sequences are simply not re-added.
|
||||
common_batch_clear(batch);
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
@@ -1127,9 +1233,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
auto * smpl = smpls[seq_id].get();
|
||||
|
||||
common_sampler_sample(smpl, ctx_dft, i_batch, true);
|
||||
h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_batch);
|
||||
++i_batch;
|
||||
common_sampler_sample(smpl, ctx_dft, i_last[seq_id], true);
|
||||
const float * h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_last[seq_id]);
|
||||
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
@@ -1163,30 +1268,41 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (is_mem_shared) {
|
||||
if (chain_heads) {
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448031546
|
||||
chain_h[seq_id].insert(chain_h[seq_id].end(), h_row, h_row + n_embd);
|
||||
|
||||
const int n_rows = (int) result.size() + 1; // id_last + tokens drafted so far
|
||||
for (int t = 0; t < n_rows; ++t) {
|
||||
const llama_token tok = (t == 0) ? dp.id_last : result[t - 1];
|
||||
common_batch_add(batch, tok, dp.n_past + t, { seq_id }, t == n_rows - 1);
|
||||
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd,
|
||||
chain_h[seq_id].data() + (size_t) t * n_embd, row_bytes);
|
||||
}
|
||||
} else if (is_mem_shared) {
|
||||
// note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens
|
||||
// ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37
|
||||
common_batch_add(batch, id, dp.n_past, { seq_id }, true);
|
||||
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
|
||||
} else {
|
||||
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
|
||||
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
|
||||
}
|
||||
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
|
||||
|
||||
i_last[seq_id] = batch.n_tokens - 1;
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
// evaluate the drafted tokens on the draft model
|
||||
ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
++i;
|
||||
}
|
||||
|
||||
if (chain_heads) {
|
||||
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
|
||||
}
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
auto & dp = dparams[seq_id];
|
||||
if (!dp.drafting) {
|
||||
@@ -1196,8 +1312,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
if (dp.result->size() < (size_t) params.n_min) {
|
||||
dp.result->clear();
|
||||
}
|
||||
|
||||
last_n_drafted[seq_id] = (uint16_t) dp.result->size();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1810,7 +1924,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
|
||||
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
|
||||
bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr;
|
||||
bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
|
||||
bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
|
||||
|
||||
|
||||
|
||||
@@ -1848,7 +1962,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
if (has_draft_eagle3) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params));
|
||||
}
|
||||
if (has_mtp) {
|
||||
if (has_draft_mtp) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
|
||||
}
|
||||
}
|
||||
@@ -2118,6 +2232,31 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: support the case of more than one speculative implementations having a state
|
||||
bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector<uint8_t> & data) {
|
||||
if (spec == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto & impl : spec->impls) {
|
||||
if (impl->get_state(seq_id, data)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector<uint8_t> & data) {
|
||||
if (spec == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto & impl : spec->impls) {
|
||||
impl->set_state(seq_id, data);
|
||||
}
|
||||
}
|
||||
|
||||
void common_speculative_print_stats(const common_speculative * spec) {
|
||||
if (spec == nullptr) {
|
||||
return;
|
||||
|
||||
@@ -68,6 +68,10 @@ void common_speculative_draft(common_speculative * spec);
|
||||
// informs the speculative context that n_accepted tokens were accepted by the target model
|
||||
void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted);
|
||||
|
||||
// (optional) get/set internal state
|
||||
bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector<uint8_t> & data);
|
||||
void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector<uint8_t> & data);
|
||||
|
||||
// print statistics about the speculative decoding
|
||||
void common_speculative_print_stats(const common_speculative * spec);
|
||||
|
||||
|
||||
@@ -96,6 +96,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"GraniteMoeHybridForCausalLM": "granite",
|
||||
"GraniteMoeSharedForCausalLM": "granite",
|
||||
"GraniteSpeechForConditionalGeneration": "granite",
|
||||
"GraniteSpeechPlusForConditionalGeneration": "granite",
|
||||
"Grok1ForCausalLM": "grok",
|
||||
"GrokForCausalLM": "grok",
|
||||
"GroveMoeForCausalLM": "grovemoe",
|
||||
@@ -261,6 +262,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
|
||||
"GlmasrModel": "ultravox",
|
||||
"Granite4VisionForConditionalGeneration": "granite",
|
||||
"GraniteSpeechForConditionalGeneration": "granite",
|
||||
"GraniteSpeechPlusForConditionalGeneration": "granite",
|
||||
"HunYuanVLForConditionalGeneration": "hunyuan",
|
||||
"Idefics3ForConditionalGeneration": "smolvlm",
|
||||
"InternVisionModel": "internvl",
|
||||
|
||||
@@ -126,7 +126,7 @@ class BailingMoeV2Model(TextModel):
|
||||
if (rope_dim := hparams.get("head_dim")) is None:
|
||||
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
|
||||
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5)))
|
||||
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
|
||||
|
||||
+7
-1
@@ -1119,8 +1119,10 @@ class TextModel(ModelBase):
|
||||
|
||||
rope_theta = self.find_hparam(["global_rope_theta", "rope_global_theta", "rope_theta_global", "rope_theta", "rotary_emb_base"], optional=True)
|
||||
local_rope_theta = self.find_hparam(["local_rope_theta", "rope_local_theta", "rope_theta_local", "swa_rope_theta", "rope_local_base_freq"], optional=True)
|
||||
partial_rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"], optional=True)
|
||||
original_max_position_embeddings = self.find_hparam(["original_max_position_embeddings"], optional=True)
|
||||
|
||||
# Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters
|
||||
# Ensure global params are mirrored in rope_parameters
|
||||
if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters:
|
||||
if local_rope_theta is not None:
|
||||
self.rope_parameters["sliding_attention"] = {"rope_theta": local_rope_theta}
|
||||
@@ -1128,6 +1130,10 @@ class TextModel(ModelBase):
|
||||
self.rope_parameters["rope_theta"] = rope_theta
|
||||
if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None:
|
||||
self.rope_parameters["rope_type"] = rope_type
|
||||
if "partial_rotary_factor" not in self.rope_parameters and partial_rotary_factor is not None:
|
||||
self.rope_parameters["partial_rotary_factor"] = partial_rotary_factor
|
||||
if "original_max_position_embeddings" not in self.rope_parameters and original_max_position_embeddings is not None:
|
||||
self.rope_parameters["original_max_position_embeddings"] = original_max_position_embeddings
|
||||
|
||||
@classmethod
|
||||
def __init_subclass__(cls):
|
||||
|
||||
@@ -148,7 +148,7 @@ class ChatGLMModel(TextModel):
|
||||
rope_dim = self.hparams["attention_dim"]
|
||||
else:
|
||||
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5)))
|
||||
self.gguf_writer.add_add_bos_token(False)
|
||||
rope_freq = 10000
|
||||
if "rope_ratio" in self.hparams:
|
||||
|
||||
+1
-1
@@ -161,7 +161,7 @@ class DeciModel(TextModel):
|
||||
factor = rope_params.get("factor", 8.0)
|
||||
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
|
||||
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
|
||||
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
|
||||
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
@@ -24,7 +24,7 @@ class ExaoneModel(TextModel):
|
||||
|
||||
assert (hparams["activation_function"] == "silu")
|
||||
|
||||
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True)
|
||||
rotary_factor = self.rope_parameters.get("partial_rotary_factor")
|
||||
rotary_factor = rotary_factor if rotary_factor is not None else 1.0
|
||||
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
|
||||
|
||||
@@ -39,7 +39,7 @@ class ExaoneModel(TextModel):
|
||||
factor = rope_params.get("factor", 8.0)
|
||||
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
|
||||
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
|
||||
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
|
||||
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
@@ -104,7 +104,7 @@ class Exaone4Model(TextModel):
|
||||
factor = rope_params.get("factor", 16.0)
|
||||
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
|
||||
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
|
||||
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
|
||||
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
+1
-1
@@ -693,7 +693,7 @@ class Gemma4Model(Gemma3Model):
|
||||
self.gguf_writer.add_head_count_kv(value_arr)
|
||||
|
||||
# handle n_rot differently for global vs swa layers
|
||||
partial_rotary_factor_swa = self.hparams.get("partial_rotary_factor", 1.0)
|
||||
partial_rotary_factor_swa = self.rope_parameters.get("partial_rotary_factor", 1.0)
|
||||
n_rot_full = int(head_dim_full) # "proportional" is used, see generate_extra_tensors
|
||||
n_rot_swa = int(head_dim_swa * partial_rotary_factor_swa)
|
||||
self.gguf_writer.add_rope_dimension_count(n_rot_full)
|
||||
|
||||
+2
-2
@@ -124,7 +124,7 @@ class Glm4MoeModel(TextModel):
|
||||
self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
)
|
||||
self.gguf_writer.add_rope_dimension_count(
|
||||
int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))
|
||||
int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5))
|
||||
)
|
||||
|
||||
# MoE parameters - Use only routed expert count (shared experts handled separately)
|
||||
@@ -226,7 +226,7 @@ class GlmMoeDsaModel(DeepseekV2Model):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
rope_dim = self.hparams["qk_rope_head_dim"]
|
||||
partial_rotary_factor = self.hparams.get("partial_rotary_factor", 1.0)
|
||||
partial_rotary_factor = self.rope_parameters.get("partial_rotary_factor", 1.0)
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * partial_rotary_factor))
|
||||
|
||||
# NextN/MTP prediction layers
|
||||
|
||||
@@ -348,6 +348,34 @@ class GraniteSpeechMmprojModel(MmprojModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("GraniteSpeechPlusForConditionalGeneration")
|
||||
class GraniteSpeechPlusMmprojModel(GraniteSpeechMmprojModel):
|
||||
"""Conversion for GraniteSpeechPlus - extends GraniteSpeech with feature layer concatenation"""
|
||||
has_vision_encoder = False
|
||||
has_audio_encoder = True
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
assert self.hparams_audio is not None
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# Add feature_layer if present in encoder config
|
||||
if feature_layers := self.hparams_audio.get("cat_hidden_layers"):
|
||||
self.gguf_writer.add_audio_feature_layers(feature_layers)
|
||||
logger.info(f"gguf: audio feature_layers = {feature_layers}")
|
||||
|
||||
# Validate projector dimension matches concatenated encoder output
|
||||
hidden_dim = self.hparams_audio["hidden_dim"]
|
||||
expected_dim = hidden_dim * (len(feature_layers) + 1)
|
||||
projector_dim = self.global_config["projector_config"]["encoder_hidden_size"]
|
||||
|
||||
if projector_dim != expected_dim:
|
||||
raise ValueError(
|
||||
f"Projector encoder_hidden_size ({projector_dim}) does not match "
|
||||
f"expected concatenated dimension ({expected_dim}). "
|
||||
f"Expected: hidden_dim ({hidden_dim}) * (len(feature_layers) + 1) = {expected_dim}"
|
||||
)
|
||||
|
||||
|
||||
@ModelBase.register("Granite4VisionForConditionalGeneration")
|
||||
class Granite4VisionMmprojModel(MmprojModel):
|
||||
has_vision_encoder = True
|
||||
|
||||
+1
-1
@@ -289,7 +289,7 @@ class LlamaModel(TextModel):
|
||||
factor = rope_params.get("factor", 8.0)
|
||||
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
|
||||
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
|
||||
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
|
||||
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
+1
-1
@@ -154,7 +154,7 @@ class MimoV2Model(TextModel):
|
||||
self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
|
||||
|
||||
rope_dim = int(self.hparams["head_dim"] * self.hparams["partial_rotary_factor"])
|
||||
rope_dim = int(self.hparams["head_dim"] * self.rope_parameters["partial_rotary_factor"])
|
||||
self.gguf_writer.add_rope_dimension_count(rope_dim)
|
||||
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon", 1e-5))
|
||||
|
||||
+6
-10
@@ -32,11 +32,9 @@ class MiniCPMModel(TextModel):
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
rope_dims = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
|
||||
rope_scaling = self.find_hparam(['rope_scaling'], True)
|
||||
if rope_scaling is not None:
|
||||
long_factors = rope_scaling.get('long_factor', None)
|
||||
short_factors = rope_scaling.get('short_factor', None)
|
||||
|
||||
long_factors = self.rope_parameters.get('long_factor')
|
||||
short_factors = self.rope_parameters.get('short_factor')
|
||||
if long_factors or short_factors:
|
||||
if long_factors is None or short_factors is None:
|
||||
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
|
||||
|
||||
@@ -85,13 +83,11 @@ class MiniCPM3Model(TextModel):
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
rope_scaling = self.find_hparam(['rope_scaling'], True)
|
||||
if rope_scaling is not None:
|
||||
long_factors = self.rope_parameters.get('long_factor')
|
||||
short_factors = self.rope_parameters.get('short_factor')
|
||||
if long_factors or short_factors:
|
||||
rope_dims = self.hparams["qk_rope_head_dim"]
|
||||
|
||||
long_factors = rope_scaling.get('long_factor', None)
|
||||
short_factors = rope_scaling.get('short_factor', None)
|
||||
|
||||
if long_factors is None or short_factors is None:
|
||||
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
|
||||
|
||||
|
||||
@@ -125,17 +125,18 @@ class NemotronModel(TextModel):
|
||||
self.gguf_writer.add_layer_norm_eps(f_norm_eps)
|
||||
|
||||
# * Partial RoPE
|
||||
rot_pct = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"])
|
||||
rot_pct = self.rope_parameters["partial_rotary_factor"]
|
||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||
self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
|
||||
|
||||
# * RopeScaling for Nemotron
|
||||
if "rope_scaling" not in self.hparams or self.hparams["rope_scaling"] is None:
|
||||
factor = self.hparams.get("factor") or self.rope_parameters.get("factor")
|
||||
if factor is None:
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
|
||||
else:
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||
self.gguf_writer.add_rope_scaling_factor(self.hparams["factor"])
|
||||
self.gguf_writer.add_rope_scaling_factor(factor)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# * Adding +1 to LayerNorm's weights here to implement layernorm1p w/o changing anything on the GGML engine side
|
||||
|
||||
+9
-11
@@ -18,7 +18,7 @@ class Phi2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.PHI2
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
rot_pct = self.find_hparam(["partial_rotary_factor"])
|
||||
rot_pct = self.rope_parameters["partial_rotary_factor"]
|
||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||
|
||||
@@ -149,8 +149,8 @@ class Phi3MiniModel(TextModel):
|
||||
n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
|
||||
rms_eps = self.find_hparam(["rms_norm_eps"])
|
||||
max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
|
||||
orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
|
||||
rot_pct = self.hparams.get("partial_rotary_factor", 1.0)
|
||||
orig_max_pos_embds = self.rope_parameters["original_max_position_embeddings"]
|
||||
rot_pct = self.rope_parameters.get("partial_rotary_factor", 1.0)
|
||||
rope_dims = int(rot_pct * n_embd) // n_head
|
||||
|
||||
self.gguf_writer.add_context_length(max_pos_embds)
|
||||
@@ -174,18 +174,19 @@ class Phi3MiniModel(TextModel):
|
||||
n_embd = self.find_hparam(["hidden_size", "n_embd"])
|
||||
n_head = self.find_hparam(["num_attention_heads", "n_head"])
|
||||
max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
|
||||
orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
|
||||
rot_pct = self.hparams.get("partial_rotary_factor", 1.0)
|
||||
orig_max_pos_embds = self.rope_parameters["original_max_position_embeddings"]
|
||||
rot_pct = self.rope_parameters.get("partial_rotary_factor", 1.0)
|
||||
rope_dims = int(rot_pct * n_embd) // n_head
|
||||
|
||||
# write rope scaling for long context (128k) model
|
||||
rope_scaling = self.find_hparam(['rope_scaling'], True)
|
||||
if rope_scaling is None:
|
||||
long_factors = self.rope_parameters.get('long_factor')
|
||||
short_factors = self.rope_parameters.get('short_factor')
|
||||
if not long_factors:
|
||||
return
|
||||
|
||||
scale = max_pos_embds / orig_max_pos_embds
|
||||
|
||||
rope_scaling_type = rope_scaling.get('rope_type', rope_scaling.get('type', '')).lower()
|
||||
rope_scaling_type = self.rope_parameters.get('rope_type', '').lower()
|
||||
if len(rope_scaling_type) == 0:
|
||||
raise KeyError('Missing the required key rope_scaling.type')
|
||||
|
||||
@@ -198,9 +199,6 @@ class Phi3MiniModel(TextModel):
|
||||
|
||||
self.gguf_writer.add_rope_scaling_attn_factors(attn_factor)
|
||||
|
||||
long_factors = rope_scaling.get('long_factor', None)
|
||||
short_factors = rope_scaling.get('short_factor', None)
|
||||
|
||||
if long_factors is None or short_factors is None:
|
||||
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
|
||||
|
||||
|
||||
+1
-1
@@ -280,7 +280,7 @@ class Qwen3NextModel(Qwen2MoeModel):
|
||||
self.gguf_writer.add_full_attention_interval(self.hparams.get("full_attention_interval", 4))
|
||||
if (rope_dim := self.hparams.get("head_dim")) is None:
|
||||
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25)))
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.25)))
|
||||
|
||||
@classmethod
|
||||
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
|
||||
|
||||
@@ -28,7 +28,7 @@ class StableLMModel(TextModel):
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
|
||||
rotary_factor = self.rope_parameters["partial_rotary_factor"]
|
||||
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
|
||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])
|
||||
|
||||
+1
-1
@@ -314,7 +314,7 @@ class Step35Model(TextModel):
|
||||
factor = float(rope_params.get("factor", 8.0))
|
||||
low_freq_factor = float(rope_params.get("low_freq_factor", 1.0))
|
||||
high_freq_factor = float(rope_params.get("high_freq_factor", 4.0))
|
||||
old_context_len = int(rope_params.get("original_max_position_embeddings", self.hparams.get("original_max_position_embeddings", 8192)))
|
||||
old_context_len = int(rope_params.get("original_max_position_embeddings", 8192))
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
high_freq_wavelen = old_context_len / high_freq_factor
|
||||
|
||||
+1
-1
@@ -29,7 +29,7 @@ With Termux, you can install and run `llama.cpp` as if the environment were Linu
|
||||
|
||||
```
|
||||
$ apt update && apt upgrade -y
|
||||
$ apt install git cmake
|
||||
$ apt install git cmake libandroid-spawn
|
||||
```
|
||||
|
||||
Then, follow the [build instructions](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md), specifically for CMake.
|
||||
|
||||
+36
-38
@@ -8,55 +8,53 @@ The INI preset feature, introduced in [PR#17859](https://github.com/ggml-org/lla
|
||||
|
||||
When running multiple models on the server (router mode), INI preset files can be used to configure model-specific parameters. Please refer to the [server documentation](../tools/server/README.md) for more details.
|
||||
|
||||
### Using a Remote Preset
|
||||
### Using a Hugging Face Preset
|
||||
|
||||
> [!NOTE]
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
> This feature is currently only supported via the `-hf` option.
|
||||
> Please only use presets that you can trust! Unknown presets may be unsafe
|
||||
|
||||
For GGUF models hosted on Hugging Face, you can include a `preset.ini` file in the root directory of the repository to define specific configurations for that model.
|
||||
You can push your preset to Hugging Face Hub and share with other users by:
|
||||
1. Creating an empty model repository on Hugging Face
|
||||
2. Creating a `preset.ini` file in the root directory of the repository
|
||||
|
||||
Example:
|
||||
Example of a `preset.ini`:
|
||||
|
||||
```ini
|
||||
hf-repo-draft = username/my-draft-model-GGUF
|
||||
temp = 0.5
|
||||
top-k = 20
|
||||
top-p = 0.95
|
||||
[*]
|
||||
ctx-size = 0
|
||||
mmap = 1
|
||||
kv-unified = 1
|
||||
parallel = 4
|
||||
spec-default = 1
|
||||
|
||||
[Qwen3.5-4B]
|
||||
hf = unsloth/Qwen3.5-4B-GGUF:Q4_K_M
|
||||
ctx-size = 262144
|
||||
batch-size = 2048
|
||||
ubatch-size = 2048
|
||||
top-p = 1.0
|
||||
top-k = 0
|
||||
min-p = 0.01
|
||||
temp = 1.0
|
||||
|
||||
[gpt-oss-120b-hf]
|
||||
hf = ggml-org/gpt-oss-120b-GGUF
|
||||
ctx-size = 262144
|
||||
batch-size = 2048
|
||||
ubatch-size = 2048
|
||||
top-p = 1.0
|
||||
top-k = 0
|
||||
min-p = 0.01
|
||||
temp = 1.0
|
||||
chat-template-kwargs = {"reasoning_effort": "high"}
|
||||
```
|
||||
|
||||
For security reasons, only certain options are allowed. Please refer to [preset.cpp](../common/preset.cpp) for the complete list of permitted options.
|
||||
|
||||
Example usage:
|
||||
|
||||
Assuming your repository `username/my-model-with-preset` contains a `preset.ini` with the configuration above:
|
||||
|
||||
```sh
|
||||
llama-cli -hf username/my-model-with-preset
|
||||
|
||||
# This is equivalent to:
|
||||
llama-cli -hf username/my-model-with-preset \
|
||||
--hf-repo-draft username/my-draft-model-GGUF \
|
||||
--temp 0.5 \
|
||||
--top-k 20 \
|
||||
--top-p 0.95
|
||||
```
|
||||
|
||||
You can also override preset arguments by specifying them on the command line:
|
||||
The preset will be loaded similarly to the `--models-preset` option. Therefore, you can also override certain params via CLI arguments:
|
||||
|
||||
```sh
|
||||
# Force temp = 0.1, overriding the preset value
|
||||
llama-cli -hf username/my-model-with-preset --temp 0.1
|
||||
```
|
||||
|
||||
If you want to define multiple preset configurations for one or more GGUF models, you can create a blank HF repo for each preset. Each HF repo should contain a `preset.ini` file that references the actual model(s):
|
||||
|
||||
```ini
|
||||
hf-repo = user/my-model-main
|
||||
hf-repo-draft = user/my-model-draft
|
||||
temp = 0.8
|
||||
ctx-size = 1024
|
||||
; (and other configurations)
|
||||
llama-cli -hf username/my-preset --temp 0.1
|
||||
```
|
||||
|
||||
### Named presets
|
||||
|
||||
@@ -198,18 +198,18 @@ class BuiltinRule:
|
||||
SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}'
|
||||
|
||||
PRIMITIVE_RULES = {
|
||||
'boolean' : BuiltinRule('("true" | "false") space', []),
|
||||
'boolean' : BuiltinRule('("true" | "false")', []),
|
||||
'decimal-part' : BuiltinRule('[0-9]{1,16}', []),
|
||||
'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []),
|
||||
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
|
||||
'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']),
|
||||
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?', ['integral-part', 'decimal-part']),
|
||||
'integer' : BuiltinRule('("-"? integral-part)', ['integral-part']),
|
||||
'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
|
||||
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
|
||||
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
|
||||
'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []),
|
||||
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? space "}"', ['string', 'value']),
|
||||
'array' : BuiltinRule('"[" space ( value ("," space value)* )? space "]"', ['value']),
|
||||
'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\""', []),
|
||||
'char' : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []),
|
||||
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
|
||||
'null' : BuiltinRule('"null" space', []),
|
||||
'string' : BuiltinRule(r'"\"" char* "\""', ['char']),
|
||||
'null' : BuiltinRule('"null"', []),
|
||||
}
|
||||
|
||||
# TODO: support "uri", "email" string formats
|
||||
@@ -217,9 +217,9 @@ STRING_FORMAT_RULES = {
|
||||
'date' : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
|
||||
'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
|
||||
'date-time' : BuiltinRule('date "T" time', ['date', 'time']),
|
||||
'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']),
|
||||
'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']),
|
||||
'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
|
||||
'date-string' : BuiltinRule('"\\"" date "\\""', ['date']),
|
||||
'time-string' : BuiltinRule('"\\"" time "\\""', ['time']),
|
||||
'date-time-string': BuiltinRule('"\\"" date-time "\\""', ['date-time']),
|
||||
}
|
||||
|
||||
DOTALL = '[\\U00000000-\\U0010FFFF]'
|
||||
@@ -319,7 +319,7 @@ class SchemaConverter:
|
||||
out.append(f'[^"{"".join(rejects)}] {char_rule}*')
|
||||
visit(trie)
|
||||
|
||||
out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space')
|
||||
out.append(f' ){"" if trie.is_end_of_string else "?"} ["]')
|
||||
return ''.join(out)
|
||||
|
||||
def _add_rule(self, name, rule):
|
||||
@@ -549,7 +549,7 @@ class SchemaConverter:
|
||||
return self._add_rule(
|
||||
name,
|
||||
to_rule(transform()) if self._raw_pattern \
|
||||
else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space")
|
||||
else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\"")
|
||||
|
||||
|
||||
def _resolve_ref(self, ref):
|
||||
@@ -580,10 +580,10 @@ class SchemaConverter:
|
||||
return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type]))
|
||||
|
||||
elif 'const' in schema:
|
||||
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space')
|
||||
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))
|
||||
|
||||
elif 'enum' in schema:
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space'
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ')'
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
elif schema_type in (None, 'object') and \
|
||||
@@ -624,7 +624,7 @@ class SchemaConverter:
|
||||
enum_intersection &= s
|
||||
|
||||
if enum_intersection:
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ')'
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
|
||||
@@ -638,12 +638,12 @@ class SchemaConverter:
|
||||
' "," space '.join(
|
||||
self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
|
||||
for i, item in enumerate(items)) +
|
||||
' "]" space')
|
||||
' space "]"')
|
||||
else:
|
||||
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
|
||||
min_items = schema.get("minItems", 0)
|
||||
max_items = schema.get("maxItems")
|
||||
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')
|
||||
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' space "]"')
|
||||
|
||||
elif schema_type in (None, 'string') and 'pattern' in schema:
|
||||
return self._visit_pattern(schema['pattern'], rule_name)
|
||||
@@ -663,7 +663,7 @@ class SchemaConverter:
|
||||
min_len = schema.get('minLength', 0)
|
||||
max_len = schema.get('maxLength')
|
||||
|
||||
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')
|
||||
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\""')
|
||||
|
||||
elif schema_type in (None, 'integer') and \
|
||||
('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema):
|
||||
@@ -680,7 +680,7 @@ class SchemaConverter:
|
||||
|
||||
out = ["("]
|
||||
_generate_min_max_int(min_value, max_value, out)
|
||||
out.append(") space")
|
||||
out.append(")")
|
||||
return self._add_rule(rule_name, ''.join(out))
|
||||
|
||||
elif (schema_type == 'object') or (len(schema) == 0):
|
||||
@@ -765,7 +765,7 @@ class SchemaConverter:
|
||||
rule += ' )'
|
||||
rule += ' )?'
|
||||
|
||||
rule += ' "}" space'
|
||||
rule += ' space "}"'
|
||||
|
||||
return rule
|
||||
|
||||
|
||||
+1
-1
@@ -5,7 +5,7 @@ project("ggml" C CXX ASM)
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 15)
|
||||
set(GGML_VERSION_PATCH 1)
|
||||
set(GGML_VERSION_PATCH 2)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
|
||||
@@ -2417,15 +2417,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
|
||||
// Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size
|
||||
GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
|
||||
|
||||
parallel_for_ggml(params, n_batch, [&](int begin, int end) {
|
||||
for (int batch_idx = begin; batch_idx < end; ++batch_idx) {
|
||||
parallel_for_ggml(params, n_batch * M, [&](int begin, int end) {
|
||||
for (int idx = begin; idx < end; ++idx) {
|
||||
int batch_idx = idx / M;
|
||||
int m = idx % M;
|
||||
int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
|
||||
const float * A_data = (const float *)((const char *)src1->data + src1_offset);
|
||||
char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A;
|
||||
|
||||
for (int m = 0; m < M; ++m) {
|
||||
from_float<vec_dot_type>(A_data + m * K, wdata_batch + m * row_size_A, K);
|
||||
}
|
||||
from_float<vec_dot_type>(A_data + m * K, wdata_batch + m * row_size_A, K);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -2345,7 +2345,7 @@ class tinyBLAS_Q0_PPC {
|
||||
else if (n_aligned % 16 == 0) nc = 16;
|
||||
else nc = 8;
|
||||
}
|
||||
bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0);
|
||||
bool can_use_tiled = n_aligned > 0 && (m % mc == 0);
|
||||
if (can_use_tiled) {
|
||||
matmul_tiled(m, n_aligned, mc, nc, kc);
|
||||
if (n > n_aligned) {
|
||||
@@ -3063,13 +3063,14 @@ class tinyBLAS_Q0_PPC {
|
||||
int64_t ii = (job / xtiles) * mc;
|
||||
int64_t jj = (job % xtiles) * nc;
|
||||
for (int64_t kk = 0; kk < k; kk += kc) {
|
||||
int64_t k_cur = MIN(kc, k - kk);
|
||||
if constexpr(is_Ablock_q4) {
|
||||
packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
|
||||
packNormal_q4_fp16(A + ii * lda + kk, lda, mc, k_cur, (uint8_t *)A_pack);
|
||||
} else {
|
||||
packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
|
||||
packNormal_q8_fp16(A + ii * lda + kk, lda, mc, k_cur, (uint8_t *)A_pack);
|
||||
}
|
||||
packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack);
|
||||
KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack);
|
||||
packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, k_cur, (uint8_t *)B_pack);
|
||||
KERNEL_Q0(ii, jj, mc, nc, k_cur, kk, A_pack, B_pack);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+50
-23
@@ -3688,8 +3688,6 @@ static void ggml_compute_forward_norm_f32(
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -3703,25 +3701,49 @@ static void ggml_compute_forward_norm_f32(
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
||||
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
float sum = 0.0;
|
||||
ggml_vec_sum_f32(ne00, &sum, x);
|
||||
float mean = sum/ne00;
|
||||
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
|
||||
const float * xf = (const float *) x;
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
float variance = 0;
|
||||
float sum = 0.0;
|
||||
ggml_vec_sum_f32(ne00, &sum, xf);
|
||||
float mean = sum/ne00;
|
||||
|
||||
float * yf = (float *) y;
|
||||
float variance = 0;
|
||||
|
||||
#ifdef GGML_USE_ACCELERATE
|
||||
mean = -mean;
|
||||
vDSP_vsadd(x, 1, &mean, y, 1, ne00);
|
||||
vDSP_measqv(y, 1, &variance, ne00);
|
||||
mean = -mean;
|
||||
vDSP_vsadd(xf, 1, &mean, yf, 1, ne00);
|
||||
vDSP_measqv(yf, 1, &variance, ne00);
|
||||
#else
|
||||
variance = ggml_vec_cvar_f32(ne00, y, x, mean);
|
||||
variance = ggml_vec_cvar_f32(ne00, yf, xf, mean);
|
||||
#endif //GGML_USE_ACCELERATE
|
||||
|
||||
const float scale = 1.0f/sqrtf(variance + eps);
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
const float scale = 1.0f/sqrtf(variance + eps);
|
||||
ggml_vec_scale_f32(ne00, yf, scale);
|
||||
} else {
|
||||
float sum = 0.0;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
sum += *(const float *) (x + i00*nb00);
|
||||
}
|
||||
const float mean = sum/ne00;
|
||||
|
||||
float variance = 0.0f;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const float v = *(const float *) (x + i00*nb00) - mean;
|
||||
*(float *) (y + i00*nb0) = v;
|
||||
variance += v * v;
|
||||
}
|
||||
variance /= ne00;
|
||||
|
||||
const float scale = 1.0f/sqrtf(variance + eps);
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
*(float *) (y + i00*nb0) *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4142,8 +4164,6 @@ static void ggml_compute_forward_l2_norm_f32(
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -4158,20 +4178,27 @@ static void ggml_compute_forward_l2_norm_f32(
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
||||
|
||||
ggml_float sum = 0.0;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
sum += (ggml_float)(x[i00] * x[i00]);
|
||||
const float xi = *(const float *) (x + i00*nb00);
|
||||
sum += (ggml_float)(xi * xi);
|
||||
}
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
memcpy(y, x, ne00 * sizeof(float));
|
||||
|
||||
const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
|
||||
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
|
||||
memcpy(y, x, ne00 * sizeof(float));
|
||||
ggml_vec_scale_f32(ne00, (float *) y, scale);
|
||||
} else {
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const float xi = *(const float *) (x + i00*nb00);
|
||||
*(float *) (y + i00*nb0) = xi * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
#include "col2im-1d.cuh"
|
||||
#include "convert.cuh"
|
||||
|
||||
// col2im_1d: scatter-add GEMM columns to 1D signal (gather approach)
|
||||
// columns: [K*OC, T_in] -> output: [T_out, OC]
|
||||
// Supports F32, F16, BF16 data with F32 accumulator.
|
||||
|
||||
template <typename T>
|
||||
static __global__ void col2im_1d_kernel(
|
||||
const T * __restrict__ col,
|
||||
T * __restrict__ dst,
|
||||
const int T_in, const uint3 T_out_fd,
|
||||
const int OC, const int K, const int K_OC,
|
||||
const int s0, const int p0, const int total) {
|
||||
|
||||
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (idx >= total) return;
|
||||
|
||||
// dst layout: [T_out, OC], ne[0]=T_out fastest
|
||||
const uint2 qr = fast_div_modulo((uint32_t)idx, T_out_fd); // qr.x = idx / T_out, qr.y = idx % T_out
|
||||
const int oc = (int)qr.x;
|
||||
const int t_out = (int)qr.y;
|
||||
const int t_abs = t_out + p0; // absolute position in uncropped signal
|
||||
|
||||
// Gather: find all (t_in, k) where t_in*s + k == t_abs, 0 <= k < K
|
||||
int t_in_min = (t_abs - K + s0) / s0; // ceil((t_abs - K + 1) / s)
|
||||
if (t_in_min < 0) t_in_min = 0;
|
||||
int t_in_max = t_abs / s0;
|
||||
if (t_in_max >= T_in) t_in_max = T_in - 1;
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int t_in = t_in_min; t_in <= t_in_max; t_in++) {
|
||||
const int k = t_abs - t_in * s0;
|
||||
// col layout: [K*OC, T_in], column index = oc * K + k
|
||||
sum += ggml_cuda_cast<float>(col[(oc * K + k) + t_in * K_OC]);
|
||||
}
|
||||
|
||||
dst[idx] = ggml_cuda_cast<T>(sum);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_col2im_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
||||
const int32_t OC = ((const int32_t *)(dst->op_params))[1];
|
||||
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
|
||||
|
||||
const int K_OC = (int) src0->ne[0];
|
||||
const int T_in = (int) src0->ne[1];
|
||||
const int K = K_OC / OC;
|
||||
const int T_out = (int) dst->ne[0];
|
||||
|
||||
const uint3 T_out_fd = init_fastdiv_values((uint32_t)T_out);
|
||||
|
||||
const int total = T_out * OC;
|
||||
const int block_size = 256;
|
||||
const int num_blocks = (total + block_size - 1) / block_size;
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
col2im_1d_kernel<<<num_blocks, block_size, 0, stream>>>(
|
||||
(const float *)src0->data, (float *)dst->data,
|
||||
T_in, T_out_fd, OC, K, K_OC, s0, p0, total);
|
||||
} break;
|
||||
case GGML_TYPE_F16: {
|
||||
col2im_1d_kernel<<<num_blocks, block_size, 0, stream>>>(
|
||||
(const half *)src0->data, (half *)dst->data,
|
||||
T_in, T_out_fd, OC, K, K_OC, s0, p0, total);
|
||||
} break;
|
||||
case GGML_TYPE_BF16: {
|
||||
col2im_1d_kernel<<<num_blocks, block_size, 0, stream>>>(
|
||||
(const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
|
||||
T_in, T_out_fd, OC, K, K_OC, s0, p0, total);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("col2im_1d: unsupported type");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_op_col2im_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "ggml-cuda/argsort.cuh"
|
||||
#include "ggml-cuda/binbcast.cuh"
|
||||
#include "ggml-cuda/clamp.cuh"
|
||||
#include "ggml-cuda/col2im-1d.cuh"
|
||||
#include "ggml-cuda/concat.cuh"
|
||||
#include "ggml-cuda/conv-transpose-1d.cuh"
|
||||
#include "ggml-cuda/conv2d.cuh"
|
||||
@@ -3051,6 +3052,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
ggml_cuda_op_conv_transpose_1d(ctx,dst);
|
||||
break;
|
||||
case GGML_OP_COL2IM_1D:
|
||||
ggml_cuda_op_col2im_1d(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_POOL_2D:
|
||||
ggml_cuda_op_pool2d(ctx, dst);
|
||||
break;
|
||||
@@ -5316,13 +5320,21 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
}
|
||||
return false;
|
||||
} break;
|
||||
case GGML_OP_COL2IM_1D:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
return (src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_F16 || src0_type == GGML_TYPE_BF16) &&
|
||||
op->type == src0_type &&
|
||||
ggml_is_contiguous(op->src[0]) &&
|
||||
ggml_is_contiguous(op);
|
||||
} break;
|
||||
case GGML_OP_SILU_BACK:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
break;
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return true;
|
||||
return ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
break;
|
||||
|
||||
@@ -69,6 +69,7 @@ static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE;
|
||||
static int opt_opbatch = 1024; // max number of ops in a batch
|
||||
static int opt_opqueue = 16; // max number of pending batches
|
||||
static int opt_oppoll = 0; // polling for batch completions
|
||||
static int opt_optrace = 0; // trace buffer size per thread (0 means default)
|
||||
|
||||
static std::regex* opt_opfilter = NULL; // regex of ops to not claim
|
||||
|
||||
@@ -118,20 +119,39 @@ static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct
|
||||
ggml_op_desc(op), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, supp ? "yes" : "no");
|
||||
}
|
||||
|
||||
static const char * htp_event_name(uint16_t id) {
|
||||
switch (id) {
|
||||
case HTP_TRACE_EVT_DMA: return "DMA";
|
||||
case HTP_TRACE_EVT_HVX_COMP: return "HVX_COMP";
|
||||
case HTP_TRACE_EVT_HVX_A_QUANT: return "HVX_A_QUANT";
|
||||
case HTP_TRACE_EVT_HVX_A_PREP: return "HVX_A_PREP";
|
||||
case HTP_TRACE_EVT_HVX_W_DEQUANT: return "HVX_W_DEQUANT";
|
||||
case HTP_TRACE_EVT_HVX_W_PREP: return "HVX_W_PREP";
|
||||
case HTP_TRACE_EVT_HVX_O_PROC: return "HVX_O_PROC";
|
||||
case HTP_TRACE_EVT_HMX_COMP: return "HMX_COMP";
|
||||
default: return "UNKNOWN";
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const htp_opnode & node,
|
||||
uint32_t op_usec, uint32_t op_cycles, const uint32_t pmu[]) {
|
||||
const htp_prof_desc & pd) {
|
||||
if (!opt_profile) return;
|
||||
|
||||
uint32_t op_usec = pd.usecs;
|
||||
uint32_t op_cycles = pd.cycles_stop - pd.cycles_start;
|
||||
const uint32_t * pmu = pd.pmu;
|
||||
|
||||
char pmu_str[256] = "";
|
||||
if (opt_profile > 1) {
|
||||
if (opt_profile == 2) {
|
||||
static_assert(HTP_PROF_PMU_NCNT == 8, "current implementation assumes 8 PMU counters");
|
||||
sprintf(pmu_str, " pmu [%u,%u,%u,%u,%u,%u,%u,%u]",
|
||||
pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]);
|
||||
}
|
||||
|
||||
htp_opformat fmt(node);
|
||||
GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u%s\n", sess_name.c_str(),
|
||||
node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, op_usec, op_cycles, pmu_str);
|
||||
float mhz = op_usec > 0 ? (float) op_cycles / op_usec : 0.0f;
|
||||
GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u start %u mhz %.1f%s\n", sess_name.c_str(),
|
||||
node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, op_usec, op_cycles, pd.cycles_start, mhz, pmu_str);
|
||||
}
|
||||
|
||||
// ** backend sessions
|
||||
@@ -1995,10 +2015,16 @@ struct ggml_hexagon_opqueue {
|
||||
size_t n_ops = batch_size;
|
||||
size_t n_tensors = n_ops + n_ops * HTP_OP_MAX_INPUTS;
|
||||
|
||||
size_t tr_size = 0;
|
||||
if (opt_profile == 3) {
|
||||
tr_size = (HTP_MAX_NTHREADS + 1) * opt_optrace * sizeof(htp_trace_desc);
|
||||
}
|
||||
|
||||
shm_blk_size = sizeof(htp_buf_desc) * n_bufs +
|
||||
sizeof(htp_tensor) * n_tensors +
|
||||
sizeof(htp_op_desc) * n_ops +
|
||||
sizeof(htp_prof_desc) * n_ops;
|
||||
sizeof(htp_prof_desc) * n_ops +
|
||||
tr_size;
|
||||
|
||||
shm_buf = new ggml_hexagon_shared_buffer(sess, shm_blk_size * depth, true /* pinned */);
|
||||
|
||||
@@ -2042,11 +2068,19 @@ struct ggml_hexagon_opqueue {
|
||||
const size_t o_size = sizeof(htp_op_desc) * req.n_ops;
|
||||
const size_t p_size = sizeof(htp_prof_desc) * req.n_ops;
|
||||
|
||||
size_t tr_size = 0;
|
||||
if (opt_profile == 3) {
|
||||
req.n_traces = opt_optrace;
|
||||
tr_size = (HTP_MAX_NTHREADS + 1) * req.n_traces * sizeof(htp_trace_desc);
|
||||
} else {
|
||||
req.n_traces = 0;
|
||||
}
|
||||
|
||||
dbuf.ptr = shm_buf->base + (req.id * shm_blk_size);
|
||||
dbuf.fd = shm_buf->fd;
|
||||
dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT;
|
||||
dbuf.offset = (uint8_t*) dbuf.ptr - (uint8_t*) shm_buf->base;
|
||||
dbuf.size = b_size + t_size + o_size + p_size;
|
||||
dbuf.size = b_size + t_size + o_size + p_size + tr_size;
|
||||
|
||||
GGML_ASSERT(dbuf.size <= shm_blk_size);
|
||||
|
||||
@@ -2092,7 +2126,14 @@ struct ggml_hexagon_opqueue {
|
||||
const size_t o_size = sizeof(htp_op_desc) * rsp.n_ops;
|
||||
const size_t p_size = sizeof(htp_prof_desc) * rsp.n_ops;
|
||||
|
||||
const size_t m_size = b_size + t_size + o_size + p_size;
|
||||
size_t tr_size = 0;
|
||||
uint32_t n_traces = 0;
|
||||
if (opt_profile == 3) {
|
||||
n_traces = opt_optrace;
|
||||
tr_size = (HTP_MAX_NTHREADS + 1) * n_traces * sizeof(htp_trace_desc);
|
||||
}
|
||||
|
||||
const size_t m_size = b_size + t_size + o_size + p_size + tr_size;
|
||||
GGML_ASSERT(m_size <= shm_blk_size);
|
||||
|
||||
HEX_VERBOSE("ggml-hex: %s op-queue pop batch #%u : n-bufs %u n-tensors %u n-ops %u : m-size %zu b-size %zu t-size %zu o-size %zu\n",
|
||||
@@ -2111,13 +2152,62 @@ struct ggml_hexagon_opqueue {
|
||||
GGML_ASSERT(rsp.n_ops <= ops.size());
|
||||
|
||||
const htp_prof_desc * pd = (const htp_prof_desc *) p_ptr;
|
||||
for (uint32_t i = 0; i < rsp.n_ops; i++) {
|
||||
htp_usec += pd[i].usecs;
|
||||
ggml_hexagon_dump_op_prof(shm_buf->sess->name, ops[i], pd[i].usecs, pd[i].cycles, pd[i].pmu);
|
||||
|
||||
const htp_trace_desc * trace_events = nullptr;
|
||||
|
||||
if (opt_profile == 3) {
|
||||
trace_events = (const htp_trace_desc *) (p_ptr + p_size);
|
||||
}
|
||||
|
||||
GGML_LOG_DEBUG("ggml-hex: %s profile-batch n-ops %u batch-dur-usec %lld htp-ops-usec %u\n",
|
||||
shm_buf->sess->c_name(), rsp.n_ops, (long long) batch_usec, htp_usec);
|
||||
uint32_t trace_idx[HTP_MAX_NTHREADS + 1] = {0};
|
||||
uint32_t valid_cnt[HTP_MAX_NTHREADS + 1] = {0};
|
||||
|
||||
if (opt_profile == 3) {
|
||||
for (uint32_t t = 0; t <= HTP_MAX_NTHREADS; t++) {
|
||||
uint32_t count = rsp.n_traces[t];
|
||||
valid_cnt[t] = count > n_traces ? n_traces : count;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < rsp.n_ops; i++) {
|
||||
htp_usec += pd[i].usecs;
|
||||
|
||||
ggml_hexagon_dump_op_prof(shm_buf->sess->name, ops[i], pd[i]);
|
||||
|
||||
if (opt_profile == 3) {
|
||||
uint32_t op_duration = pd[i].cycles_stop - pd[i].cycles_start;
|
||||
|
||||
for (uint32_t t = 0; t <= HTP_MAX_NTHREADS; t++) {
|
||||
while (trace_idx[t] < valid_cnt[t]) {
|
||||
const auto & e = trace_events[t * n_traces + trace_idx[t]];
|
||||
uint32_t offset = e.cycles - pd[i].cycles_start;
|
||||
if (offset >= 0x80000000) {
|
||||
trace_idx[t]++;
|
||||
continue;
|
||||
}
|
||||
if (offset > op_duration) {
|
||||
break;
|
||||
}
|
||||
bool is_stop = (e.info & 0x8000) != 0;
|
||||
uint16_t info = e.info & 0x7FFF;
|
||||
GGML_LOG_DEBUG("ggml-hex: %s trace-op %s: thread %u event %s info %u %s %u\n",
|
||||
shm_buf->sess->c_name(), ops[i].op_name().c_str(), t, htp_event_name(e.id), info, is_stop ? "stop" : "start", e.cycles);
|
||||
trace_idx[t]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
char evt_str[256] = "";
|
||||
if (opt_profile == 3) {
|
||||
sprintf(evt_str, " evt [%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u]",
|
||||
rsp.n_traces[0], rsp.n_traces[1], rsp.n_traces[2], rsp.n_traces[3],
|
||||
rsp.n_traces[4], rsp.n_traces[5], rsp.n_traces[6], rsp.n_traces[7],
|
||||
rsp.n_traces[8], rsp.n_traces[9], rsp.n_traces[10]);
|
||||
}
|
||||
|
||||
GGML_LOG_DEBUG("ggml-hex: %s profile-batch n-ops %u batch-dur-usec %lld htp-ops-usec %u%s\n",
|
||||
shm_buf->sess->c_name(), rsp.n_ops, (long long) batch_usec, htp_usec, evt_str);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -3901,6 +3991,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH");
|
||||
const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE");
|
||||
const char * str_oppoll = getenv("GGML_HEXAGON_OPPOLL");
|
||||
const char * str_optrace = getenv("GGML_HEXAGON_OPTRACE");
|
||||
const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER");
|
||||
const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
|
||||
const char * str_etm = getenv("GGML_HEXAGON_ETM");
|
||||
@@ -3939,6 +4030,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch;
|
||||
opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue;
|
||||
opt_oppoll = str_oppoll ? strtoul(str_oppoll, NULL, 0) : opt_oppoll;
|
||||
opt_optrace = str_optrace ? strtoul(str_optrace, NULL, 0) : (opt_opbatch * 128);
|
||||
opt_profile = str_profile ? atoi(str_profile) : 0;
|
||||
opt_etm = str_etm ? atoi(str_etm) : 0;
|
||||
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
|
||||
|
||||
@@ -37,8 +37,8 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
|
||||
|
||||
if (_hmx_idx GREATER_EQUAL 0)
|
||||
target_sources(${HTP_LIB} PRIVATE
|
||||
hmx-matmul-ops.c
|
||||
hmx-flash-attn-ops.c
|
||||
hmx-matmul-ops.c
|
||||
hmx-queue.c
|
||||
)
|
||||
|
||||
|
||||
@@ -339,6 +339,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
|
||||
|
||||
if (ir0 >= ir1) return;
|
||||
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
|
||||
dma_queue * dma = octx->ctx->dma[ith];
|
||||
|
||||
const uint32_t DK = nek0;
|
||||
@@ -615,6 +618,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
|
||||
hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
|
||||
}
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
}
|
||||
|
||||
int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hex-profile.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
@@ -88,6 +90,7 @@ typedef struct {
|
||||
uint32_t pop_idx;
|
||||
uint32_t capacity;
|
||||
uint32_t idx_mask;
|
||||
struct htp_thread_trace * trace;
|
||||
} dma_queue;
|
||||
|
||||
dma_queue * dma_queue_create(size_t capacity);
|
||||
@@ -152,6 +155,7 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
|
||||
q->dptr[q->push_idx] = dptr;
|
||||
|
||||
if (size) {
|
||||
htp_trace_event_start(q->trace, HTP_TRACE_EVT_DMA, q->push_idx);
|
||||
dmlink(q->tail, desc);
|
||||
q->tail = (dma_descriptor_2d *) desc;
|
||||
} else {
|
||||
@@ -202,6 +206,7 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t
|
||||
q->dptr[q->push_idx] = dptr;
|
||||
|
||||
if (nrows) {
|
||||
htp_trace_event_start(q->trace, HTP_TRACE_EVT_DMA, q->push_idx);
|
||||
dmlink(q->tail, desc);
|
||||
q->tail = desc;
|
||||
} else {
|
||||
@@ -223,10 +228,12 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
|
||||
dma_descriptor_2d * desc = &q->desc[q->pop_idx];
|
||||
|
||||
// Wait for desc to complete
|
||||
while (!desc->done) {
|
||||
// FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
|
||||
dmpoll();
|
||||
if (!desc->done) {
|
||||
while (!desc->done) {
|
||||
dmpoll();
|
||||
}
|
||||
}
|
||||
htp_trace_event_stop(q->trace, HTP_TRACE_EVT_DMA, q->pop_idx);
|
||||
|
||||
dptr = q->dptr[q->pop_idx];
|
||||
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
#ifndef HEX_PROFILE_H
|
||||
#define HEX_PROFILE_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <qurt.h>
|
||||
|
||||
#include "hex-utils.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#define HTP_TRACE_EVT_START 0
|
||||
#define HTP_TRACE_EVT_STOP 1
|
||||
|
||||
#ifndef HEX_NUM_PMU_COUNTERS
|
||||
#define HEX_NUM_PMU_COUNTERS 8
|
||||
#endif
|
||||
|
||||
static inline void hex_get_pmu(uint32_t counters[]) {
|
||||
#if __HVX_ARCH__ >= 79
|
||||
asm volatile("%0 = upmucnt0" : "=r"(counters[0]));
|
||||
asm volatile("%0 = upmucnt1" : "=r"(counters[1]));
|
||||
asm volatile("%0 = upmucnt2" : "=r"(counters[2]));
|
||||
asm volatile("%0 = upmucnt3" : "=r"(counters[3]));
|
||||
asm volatile("%0 = upmucnt4" : "=r"(counters[4]));
|
||||
asm volatile("%0 = upmucnt5" : "=r"(counters[5]));
|
||||
asm volatile("%0 = upmucnt6" : "=r"(counters[6]));
|
||||
asm volatile("%0 = upmucnt7" : "=r"(counters[7]));
|
||||
#else
|
||||
counters[0] = qurt_pmu_get(QURT_PMUCNT0);
|
||||
counters[1] = qurt_pmu_get(QURT_PMUCNT1);
|
||||
counters[2] = qurt_pmu_get(QURT_PMUCNT2);
|
||||
counters[3] = qurt_pmu_get(QURT_PMUCNT3);
|
||||
counters[4] = qurt_pmu_get(QURT_PMUCNT4);
|
||||
counters[5] = qurt_pmu_get(QURT_PMUCNT5);
|
||||
counters[6] = qurt_pmu_get(QURT_PMUCNT6);
|
||||
counters[7] = qurt_pmu_get(QURT_PMUCNT7);
|
||||
#endif
|
||||
}
|
||||
|
||||
struct htp_thread_trace {
|
||||
uint32_t count;
|
||||
uint32_t max_events;
|
||||
struct htp_trace_desc * events;
|
||||
};
|
||||
|
||||
static inline void htp_trace_event(struct htp_thread_trace * tr, uint16_t id, uint16_t info, uint32_t type) {
|
||||
if (tr && tr->events && tr->count < tr->max_events) {
|
||||
uint32_t idx = tr->count;
|
||||
tr->events[idx].id = id;
|
||||
tr->events[idx].info = info | (type == HTP_TRACE_EVT_STOP ? 0x8000 : 0);
|
||||
tr->events[idx].cycles = (uint32_t) hex_get_cycles();
|
||||
tr->count++;
|
||||
}
|
||||
}
|
||||
|
||||
static inline void htp_trace_event_start(struct htp_thread_trace * tr, uint16_t id, uint16_t info) {
|
||||
htp_trace_event(tr, id, info, HTP_TRACE_EVT_START);
|
||||
}
|
||||
|
||||
static inline void htp_trace_event_stop(struct htp_thread_trace * tr, uint16_t id, uint16_t info) {
|
||||
htp_trace_event(tr, id, info, HTP_TRACE_EVT_STOP);
|
||||
}
|
||||
|
||||
#endif /* HEX_PROFILE_H */
|
||||
@@ -107,31 +107,4 @@ static inline void hex_pause() {
|
||||
asm volatile(" pause(#255)\n");
|
||||
}
|
||||
|
||||
#ifndef HEX_NUM_PMU_COUNTERS
|
||||
#define HEX_NUM_PMU_COUNTERS 8
|
||||
#endif
|
||||
|
||||
static inline void hex_get_pmu(uint32_t counters[]) {
|
||||
#if __HVX_ARCH__ >= 79
|
||||
asm volatile("%0 = upmucnt0" : "=r"(counters[0]));
|
||||
asm volatile("%0 = upmucnt1" : "=r"(counters[1]));
|
||||
asm volatile("%0 = upmucnt2" : "=r"(counters[2]));
|
||||
asm volatile("%0 = upmucnt3" : "=r"(counters[3]));
|
||||
asm volatile("%0 = upmucnt4" : "=r"(counters[4]));
|
||||
asm volatile("%0 = upmucnt5" : "=r"(counters[5]));
|
||||
asm volatile("%0 = upmucnt6" : "=r"(counters[6]));
|
||||
asm volatile("%0 = upmucnt7" : "=r"(counters[7]));
|
||||
#else
|
||||
counters[0] = qurt_pmu_get(QURT_PMUCNT0);
|
||||
counters[1] = qurt_pmu_get(QURT_PMUCNT1);
|
||||
counters[2] = qurt_pmu_get(QURT_PMUCNT2);
|
||||
counters[3] = qurt_pmu_get(QURT_PMUCNT3);
|
||||
counters[4] = qurt_pmu_get(QURT_PMUCNT4);
|
||||
counters[5] = qurt_pmu_get(QURT_PMUCNT5);
|
||||
counters[6] = qurt_pmu_get(QURT_PMUCNT6);
|
||||
counters[7] = qurt_pmu_get(QURT_PMUCNT7);
|
||||
// qurt_pmu_get_pmucnt(counters);
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif /* HEX_UTILS_H */
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
#include "ggml-common.h"
|
||||
#include "hex-dma.h"
|
||||
#include "hex-fastdiv.h"
|
||||
#include "hmx-profile.h"
|
||||
#include "hex-profile.h"
|
||||
#include "hmx-queue.h"
|
||||
#include "hmx-utils.h"
|
||||
#include "htp-ctx.h"
|
||||
@@ -367,8 +367,11 @@ static void fa_k_interleave_thread(unsigned int n, unsigned int i, void * data)
|
||||
return;
|
||||
}
|
||||
|
||||
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
hmx_interleave_rows_to_tiles(factx->vtcm_k_tiles, factx->vtcm_k_fp16[args->buf_idx], total_rows, (int) factx->DK,
|
||||
(int) args->src_stride, start, end);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
}
|
||||
|
||||
static void fa_phase_k_interleave(struct hmx_fa_context * factx, int kv_rows, size_t src_stride, size_t buf_idx) {
|
||||
@@ -408,8 +411,11 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data)
|
||||
|
||||
__fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0];
|
||||
|
||||
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
hmx_interleave_cols_to_tiles(v_tiles_dest, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV,
|
||||
(int) args->src_stride, (int) args->n_col_tiles, start, end);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
}
|
||||
|
||||
static void fa_phase_v_interleave(struct hmx_fa_context * factx,
|
||||
@@ -462,6 +468,9 @@ static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) {
|
||||
return;
|
||||
}
|
||||
|
||||
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
|
||||
const struct htp_tensor * q = args->q;
|
||||
const uint32_t q_start = args->q_start;
|
||||
const uint32_t kv_head = args->kv_head;
|
||||
@@ -515,6 +524,7 @@ static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) {
|
||||
}
|
||||
}
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
}
|
||||
|
||||
static void fa_phase_q_load(struct hmx_fa_context * factx,
|
||||
@@ -566,6 +576,9 @@ static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) {
|
||||
return;
|
||||
}
|
||||
|
||||
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
|
||||
const struct htp_tensor * dst = args->dst;
|
||||
const __fp16 * o_tile_src = args->o_tile_src;
|
||||
const uint32_t q_start = args->q_start;
|
||||
@@ -611,6 +624,7 @@ static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) {
|
||||
}
|
||||
}
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
}
|
||||
|
||||
static void fa_phase_o_store(struct hmx_fa_context * factx,
|
||||
@@ -680,6 +694,9 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
|
||||
return;
|
||||
}
|
||||
|
||||
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, vec_start);
|
||||
|
||||
// Per-thread row scratch: thread i uses bufs at offset i * 2 * stride
|
||||
const size_t row_buf_stride = factx->row_buf_stride;
|
||||
HVX_Vector * my_row_buf0 = factx->vtcm_row_bufs + i * 2 * row_buf_stride;
|
||||
@@ -950,6 +967,7 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
|
||||
factx->vtcm_s_rowmax[r_vec_idx] = rowmax_acc_v;
|
||||
factx->vtcm_p_rowsum[r_vec_idx] = rowsum_acc_v;
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, vec_start);
|
||||
}
|
||||
|
||||
// Serial m/l update + build_D. Must run after softmax barrier (s_rowmax written by all threads).
|
||||
@@ -1245,6 +1263,7 @@ static __attribute__((noinline)) void fa_compute_slopes(
|
||||
// ============================================================================
|
||||
|
||||
int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[HTP_MAX_NTHREADS] : NULL;
|
||||
const struct htp_tensor * q = octx->src[0];
|
||||
const struct htp_tensor * k = octx->src[1];
|
||||
const struct htp_tensor * v = octx->src[2];
|
||||
@@ -1422,19 +1441,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
// Profiling timers
|
||||
TIMER_DEFINE(total);
|
||||
TIMER_DEFINE(q_load);
|
||||
TIMER_DEFINE(kv_dma);
|
||||
TIMER_DEFINE(k_interleave);
|
||||
TIMER_DEFINE(v_interleave);
|
||||
TIMER_DEFINE(qk_dot);
|
||||
TIMER_DEFINE(softmax);
|
||||
TIMER_DEFINE(o_update);
|
||||
TIMER_DEFINE(o_norm);
|
||||
TIMER_DEFINE(o_store);
|
||||
|
||||
TIMER_START(total);
|
||||
|
||||
// ======== DMA setup ========
|
||||
dma_queue * const dma = ctx->dma[0];
|
||||
@@ -1474,12 +1480,10 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
const size_t n_row_tiles = g_br_actual / HMX_FP16_TILE_N_ROWS;
|
||||
|
||||
// ---- Load Q block [g_br, D] -> tiles, interleaving G heads ----
|
||||
TIMER_START(q_load);
|
||||
if (n_rows_g < g_br) {
|
||||
hvx_splat_u8_a(factx.vtcm_q_tiles, 0, q_tile_bytes);
|
||||
}
|
||||
fa_phase_q_load(&factx, q, q_start, kv_head, ib3, n_rows_g);
|
||||
TIMER_STOP(q_load);
|
||||
|
||||
// ---- Initialize per-block state ----
|
||||
hvx_splat_u8_a(factx.vtcm_l_vec, 0, col_vec_bytes);
|
||||
@@ -1558,10 +1562,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS);
|
||||
|
||||
// Wait for current KV DMA
|
||||
TIMER_START(kv_dma);
|
||||
dma_queue_pop(dma); // K
|
||||
dma_queue_pop(dma); // V
|
||||
TIMER_STOP(kv_dma);
|
||||
|
||||
// Push mask DMA for this block (single 2D DMA when broadcast)
|
||||
bool has_mask_dma = false;
|
||||
@@ -1583,10 +1585,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
ou_job.DV = DV;
|
||||
hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job));
|
||||
}
|
||||
|
||||
TIMER_START(k_interleave);
|
||||
fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx);
|
||||
TIMER_STOP(k_interleave);
|
||||
|
||||
// ---- Phase 2: qk_dot(blk) on HMX ‖ V_int(blk) + DMA prefetch on HVX ----
|
||||
qk_job.q_tiles = factx.vtcm_q_tiles;
|
||||
@@ -1597,15 +1596,11 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
qk_job.n_dot_tiles = DK / 32;
|
||||
qk_job.n_tiles_per_bc = n_tiles_per_bc;
|
||||
qk_job.hmx_scales = factx.vtcm_hmx_scales_qk;
|
||||
TIMER_START(qk_dot);
|
||||
hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_qk_dot_worker, &qk_job));
|
||||
|
||||
// DMA push next block (non-blocking, before worker_pool)
|
||||
DMA_PREFETCH_KV(kv_blk + 1);
|
||||
|
||||
TIMER_START(v_interleave);
|
||||
fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc);
|
||||
TIMER_STOP(v_interleave);
|
||||
|
||||
// Pop and swap previous block's output update (deferred HMX pop)
|
||||
if (kv_blk > 0) {
|
||||
@@ -1615,7 +1610,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
|
||||
// Pop current block's dot product job
|
||||
hmx_queue_pop(hmx_q);
|
||||
TIMER_STOP(qk_dot);
|
||||
|
||||
// ---- Phase 3: softmax(blk) + build_D(blk) | HMX idle ----
|
||||
// Pop mask DMA before softmax (ensures VTCM buffer is ready)
|
||||
@@ -1641,10 +1635,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL;
|
||||
sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride;
|
||||
sargs.slopes = factx.vtcm_slopes;
|
||||
|
||||
TIMER_START(softmax);
|
||||
fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br);
|
||||
TIMER_STOP(softmax);
|
||||
|
||||
buf_idx = 1 - buf_idx;
|
||||
} // end KV block loop (pipeline)
|
||||
@@ -1664,11 +1655,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
ou_job.n_row_tiles_g_br = n_row_tiles_g_br;
|
||||
ou_job.n_tiles_per_bc = n_tiles_per_bc;
|
||||
ou_job.DV = DV;
|
||||
|
||||
TIMER_START(o_update);
|
||||
hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job));
|
||||
hmx_queue_pop(hmx_q);
|
||||
TIMER_STOP(o_update);
|
||||
|
||||
hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev);
|
||||
}
|
||||
@@ -1683,23 +1671,14 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
const uint32_t kv_start = kv_blk * Bc;
|
||||
const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start);
|
||||
const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS);
|
||||
|
||||
TIMER_START(kv_dma);
|
||||
dma_queue_pop(dma); // K
|
||||
dma_queue_pop(dma); // V
|
||||
TIMER_STOP(kv_dma);
|
||||
|
||||
bool has_mask_dma = false;
|
||||
MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma);
|
||||
DMA_PREFETCH_KV(kv_blk + 1);
|
||||
|
||||
// K interleave (multi-thread HVX)
|
||||
TIMER_START(k_interleave);
|
||||
fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx);
|
||||
TIMER_STOP(k_interleave);
|
||||
|
||||
// QK dot (inline HMX on main thread)
|
||||
TIMER_START(qk_dot);
|
||||
{
|
||||
const size_t n_dot_tiles = (size_t) (DK / 32);
|
||||
const __fp16 * restrict q_base = factx.vtcm_q_tiles;
|
||||
@@ -1709,6 +1688,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
__builtin_assume(n_col_tiles > 0);
|
||||
__builtin_assume(n_dot_tiles > 0);
|
||||
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_qk);
|
||||
for (size_t r = 0; r < n_row_tiles; ++r) {
|
||||
for (size_t c = 0; c < n_col_tiles; ++c) {
|
||||
@@ -1724,8 +1704,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
Q6_mxmem_AR_after_hf(out_tile, 0);
|
||||
}
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
}
|
||||
TIMER_STOP(qk_dot);
|
||||
|
||||
// Pop mask DMA
|
||||
MASK_DMA_POP(has_mask_dma);
|
||||
@@ -1751,21 +1731,9 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL;
|
||||
sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride;
|
||||
sargs.slopes = factx.vtcm_slopes;
|
||||
|
||||
TIMER_START(softmax);
|
||||
fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br);
|
||||
TIMER_STOP(softmax);
|
||||
|
||||
// V interleave (multi-thread HVX)
|
||||
TIMER_START(v_interleave);
|
||||
// FIX(v-stride): use n_tiles_per_bc (block-invariant) as V tile layout
|
||||
// stride to match o_update's v_tile access. Using per-block n_col_tiles
|
||||
// misplaces DV_tile 1..3 in the last partial KV block.
|
||||
fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc);
|
||||
TIMER_STOP(v_interleave);
|
||||
|
||||
// O update (inline HMX on main thread)
|
||||
TIMER_START(o_update);
|
||||
{
|
||||
const size_t DV_tiles = (size_t) (DV / 32);
|
||||
const __fp16 * restrict d_base = factx.vtcm_d_tiles;
|
||||
@@ -1777,6 +1745,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
__builtin_assume(n_col_tiles > 0);
|
||||
__builtin_assume(DV_tiles > 0);
|
||||
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id);
|
||||
for (size_t r = 0; r < n_row_tiles; ++r) {
|
||||
for (size_t c = 0; c < DV_tiles; ++c) {
|
||||
@@ -1798,16 +1767,15 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
Q6_mxmem_AR_after_hf(o_tile_out, 0);
|
||||
}
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev);
|
||||
}
|
||||
TIMER_STOP(o_update);
|
||||
|
||||
buf_idx = 1 - buf_idx;
|
||||
} // end KV block loop (fallback)
|
||||
}
|
||||
|
||||
// ---- Final normalization: O = diag(1/l) @ O ----
|
||||
TIMER_START(o_norm);
|
||||
{
|
||||
fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br);
|
||||
|
||||
@@ -1830,6 +1798,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
__builtin_assume(n_row_tiles > 0);
|
||||
__builtin_assume(DV_tiles > 0);
|
||||
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id);
|
||||
for (size_t r = 0; r < n_row_tiles; ++r) {
|
||||
for (size_t c = 0; c < DV_tiles; ++c) {
|
||||
@@ -1842,14 +1811,12 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
Q6_mxmem_AR_after_hf(o_out, 0);
|
||||
}
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
}
|
||||
}
|
||||
TIMER_STOP(o_norm);
|
||||
|
||||
// ---- Store O block ----
|
||||
TIMER_START(o_store);
|
||||
fa_phase_o_store(&factx, dst, o_tile_curr, q_start, kv_head, ib3, n_rows_g);
|
||||
TIMER_STOP(o_store);
|
||||
|
||||
#undef MASK_DMA_PUSH
|
||||
#undef MASK_DMA_POP
|
||||
@@ -1865,14 +1832,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
|
||||
}
|
||||
|
||||
TIMER_STOP(total);
|
||||
|
||||
#if defined(ENABLE_PROFILE_TIMERS)
|
||||
FARF(HIGH, "hmx-fa: %lld us, q_load=%lld kv_dma=%lld k_interleave=%lld v_interleave=%lld", TIMER_US(total),
|
||||
TIMER_US(q_load), TIMER_US(kv_dma), TIMER_US(k_interleave), TIMER_US(v_interleave));
|
||||
FARF(HIGH, " qk_dot=%lld softmax=%lld o_update=%lld o_norm=%lld o_store=%lld", TIMER_US(qk_dot), TIMER_US(softmax),
|
||||
TIMER_US(o_update), TIMER_US(o_norm), TIMER_US(o_store));
|
||||
#endif
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@
|
||||
#include "hmx-ops.h"
|
||||
#include "hmx-utils.h"
|
||||
#include "hmx-queue.h"
|
||||
#include "hmx-profile.h"
|
||||
#include "hex-profile.h"
|
||||
|
||||
#include "vtcm-utils.h"
|
||||
|
||||
@@ -430,6 +430,7 @@ typedef struct {
|
||||
int n_tasks;
|
||||
int n_k_tiles;
|
||||
struct fastdiv_values n_k_tiles_div;
|
||||
struct htp_thread_trace * traces;
|
||||
} x4x2_dequantize_state_t;
|
||||
|
||||
// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16.
|
||||
@@ -533,11 +534,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(
|
||||
\
|
||||
static void dequantize_x4x2_worker_loop_##suffix(unsigned int n, unsigned int i, void *data) { \
|
||||
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; \
|
||||
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; \
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { \
|
||||
int start = task_id * state->n_tiles_per_task; \
|
||||
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); \
|
||||
dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(state, start, end); \
|
||||
} \
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \
|
||||
}
|
||||
|
||||
DEFINE_DEQUANTIZE_Q4_TASK(q4_0, q4_0_to_fp16_lut, q4_0, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16))
|
||||
@@ -657,11 +661,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(
|
||||
|
||||
static void dequantize_x4x2_worker_loop_mxfp4(unsigned int n, unsigned int i, void *data) {
|
||||
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
|
||||
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
|
||||
int start = task_id * state->n_tiles_per_task;
|
||||
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
|
||||
dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(state, start, end);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
|
||||
}
|
||||
|
||||
static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(
|
||||
@@ -717,11 +724,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(
|
||||
|
||||
static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, void *data) {
|
||||
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
|
||||
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
|
||||
int start = task_id * state->n_tiles_per_task;
|
||||
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
|
||||
dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(state, start, end);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
|
||||
}
|
||||
|
||||
static void convert_f16_weight_to_fp16_tiles_task(
|
||||
@@ -773,11 +783,14 @@ static void convert_f16_weight_to_fp16_tiles_task(
|
||||
|
||||
static void convert_f16_worker_loop(unsigned int n, unsigned int i, void *data) {
|
||||
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
|
||||
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
|
||||
int start = task_id * state->n_tiles_per_task;
|
||||
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
|
||||
convert_f16_weight_to_fp16_tiles_task(state, start, end);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
|
||||
}
|
||||
|
||||
static void quantize_f32_weight_to_fp16_tiles_task(
|
||||
@@ -833,11 +846,14 @@ static void quantize_f32_weight_to_fp16_tiles_task(
|
||||
|
||||
static void quantize_f32_worker_loop(unsigned int n, unsigned int i, void *data) {
|
||||
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
|
||||
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
|
||||
int start = task_id * state->n_tiles_per_task;
|
||||
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
|
||||
quantize_f32_weight_to_fp16_tiles_task(state, start, end);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
|
||||
}
|
||||
|
||||
|
||||
@@ -868,6 +884,7 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
|
||||
state.weight_type = weight_type;
|
||||
state.n_k_tiles = n_k_tiles;
|
||||
state.n_k_tiles_div = n_k_tiles_div;
|
||||
state.traces = ctx ? ctx->trace : NULL;
|
||||
|
||||
if (state.n_tasks == 1 || n_threads == 1) {
|
||||
dequant_worker_fn(1, 0, &state);
|
||||
@@ -985,10 +1002,13 @@ typedef struct {
|
||||
int n_chunks_per_task;
|
||||
int n_cols;
|
||||
int n; // DDR row stride (total output columns)
|
||||
struct htp_thread_trace * traces;
|
||||
} output_transfer_task_state_t;
|
||||
|
||||
static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void *data) {
|
||||
output_transfer_task_state_t *st = (output_transfer_task_state_t *) data;
|
||||
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
|
||||
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) {
|
||||
int chunk_idx = task_id * st->n_chunks_per_task;
|
||||
@@ -998,6 +1018,7 @@ static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void
|
||||
const __fp16 *vtcm_src = st->vtcm_src + chunk_idx * st->n_cols;
|
||||
transfer_output_chunk_fp16_to_fp32(dst, vtcm_src, chunk_size, st->n_cols, st->n);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
|
||||
}
|
||||
|
||||
static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src,
|
||||
@@ -1015,6 +1036,7 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst,
|
||||
state.vtcm_src = vtcm_src;
|
||||
state.n_cols = n_cols;
|
||||
state.n = n;
|
||||
state.traces = ctx ? ctx->trace : NULL;
|
||||
|
||||
if (state.n_tasks == 1 || n_threads == 1) {
|
||||
transfer_output_chunk_worker_fn(1, 0, &state);
|
||||
@@ -1086,10 +1108,13 @@ typedef struct {
|
||||
int n_chunks_per_task;
|
||||
int k_block;
|
||||
int k_stride;
|
||||
struct htp_thread_trace * traces;
|
||||
} activation_transfer_task_state_t;
|
||||
|
||||
static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) {
|
||||
activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data;
|
||||
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
|
||||
|
||||
for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) {
|
||||
// one chunk: one row
|
||||
@@ -1100,6 +1125,7 @@ static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i,
|
||||
const float *src = st->src + chunk_idx * st->k_stride;
|
||||
transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
|
||||
}
|
||||
|
||||
static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride, int n_threads) {
|
||||
@@ -1117,6 +1143,7 @@ static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *
|
||||
state.src = src;
|
||||
state.k_block = k_block;
|
||||
state.k_stride = k_stride;
|
||||
state.traces = ctx ? ctx->trace : NULL;
|
||||
|
||||
if (state.n_tasks == 1 || n_threads == 1) {
|
||||
transfer_activation_chunk_worker_fn(1, 0, &state);
|
||||
@@ -1245,13 +1272,7 @@ int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float
|
||||
FARF(HIGH, "hmx-mm-2d: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu",
|
||||
m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget);
|
||||
|
||||
TIMER_DEFINE(activation_load);
|
||||
TIMER_DEFINE(weight_load);
|
||||
TIMER_DEFINE(hmx_core);
|
||||
TIMER_DEFINE(output_store);
|
||||
|
||||
TIMER_DEFINE(total);
|
||||
TIMER_START(total);
|
||||
|
||||
int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols);
|
||||
|
||||
@@ -1370,7 +1391,12 @@ int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads);
|
||||
|
||||
// C: HMX Compute (Synchronous)
|
||||
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
|
||||
{
|
||||
struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
}
|
||||
|
||||
// D: Output Store
|
||||
float *output_chunk = dst + (mr * n + nc);
|
||||
@@ -1380,18 +1406,7 @@ int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float
|
||||
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
|
||||
}
|
||||
|
||||
TIMER_STOP(total);
|
||||
|
||||
#if defined(ENABLE_PROFILE_TIMERS)
|
||||
FARF(HIGH, "hex-mm-2d: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n);
|
||||
if (!use_pipeline) {
|
||||
FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us",
|
||||
TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store));
|
||||
size_t weight_size = (size_t)n * row_stride;
|
||||
float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load);
|
||||
FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth);
|
||||
}
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1523,13 +1538,7 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
|
||||
m_chunk_n_rows, n_chunk_n_cols,
|
||||
(size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget);
|
||||
|
||||
TIMER_DEFINE(activation_load);
|
||||
TIMER_DEFINE(weight_load);
|
||||
TIMER_DEFINE(hmx_core);
|
||||
TIMER_DEFINE(output_store);
|
||||
TIMER_DEFINE(total);
|
||||
|
||||
TIMER_START(total);
|
||||
|
||||
const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16);
|
||||
const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16);
|
||||
@@ -1549,7 +1558,6 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
|
||||
// contiguous rows into a VTCM scratch buffer first, then HVX
|
||||
// converts from the contiguous VTCM buffer. This avoids L2 cache
|
||||
// thrashing from HVX loads at large strides.
|
||||
TIMER_START(activation_load);
|
||||
for (int g = 0; g < group_size; ++g) {
|
||||
const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride;
|
||||
__fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride;
|
||||
@@ -1569,7 +1577,6 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
|
||||
params->k, params->act_stride, ctx->n_threads);
|
||||
}
|
||||
}
|
||||
TIMER_STOP(activation_load);
|
||||
|
||||
void *buf_curr = vtcm_scratch0;
|
||||
void *buf_next = vtcm_scratch1;
|
||||
@@ -1584,7 +1591,6 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
|
||||
const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols);
|
||||
const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS);
|
||||
|
||||
TIMER_START(weight_load);
|
||||
{
|
||||
dma_queue_pop(ctx->dma[0]);
|
||||
|
||||
@@ -1601,24 +1607,22 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
|
||||
0, n_cols);
|
||||
hex_swap_ptr(&buf_curr, &buf_next);
|
||||
}
|
||||
TIMER_STOP(weight_load);
|
||||
|
||||
// Reuse the interleaved weight for every q_head in this GQA group
|
||||
for (int g = 0; g < group_size; ++g) {
|
||||
TIMER_START(hmx_core);
|
||||
{
|
||||
const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride;
|
||||
struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles,
|
||||
params->k / 32);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
}
|
||||
TIMER_STOP(hmx_core);
|
||||
|
||||
TIMER_START(output_store);
|
||||
{
|
||||
float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc;
|
||||
transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride, ctx->n_threads);
|
||||
}
|
||||
TIMER_STOP(output_store);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1627,14 +1631,7 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
|
||||
|
||||
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
|
||||
|
||||
TIMER_STOP(total);
|
||||
|
||||
#if defined(ENABLE_PROFILE_TIMERS)
|
||||
FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d group=%d", __func__, TIMER_US(total),
|
||||
params->m, params->k, params->n, group_size);
|
||||
FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us",
|
||||
TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store));
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1668,6 +1665,7 @@ typedef struct {
|
||||
size_t nb12;
|
||||
int start_row;
|
||||
int cne1;
|
||||
struct htp_thread_trace *traces;
|
||||
} activation_transfer_gathered_task_state_t;
|
||||
|
||||
typedef struct {
|
||||
@@ -1684,6 +1682,7 @@ typedef struct {
|
||||
size_t dst_nb2;
|
||||
int start_row;
|
||||
int cne1;
|
||||
struct htp_thread_trace *traces;
|
||||
} output_transfer_scattered_task_state_t;
|
||||
|
||||
static void transfer_activation_chunk_fp32_to_fp16_gathered(
|
||||
@@ -1780,6 +1779,9 @@ static void transfer_activation_chunk_fp32_to_fp16_gathered(
|
||||
|
||||
static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigned int i, void *data) {
|
||||
activation_transfer_gathered_task_state_t *st = data;
|
||||
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
|
||||
|
||||
int chunk_idx = i;
|
||||
int chunk_size = st->n_chunks_per_task;
|
||||
int start_row = st->start_row + chunk_idx * chunk_size;
|
||||
@@ -1791,6 +1793,7 @@ static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigne
|
||||
st->matrix_rows, st->cur_a, st->mapping_stride,
|
||||
st->ne11, &st->ne11_div, st->nb11, st->nb12, st->cne1);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
|
||||
}
|
||||
|
||||
static void transfer_activation_chunk_gathered_threaded(
|
||||
@@ -1830,6 +1833,7 @@ static void transfer_activation_chunk_gathered_threaded(
|
||||
.nb12 = nb12,
|
||||
.start_row = start_row,
|
||||
.cne1 = cne1,
|
||||
.traces = ctx ? ctx->trace : NULL,
|
||||
};
|
||||
|
||||
if (actual_threads <= 1) {
|
||||
@@ -1895,6 +1899,9 @@ static void transfer_output_chunk_fp16_to_fp32_scattered(
|
||||
|
||||
static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned int i, void *data) {
|
||||
output_transfer_scattered_task_state_t *st = data;
|
||||
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
|
||||
|
||||
int chunk_idx = i;
|
||||
int chunk_size = st->n_chunks_per_task;
|
||||
int start_row = st->start_row + chunk_idx * chunk_size;
|
||||
@@ -1906,6 +1913,7 @@ static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned i
|
||||
st->matrix_rows, st->cur_a, st->mapping_stride,
|
||||
st->dst_nb1, st->dst_nb2, st->cne1);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
|
||||
}
|
||||
|
||||
static void transfer_output_chunk_scattered_threaded(
|
||||
@@ -1942,6 +1950,7 @@ static void transfer_output_chunk_scattered_threaded(
|
||||
.dst_nb2 = dst_nb2,
|
||||
.start_row = start_row,
|
||||
.cne1 = cne1,
|
||||
.traces = ctx ? ctx->trace : NULL,
|
||||
};
|
||||
|
||||
if (actual_threads <= 1) {
|
||||
@@ -2053,7 +2062,12 @@ int hmx_matmul_id_2d_f32(struct htp_context *ctx,
|
||||
|
||||
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads);
|
||||
|
||||
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
|
||||
{
|
||||
struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
|
||||
}
|
||||
|
||||
transfer_output_chunk_scattered_threaded(
|
||||
ctx, dst, vtcm_output, (int) mr, (int) n_rows, (int) n_cols,
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
// Conditional fine-grained profiling macros for HMX operations.
|
||||
//
|
||||
// Define ENABLE_PROFILE_TIMERS (via compiler flag or before including this
|
||||
// header) to instrument sub-operation latencies with HAP qtimer. When the
|
||||
// macro is not defined the TIMER_* helpers expand to nothing so there is zero
|
||||
// overhead.
|
||||
//
|
||||
// Usage:
|
||||
// TIMER_DEFINE(my_phase); // declare accumulator variable
|
||||
// TIMER_START(my_phase); // snapshot start time
|
||||
// ... work ...
|
||||
// TIMER_STOP(my_phase); // accumulate elapsed ticks
|
||||
// FARF(ALWAYS, "my_phase: %lld us", TIMER_US(my_phase));
|
||||
|
||||
#ifndef HMX_PROFILE_H
|
||||
#define HMX_PROFILE_H
|
||||
|
||||
#include <HAP_perf.h>
|
||||
|
||||
// #define ENABLE_PROFILE_TIMERS
|
||||
|
||||
#if defined(ENABLE_PROFILE_TIMERS)
|
||||
# define TIMER_DEFINE(name) int64_t name##_ticks = 0
|
||||
# define TIMER_START(name) int64_t name##_t0 = HAP_perf_get_qtimer_count()
|
||||
# define TIMER_STOP(name) name##_ticks += HAP_perf_get_qtimer_count() - name##_t0
|
||||
# define TIMER_US(name) HAP_perf_qtimer_count_to_us(name##_ticks)
|
||||
#else
|
||||
# define TIMER_DEFINE(name)
|
||||
# define TIMER_START(name)
|
||||
# define TIMER_STOP(name)
|
||||
# define TIMER_US(name) 0LL
|
||||
#endif
|
||||
|
||||
#endif // HMX_PROFILE_H
|
||||
@@ -44,7 +44,9 @@ static inline void hmx_queue_process(struct hmx_queue *q, bool* killed) {
|
||||
case HMX_QUEUE_SUSPEND: hmx_unlock(q); break;
|
||||
default:
|
||||
hmx_lock(q);
|
||||
htp_trace_event_start(q->trace, HTP_TRACE_EVT_HMX_COMP, ir);
|
||||
d->func(d->data);
|
||||
htp_trace_event_stop(q->trace, HTP_TRACE_EVT_HMX_COMP, ir);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <HAP_farf.h>
|
||||
|
||||
#include "hex-utils.h"
|
||||
#include "hex-profile.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
@@ -47,6 +48,7 @@ struct hmx_queue {
|
||||
void * stack;
|
||||
uint32_t hap_rctx;
|
||||
bool hmx_locked;
|
||||
struct htp_thread_trace * trace;
|
||||
};
|
||||
|
||||
struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx);
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "hex-dma.h"
|
||||
#include "hmx-queue.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hex-profile.h"
|
||||
#include "worker-pool.h"
|
||||
|
||||
#include <assert.h>
|
||||
@@ -70,6 +71,7 @@ struct htp_context {
|
||||
bool hmx_enabled;
|
||||
bool etm;
|
||||
uint32_t profiler;
|
||||
struct htp_thread_trace trace[HTP_MAX_NTHREADS + 1];
|
||||
|
||||
uint8_t * vtcm_base;
|
||||
size_t vtcm_size;
|
||||
|
||||
@@ -146,10 +146,36 @@ struct htp_op_desc {
|
||||
uint16_t dst; // Output tensor index
|
||||
};
|
||||
|
||||
#ifndef HTP_MAX_NTHREADS
|
||||
#define HTP_MAX_NTHREADS 10
|
||||
#endif
|
||||
|
||||
#define HTP_TRACE_MAX_EVENTS 256
|
||||
|
||||
enum htp_profiler_mode {
|
||||
HTP_PROF_DISABLED = 0,
|
||||
HTP_PROF_BASIC = 1,
|
||||
HTP_PROF_PMU = 2,
|
||||
HTP_PROF_TRACE = 3,
|
||||
};
|
||||
|
||||
enum htp_trace_event_id {
|
||||
HTP_TRACE_EVT_DMA = 0,
|
||||
|
||||
HTP_TRACE_EVT_HVX_COMP = 20,
|
||||
HTP_TRACE_EVT_HVX_A_QUANT = 21,
|
||||
HTP_TRACE_EVT_HVX_A_PREP = 22,
|
||||
HTP_TRACE_EVT_HVX_W_DEQUANT = 23,
|
||||
HTP_TRACE_EVT_HVX_W_PREP = 24,
|
||||
HTP_TRACE_EVT_HVX_O_PROC = 25,
|
||||
|
||||
HTP_TRACE_EVT_HMX_COMP = 40,
|
||||
};
|
||||
|
||||
struct htp_trace_desc {
|
||||
uint32_t cycles; // lower 32-bits of cycle counter
|
||||
uint16_t id; // Event ID
|
||||
uint16_t info; // bit 15: is_stop. bits 14-0: tile/chunk index or other metadata.
|
||||
};
|
||||
|
||||
#define HTP_PROF_PMU_NCNT 8
|
||||
@@ -158,8 +184,8 @@ enum htp_profiler_mode {
|
||||
struct htp_prof_desc {
|
||||
uint32_t opcode; // GGML/HTP Op
|
||||
uint32_t usecs; // Number of usec
|
||||
uint32_t cycles; // Number of cycles
|
||||
uint32_t pad; // Unused
|
||||
uint32_t cycles_start; // Start cycle counter
|
||||
uint32_t cycles_stop; // Stop cycle counter
|
||||
uint32_t pmu[HTP_PROF_PMU_NCNT]; // PMU counters
|
||||
};
|
||||
|
||||
@@ -168,7 +194,7 @@ struct htp_opbatch_req {
|
||||
uint32_t n_bufs; // Number of buffers
|
||||
uint32_t n_tensors; // Number of tensors
|
||||
uint32_t n_ops; // Number of ops
|
||||
uint32_t flags; // unused
|
||||
uint32_t n_traces; // Number of trace descriptors per thread
|
||||
uint32_t pad; // unused
|
||||
// struct htp_buf_desc bufs[]; -- dspqueue buf 0
|
||||
// struct htp_tensor tensors[]; -- dspqueue buf 0
|
||||
@@ -181,7 +207,8 @@ struct htp_opbatch_rsp {
|
||||
uint32_t n_bufs; // Number of buffers
|
||||
uint32_t n_tensors; // Number of tensors
|
||||
uint32_t n_ops; // Number of op profile descriptors
|
||||
uint32_t pad; // unused
|
||||
uint32_t n_traces[HTP_MAX_NTHREADS + 1];
|
||||
uint8_t pad[8]; // align to 8 bytes
|
||||
// struct htp_prof_desc profs[]; -- dspqueue buf 0
|
||||
};
|
||||
|
||||
|
||||
@@ -400,7 +400,9 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
|
||||
ctx->hmx_queue = NULL;
|
||||
if (use_hmx) {
|
||||
ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx);
|
||||
if (!ctx->hmx_queue) {
|
||||
if (ctx->hmx_queue) {
|
||||
ctx->hmx_queue->trace = &ctx->trace[HTP_MAX_NTHREADS];
|
||||
} else {
|
||||
FARF(ERROR, "hmx-queue-create failed");
|
||||
ctx->hmx_enabled = false;
|
||||
}
|
||||
@@ -425,6 +427,9 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
|
||||
ctx->n_threads = n_hvx;
|
||||
for (int i = 0; i < ctx->n_threads; i++) {
|
||||
ctx->dma[i] = dma_queue_create(256); // queue depth
|
||||
if (ctx->dma[i]) {
|
||||
ctx->dma[i]->trace = &ctx->trace[i];
|
||||
}
|
||||
}
|
||||
|
||||
ctx->ddr_spad_size = 512 * 1024; // 512 KB
|
||||
@@ -502,7 +507,8 @@ static void htp_error_callback(dspqueue_t queue, int error, void * context) {
|
||||
|
||||
struct profile_data {
|
||||
uint64_t usecs;
|
||||
uint64_t cycles;
|
||||
uint64_t cycles_start;
|
||||
uint64_t cycles_stop;
|
||||
uint32_t pmu_counters[HEX_NUM_PMU_COUNTERS];
|
||||
};
|
||||
|
||||
@@ -512,8 +518,9 @@ static inline void profile_start(uint32_t mode, struct profile_data * d) {
|
||||
hex_get_pmu(d->pmu_counters);
|
||||
// fallthrough
|
||||
case HTP_PROF_BASIC:
|
||||
case HTP_PROF_TRACE:
|
||||
d->usecs = HAP_perf_get_qtimer_count();
|
||||
d->cycles = hex_get_cycles();
|
||||
d->cycles_start = hex_get_cycles();
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
@@ -530,8 +537,9 @@ static inline void profile_stop(uint32_t mode, struct profile_data * d) {
|
||||
}
|
||||
// fallthrough
|
||||
case HTP_PROF_BASIC:
|
||||
case HTP_PROF_TRACE:
|
||||
d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs);
|
||||
d->cycles = hex_get_cycles() - d->cycles;
|
||||
d->cycles_stop = hex_get_cycles();
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
@@ -845,14 +853,15 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
const uint32_t t_size = sizeof(struct htp_tensor) * n_tens;
|
||||
const uint32_t o_size = sizeof(struct htp_op_desc) * n_ops;
|
||||
const uint32_t p_size = sizeof(struct htp_prof_desc) * n_ops;
|
||||
const uint32_t tr_size = (HTP_MAX_NTHREADS + 1) * req.n_traces * sizeof(struct htp_trace_desc);
|
||||
|
||||
if (dbuf.size < b_size + t_size + o_size + p_size) {
|
||||
FARF(ERROR, "invalid opbatch memory block size %u", dbuf.size);
|
||||
if (dbuf.size < b_size + t_size + o_size + p_size + tr_size) {
|
||||
FARF(ERROR, "invalid opbatch memory block size %u (req %u)", dbuf.size, b_size + t_size + o_size + p_size + tr_size);
|
||||
break;
|
||||
}
|
||||
|
||||
FARF(HIGH, "processing opbatch #%u: n-bufs %u n-tensors %u n-ops %u : m-size %u b-size %u t-size %u o-size %u", req.id,
|
||||
n_bufs, n_tens, n_ops, dbuf.size, b_size, t_size, o_size);
|
||||
FARF(HIGH, "processing opbatch #%u: n-bufs %u n-tensors %u n-ops %u n-traces %u : m-size %u b-size %u t-size %u o-size %u", req.id,
|
||||
n_bufs, n_tens, n_ops, req.n_traces, dbuf.size, b_size, t_size, o_size);
|
||||
|
||||
// Setup descriptor pointers
|
||||
uint8_t * m_ptr = dbuf.ptr;
|
||||
@@ -869,6 +878,20 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
octx->n_threads = ctx->n_threads;
|
||||
octx->ctx = ctx;
|
||||
|
||||
if (ctx->profiler == HTP_PROF_TRACE) {
|
||||
memset(ctx->trace, 0, sizeof(ctx->trace));
|
||||
struct htp_trace_desc * trace_events = (struct htp_trace_desc *) (m_ptr + p_size);
|
||||
for (int t = 0; t <= HTP_MAX_NTHREADS; t++) {
|
||||
ctx->trace[t].events = &trace_events[t * req.n_traces];
|
||||
ctx->trace[t].max_events = req.n_traces;
|
||||
}
|
||||
} else {
|
||||
for (int t = 0; t <= HTP_MAX_NTHREADS; t++) {
|
||||
ctx->trace[t].events = NULL;
|
||||
ctx->trace[t].max_events = 0;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i=0; i < n_ops; i++) {
|
||||
struct profile_data prof;
|
||||
|
||||
@@ -886,7 +909,8 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
if (ctx->profiler) {
|
||||
pds[i].opcode = ops[i].opcode;
|
||||
pds[i].usecs = prof.usecs;
|
||||
pds[i].cycles = prof.cycles;
|
||||
pds[i].cycles_start = prof.cycles_start;
|
||||
pds[i].cycles_stop = prof.cycles_stop;
|
||||
for (int j = 0; j < HEX_NUM_PMU_COUNTERS; j++) {
|
||||
pds[i].pmu[j] = prof.pmu_counters[j];
|
||||
}
|
||||
@@ -899,6 +923,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
rsp.n_bufs = n_bufs;
|
||||
rsp.n_tensors = n_tens;
|
||||
rsp.n_ops = n_ops;
|
||||
memset(rsp.pad, 0, sizeof(rsp.pad));
|
||||
if (ctx->profiler == HTP_PROF_TRACE) {
|
||||
for (int t = 0; t <= HTP_MAX_NTHREADS; t++) {
|
||||
rsp.n_traces[t] = ctx->trace[t].count;
|
||||
}
|
||||
} else {
|
||||
memset(rsp.n_traces, 0, sizeof(rsp.n_traces));
|
||||
}
|
||||
|
||||
dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT;
|
||||
|
||||
|
||||
@@ -3350,6 +3350,7 @@ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void *
|
||||
|
||||
static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
|
||||
htp_matmul_preamble;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
@@ -3411,10 +3412,12 @@ static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
|
||||
float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
|
||||
|
||||
const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, iir0);
|
||||
for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
|
||||
const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
|
||||
mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, iir0);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3430,6 +3433,7 @@ static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
|
||||
// src1 tensor is already in VTCM spad
|
||||
static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
htp_matmul_preamble;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows
|
||||
@@ -3477,6 +3481,8 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
|
||||
// Process src1 columns in pairs (2×2 tiling)
|
||||
uint32_t ir1 = 0;
|
||||
for (; ir1 + 1 < src1_nrows; ir1 += 2) {
|
||||
@@ -3494,6 +3500,8 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col);
|
||||
}
|
||||
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
|
||||
// Prefetch next (n + spad_nrows) row
|
||||
const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
|
||||
const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
|
||||
@@ -3511,12 +3519,14 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
src0_stride, src0_row_size, 1);
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
#pragma unroll(2)
|
||||
for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
|
||||
const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
|
||||
float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
|
||||
mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
@@ -3530,6 +3540,7 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
// q8x4x2 src1 tensor is already in VTCM spad
|
||||
static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
htp_matmul_preamble;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
const uint32_t src0_nrows = ne01;
|
||||
|
||||
@@ -3581,7 +3592,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
// Process src0 rows
|
||||
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) {
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
mmctx->vec_dot_4x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, ss0 + 2 * src0_stride, ss0 + 3 * src0_stride, src1_col);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
|
||||
// Prefetch next (n + spad_nrows) row
|
||||
const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
|
||||
@@ -3599,7 +3612,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
|
||||
src0_stride, src0_row_size, 2);
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
ir0 += 2;
|
||||
}
|
||||
if (ir0 < src0_end_row) {
|
||||
@@ -3607,7 +3622,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
|
||||
src0_stride, src0_row_size, 1);
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
ir0 += 1;
|
||||
}
|
||||
} else {
|
||||
@@ -3627,7 +3644,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
// Process src0 rows
|
||||
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
|
||||
// Prefetch next (n + spad_nrows) row
|
||||
const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
|
||||
@@ -3645,7 +3664,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
|
||||
src0_stride, src0_row_size, 1);
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3669,6 +3690,7 @@ struct mmid_row_mapping {
|
||||
// src1 tensor is already in VTCM spad
|
||||
static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
htp_matmul_preamble;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
const struct htp_tensor * restrict ids = octx->src[2];
|
||||
struct htp_spad * restrict src2_spad = &octx->src2_spad;
|
||||
@@ -3735,6 +3757,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
for (uint32_t cid = 0; cid < cne1; ++cid) {
|
||||
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
|
||||
const int rm1 = row_mapping.i1; // expert idx
|
||||
@@ -3746,6 +3769,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
|
||||
mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
|
||||
// Prefetch next (n + spad_nrows) row
|
||||
const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
|
||||
@@ -3764,6 +3788,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
src0_row_size_padded, src0_row_size, 1);
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
for (uint32_t cid = 0; cid < cne1; ++cid) {
|
||||
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
|
||||
const int rm1 = row_mapping.i1; // expert idx
|
||||
@@ -3775,6 +3800,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
|
||||
mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
|
||||
}
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3789,6 +3815,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
// src1 tensor is already in VTCM spad
|
||||
static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
htp_matmul_preamble;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
const struct htp_tensor * restrict ids = octx->src[2];
|
||||
struct htp_spad * restrict src2_spad = &octx->src2_spad;
|
||||
@@ -3847,7 +3874,9 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
// Process src0 rows
|
||||
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
|
||||
// Prefetch next (n + spad_nrows) row
|
||||
const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
|
||||
@@ -3865,7 +3894,9 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
|
||||
src0_row_size_padded, src0_row_size, 1);
|
||||
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4147,6 +4178,7 @@ static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, ui
|
||||
static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_matmul_context * mmctx = data;
|
||||
struct htp_ops_context * octx = mmctx->octx;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
const struct htp_tensor * src = octx->src[1];
|
||||
uint8_t * restrict dst = octx->src1_spad.data;
|
||||
@@ -4163,6 +4195,7 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data)
|
||||
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
|
||||
|
||||
const uint32_t ir_first = nrows_per_thread * ith; // first row
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
|
||||
|
||||
const size_t src_row_size = src->nb[1];
|
||||
@@ -4189,6 +4222,7 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data)
|
||||
|
||||
FARF(HIGH, "quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
|
||||
ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
}
|
||||
|
||||
static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
|
||||
@@ -4219,6 +4253,7 @@ static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y,
|
||||
static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_matmul_context * mmctx = data;
|
||||
struct htp_ops_context * octx = mmctx->octx;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
const struct htp_tensor * src = octx->src[1];
|
||||
uint8_t * restrict dst = octx->src1_spad.data;
|
||||
@@ -4235,6 +4270,7 @@ static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * dat
|
||||
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
|
||||
|
||||
const uint32_t ir_first = nrows_per_thread * ith; // first row
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
|
||||
|
||||
const size_t src_row_size = src->nb[1];
|
||||
@@ -4260,11 +4296,13 @@ static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * dat
|
||||
|
||||
FARF(HIGH, "quantize-f32-q8_1x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
|
||||
ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
}
|
||||
|
||||
static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_matmul_context * mmctx = data;
|
||||
struct htp_ops_context * octx = mmctx->octx;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
const struct htp_tensor * src = octx->src[1];
|
||||
uint8_t * restrict dst = octx->src1_spad.data;
|
||||
@@ -4281,6 +4319,7 @@ static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
|
||||
|
||||
const uint32_t ir_first = nrows_per_thread * ith; // first row
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
|
||||
|
||||
const size_t src_row_size = ne0 * sizeof(float);
|
||||
@@ -4301,11 +4340,13 @@ static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
|
||||
FARF(HIGH, "quantize-f32-f32: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
|
||||
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
}
|
||||
|
||||
static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_matmul_context * mmctx = data;
|
||||
struct htp_ops_context * octx = mmctx->octx;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
const struct htp_tensor * src = octx->src[1];
|
||||
uint8_t * restrict dst = octx->src1_spad.data;
|
||||
@@ -4322,6 +4363,7 @@ static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
|
||||
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
|
||||
|
||||
const uint32_t ir_first = nrows_per_thread * ith; // first row
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
|
||||
|
||||
const size_t src_row_size = ne0 * sizeof(float);
|
||||
@@ -4342,12 +4384,14 @@ static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
|
||||
|
||||
FARF(HIGH, "quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
|
||||
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
}
|
||||
|
||||
// TODO just a plain copy that should be done via the DMA during the Op setup
|
||||
static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_matmul_context * mmctx = data;
|
||||
struct htp_ops_context * octx = mmctx->octx;
|
||||
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
|
||||
|
||||
const struct htp_tensor * src = octx->src[1];
|
||||
uint8_t * restrict dst = octx->src1_spad.data;
|
||||
@@ -4364,6 +4408,7 @@ static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
|
||||
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
|
||||
|
||||
const uint32_t ir_first = nrows_per_thread * ith; // first row
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
|
||||
|
||||
const size_t src_row_size = ne0 * sizeof(float);
|
||||
@@ -4384,6 +4429,7 @@ static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
|
||||
|
||||
FARF(HIGH, "quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
|
||||
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -183,24 +183,25 @@ static inline void hvx_transpose_32x32_f32(HVX_Vector m[32]) {
|
||||
// transposed into VTCM.
|
||||
//
|
||||
// VTCM layouts (per thread):
|
||||
// src1_T : {d_inner_per_thread, d_conv} — staged once per launch (small).
|
||||
// src0_T : {d_inner_tile, ncs} — staged per d_inner-tile.
|
||||
// src1_T : {d_inner_stride, d_conv} - staged once per launch (small).
|
||||
// src0_T : {d_inner_tile, ncs} - staged per d_inner-tile.
|
||||
//
|
||||
// d_inner_tile is chosen so that per-thread VTCM stays under the budget.
|
||||
// Each thread iterates ceil(d_inner_per_thread d_inner_tile) tiles serially.
|
||||
#define HTP_SSM_CONV_VTCM_BUDGET (1u << 20) // 1 MiB per thread
|
||||
|
||||
// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_per_thread, d_conv} (VTCM)
|
||||
// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_stride, d_conv} (VTCM)
|
||||
static inline void transpose_src1(const float * src1_data,
|
||||
uint32_t src1_stride_inner,
|
||||
uint32_t i1_off,
|
||||
uint32_t d_inner_per_thread,
|
||||
uint32_t d_inner_stride,
|
||||
uint32_t d_conv,
|
||||
float * src1_T) {
|
||||
for (uint32_t i = 0; i < d_inner_per_thread; ++i) {
|
||||
const float * src_row = src1_data + (i1_off + i) * src1_stride_inner;
|
||||
for (uint32_t j = 0; j < d_conv; ++j) {
|
||||
src1_T[j * d_inner_per_thread + i] = src_row[j];
|
||||
src1_T[j * d_inner_stride + i] = src_row[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -280,6 +281,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
|
||||
}
|
||||
|
||||
const uint32_t d_inner_per_thread = ir1 - ir0;
|
||||
const uint32_t d_inner_stride = scctx->nrows_per_thread;
|
||||
const uint32_t d_inner_tile = scctx->d_inner_tile;
|
||||
|
||||
const float * src0_data = (const float *) src0->data;
|
||||
@@ -290,8 +292,8 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
|
||||
float * src0_T = (float *)(octx->src0_spad.data + ith * octx->src0_spad.size_per_thread);
|
||||
float * src1_T = (float *)(octx->src1_spad.data + ith * octx->src1_spad.size_per_thread);
|
||||
|
||||
// Stage src1 weights once into VTCM in {d_inner_per_thread, d_conv} layout.
|
||||
transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_conv, src1_T);
|
||||
// Stage src1 weights once into VTCM in {d_inner_stride, d_conv} layout.
|
||||
transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_inner_stride, d_conv, src1_T);
|
||||
|
||||
const uint32_t C_TILE = VLEN_FP32;
|
||||
|
||||
@@ -314,7 +316,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
|
||||
HVX_Vector acc = hvx_vec_splat_f32(0.0f);
|
||||
for (uint32_t j = 0; j < d_conv; ++j) {
|
||||
HVX_Vector x = *(const HVX_Vector *) (src0_T + (t + j) * d_inner_tile + cb);
|
||||
HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_per_thread + tile_off + cb);
|
||||
HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_stride + tile_off + cb);
|
||||
acc = Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(x, w));
|
||||
}
|
||||
HVX_Vector res = Q6_Vsf_equals_Vqf32(acc);
|
||||
@@ -362,8 +364,7 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
|
||||
use_hvx = 1;
|
||||
}
|
||||
|
||||
scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads;
|
||||
scctx.nrows_per_thread += (scctx.nrows_per_thread & 1);
|
||||
scctx.nrows_per_thread = hex_round_up((d_inner + n_threads - 1) / n_threads, VLEN_FP32);
|
||||
|
||||
const uint32_t d_inner_per_thread = scctx.nrows_per_thread;
|
||||
const uint32_t ncs = src0->ne[0];
|
||||
|
||||
@@ -174,7 +174,7 @@ __kernel void kernel_gemv_noshuffle_q8_0_f32(
|
||||
regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
|
||||
regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
|
||||
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, convert_float(regS), regB);
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #wave=4
|
||||
|
||||
@@ -293,6 +293,11 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
|
||||
(sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2,
|
||||
ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0),
|
||||
ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_BF16) {
|
||||
op()((const sycl::ext::oneapi::bfloat16 *) src0->data, (const float *) src1->data,
|
||||
(sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2,
|
||||
ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0),
|
||||
ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
|
||||
#endif
|
||||
} else {
|
||||
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type),
|
||||
|
||||
@@ -43,14 +43,44 @@ static __dpct_inline__ T op_sgn(T x) {
|
||||
return x > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_abs(T x) {
|
||||
return sycl::fabs(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fabs(x); // or experimental namespace if needed
|
||||
} else {
|
||||
return sycl::fabs(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_expm1(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return static_cast<sycl::ext::oneapi::bfloat16>(
|
||||
sycl::expm1(static_cast<float>(x))
|
||||
);
|
||||
} else {
|
||||
return sycl::expm1(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_elu(T x) {
|
||||
return (x > static_cast<T>(0.f)) ? x : sycl::expm1(x);
|
||||
return (x > static_cast<T>(0.f)) ? x : op_expm1(x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_tanh(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
constexpr int ver = __INTEL_LLVM_COMPILER;
|
||||
#if defined(__INTEL_LLVM_COMPILER) && (__INTEL_LLVM_COMPILER >= 20260000)
|
||||
return sycl::ext::oneapi::experimental::tanh(x);
|
||||
#else
|
||||
return static_cast<T>(sycl::tanh(static_cast<float>(x)));
|
||||
#endif
|
||||
} else {
|
||||
return sycl::tanh(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -59,74 +89,106 @@ static __dpct_inline__ T op_gelu(T x) {
|
||||
const T SQRT_2_OVER_PI = static_cast<T>(0.79788456080286535587989211986876f);
|
||||
return static_cast<T>(0.5f) * x *
|
||||
(static_cast<T>(1.0f) +
|
||||
sycl::tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
|
||||
op_tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_exp(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::exp(x);
|
||||
} else {
|
||||
return sycl::exp(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_silu(T x) {
|
||||
return x / (static_cast<T>(1.0f) + sycl::native::exp(-x));
|
||||
return x / (static_cast<T>(1.0f) + op_exp(-x));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_gelu_quick(T x) {
|
||||
const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
|
||||
return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x)));
|
||||
static __dpct_inline__ T op_erf(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return static_cast<sycl::ext::oneapi::bfloat16>(
|
||||
sycl::erf(static_cast<float>(x))
|
||||
);
|
||||
} else {
|
||||
return sycl::erf(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_gelu_erf(T x) {
|
||||
const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
|
||||
return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + sycl::erf(x * SQRT_2_INV));
|
||||
return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + op_erf(x * SQRT_2_INV));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_tanh(T x) {
|
||||
return sycl::tanh(x);
|
||||
static __dpct_inline__ T op_gelu_quick(T x) {
|
||||
const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
|
||||
return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + op_exp(GELU_QUICK_COEF_LOCAL * x)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_relu(T x) {
|
||||
return sycl::fmax(x, static_cast<T>(0));
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fmax(x, static_cast<T>(0));
|
||||
} else {
|
||||
return sycl::fmax(x, static_cast<T>(0));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_sigmoid(T x) {
|
||||
return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(-x));
|
||||
return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + op_exp(-x));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_sqrt(T x) {
|
||||
return sycl::sqrt(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::sqrt(x);
|
||||
} else {
|
||||
return sycl::sqrt(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_sin(T x) {
|
||||
return sycl::sin(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::sin(x);
|
||||
} else {
|
||||
return sycl::sin(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_cos(T x) {
|
||||
return sycl::cos(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::cos(x);
|
||||
} else {
|
||||
return sycl::cos(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_hardsigmoid(T x) {
|
||||
return sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fmin(
|
||||
static_cast<T>(1.0f), sycl::ext::oneapi::experimental::fmax(
|
||||
static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
} else {
|
||||
return sycl::fmin(static_cast<T>(1.0f),
|
||||
sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_hardswish(T x) {
|
||||
return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_exp(T x) {
|
||||
return sycl::exp(x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_expm1(T x) {
|
||||
return sycl::expm1(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return x * sycl::ext::oneapi::experimental::fmin(static_cast<T>(1.0f), sycl::ext::oneapi::experimental::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
} else {
|
||||
return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -134,13 +196,17 @@ static __dpct_inline__ T op_log(T x) {
|
||||
if (x <= static_cast<T>(0)) {
|
||||
return neg_infinity<T>();
|
||||
}
|
||||
return sycl::log(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::log(x);
|
||||
} else {
|
||||
return sycl::log(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_softplus(T x) {
|
||||
const float xf = (float) x;
|
||||
const float ax = sycl::fabs(xf);
|
||||
const float ax = op_abs(xf);
|
||||
const float m = sycl::fmax(xf, 0.0f);
|
||||
const float y = m + sycl::log1p(sycl::exp(-ax));
|
||||
return (T) y;
|
||||
@@ -159,8 +225,14 @@ static __dpct_inline__ T op_step(T x) {
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) {
|
||||
T neg_slope_T = static_cast<T>(negative_slope);
|
||||
return sycl::fmax(x, static_cast<T>(0)) +
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fmax(x, static_cast<T>(0)) +
|
||||
sycl::ext::oneapi::experimental::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
|
||||
|
||||
} else {
|
||||
return sycl::fmax(x, static_cast<T>(0)) +
|
||||
sycl::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -175,22 +247,40 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_floor(T x) {
|
||||
return sycl::floor(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::floor(x);
|
||||
} else {
|
||||
return sycl::floor(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_ceil(T x) {
|
||||
return sycl::ceil(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::ceil(x);
|
||||
} else {
|
||||
return sycl::ceil(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_round(T x) {
|
||||
return sycl::round(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return static_cast<sycl::ext::oneapi::bfloat16>(
|
||||
sycl::round(static_cast<float>(x))
|
||||
);
|
||||
} else {
|
||||
return sycl::round(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_trunc(T x) {
|
||||
return sycl::trunc(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::trunc(x);
|
||||
} else {
|
||||
return sycl::trunc(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename F>
|
||||
@@ -339,7 +429,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
||||
const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
|
||||
});
|
||||
}
|
||||
@@ -354,8 +444,8 @@ static void arange_kernel(T * dst, const int k, T start, T step,
|
||||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16 || dst->src[0]->type == GGML_TYPE_BF16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_BF16);
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
@@ -367,6 +457,14 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx,
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#ifdef GGML_SYCL_HAS_BF16
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
auto data_pts = cast_data<sycl::ext::oneapi::bfloat16>(dst);
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
auto data_pts = cast_data<float>(dst);
|
||||
@@ -480,7 +578,7 @@ static inline void ggml_sycl_op_unary(
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
||||
sycl::range<1>(256)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_generic_kernel(
|
||||
src, dst_ptr, k_elements,
|
||||
ne0, ne1, ne2, ne3,
|
||||
@@ -508,7 +606,7 @@ static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_ten
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
arange_kernel(dst_ptr, k, start, step, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -602,7 +700,7 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -640,7 +738,7 @@ static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tenso
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -653,7 +751,7 @@ static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -666,7 +764,7 @@ static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -681,7 +779,7 @@ static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1);
|
||||
});
|
||||
}, negative_slope);
|
||||
@@ -694,7 +792,7 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -711,7 +809,7 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1);
|
||||
});
|
||||
}, min_val, max_val);
|
||||
@@ -774,7 +872,8 @@ static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tens
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -785,7 +884,8 @@ static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tens
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -796,7 +896,8 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -811,7 +912,6 @@ __dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alp
|
||||
return out_glu;
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
static void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k,
|
||||
const int64_t n, const int64_t o0, const int64_t o1,
|
||||
@@ -845,7 +945,7 @@ static void swiglu_oai_sycl(const T * x,
|
||||
const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -899,7 +999,8 @@ static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -910,7 +1011,8 @@ static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggm
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -108,6 +108,9 @@ if (Vulkan_FOUND)
|
||||
|
||||
if (GGML_VULKAN_CHECK_RESULTS)
|
||||
add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
|
||||
# the result-checking path computes a CPU reference graph via
|
||||
# ggml_graph_compute_with_ctx(), which is defined in ggml-cpu
|
||||
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
|
||||
endif()
|
||||
|
||||
if (GGML_VULKAN_DEBUG)
|
||||
@@ -129,6 +132,8 @@ if (Vulkan_FOUND)
|
||||
|
||||
if (GGML_VULKAN_RUN_TESTS)
|
||||
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
|
||||
# the test path also calls ggml_graph_compute_with_ctx() (ggml-cpu)
|
||||
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
|
||||
endif()
|
||||
|
||||
# Set up toolchain for host compilation whether cross-compiling or not
|
||||
|
||||
@@ -493,6 +493,20 @@ struct vk_conv2d_pipeline_state {
|
||||
}
|
||||
};
|
||||
|
||||
struct vk_conv3d_pipeline_state {
|
||||
vk_conv3d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t s2, uint32_t p0, uint32_t p1, uint32_t p2,
|
||||
uint32_t d0, uint32_t d1, uint32_t d2, uint32_t KW, uint32_t KH, uint32_t KD, uint32_t aligned)
|
||||
: s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), KW(KW), KH(KH), KD(KD), aligned(aligned) {}
|
||||
|
||||
uint32_t s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD;
|
||||
uint32_t aligned;
|
||||
|
||||
bool operator<(const vk_conv3d_pipeline_state &b) const {
|
||||
return std::tie(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned) <
|
||||
std::tie(b.s0, b.s1, b.s2, b.p0, b.p1, b.p2, b.d0, b.d1, b.d2, b.KW, b.KH, b.KD, b.aligned);
|
||||
}
|
||||
};
|
||||
|
||||
struct vk_solve_tri_pipeline_state {
|
||||
vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
|
||||
: N(N), K(K) {}
|
||||
@@ -777,6 +791,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
|
||||
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_get_rows_back_f32;
|
||||
vk_pipeline pipeline_acc_f32;
|
||||
vk_pipeline pipeline_set_f32;
|
||||
|
||||
@@ -801,14 +816,10 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_concat_i8, pipeline_concat_i16, pipeline_concat_i32, pipeline_concat_i64;
|
||||
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32;
|
||||
vk_pipeline pipeline_scale_f32;
|
||||
vk_pipeline pipeline_sqr_f32;
|
||||
vk_pipeline pipeline_sqrt_f32;
|
||||
vk_pipeline pipeline_sin_f32;
|
||||
vk_pipeline pipeline_cos_f32;
|
||||
vk_pipeline pipeline_log[2];
|
||||
vk_pipeline pipeline_tri[2];
|
||||
vk_pipeline pipeline_diag[2];
|
||||
vk_pipeline pipeline_clamp_f32;
|
||||
vk_pipeline pipeline_clamp[2];
|
||||
vk_pipeline pipeline_pad_f32;
|
||||
vk_pipeline pipeline_roll_f32;
|
||||
vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32;
|
||||
@@ -840,6 +851,10 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_gelu_quick[2];
|
||||
vk_pipeline pipeline_silu[2];
|
||||
vk_pipeline pipeline_relu[2];
|
||||
vk_pipeline pipeline_sqr[2];
|
||||
vk_pipeline pipeline_sqrt[2];
|
||||
vk_pipeline pipeline_sin[2];
|
||||
vk_pipeline pipeline_cos[2];
|
||||
vk_pipeline pipeline_xielu[2];
|
||||
vk_pipeline pipeline_neg[2];
|
||||
vk_pipeline pipeline_tanh[2];
|
||||
@@ -871,7 +886,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_geglu_erf[2];
|
||||
vk_pipeline pipeline_geglu_quick[2];
|
||||
|
||||
vk_pipeline pipeline_leaky_relu_f32;
|
||||
vk_pipeline pipeline_leaky_relu[2];
|
||||
vk_pipeline pipeline_silu_back_f32;
|
||||
vk_pipeline pipeline_diag_mask_inf_f32;
|
||||
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
|
||||
@@ -924,6 +939,8 @@ struct vk_device_struct {
|
||||
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
|
||||
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
|
||||
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
|
||||
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f32[CONV_SHAPE_COUNT];
|
||||
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f16_f32[CONV_SHAPE_COUNT];
|
||||
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
|
||||
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
|
||||
|
||||
@@ -1669,6 +1686,41 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
|
||||
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
|
||||
}
|
||||
|
||||
struct vk_op_conv3d_push_constants {
|
||||
uint32_t OC;
|
||||
uint32_t IC;
|
||||
uint32_t N;
|
||||
|
||||
uint32_t IW;
|
||||
uint32_t IH;
|
||||
uint32_t ID;
|
||||
uint32_t OW;
|
||||
uint32_t OH;
|
||||
uint32_t OD;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
|
||||
uint32_t nb1;
|
||||
uint32_t nb2;
|
||||
uint32_t nb3;
|
||||
|
||||
uint32_t OWmp; uint32_t OWL;
|
||||
uint32_t OWOHmp; uint32_t OWOHL;
|
||||
uint32_t OWOHODmp; uint32_t OWOHODL;
|
||||
};
|
||||
|
||||
template <> void init_pushconst_fastdiv(vk_op_conv3d_push_constants &p) {
|
||||
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
|
||||
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
|
||||
init_fastdiv_values(p.OW*p.OH*p.OD, p.OWOHODmp, p.OWOHODL);
|
||||
}
|
||||
|
||||
struct vk_op_conv2d_dw_push_constants {
|
||||
uint32_t ne;
|
||||
uint32_t batches;
|
||||
@@ -4074,19 +4126,35 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
}
|
||||
#endif
|
||||
|
||||
auto const &ggml_vk_mul_mm_spec = [](std::vector<uint32_t> spec, bool aligned) {
|
||||
spec.push_back(aligned ? 1u : 0u);
|
||||
return spec;
|
||||
};
|
||||
|
||||
const int mul_mat_id_param_count = 5;
|
||||
|
||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (device->coopmat2) {
|
||||
auto const &ggml_vk_mul_mm_cm2_spec = [](std::vector<uint32_t> spec, bool aligned, bool mul_mat_id) {
|
||||
if (mul_mat_id && spec.size() > 5) {
|
||||
spec.insert(spec.begin() + 5, aligned ? 1u : 0u);
|
||||
} else {
|
||||
spec.push_back(aligned ? 1u : 0u);
|
||||
}
|
||||
if (mul_mat_id && spec.size() == 6) {
|
||||
spec.push_back(32);
|
||||
}
|
||||
return spec;
|
||||
};
|
||||
|
||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), l_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), m_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), s_align, true); \
|
||||
|
||||
// Create 2 variants, {f16,f32} accumulator
|
||||
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
||||
@@ -4161,17 +4229,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, true); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, true); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, true); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, true); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, true); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, true); \
|
||||
|
||||
// Create 2 variants, {f16,f32} accumulator
|
||||
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
@@ -4284,32 +4352,32 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
// Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true
|
||||
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
|
||||
// bf16 scalar path promotes to f32, no dot2 variant
|
||||
#define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
|
||||
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l_int[TYPE]) { \
|
||||
@@ -4474,17 +4542,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
|
||||
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
if (device->mul_mat ## ID ## _l_int[TYPE]) \
|
||||
@@ -4879,6 +4947,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_back_f32, "get_rows_back_f32", get_rows_back_f32_len, get_rows_back_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {256, 1, 1}, {}, 1, true);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
|
||||
@@ -4903,7 +4972,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
}
|
||||
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_nc_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
|
||||
@@ -5023,11 +5092,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
@@ -5037,8 +5101,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
@@ -5058,6 +5120,12 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
CREATE_UNARY(gelu_quick)
|
||||
CREATE_UNARY(silu)
|
||||
CREATE_UNARY(relu)
|
||||
CREATE_UNARY(sqr)
|
||||
CREATE_UNARY(sqrt)
|
||||
CREATE_UNARY(sin)
|
||||
CREATE_UNARY(cos)
|
||||
CREATE_UNARY(clamp)
|
||||
CREATE_UNARY(leaky_relu)
|
||||
CREATE_UNARY(xielu)
|
||||
CREATE_UNARY(neg)
|
||||
CREATE_UNARY(tanh)
|
||||
@@ -5097,7 +5165,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
CREATE_GLU(geglu_quick)
|
||||
#undef CREATE_GLU
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
|
||||
@@ -5314,7 +5381,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
// conv2d, conv_transpose_2d
|
||||
// conv2d, conv_transpose_2d, conv3d
|
||||
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
|
||||
// smaller WG for the small-tile fallback gives more concurrent WGs per SM
|
||||
uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256;
|
||||
@@ -5377,8 +5444,8 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size;
|
||||
};
|
||||
|
||||
// coopmat1 needs to store the output through shared memory, so check up front
|
||||
// whether it'll fit and disable it before applying coopmat1 parameters.
|
||||
// 2D, transpose-2D, and 3D conv use the same KxCRS @ CRSxNPQ shmem
|
||||
// layout. cm1 needs Csh for output, so check before applying cm1 params.
|
||||
if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) {
|
||||
conv2d_use_cm1 = false;
|
||||
}
|
||||
@@ -5470,6 +5537,53 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
}
|
||||
#undef CREATE_CONV
|
||||
#undef CREATE_CONVS
|
||||
|
||||
std::vector<uint32_t> conv3d_spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, conv2d_SHMEM_PAD };
|
||||
#define CREATE_CONV3D(type_suffix, spv_suffix) \
|
||||
for (auto &c : device->pipeline_conv3d##type_suffix[s]) { \
|
||||
const vk_conv3d_pipeline_state &state = c.first; \
|
||||
std::vector<uint32_t> spec_constants_cpy = conv3d_spec_constants; \
|
||||
spec_constants_cpy.push_back(state.s0); \
|
||||
spec_constants_cpy.push_back(state.s1); \
|
||||
spec_constants_cpy.push_back(state.s2); \
|
||||
spec_constants_cpy.push_back(state.p0); \
|
||||
spec_constants_cpy.push_back(state.p1); \
|
||||
spec_constants_cpy.push_back(state.p2); \
|
||||
spec_constants_cpy.push_back(state.d0); \
|
||||
spec_constants_cpy.push_back(state.d1); \
|
||||
spec_constants_cpy.push_back(state.d2); \
|
||||
spec_constants_cpy.push_back(state.KW); \
|
||||
spec_constants_cpy.push_back(state.KH); \
|
||||
spec_constants_cpy.push_back(state.KD); \
|
||||
spec_constants_cpy.push_back(state.aligned); \
|
||||
spec_constants_cpy.push_back(conv2d_csh_store); \
|
||||
spec_constants_cpy.push_back(conv2d_WM); \
|
||||
spec_constants_cpy.push_back(conv2d_WN); \
|
||||
ggml_vk_create_pipeline( \
|
||||
device, c.second, "conv3d" #type_suffix, \
|
||||
conv3d##type_suffix##spv_suffix##_len, conv3d##type_suffix##spv_suffix##_data, "main", 3, \
|
||||
sizeof(vk_op_conv3d_push_constants), wg_denoms, spec_constants_cpy, 1, true, conv2d_required_subgroup_size != 0, conv2d_required_subgroup_size); \
|
||||
}
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (device->coopmat2) {
|
||||
CREATE_CONV3D(_f32, _cm2)
|
||||
CREATE_CONV3D(_f16_f32, _cm2)
|
||||
} else
|
||||
#endif
|
||||
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (conv2d_use_cm1) {
|
||||
CREATE_CONV3D(_f32, _cm1)
|
||||
CREATE_CONV3D(_f16_f32, _cm1)
|
||||
} else
|
||||
#endif
|
||||
if (conv2d_UNROLL) {
|
||||
CREATE_CONV3D(_f32, _unroll)
|
||||
CREATE_CONV3D(_f16_f32, _unroll)
|
||||
} else {
|
||||
CREATE_CONV3D(_f32, )
|
||||
CREATE_CONV3D(_f16_f32, )
|
||||
}
|
||||
#undef CREATE_CONV3D
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
||||
@@ -10294,6 +10408,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_get_rows_f32[src0->type];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_get_rows_back_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_ACC:
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_acc_f32;
|
||||
@@ -10400,23 +10519,27 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SQR:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_sqr_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_sqr[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SQRT:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_sqrt_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_sqrt[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SIN:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_sin_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_sin[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_COS:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_cos_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_cos[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_LOG:
|
||||
@@ -10438,8 +10561,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CLAMP:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_clamp_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_clamp[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_PAD:
|
||||
@@ -10807,8 +10931,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_leaky_relu_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_leaky_relu[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CONV_2D:
|
||||
@@ -10885,6 +11010,61 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CONV_3D:
|
||||
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
const uint32_t OC = (uint32_t)ggml_get_op_params_i32(dst, 11);
|
||||
const uint32_t IC = (uint32_t)ggml_get_op_params_i32(dst, 9);
|
||||
const uint32_t N = (uint32_t)ggml_get_op_params_i32(dst, 10);
|
||||
const uint32_t NPQ = N * dst->ne[2] * dst->ne[1] * dst->ne[0];
|
||||
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, OC, NPQ);
|
||||
|
||||
const uint32_t KW = (uint32_t)src0->ne[0];
|
||||
const uint32_t KH = (uint32_t)src0->ne[1];
|
||||
const uint32_t KD = (uint32_t)src0->ne[2];
|
||||
const uint32_t s0 = (uint32_t)ggml_get_op_params_i32(dst, 0);
|
||||
const uint32_t s1 = (uint32_t)ggml_get_op_params_i32(dst, 1);
|
||||
const uint32_t s2 = (uint32_t)ggml_get_op_params_i32(dst, 2);
|
||||
const uint32_t p0 = (uint32_t)ggml_get_op_params_i32(dst, 3);
|
||||
const uint32_t p1 = (uint32_t)ggml_get_op_params_i32(dst, 4);
|
||||
const uint32_t p2 = (uint32_t)ggml_get_op_params_i32(dst, 5);
|
||||
const uint32_t d0 = (uint32_t)ggml_get_op_params_i32(dst, 6);
|
||||
const uint32_t d1 = (uint32_t)ggml_get_op_params_i32(dst, 7);
|
||||
const uint32_t d2 = (uint32_t)ggml_get_op_params_i32(dst, 8);
|
||||
|
||||
const uint32_t CRS = IC * KW * KH * KD;
|
||||
const uint32_t BS_K = vk_conv_block_sizes[shape].K;
|
||||
const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS;
|
||||
const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ;
|
||||
const uint32_t aligned = ((OC % BS_K == 0) &&
|
||||
(CRS % BS_CRS == 0) &&
|
||||
(NPQ % BS_NPQ == 0)) ? 1u : 0u;
|
||||
|
||||
vk_conv3d_pipeline_state conv3d_pipeline_state(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned);
|
||||
|
||||
std::map<vk_conv3d_pipeline_state, vk_pipeline> *pipelines = nullptr;
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
pipelines = &ctx->device->pipeline_conv3d_f32[shape];
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
pipelines = &ctx->device->pipeline_conv3d_f16_f32[shape];
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
vk_pipeline pipeline = nullptr;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
|
||||
auto it = pipelines->find(conv3d_pipeline_state);
|
||||
if (it != pipelines->end()) {
|
||||
pipeline = it->second;
|
||||
} else {
|
||||
(*pipelines)[conv3d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
|
||||
}
|
||||
}
|
||||
|
||||
return pipeline;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_ADD1:
|
||||
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||
return ctx->device->pipeline_add1_f16_f16;
|
||||
@@ -11135,6 +11315,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
||||
break;
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
elements = { (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], 1 };
|
||||
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||
break;
|
||||
case GGML_OP_ARGSORT:
|
||||
GGML_ASSERT(0);
|
||||
break;
|
||||
@@ -11220,6 +11404,21 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
GGML_ABORT("invalid push constant type for CONV_2D");
|
||||
}
|
||||
break;
|
||||
case GGML_OP_CONV_3D:
|
||||
if constexpr (std::is_same_v<PC, vk_op_conv3d_push_constants>) {
|
||||
const uint32_t NPQ = pc.N * pc.OD * pc.OH * pc.OW;
|
||||
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.OC, NPQ);
|
||||
const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ);
|
||||
|
||||
elements = { pc.OC, NPQ_blocks, 1 };
|
||||
if (elements[1] > 512) {
|
||||
elements[2] = CEIL_DIV(elements[1], 512);
|
||||
elements[1] = 512;
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("invalid push constant type for CONV_3D");
|
||||
}
|
||||
break;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_DIV:
|
||||
@@ -11236,6 +11435,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
case GGML_OP_TRI:
|
||||
case GGML_OP_DIAG:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ROLL:
|
||||
case GGML_OP_REPEAT:
|
||||
@@ -11380,6 +11580,21 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_vk_get_rows_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS_BACK, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2], (uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f, 0,
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||
@@ -12087,8 +12302,10 @@ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||
|
||||
static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
float * op_params = (float *)dst->op_params;
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
p.param1 = op_params[0];
|
||||
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, std::move(p));
|
||||
}
|
||||
|
||||
static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
@@ -13118,6 +13335,51 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
|
||||
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p));
|
||||
}
|
||||
|
||||
static void ggml_vk_conv_3d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
|
||||
const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nb10 == sizeof(float));
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
|
||||
vk_op_conv3d_push_constants p{};
|
||||
p.IC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 9));
|
||||
p.N = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 10));
|
||||
p.OC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 11));
|
||||
GGML_ASSERT(src0->ne[3] == (int64_t)p.IC * p.OC);
|
||||
GGML_ASSERT(src1->ne[3] == (int64_t)p.IC * p.N);
|
||||
GGML_ASSERT(dst->ne[3] == (int64_t)p.OC * p.N);
|
||||
|
||||
p.IW = static_cast<uint32_t>(ne10);
|
||||
p.IH = static_cast<uint32_t>(ne11);
|
||||
p.ID = static_cast<uint32_t>(ne12);
|
||||
p.OW = static_cast<uint32_t>(ne0);
|
||||
p.OH = static_cast<uint32_t>(ne1);
|
||||
p.OD = static_cast<uint32_t>(ne2);
|
||||
|
||||
// the shader clamps src addresses to p.IC * p.N * p.IW * p.IH * p.ID - 1 in uint32, so the
|
||||
// total input element count must fit in a uint32.
|
||||
GGML_ASSERT((uint64_t)p.IC * p.N * p.IW * p.IH * p.ID <= 0xFFFFFFFFull);
|
||||
|
||||
p.nb01 = static_cast<uint32_t>(nb01 / nb00);
|
||||
p.nb02 = static_cast<uint32_t>(nb02 / nb00);
|
||||
p.nb03 = static_cast<uint32_t>(nb03 / nb00);
|
||||
|
||||
p.nb11 = static_cast<uint32_t>(nb11 / nb10);
|
||||
p.nb12 = static_cast<uint32_t>(nb12 / nb10);
|
||||
p.nb13 = static_cast<uint32_t>(nb13 / nb10);
|
||||
|
||||
p.nb1 = static_cast<uint32_t>(nb1 / nb0);
|
||||
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
|
||||
p.nb3 = static_cast<uint32_t>(nb3 / nb0);
|
||||
|
||||
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_3D, std::move(p));
|
||||
}
|
||||
|
||||
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
vk_op_conv2d_dw_push_constants p{};
|
||||
p.ne = ggml_nelements(dst);
|
||||
@@ -13144,7 +13406,10 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
|
||||
static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
const float * op_params = (const float *)dst->op_params;
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
p.param1 = op_params[0];
|
||||
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, std::move(p));
|
||||
}
|
||||
|
||||
#ifdef GGML_VULKAN_RUN_TESTS
|
||||
@@ -14247,6 +14512,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_OP_GET_ROWS:
|
||||
ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
ggml_vk_get_rows_back(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_ADD:
|
||||
if (ctx->num_additional_fused_ops) {
|
||||
@@ -14515,6 +14784,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_CONV_3D:
|
||||
ggml_vk_conv_3d(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node);
|
||||
@@ -16964,6 +17237,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
return false;
|
||||
}
|
||||
}
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
switch (op->type) {
|
||||
@@ -17060,12 +17335,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_RMS_NORM:
|
||||
return true;
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return ggml_is_contiguous_rows(op->src[0]) &&
|
||||
op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
@@ -17084,8 +17358,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||
op->type == op->src[0]->type;
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
@@ -17285,6 +17560,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
ggml_is_contiguous(op->src[1]) &&
|
||||
ggml_is_contiguous(op));
|
||||
}
|
||||
case GGML_OP_CONV_3D:
|
||||
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||
op->src[1]->type == GGML_TYPE_F32 &&
|
||||
op->type == GGML_TYPE_F32 &&
|
||||
ggml_is_contiguous(op->src[0]) &&
|
||||
ggml_is_contiguous(op->src[1]) &&
|
||||
ggml_is_contiguous(op);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -18128,6 +18410,20 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||
const int32_t d0 = tensor->op_params[4];
|
||||
const int32_t d1 = tensor->op_params[5];
|
||||
tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
|
||||
} else if (tensor->op == GGML_OP_CONV_3D) {
|
||||
const int32_t s0 = tensor->op_params[0];
|
||||
const int32_t s1 = tensor->op_params[1];
|
||||
const int32_t s2 = tensor->op_params[2];
|
||||
const int32_t p0 = tensor->op_params[3];
|
||||
const int32_t p1 = tensor->op_params[4];
|
||||
const int32_t p2 = tensor->op_params[5];
|
||||
const int32_t d0 = tensor->op_params[6];
|
||||
const int32_t d1 = tensor->op_params[7];
|
||||
const int32_t d2 = tensor->op_params[8];
|
||||
const int32_t IC = tensor->op_params[9];
|
||||
const int32_t N = tensor->op_params[10];
|
||||
const int32_t OC = tensor->op_params[11];
|
||||
tensor_clone = ggml_conv_3d_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, s2, p0, p1, p2, d0, d1, d2, IC, N, OC);
|
||||
} else if (tensor->op == GGML_OP_CONV_2D_DW) {
|
||||
const int32_t s0 = tensor->op_params[0];
|
||||
const int32_t s1 = tensor->op_params[1];
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
|
||||
}
|
||||
@@ -0,0 +1,431 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#ifdef COOPMAT2
|
||||
#extension GL_NV_cooperative_matrix2 : enable
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#endif
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
|
||||
layout(binding = 0) readonly buffer A {
|
||||
A_TYPE knl_data[];
|
||||
}; // src0 - kernel: [KW, KH, KD, IC*OC]
|
||||
|
||||
layout(binding = 1) readonly buffer B {
|
||||
B_TYPE src_data[];
|
||||
}; // src1 - input: [IW, IH, ID, IC*N] -- channel_first format
|
||||
|
||||
layout(binding = 2) writeonly buffer D {
|
||||
D_TYPE dst_data[];
|
||||
}; // dst - result: [OW, OH, OD, OC*N]
|
||||
|
||||
layout(push_constant) uniform parameter {
|
||||
// I/O channels, batch size
|
||||
uint32_t OC;
|
||||
uint32_t IC;
|
||||
uint32_t N;
|
||||
|
||||
// Tensor spatial sizes: input, output
|
||||
uint32_t IW;
|
||||
uint32_t IH;
|
||||
uint32_t ID;
|
||||
uint32_t OW;
|
||||
uint32_t OH;
|
||||
uint32_t OD;
|
||||
|
||||
// Strides in elements
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
|
||||
uint32_t nb1;
|
||||
uint32_t nb2;
|
||||
uint32_t nb3;
|
||||
|
||||
// fastdiv helper values
|
||||
uint32_t OWmp; uint32_t OWL;
|
||||
uint32_t OWOHmp; uint32_t OWOHL;
|
||||
uint32_t OWOHODmp; uint32_t OWOHODL;
|
||||
}
|
||||
|
||||
p;
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
// Blocktile sizes
|
||||
layout(constant_id = 1) const uint BS_K = 128;
|
||||
layout(constant_id = 2) const uint BS_CRS = 16;
|
||||
layout(constant_id = 3) const uint BS_NPQ = 128;
|
||||
// Thread-tile sizes
|
||||
layout(constant_id = 4) const uint TS_K = 8;
|
||||
layout(constant_id = 5) const uint SHMEM_PAD = 4;
|
||||
// Stride, padding, dilation
|
||||
layout(constant_id = 6) const uint s0 = 1;
|
||||
layout(constant_id = 7) const uint s1 = 1;
|
||||
layout(constant_id = 8) const uint s2 = 1;
|
||||
layout(constant_id = 9) const uint p0 = 0;
|
||||
layout(constant_id = 10) const uint p1 = 0;
|
||||
layout(constant_id = 11) const uint p2 = 0;
|
||||
layout(constant_id = 12) const uint d0 = 1;
|
||||
layout(constant_id = 13) const uint d1 = 1;
|
||||
layout(constant_id = 14) const uint d2 = 1;
|
||||
// Kernel spatial sizes
|
||||
layout(constant_id = 15) const uint KW = 1;
|
||||
layout(constant_id = 16) const uint KH = 1;
|
||||
layout(constant_id = 17) const uint KD = 1;
|
||||
// when set, skip bounds checks and address clamps (K/CRS/NPQ are tile-aligned)
|
||||
layout(constant_id = 18) const uint aligned = 0;
|
||||
// stage cm2 result through shmem (Csh) for coalesced stores. cm1 always does this.
|
||||
layout(constant_id = 19) const uint csh_store = 0;
|
||||
|
||||
#ifdef COOPMAT
|
||||
// cm1 subgroup tile: each subgroup computes a WM x WN region as a grid of
|
||||
// TM x TN x TK fragments. Requires WM%TM == WN%TN == BS_K%WM == BS_NPQ%WN ==
|
||||
// BS_CRS%TK == 0, and WG_SIZE == (BS_K/WM) * (BS_NPQ/WN) * subgroup_size.
|
||||
layout(constant_id = 20) const uint WM = 32;
|
||||
layout(constant_id = 21) const uint WN = 32;
|
||||
const uint TM = 16;
|
||||
const uint TN = 16;
|
||||
const uint TK = 16;
|
||||
const uint cms_per_row = WM / TM;
|
||||
const uint cms_per_col = WN / TN;
|
||||
const uint warps_M = BS_K / WM;
|
||||
const uint warps_N = BS_NPQ / WN;
|
||||
#endif
|
||||
|
||||
// without padding, ID_idx/IH_idx/IW_idx are in bounds by construction
|
||||
const bool dhw_in_bounds = (p0 == 0) && (p1 == 0) && (p2 == 0);
|
||||
|
||||
uint32_t tid = gl_LocalInvocationID.x;
|
||||
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
|
||||
|
||||
uint splitWork(uint work_size, uint block_size) {
|
||||
return (block_size + work_size - 1) / block_size;
|
||||
}
|
||||
|
||||
uint32_t K = p.OC;
|
||||
uint32_t CRS = p.IC * KD * KH * KW;
|
||||
uint32_t NPQ = p.N * p.OD * p.OH * p.OW;
|
||||
|
||||
// Number of blocktiles per input
|
||||
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
|
||||
|
||||
#if defined(COOPMAT2) || defined(COOPMAT)
|
||||
#define SHMEM_TYPE float16_t
|
||||
#else
|
||||
#define SHMEM_TYPE float
|
||||
#endif
|
||||
|
||||
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
|
||||
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
|
||||
|
||||
const uint32_t Ash_len = BS_K * Ash_stride;
|
||||
const uint32_t Bsh_len = BS_CRS * Bsh_stride;
|
||||
|
||||
shared SHMEM_TYPE Ash[Ash_len]; // K x CRS
|
||||
shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ
|
||||
|
||||
#if defined(COOPMAT2) || defined(COOPMAT)
|
||||
// stage matC through shmem so global stores are row-major (NPQ-contiguous)
|
||||
const uint32_t Csh_stride = BS_NPQ;
|
||||
#ifdef COOPMAT
|
||||
const uint32_t Csh_len = BS_K * Csh_stride;
|
||||
#else
|
||||
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1;
|
||||
#endif
|
||||
shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ
|
||||
#endif
|
||||
|
||||
// Threadtile sizes
|
||||
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
|
||||
|
||||
// Number of threadtiles per blocktile
|
||||
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
|
||||
|
||||
/*
|
||||
Compute
|
||||
KxCRS @ CRSxNPQ = K x NPQ
|
||||
K=OC
|
||||
C=IC
|
||||
D,R,S=KD,KH,KW
|
||||
Z,P,Q=OD,OH,OW
|
||||
*/
|
||||
|
||||
uint32_t B_idx_K = gl_WorkGroupID.x;
|
||||
uint32_t B_idx_NPQ = gl_WorkGroupID.y + gl_WorkGroupID.z * 512;
|
||||
|
||||
uint32_t T_y = tid / NT_NPQ;
|
||||
uint32_t T_x = tid % NT_NPQ;
|
||||
|
||||
uint32_t Ar = tid / BS_CRS;
|
||||
uint32_t Ac = tid % BS_CRS;
|
||||
const uint32_t ArpWg = WG_SIZE / BS_CRS;
|
||||
|
||||
uint32_t Br = tid / BS_NPQ;
|
||||
uint32_t Bc = tid % BS_NPQ;
|
||||
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
|
||||
|
||||
// see init_fastdiv_values in ggml-vulkan.cpp
|
||||
uint fastdiv(uint n, uint mp, uint L) {
|
||||
uint msbs, lsbs;
|
||||
// msbs = mulhi(n, mp)
|
||||
umulExtended(n, mp, msbs, lsbs);
|
||||
return (msbs + n) >> L;
|
||||
}
|
||||
|
||||
void split_crs(uint32_t crs_idx, out uint32_t ic, out uint32_t kd, out uint32_t kh, out uint32_t kw) {
|
||||
const uint32_t KHKW = KH * KW;
|
||||
const uint32_t KDKHKW = KD * KHKW;
|
||||
ic = crs_idx / KDKHKW;
|
||||
uint32_t rem = crs_idx - ic * KDKHKW;
|
||||
kd = rem / KHKW;
|
||||
rem = rem - kd * KHKW;
|
||||
kh = rem / KW;
|
||||
kw = rem - kh * KW;
|
||||
}
|
||||
|
||||
void split_npq(uint32_t npq_idx, out uint32_t n, out uint32_t od, out uint32_t oh, out uint32_t ow) {
|
||||
const uint32_t OWOH = p.OW * p.OH;
|
||||
n = fastdiv(npq_idx, p.OWOHODmp, p.OWOHODL);
|
||||
uint32_t rem = npq_idx - n * p.OD * OWOH;
|
||||
od = fastdiv(rem, p.OWOHmp, p.OWOHL);
|
||||
rem = rem - od * OWOH;
|
||||
oh = fastdiv(rem, p.OWmp, p.OWL);
|
||||
ow = rem - oh * p.OW;
|
||||
}
|
||||
|
||||
#ifdef COOPMAT2
|
||||
#define ACC_TYPE float16_t
|
||||
|
||||
ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
|
||||
{
|
||||
uint32_t K_idx = B_idx_K * BS_K + r;
|
||||
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;
|
||||
uint32_t N_idx;
|
||||
uint32_t OD_idx;
|
||||
uint32_t OH_idx;
|
||||
uint32_t OW_idx;
|
||||
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
|
||||
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
|
||||
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
|
||||
dst_data[dst_idx] = D_TYPE(elem);
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
if (B_idx_NPQ * BS_NPQ >= NPQ) {
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef COOPMAT2
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
|
||||
matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
|
||||
#elif defined(COOPMAT)
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
|
||||
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
|
||||
sums[i] = coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0);
|
||||
}
|
||||
const uint warp_r = gl_SubgroupID / warps_N;
|
||||
const uint warp_c = gl_SubgroupID % warps_N;
|
||||
#else
|
||||
float regC[TS_K][TS_NPQ];
|
||||
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
||||
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
||||
regC[T_ly][T_lx] = 0.0;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
/* Advance block in CRS dim */
|
||||
[[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
|
||||
uint32_t CRS_idx_a = B_idx_CRS * BS_CRS + Ac;
|
||||
uint32_t IC_idx_a;
|
||||
uint32_t KD_idx_a;
|
||||
uint32_t KH_idx_a;
|
||||
uint32_t KW_idx_a;
|
||||
split_crs(CRS_idx_a, IC_idx_a, KD_idx_a, KH_idx_a, KW_idx_a);
|
||||
|
||||
/* Load kernel to A_block: (BS_K x BS_CRS)*/
|
||||
UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
|
||||
uint32_t B_ly = r_offset + Ar;
|
||||
uint32_t B_lx = Ac;
|
||||
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
|
||||
uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + KD_idx_a * p.nb02 + (K_idx * p.IC + IC_idx_a) * p.nb03;
|
||||
if (aligned == 0) {
|
||||
knl_idx = min(knl_idx, K * CRS - 1);
|
||||
}
|
||||
float val = knl_data[knl_idx];
|
||||
if (aligned == 0 && (K_idx >= K || CRS_idx_a >= CRS)) {
|
||||
val = 0.0;
|
||||
}
|
||||
Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
|
||||
}
|
||||
/* Load input to B_block: (BS_CRS x BS_NPQ) */
|
||||
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
|
||||
uint32_t B_ly = r_offset + Br; /* Row index of B block */
|
||||
uint32_t B_lx = Bc;
|
||||
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
|
||||
uint32_t N_idx;
|
||||
uint32_t OD_idx;
|
||||
uint32_t OH_idx;
|
||||
uint32_t OW_idx;
|
||||
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
|
||||
|
||||
uint32_t CRS_idx_b = B_idx_CRS * BS_CRS + B_ly;
|
||||
uint32_t IC_idx_b;
|
||||
uint32_t KD_idx_b;
|
||||
uint32_t KH_idx_b;
|
||||
uint32_t KW_idx_b;
|
||||
split_crs(CRS_idx_b, IC_idx_b, KD_idx_b, KH_idx_b, KW_idx_b);
|
||||
|
||||
uint32_t ID_idx = OD_idx * s2 + KD_idx_b * d2 - p2;
|
||||
uint32_t IH_idx = OH_idx * s1 + KH_idx_b * d1 - p1;
|
||||
uint32_t IW_idx = OW_idx * s0 + KW_idx_b * d0 - p0;
|
||||
|
||||
uint32_t src_idx = IW_idx + IH_idx * p.nb11 + ID_idx * p.nb12 + (N_idx * p.IC + IC_idx_b) * p.nb13;
|
||||
// skip clamp when address can't go OOB
|
||||
if (aligned == 0 || !dhw_in_bounds) {
|
||||
src_idx = min(src_idx, p.IC * p.N * p.IW * p.IH * p.ID - 1);
|
||||
}
|
||||
float val = src_data[src_idx];
|
||||
bool oob = false;
|
||||
if (aligned == 0 && (CRS_idx_b >= CRS || NPQ_idx >= NPQ)) {
|
||||
oob = true;
|
||||
}
|
||||
// also catches lower-bound underflow (idx wraps to 0x80000000+)
|
||||
if (!dhw_in_bounds && (ID_idx >= p.ID || IH_idx >= p.IH || IW_idx >= p.IW)) {
|
||||
oob = true;
|
||||
}
|
||||
if (oob) {
|
||||
val = 0.0;
|
||||
}
|
||||
Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
|
||||
}
|
||||
barrier();
|
||||
#ifdef COOPMAT2
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;
|
||||
|
||||
coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
matC = coopMatMulAdd(matA, matB, matC);
|
||||
#elif defined(COOPMAT)
|
||||
// each subgroup multiplies its grid of fragments per TK-sized CRS chunk
|
||||
[[unroll]] for (uint k_step = 0; k_step < BS_CRS / TK; k_step++) {
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a[cms_per_row];
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
const uint a_off = (warp_r * WM + cm_row * TM) * Ash_stride + k_step * TK;
|
||||
coopMatLoad(cache_a[cm_row], Ash, a_off, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
|
||||
const uint b_off = k_step * TK * Bsh_stride + warp_c * WN + cm_col * TN;
|
||||
coopMatLoad(cache_b, Bsh, b_off, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row], cache_b, sums[cm_col * cms_per_row + cm_row]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (T_y * TS_K < K) {
|
||||
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
|
||||
float regA[TS_K];
|
||||
float regB[TS_NPQ];
|
||||
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
||||
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
|
||||
}
|
||||
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
||||
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
|
||||
}
|
||||
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
||||
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
||||
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
barrier();
|
||||
}
|
||||
/* Save C* */
|
||||
#if defined(COOPMAT2) || defined(COOPMAT)
|
||||
// stage matC into Csh, then write to dst with coalesced NPQ-contiguous stores
|
||||
#ifdef COOPMAT
|
||||
const bool use_staged_store = true;
|
||||
#else
|
||||
const bool use_staged_store = (csh_store != 0);
|
||||
#endif
|
||||
if (use_staged_store) {
|
||||
#ifdef COOPMAT
|
||||
// cm1: each subgroup stores its fragment grid into its Csh slot
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
const uint csh_off = (warp_r * WM + cm_row * TM) * Csh_stride + warp_c * WN + cm_col * TN;
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], Csh, csh_off, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
}
|
||||
#else
|
||||
coopMatStore(matC, Csh, 0, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
#endif
|
||||
barrier();
|
||||
|
||||
// cooperative shmem->global: WG threads spread across BS_NPQ (the
|
||||
// contiguous direction of dst), each iter covers store_rows_per_iter K-rows
|
||||
const uint32_t store_rows_per_iter = WG_SIZE / BS_NPQ;
|
||||
const uint32_t store_iters = BS_K / store_rows_per_iter;
|
||||
const uint32_t k_thread_offset = tid / BS_NPQ;
|
||||
const uint32_t npq_thread = tid % BS_NPQ;
|
||||
[[unroll]] for (uint32_t i = 0; i < store_iters; i++) {
|
||||
uint32_t k_local = i * store_rows_per_iter + k_thread_offset;
|
||||
uint32_t K_idx = B_idx_K * BS_K + k_local;
|
||||
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + npq_thread;
|
||||
uint32_t N_idx;
|
||||
uint32_t OD_idx;
|
||||
uint32_t OH_idx;
|
||||
uint32_t OW_idx;
|
||||
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
|
||||
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
|
||||
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
|
||||
dst_data[dst_idx] = D_TYPE(Csh[k_local * Csh_stride + npq_thread]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef COOPMAT2
|
||||
else {
|
||||
coopMatPerElementNV(matC, matC, perElemOpStore);
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
if (T_y * TS_K < K) {
|
||||
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
||||
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
||||
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
|
||||
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
|
||||
uint32_t N_idx;
|
||||
uint32_t OD_idx;
|
||||
uint32_t OH_idx;
|
||||
uint32_t OW_idx;
|
||||
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
|
||||
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
|
||||
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
|
||||
dst_data[dst_idx] = D_TYPE(regC[T_ly][T_lx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val));
|
||||
}
|
||||
@@ -463,6 +463,7 @@ void main() {
|
||||
}
|
||||
rowmaxf = max(rowmaxf, float(Sf[r][c]));
|
||||
}
|
||||
rowmaxf += FATTN_KQ_MAX_OFFSET;
|
||||
float Moldf = Mf[r];
|
||||
|
||||
// M = max(rowmax, Mold)
|
||||
|
||||
@@ -352,6 +352,7 @@ void main() {
|
||||
}
|
||||
rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));
|
||||
}
|
||||
rowmaxf += FATTN_KQ_MAX_OFFSET;
|
||||
float Moldf = Mf[r];
|
||||
|
||||
// Compute max across the row
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_binary_head.glsl"
|
||||
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint col = gl_GlobalInvocationID.x;
|
||||
|
||||
if (col >= p.ne20) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (uint row = gl_GlobalInvocationID.y; row < p.ne21; row += gl_WorkGroupSize.y * gl_NumWorkGroups.y) {
|
||||
float sum = 0.0f;
|
||||
for (uint i = 0; i < p.ne10; ++i) {
|
||||
if (data_b[get_boffset() + i*p.nb10] == int(row)) {
|
||||
sum += data_a[get_aoffset() + i*p.nb01 + col*p.nb00];
|
||||
}
|
||||
}
|
||||
|
||||
data_d[get_doffset() + row*p.nb21 + col*p.nb20] = sum;
|
||||
}
|
||||
}
|
||||
@@ -14,16 +14,13 @@ void main() {
|
||||
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
const uint i3 = row / (p.ne11 * p.ne12);
|
||||
const uint i3_offset = i3 * p.ne12 * p.ne11;
|
||||
const uint i2 = (row - i3_offset) / p.ne11;
|
||||
const uint i2_offset = i2 * p.ne11;
|
||||
const uint i1 = row - i3_offset - i2_offset;
|
||||
const uint a_base = get_aoffset() + src0_idx(row * p.ne00);
|
||||
const uint d_base = get_doffset() + dst_idx(row * p.ne10);
|
||||
|
||||
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]);
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_base + i0*p.nb00]);
|
||||
sum[tid] += xi * xi;
|
||||
}
|
||||
|
||||
@@ -39,6 +36,6 @@ void main() {
|
||||
const FLOAT_TYPE scale = 1.0f / max(sqrt(sum[0]), FLOAT_TYPE(p.param1));
|
||||
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]));
|
||||
data_d[d_base + i0*p.nb10] = D_TYPE(scale * FLOAT_TYPE(data_a[a_base + i0*p.nb00]));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.KX) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float val = float(data_a[i]);
|
||||
data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1);
|
||||
}
|
||||
@@ -38,17 +38,7 @@
|
||||
#define LOAD_VEC_B 1
|
||||
#endif
|
||||
|
||||
// Load 2 values at once without affecting index calculations through LOAD_VEC
|
||||
#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)
|
||||
#define LOAD_VEC_BATCH_A 2
|
||||
#else
|
||||
#define LOAD_VEC_BATCH_A 1
|
||||
#endif
|
||||
#if !defined(ALIGNED)
|
||||
#define LOAD_VEC_BATCH_B 2
|
||||
#else
|
||||
#define LOAD_VEC_BATCH_B 1
|
||||
#endif
|
||||
layout (constant_id = 11) const uint ALIGNED = 0;
|
||||
|
||||
#if !defined(TO_FLOAT_TYPE)
|
||||
#define TO_FLOAT_TYPE FLOAT_TYPE
|
||||
@@ -57,6 +47,13 @@
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
#if defined(DATA_A_F32)
|
||||
layout (binding = 0) readonly buffer A_SCALAR {float data_a_scalar[];};
|
||||
#elif defined(DATA_A_F16)
|
||||
layout (binding = 0) readonly buffer A_SCALAR {float16_t data_a_scalar[];};
|
||||
#elif defined(DATA_A_BF16)
|
||||
layout (binding = 0) readonly buffer A_SCALAR {uint16_t data_a_scalar[];};
|
||||
#endif
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
|
||||
#endif
|
||||
@@ -65,6 +62,7 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32
|
||||
#endif
|
||||
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
layout (binding = 1) readonly buffer B_SCALAR {B_TYPE_SCALAR data_b_scalar[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
@@ -194,13 +192,23 @@ void main() {
|
||||
const uint warp_r = warp_i % (BM / WM);
|
||||
const uint warp_c = warp_i / (BM / WM);
|
||||
|
||||
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
|
||||
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
|
||||
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
|
||||
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
|
||||
const uint LOAD_VEC_A_EFF = (ALIGNED != 0) ? LOAD_VEC_A : 1;
|
||||
const uint LOAD_VEC_BATCH_A = (ALIGNED != 0) ? 1 : 2;
|
||||
#else
|
||||
const uint LOAD_VEC_A_EFF = LOAD_VEC_A;
|
||||
const uint LOAD_VEC_BATCH_A = 1;
|
||||
#endif
|
||||
const uint LOAD_VEC_B_EFF = (ALIGNED != 0) ? LOAD_VEC_B : 1;
|
||||
const uint LOAD_VEC_BATCH_B = (ALIGNED != 0) ? 1 : 2;
|
||||
|
||||
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;
|
||||
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;
|
||||
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A);
|
||||
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A);
|
||||
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B);
|
||||
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B);
|
||||
|
||||
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A_EFF * LOAD_VEC_BATCH_A / BK;
|
||||
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B_EFF * LOAD_VEC_BATCH_B / BK;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||
@@ -239,15 +247,15 @@ void main() {
|
||||
|
||||
uint pos_a =
|
||||
#ifdef MUL_MAT_ID
|
||||
expert_idx * (p.batch_stride_a / LOAD_VEC_A) +
|
||||
expert_idx * (p.batch_stride_a / LOAD_VEC_A_EFF) +
|
||||
#else
|
||||
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) +
|
||||
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A_EFF) +
|
||||
#endif
|
||||
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
|
||||
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A_EFF;
|
||||
#ifdef MUL_MAT_ID
|
||||
uint pos_b = 0;
|
||||
#else
|
||||
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
|
||||
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B_EFF;
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
@@ -287,8 +295,8 @@ void main() {
|
||||
|
||||
barrier();
|
||||
|
||||
pos_a += BK / LOAD_VEC_A;
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
pos_a += BK / LOAD_VEC_A_EFF;
|
||||
pos_b += BK / LOAD_VEC_B_EFF;
|
||||
|
||||
#ifdef COOPMAT
|
||||
[[unroll]] for (uint i = 0; i < BK; i += TK) {
|
||||
|
||||
@@ -36,6 +36,7 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working wit
|
||||
layout (constant_id = 4) const bool enable_smaller_matrices = false;
|
||||
const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
|
||||
const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
|
||||
layout (constant_id = 5) const uint ALIGNED = 0;
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
@@ -111,7 +112,7 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
|
||||
};
|
||||
|
||||
uint _ne1;
|
||||
layout (constant_id = 5) const uint subgroup_size = 32;
|
||||
layout (constant_id = 6) const uint subgroup_size = 32;
|
||||
shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
|
||||
|
||||
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
@@ -297,12 +298,12 @@ void main() {
|
||||
|
||||
// Hint to the compiler that values are aligned (want 16B alignment).
|
||||
// Quants are always block-aligned, no alignment needed.
|
||||
#if ALIGNED
|
||||
if (ALIGNED != 0) {
|
||||
#if QUANT_K == 1
|
||||
stride_a &= ~7;
|
||||
#endif
|
||||
stride_b &= ~7;
|
||||
stride_a &= ~7;
|
||||
#endif
|
||||
stride_b &= ~7;
|
||||
}
|
||||
|
||||
// Create layouts for both clamped and unclamped accesses
|
||||
tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
|
||||
|
||||
@@ -1,50 +1,57 @@
|
||||
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) {
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
||||
#if LOAD_VEC_A == 8
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa[0].xy;
|
||||
buf_a[buf_idx + 1] = aa[0].zw;
|
||||
buf_a[buf_idx + 2] = aa[1].xy;
|
||||
buf_a[buf_idx + 3] = aa[1].zw;
|
||||
if (ALIGNED != 0) {
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa[0].xy;
|
||||
buf_a[buf_idx + 1] = aa[0].zw;
|
||||
buf_a[buf_idx + 2] = aa[1].xy;
|
||||
buf_a[buf_idx + 3] = aa[1].zw;
|
||||
return;
|
||||
}
|
||||
#elif LOAD_VEC_A == 4
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
#else // LOAD_VEC_BATCH_A == 2
|
||||
if (ALIGNED != 0) {
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
const uint idx = pos_a + col * p.stride_a + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx],
|
||||
data_a[idx + 1]);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx],
|
||||
data_a_scalar[idx + 1]);
|
||||
} else if (idx_m < p.M && block + row * 2 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx], 0.0f);
|
||||
} else {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
#elif defined(DATA_A_BF16)
|
||||
#if LOAD_VEC_A == 4
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
#else // LOAD_VEC_BATCH_A == 2
|
||||
if (ALIGNED != 0) {
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
const uint idx = pos_a + col * p.stride_a + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]),
|
||||
TO_FLOAT_TYPE(data_a[idx + 1]));
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]),
|
||||
TO_FLOAT_TYPE(data_a_scalar[idx + 1]));
|
||||
} else if (idx_m < p.M && block + row * 2 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]), 0.0f);
|
||||
} else {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
#elif defined(DATA_A_Q4_0)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
|
||||
@@ -526,75 +533,85 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
#if !defined(MUL_MAT_ID)
|
||||
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) {
|
||||
#if LOAD_VEC_B == 8
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
buf_b[buf_idx + 3] = bb[1].zw;
|
||||
if (ALIGNED != 0) {
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
buf_b[buf_idx + 3] = bb[1].zw;
|
||||
return;
|
||||
}
|
||||
#elif LOAD_VEC_B == 4
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
if (ALIGNED != 0) {
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
#if defined(DATA_B_BF16)
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
#else
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
#else // LOAD_VEC_BATCH_B == 2
|
||||
const uint idx = pos_b + col * p.stride_b + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
|
||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]),
|
||||
TO_FLOAT_TYPE(data_b_scalar[idx + 1]));
|
||||
} else if (idx_n < p.N && block + row * 2 < end_k) {
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f);
|
||||
} else {
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) {
|
||||
#if LOAD_VEC_B == 8
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
buf_b[buf_idx + 3] = bb[1].zw;
|
||||
if (ALIGNED != 0) {
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
buf_b[buf_idx + 3] = bb[1].zw;
|
||||
return;
|
||||
}
|
||||
#elif LOAD_VEC_B == 4
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
if (ALIGNED != 0) {
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
#if defined(DATA_B_BF16)
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
#else
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
#else // LOAD_VEC_BATCH_B == 2
|
||||
const uint row_i = ic * BN + col;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
|
||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]),
|
||||
TO_FLOAT_TYPE(data_b_scalar[idx + 1]));
|
||||
} else if (row_i < _ne1 && block + row * 2 < end_k) {
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f);
|
||||
} else {
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#define BLOCK_SIZE 512
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
shared vec2 sum[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
const uint a_base = get_aoffset() + src0_idx(row * p.ne00);
|
||||
const uint d_base = get_doffset() + dst_idx(row * p.ne10);
|
||||
|
||||
sum[tid] = vec2(0.0f, 0.0f);
|
||||
|
||||
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||
const float xi = float(data_a[row*p.KX + col]);
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
const float xi = float(data_a[a_base + i0*p.nb00]);
|
||||
sum[tid].x += xi;
|
||||
sum[tid].y += xi * xi;
|
||||
}
|
||||
@@ -34,11 +34,11 @@ void main() {
|
||||
barrier();
|
||||
}
|
||||
|
||||
const float mean = sum[0].x / p.KX;
|
||||
const float var = sum[0].y / p.KX - mean * mean;
|
||||
const float mean = sum[0].x / p.ne00;
|
||||
const float var = sum[0].y / p.ne00 - mean * mean;
|
||||
const float inv_std = inversesqrt(var + p.param1);
|
||||
|
||||
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||
data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std);
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
data_d[d_base + i0*p.nb10] = D_TYPE((float(data_a[a_base + i0*p.nb00]) - mean) * inv_std);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val));
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sqrt(val));
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val);
|
||||
}
|
||||
@@ -17,6 +17,30 @@ float op_neg(float x) {
|
||||
return -x;
|
||||
}
|
||||
|
||||
float op_sqr(float x) {
|
||||
return x * x;
|
||||
}
|
||||
|
||||
float op_sqrt(float x) {
|
||||
return sqrt(x);
|
||||
}
|
||||
|
||||
float op_sin(float x) {
|
||||
return sin(x);
|
||||
}
|
||||
|
||||
float op_cos(float x) {
|
||||
return cos(x);
|
||||
}
|
||||
|
||||
float op_clamp(float x) {
|
||||
return clamp(x, p.param1, p.param2);
|
||||
}
|
||||
|
||||
float op_leaky_relu(float x) {
|
||||
return max(x, 0.0f) + min(x, 0.0f) * p.param1;
|
||||
}
|
||||
|
||||
float op_step(float x) {
|
||||
return x >= 0.0f ? 1.0f : 0.0f;
|
||||
}
|
||||
|
||||
@@ -539,11 +539,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
};
|
||||
|
||||
// Shaders with f16 B_TYPE
|
||||
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
|
||||
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
|
||||
// bf16
|
||||
{
|
||||
@@ -565,8 +563,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
#endif
|
||||
{
|
||||
if (!dot2) {
|
||||
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE_SCALAR", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -583,8 +580,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
}
|
||||
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
// For unaligned, load one at a time for f32/f16, or two at a time for quants
|
||||
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
|
||||
// For aligned matmul loads
|
||||
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
|
||||
|
||||
@@ -597,13 +592,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
|
||||
// don't generate f32 variants for coopmat2
|
||||
if (!coopmat2) {
|
||||
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE_SCALAR", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
|
||||
if (tname != "f16" && tname != "f32") {
|
||||
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
@@ -850,21 +843,12 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("repeat_i32", "repeat.comp", {{"A_TYPE", "int32_t"}, {"D_TYPE", "int32_t"}});
|
||||
string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("get_rows_back_f32", "get_rows_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("repeat_i16", "repeat.comp", {{"A_TYPE", "int16_t"}, {"D_TYPE", "int16_t"}});
|
||||
|
||||
string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("concat_i8", "concat.comp", {{"A_TYPE", "uint8_t"}, {"B_TYPE", "uint8_t"}, {"D_TYPE", "uint8_t"}});
|
||||
@@ -891,6 +875,18 @@ void process_shaders() {
|
||||
string_to_spv("silu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_silu"}});
|
||||
string_to_spv("relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_relu"}});
|
||||
string_to_spv("relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_relu"}});
|
||||
string_to_spv("sqr_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqr"}});
|
||||
string_to_spv("sqr_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqr"}});
|
||||
string_to_spv("sqrt_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqrt"}});
|
||||
string_to_spv("sqrt_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqrt"}});
|
||||
string_to_spv("sin_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sin"}});
|
||||
string_to_spv("sin_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sin"}});
|
||||
string_to_spv("cos_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_cos"}});
|
||||
string_to_spv("cos_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_cos"}});
|
||||
string_to_spv("clamp_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_clamp"}});
|
||||
string_to_spv("clamp_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_clamp"}});
|
||||
string_to_spv("leaky_relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_leaky_relu"}});
|
||||
string_to_spv("leaky_relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_leaky_relu"}});
|
||||
string_to_spv("neg_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_neg"}});
|
||||
string_to_spv("neg_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_neg"}});
|
||||
string_to_spv("tanh_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_tanh"}});
|
||||
@@ -948,7 +944,6 @@ void process_shaders() {
|
||||
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
@@ -1060,6 +1055,31 @@ void process_shaders() {
|
||||
}
|
||||
}
|
||||
|
||||
for (auto unroll : {false, true}) {
|
||||
for (auto a_f16 : {false, true}) {
|
||||
std::map<std::string, std::string> defines = {
|
||||
{"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"},
|
||||
{"UNROLL", unroll ? "[[unroll]]" : ""},
|
||||
};
|
||||
std::string name = std::string("conv3d") + (a_f16 ? "_f16" : "") + "_f32";
|
||||
string_to_spv(name + (unroll ? "_unroll" : ""), "conv3d_mm.comp", defines);
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (unroll) {
|
||||
auto cm2_defines = defines;
|
||||
cm2_defines["COOPMAT2"] = "1";
|
||||
string_to_spv(name, "conv3d_mm.comp", cm2_defines, true, false, true);
|
||||
}
|
||||
#endif
|
||||
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (unroll) {
|
||||
auto cm1_defines = defines;
|
||||
cm1_defines["COOPMAT"] = "1";
|
||||
string_to_spv(name, "conv3d_mm.comp", cm1_defines, true, true, false);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
|
||||
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
|
||||
string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
|
||||
|
||||
@@ -905,11 +905,12 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
int vectorized;
|
||||
uint32_t num_cols;
|
||||
bool use_mmvq;
|
||||
|
||||
bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
|
||||
use_mmvq == other.use_mmvq;
|
||||
num_cols == other.num_cols && use_mmvq == other.use_mmvq;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -919,6 +920,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
|
||||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
ggml_webgpu_hash_combine(seed, key.src1_type);
|
||||
ggml_webgpu_hash_combine(seed, key.vectorized);
|
||||
ggml_webgpu_hash_combine(seed, key.num_cols);
|
||||
ggml_webgpu_hash_combine(seed, key.use_mmvq);
|
||||
return seed;
|
||||
}
|
||||
@@ -993,11 +995,12 @@ struct ggml_webgpu_mul_mat_id_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
uint32_t n_experts;
|
||||
uint32_t num_cols;
|
||||
int vectorized;
|
||||
|
||||
bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
|
||||
vectorized == other.vectorized;
|
||||
num_cols == other.num_cols && vectorized == other.vectorized;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1007,6 +1010,7 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
|
||||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
ggml_webgpu_hash_combine(seed, key.src1_type);
|
||||
ggml_webgpu_hash_combine(seed, key.n_experts);
|
||||
ggml_webgpu_hash_combine(seed, key.num_cols);
|
||||
ggml_webgpu_hash_combine(seed, key.vectorized);
|
||||
return seed;
|
||||
}
|
||||
@@ -1107,7 +1111,7 @@ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
|
||||
const ggml_tensor * src1,
|
||||
bool supports_dot_product,
|
||||
const std::string & vendor) {
|
||||
if (src1->ne[1] == 1) {
|
||||
if (src1->ne[1] <= 4) {
|
||||
bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
|
||||
if (supports_dp4a && supports_dot_product) {
|
||||
switch (src1->type) {
|
||||
@@ -1889,6 +1893,7 @@ class ggml_webgpu_shader_lib {
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
key.num_cols = context.dst->ne[1];
|
||||
key.use_mmvq =
|
||||
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
|
||||
|
||||
@@ -2004,6 +2009,7 @@ class ggml_webgpu_shader_lib {
|
||||
if (key.vectorized) {
|
||||
variant += "_vectorized";
|
||||
}
|
||||
defines.push_back(std::string("NUM_COLS=") + std::to_string(key.num_cols));
|
||||
|
||||
auto processed = preprocessor.preprocess(shader_src, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
|
||||
@@ -2421,6 +2427,7 @@ class ggml_webgpu_shader_lib {
|
||||
if (key.vectorized) {
|
||||
variant += "_vectorized";
|
||||
}
|
||||
defines.push_back(std::string("NUM_COLS=1"));
|
||||
|
||||
defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts));
|
||||
|
||||
|
||||
@@ -1418,15 +1418,17 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
|
||||
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
|
||||
const size_t q8_src1_align_offset = ROUNDUP_POW2(
|
||||
dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
const size_t q8_src1_binding_size =
|
||||
ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
const size_t q8_src1_binding_size = ROUNDUP_POW2(
|
||||
src1->ne[3] * src1->ne[2] * src1->ne[1] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
|
||||
std::vector<uint32_t> q8_params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) src1->ne[1],
|
||||
(uint32_t) src1->ne[2],
|
||||
(uint32_t) src1->ne[3],
|
||||
};
|
||||
@@ -1442,7 +1444,7 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
|
||||
uint32_t q8_wg_x = 1;
|
||||
uint32_t q8_wg_y = 1;
|
||||
const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size;
|
||||
const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec;
|
||||
const uint32_t q8_total_wg = src1->ne[1] * src1->ne[2] * src1->ne[3] * wg_per_vec;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y);
|
||||
|
||||
@@ -1456,7 +1458,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
// Determine if this is a mat-vec operation
|
||||
bool is_vec = (dst->ne[1] == 1);
|
||||
bool use_mat_vec = (dst->ne[1] <= 4);
|
||||
|
||||
// use MMVQ path for mat-vec
|
||||
bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product,
|
||||
@@ -1482,7 +1484,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
webgpu_pipeline pipeline;
|
||||
std::vector<webgpu_dispatch_desc> dispatches;
|
||||
|
||||
if (is_vec) {
|
||||
if (use_mat_vec) {
|
||||
if (use_mmvq) {
|
||||
ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches);
|
||||
}
|
||||
@@ -1529,7 +1531,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
uint32_t wg_y = 1;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
if (is_vec) {
|
||||
if (use_mat_vec) {
|
||||
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t batches = dst->ne[2] * dst->ne[3];
|
||||
@@ -3691,8 +3693,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
||||
ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product,
|
||||
ctx->webgpu_global_ctx->vendor);
|
||||
if (use_mmvq) {
|
||||
const size_t q8_src1_size =
|
||||
src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
|
||||
const size_t q8_src1_size = src1->ne[3] * src1->ne[2] * src1->ne[1] *
|
||||
(36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
|
||||
res = ROUNDUP_POW2(res + q8_src1_size +
|
||||
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
@@ -3788,7 +3790,7 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
|
||||
ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
|
||||
}
|
||||
|
||||
static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
static void ggml_backend_webgpu_request_adapter(wgpu::Instance & instance, wgpu::Adapter & adapter) {
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
@@ -3800,17 +3802,20 @@ static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
options.nextInChain = &adapterTogglesDesc;
|
||||
#endif
|
||||
|
||||
ctx->webgpu_global_ctx->instance.WaitAny(
|
||||
ctx->webgpu_global_ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
return;
|
||||
}
|
||||
ctx->webgpu_global_ctx->adapter = std::move(adapter);
|
||||
}),
|
||||
UINT64_MAX);
|
||||
instance.WaitAny(instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
return;
|
||||
}
|
||||
adapter = std::move(_adapter);
|
||||
}),
|
||||
UINT64_MAX);
|
||||
}
|
||||
|
||||
static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
|
||||
ggml_backend_webgpu_request_adapter(ctx->webgpu_global_ctx->instance, ctx->webgpu_global_ctx->adapter);
|
||||
GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr);
|
||||
|
||||
ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits);
|
||||
@@ -4265,7 +4270,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
supports_op = (op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32) && ggml_is_contiguous_rows(src0);
|
||||
break;
|
||||
case GGML_OP_ROPE:
|
||||
supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
|
||||
@@ -4543,20 +4548,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
|
||||
// Probe for adapter support
|
||||
wgpu::Adapter adapter;
|
||||
if (ctx->webgpu_global_ctx->instance != nullptr) {
|
||||
wgpu::RequestAdapterOptions options = {};
|
||||
|
||||
// probe for adapter support
|
||||
ctx->webgpu_global_ctx->instance.WaitAny(
|
||||
ctx->webgpu_global_ctx->instance.RequestAdapter(
|
||||
&options, wgpu::CallbackMode::AllowSpontaneous,
|
||||
[&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) {
|
||||
if (status != wgpu::RequestAdapterStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
|
||||
return;
|
||||
}
|
||||
adapter = std::move(_adapter);
|
||||
}),
|
||||
UINT64_MAX);
|
||||
ggml_backend_webgpu_request_adapter(ctx->webgpu_global_ctx->instance, adapter);
|
||||
}
|
||||
|
||||
// WebGPU backend requires f16 support and, on native, implicit device synchronization.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user