mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-24 22:57:41 +02:00
Compare commits
48 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8be759e6f7 | |||
| 894bb27af3 | |||
| fb401045cc | |||
| 51eae8cfca | |||
| 1191758c5d | |||
| 00139b660b | |||
| ef9c13d4c2 | |||
| 88636e178f | |||
| ac4105d68b | |||
| be4a6a63eb | |||
| 72a9269172 | |||
| 92e854ab83 | |||
| c5606364b2 | |||
| 0eb874d374 | |||
| 75ad0b23ed | |||
| c926ad0985 | |||
| a3900a6694 | |||
| 7c908502ea | |||
| 035cd8f9a6 | |||
| 73618f27a8 | |||
| 23ee8797e1 | |||
| dec5ca5577 | |||
| 9c0ac887f3 | |||
| 721354fbdf | |||
| 6ee0f65793 | |||
| 099b579acb | |||
| f8cc15f163 | |||
| 37957e8531 | |||
| d0f9d2e5ac | |||
| 0ef6f06d55 | |||
| 52b3df0023 | |||
| 7c082bc417 | |||
| bddfd2b113 | |||
| 0d135df48c | |||
| bf533823cd | |||
| 2f89acc2bc | |||
| bfa3219177 | |||
| d6d899580d | |||
| 8a118ee86c | |||
| d789527482 | |||
| 063d9c156e | |||
| c57607016a | |||
| 4a80943174 | |||
| 84de01a1f1 | |||
| 75f460ac28 | |||
| 8452824611 | |||
| e27f308597 | |||
| 67e9fd3b74 |
@@ -4,20 +4,6 @@ ARG BUILD_DATE=N/A
|
||||
ARG APP_VERSION=N/A
|
||||
ARG APP_REVISION=N/A
|
||||
|
||||
ARG NODE_VERSION=24
|
||||
|
||||
FROM docker.io/node:$NODE_VERSION AS web
|
||||
|
||||
ARG APP_VERSION
|
||||
|
||||
WORKDIR /app/tools/ui
|
||||
|
||||
COPY tools/ui/package.json tools/ui/package-lock.json ./
|
||||
RUN npm ci
|
||||
|
||||
COPY tools/ui/ ./
|
||||
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
|
||||
|
||||
### Build Llama.cpp stage
|
||||
FROM docker.io/gcc:${GCC_VERSION} AS build
|
||||
|
||||
@@ -34,8 +20,6 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
WORKDIR /app
|
||||
COPY . .
|
||||
|
||||
COPY --from=web /app/tools/ui/dist tools/ui/dist
|
||||
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
--mount=type=cache,target=/app/build \
|
||||
cmake -S . -B build -G Ninja \
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
build*/
|
||||
|
||||
tools/ui/node_modules/
|
||||
tools/ui/dist/
|
||||
|
||||
models/*
|
||||
|
||||
|
||||
@@ -58,6 +58,13 @@ jobs:
|
||||
git tag ${{ steps.srctag.outputs.name }} || exit 0
|
||||
git push origin ${{ steps.srctag.outputs.name }} || exit 0
|
||||
|
||||
build_ui:
|
||||
name: Build UI
|
||||
needs: create_tag
|
||||
uses: ./.github/workflows/ui-build.yml
|
||||
with:
|
||||
hf_ui_version: ${{ needs.create_tag.outputs.source_tag }}
|
||||
|
||||
prepare_matrices:
|
||||
name: Prepare Docker matrices
|
||||
runs-on: ubuntu-24.04
|
||||
@@ -79,7 +86,7 @@ jobs:
|
||||
[
|
||||
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-arm" },
|
||||
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x" },
|
||||
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x", "prebuilt_ui": true },
|
||||
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
|
||||
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.3.0", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
|
||||
@@ -135,7 +142,7 @@ jobs:
|
||||
|
||||
push_to_registry:
|
||||
name: Push Docker image to Docker Registry
|
||||
needs: [prepare_matrices, create_tag]
|
||||
needs: [prepare_matrices, create_tag, build_ui]
|
||||
|
||||
runs-on: ${{ matrix.config.runs_on }}
|
||||
strategy:
|
||||
@@ -150,6 +157,13 @@ jobs:
|
||||
fetch-depth: 0
|
||||
ref: ${{ needs.create_tag.outputs.source_tag }}
|
||||
|
||||
- name: Download prebuilt UI
|
||||
if: ${{ matrix.config.prebuilt_ui == true }}
|
||||
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
|
||||
with:
|
||||
name: ui-build
|
||||
path: tools/ui/dist
|
||||
|
||||
- name: Set up QEMU
|
||||
if: ${{ contains(matrix.config.platforms, 'linux/amd64') }}
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4
|
||||
|
||||
@@ -1627,6 +1627,7 @@ jobs:
|
||||
**Windows:**
|
||||
- [Windows x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-x64.zip)
|
||||
- [Windows arm64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-arm64.zip)
|
||||
- [Windows arm64 (OpenCL Adreno)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-opencl-adreno-arm64.zip)
|
||||
- [Windows x64 (CUDA 12)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip) - [CUDA 12.4 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-12.4-x64.zip)
|
||||
- [Windows x64 (CUDA 13)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-13.3-x64.zip) - [CUDA 13.3 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-13.3-x64.zip)
|
||||
- [Windows x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-vulkan-x64.zip)
|
||||
|
||||
+1
-1
@@ -10,7 +10,7 @@
|
||||
# ggml-org/ggml-rpc : rgerganov
|
||||
# ggml-org/ggml-sycl : arthw
|
||||
# ggml-org/ggml-vulkan : 0cc4m, jeffbolznv
|
||||
# ggml-org/ggml-webgpu : reeselevine
|
||||
# ggml-org/ggml-webgpu : reeselevine, yomaytk
|
||||
# ggml-org/ggml-zdnn : taronaeo
|
||||
# ggml-org/llama-common : ggerganov, aldehir, angt, danbev, ngxson, pwilkin
|
||||
# ggml-org/llama-mtmd : ngxson
|
||||
|
||||
@@ -142,7 +142,9 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
|
||||
- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
|
||||
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
|
||||
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
|
||||
- [x] [Liquid LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2)
|
||||
- [x] [Liquid LFM2.5 models](https://huggingface.co/collections/LiquidAI/lfm25)
|
||||
- [x] [Liquid Nanos](https://huggingface.co/collections/LiquidAI/liquid-nanos)
|
||||
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
|
||||
- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)
|
||||
- [x] [Mellum models](https://huggingface.co/JetBrains/models?search=mellum)
|
||||
|
||||
@@ -80,8 +80,6 @@ add_library(${TARGET}
|
||||
http.h
|
||||
imatrix-loader.cpp
|
||||
imatrix-loader.h
|
||||
json-partial.cpp
|
||||
json-partial.h
|
||||
json-schema-to-grammar.cpp
|
||||
llguidance.cpp
|
||||
log.cpp
|
||||
|
||||
+17
-8
@@ -301,6 +301,8 @@ static handle_model_result common_params_handle_model(struct common_params_model
|
||||
const common_download_opts & opts) {
|
||||
handle_model_result result;
|
||||
|
||||
// TODO @ngxson : refactor this into a new common_model_download_context
|
||||
|
||||
if (!model.docker_repo.empty()) {
|
||||
model.path = common_docker_resolve_model(model.docker_repo);
|
||||
} else if (!model.hf_repo.empty()) {
|
||||
@@ -396,7 +398,7 @@ static bool parse_bool_value(const std::string & value) {
|
||||
// CLI argument parsing functions
|
||||
//
|
||||
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex) {
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex, const common_params_handle_models_params & handle_params) {
|
||||
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
|
||||
params.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
|
||||
@@ -407,6 +409,11 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex)
|
||||
opts.skip_download = params.skip_download;
|
||||
opts.download_mtp = spec_type_draft_mtp;
|
||||
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
|
||||
opts.preset_only = handle_params.preset_only;
|
||||
|
||||
if (handle_params.callback) {
|
||||
opts.callback = handle_params.callback;
|
||||
}
|
||||
|
||||
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
|
||||
// so we should not auto-discover mtp/mmproj siblings for them
|
||||
@@ -584,17 +591,19 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
|
||||
}
|
||||
|
||||
// export_graph_ops loads only metadata
|
||||
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
const bool skip_model_download =
|
||||
// server will call common_params_handle_models() later, so we skip it here
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_SERVER ||
|
||||
// export_graph_ops loads only metadata
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
|
||||
if (!skip_model_download) {
|
||||
// handle model and download
|
||||
common_params_handle_models(params, ctx_arg.ex);
|
||||
common_params_handle_models(params, ctx_arg.ex, {});
|
||||
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
if (params.model.path.empty()
|
||||
&& ctx_arg.ex != LLAMA_EXAMPLE_SERVER
|
||||
&& !params.usage
|
||||
&& !params.completion) {
|
||||
throw std::invalid_argument("error: --model is required\n");
|
||||
@@ -924,8 +933,8 @@ static utf8_argv make_utf8_argv() {
|
||||
bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
|
||||
#ifdef _WIN32
|
||||
auto utf8 = make_utf8_argv();
|
||||
if (!utf8.ptrs.empty()) {
|
||||
argc = static_cast<int>(utf8.buf.size());
|
||||
// repair argv only when it matches the process command line
|
||||
if (static_cast<int>(utf8.buf.size()) == argc) {
|
||||
argv = utf8.ptrs.data();
|
||||
}
|
||||
#endif
|
||||
@@ -2897,7 +2906,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.server_tools = parse_csv_row(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TOOLS"));
|
||||
add_opt(common_arg(
|
||||
add_opt(common_arg(
|
||||
{"-ag", "--agent"},
|
||||
{"-no-ag", "--no-agent"},
|
||||
"whether to enable CORS proxy and all built-in tools - do not enable in untrusted environments (default: disabled)",
|
||||
|
||||
+10
-1
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "download.h"
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
@@ -129,11 +130,19 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
|
||||
// see: https://github.com/ggml-org/llama.cpp/issues/18163
|
||||
void common_params_add_preset_options(std::vector<common_arg> & args);
|
||||
|
||||
struct common_params_handle_models_params {
|
||||
common_download_callback * callback = nullptr;
|
||||
bool preset_only = false; // if true, only check & download remote preset (for router mode)
|
||||
};
|
||||
|
||||
// populate model paths (main model, mmproj, etc) from -hf if necessary
|
||||
// return true if the model is ready to use
|
||||
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
|
||||
// if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. ETag check failed)
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex);
|
||||
bool common_params_handle_models(
|
||||
common_params & params,
|
||||
llama_example curr_ex,
|
||||
const common_params_handle_models_params & handle_params);
|
||||
|
||||
// initialize argument parser context - used by test-arg-parser and preset
|
||||
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
||||
|
||||
@@ -395,10 +395,11 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
||||
arguments.name_suffix) +
|
||||
arguments.value_prefix +
|
||||
(schema_info.resolves_to_string(param_schema) ?
|
||||
p.tool_arg_string_value(until_suffix) :
|
||||
p.tool_arg_json_value(p.schema(
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false))) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)));
|
||||
p.ac(p.tool_arg_string_value(until_suffix) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)), arguments.value_suffix) :
|
||||
(p.tool_arg_json_value(p.schema(
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)))));
|
||||
|
||||
auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg);
|
||||
if (is_required) {
|
||||
|
||||
+103
-53
@@ -90,41 +90,93 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const
|
||||
return text;
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims) {
|
||||
if (delims.empty() || prompt.empty()) {
|
||||
return {};
|
||||
common_chat_role common_chat_role_from_string(const std::string & role) {
|
||||
if (role == "system") { return COMMON_CHAT_ROLE_SYSTEM; }
|
||||
if (role == "assistant") { return COMMON_CHAT_ROLE_ASSISTANT; }
|
||||
if (role == "user") { return COMMON_CHAT_ROLE_USER; }
|
||||
if (role == "tool") { return COMMON_CHAT_ROLE_TOOL; }
|
||||
return COMMON_CHAT_ROLE_UNKNOWN;
|
||||
}
|
||||
|
||||
const char * common_chat_role_to_string(common_chat_role role) {
|
||||
switch (role) {
|
||||
case COMMON_CHAT_ROLE_SYSTEM: return "system";
|
||||
case COMMON_CHAT_ROLE_ASSISTANT: return "assistant";
|
||||
case COMMON_CHAT_ROLE_USER: return "user";
|
||||
case COMMON_CHAT_ROLE_TOOL: return "tool";
|
||||
case COMMON_CHAT_ROLE_UNKNOWN: return "";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
json common_chat_msg_delimiters::to_json() const {
|
||||
json result = json::array();
|
||||
for (const auto & d : delimiters) {
|
||||
result.push_back({
|
||||
{ "role", common_chat_role_to_string(d.role) },
|
||||
{ "delimiter", d.delimiter },
|
||||
});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const json & delimiters) {
|
||||
common_chat_msg_delimiters result;
|
||||
|
||||
if (!delimiters.is_array()) {
|
||||
return result;
|
||||
}
|
||||
|
||||
auto parser = build_peg_parser([&](common_peg_parser_builder & p) {
|
||||
std::vector<std::string> all_delims;
|
||||
std::vector<common_peg_parser> tagged_messages;
|
||||
|
||||
all_delims.reserve(delims.size());
|
||||
tagged_messages.reserve(delims.size());
|
||||
for (const auto & d : delims) {
|
||||
all_delims.push_back(d.delimiter);
|
||||
result.delimiters.reserve(delimiters.size());
|
||||
for (const auto & d : delimiters) {
|
||||
if (!d.is_object()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto any_delim = p.until_one_of(all_delims);
|
||||
for (const auto & d : delims) {
|
||||
tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim));
|
||||
}
|
||||
|
||||
return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end();
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx(prompt);
|
||||
const auto result = parser.parse(ctx);
|
||||
if (!result.success()) {
|
||||
return {};
|
||||
result.delimiters.push_back({
|
||||
common_chat_role_from_string(d.value("role", std::string())),
|
||||
d.value("delimiter", std::string()),
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_span> spans;
|
||||
ctx.ast.visit(result, [&](const common_peg_ast_node & node) {
|
||||
if (!node.tag.empty()) {
|
||||
spans.push_back({ node.tag, node.start, node.end - node.start });
|
||||
return result;
|
||||
}
|
||||
|
||||
void common_chat_msg_delimiters::tokenize(const llama_vocab * vocab) {
|
||||
for (auto & d : delimiters) {
|
||||
d.tokens = common_tokenize(vocab, d.delimiter, false, true);
|
||||
}
|
||||
}
|
||||
|
||||
common_chat_msg_spans common_chat_msg_delimiters::split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips) const {
|
||||
std::vector<std::pair<common_chat_role, size_t>> matches;
|
||||
|
||||
auto skip = skips.begin();
|
||||
for (size_t i = 0; i < tokens.size();) {
|
||||
if (skip != skips.end() && i == skip->first) {
|
||||
i += skip->second;
|
||||
++skip;
|
||||
continue;
|
||||
}
|
||||
});
|
||||
for (const auto & d : delimiters) {
|
||||
if (i + d.tokens.size() > tokens.size()) {
|
||||
continue;
|
||||
}
|
||||
if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) {
|
||||
matches.emplace_back(d.role, i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
i++;
|
||||
}
|
||||
|
||||
matches.emplace_back(COMMON_CHAT_ROLE_UNKNOWN, tokens.size());
|
||||
|
||||
common_chat_msg_spans spans;
|
||||
for (size_t i = 0; i + 1 < matches.size(); i++) {
|
||||
const auto & curr = matches[i];
|
||||
const auto & next = matches[i + 1];
|
||||
spans.add(curr.first, curr.second, next.second - curr.second);
|
||||
}
|
||||
|
||||
return spans;
|
||||
}
|
||||
@@ -1081,13 +1133,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
|
||||
data.prompt = prompt;
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
data.message_spans = common_chat_split_by_role(prompt, {
|
||||
{ "assistant", "<|start|>assistant" },
|
||||
{ "user", "<|start|>user" },
|
||||
{ "system", "<|start|>developer" },
|
||||
{ "system", "<|start|>system" },
|
||||
{ "tool", "<|start|>functions" },
|
||||
});
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|start|>assistant" },
|
||||
{ COMMON_CHAT_ROLE_USER, "<|start|>user" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>developer" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>system" },
|
||||
{ COMMON_CHAT_ROLE_TOOL, "<|start|>functions" },
|
||||
};
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
@@ -1228,10 +1280,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
data.message_spans = common_chat_split_by_role(data.prompt, {
|
||||
{ "user", "<|turn>user\n" },
|
||||
{ "assistant", "<|turn>model\n" },
|
||||
});
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_USER, "<|turn>user" },
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|turn>model" },
|
||||
};
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
|
||||
data.supports_thinking = true;
|
||||
@@ -2030,15 +2082,15 @@ static common_chat_params common_chat_params_init_cohere2moe(const common_chat_t
|
||||
RESULT_START, RESULT_END,
|
||||
};
|
||||
|
||||
// Split the rendered prompt into per-role message spans. Tool results are rendered with the
|
||||
// Declare per-role message delimiters. Tool results are rendered with the
|
||||
// system token followed by <|START_TOOL_RESULT|>, so the "tool" delimiter must be listed before
|
||||
// the plain "system" one (it is a strict superset, and the role split tries delimiters in order).
|
||||
data.message_spans = common_chat_split_by_role(data.prompt, {
|
||||
{ "assistant", GEN_PREFIX },
|
||||
{ "user", TURN_START + USER },
|
||||
{ "tool", TURN_START + SYSTEM + RESULT_START },
|
||||
{ "system", TURN_START + SYSTEM },
|
||||
});
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, GEN_PREFIX },
|
||||
{ COMMON_CHAT_ROLE_USER, TURN_START + USER },
|
||||
{ COMMON_CHAT_ROLE_TOOL, TURN_START + SYSTEM + RESULT_START },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, TURN_START + SYSTEM },
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
@@ -2526,17 +2578,15 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
autoparser.analyze_template(tmpl);
|
||||
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
|
||||
|
||||
std::vector<common_chat_msg_delimiter> delimiters;
|
||||
common_chat_msg_delimiters delimiters;
|
||||
if (!autoparser.assistant_start.empty()) {
|
||||
delimiters.push_back({ "assistant", autoparser.assistant_start });
|
||||
delimiters.add(COMMON_CHAT_ROLE_ASSISTANT, autoparser.assistant_start);
|
||||
}
|
||||
if (!autoparser.user_start.empty()) {
|
||||
delimiters.push_back({ "user", autoparser.user_start });
|
||||
delimiters.add(COMMON_CHAT_ROLE_USER, autoparser.user_start);
|
||||
}
|
||||
|
||||
if (!delimiters.empty()) {
|
||||
auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters);
|
||||
}
|
||||
auto_params.message_delimiters = std::move(delimiters);
|
||||
|
||||
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
|
||||
if (auto_params.supports_thinking) {
|
||||
|
||||
+65
-6
@@ -143,15 +143,75 @@ struct common_chat_msg_diff {
|
||||
}
|
||||
};
|
||||
|
||||
enum common_chat_role {
|
||||
COMMON_CHAT_ROLE_UNKNOWN,
|
||||
COMMON_CHAT_ROLE_SYSTEM,
|
||||
COMMON_CHAT_ROLE_ASSISTANT,
|
||||
COMMON_CHAT_ROLE_USER,
|
||||
COMMON_CHAT_ROLE_TOOL
|
||||
};
|
||||
|
||||
common_chat_role common_chat_role_from_string(const std::string & role);
|
||||
const char * common_chat_role_to_string(common_chat_role role);
|
||||
|
||||
struct common_chat_msg_span {
|
||||
std::string role;
|
||||
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
|
||||
std::size_t pos = 0;
|
||||
std::size_t len = 0;
|
||||
|
||||
bool valid() const {
|
||||
return role != COMMON_CHAT_ROLE_UNKNOWN;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_spans {
|
||||
std::vector<common_chat_msg_span> spans;
|
||||
|
||||
void add(common_chat_role role, size_t pos, size_t len) {
|
||||
spans.push_back({ role, pos, len });
|
||||
}
|
||||
|
||||
bool is_user_start(int32_t pos) const {
|
||||
for (auto it = spans.begin(); it != spans.end(); ++it) {
|
||||
if (it->role == COMMON_CHAT_ROLE_USER && pos == (int32_t) it->pos) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t last_user_message_pos() const {
|
||||
for (auto it = spans.rbegin(); it != spans.rend(); ++it) {
|
||||
if (it->role == COMMON_CHAT_ROLE_USER) {
|
||||
return (int32_t) it->pos;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_chat_msg_delimiter {
|
||||
std::string role;
|
||||
std::string delimiter;
|
||||
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
|
||||
std::string delimiter;
|
||||
llama_tokens tokens = {};
|
||||
};
|
||||
|
||||
struct common_chat_msg_delimiters {
|
||||
std::vector<common_chat_msg_delimiter> delimiters;
|
||||
|
||||
common_chat_msg_delimiters() = default;
|
||||
common_chat_msg_delimiters(std::initializer_list<common_chat_msg_delimiter> delims) : delimiters(delims) {}
|
||||
|
||||
void add(common_chat_role role, const std::string & delimiter) {
|
||||
delimiters.push_back({ role, delimiter });
|
||||
}
|
||||
|
||||
void tokenize(const llama_vocab * vocab);
|
||||
|
||||
// split tokens into message spans. skips maps a start index to a length of a region to jump over without matching
|
||||
common_chat_msg_spans split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips = {}) const;
|
||||
|
||||
nlohmann::ordered_json to_json() const;
|
||||
};
|
||||
|
||||
struct common_chat_tool {
|
||||
@@ -219,7 +279,7 @@ struct common_chat_params {
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
std::string parser;
|
||||
std::vector<common_chat_msg_span> message_spans;
|
||||
common_chat_msg_delimiters message_delimiters;
|
||||
};
|
||||
|
||||
// per-message parsing syntax
|
||||
@@ -325,5 +385,4 @@ struct common_chat_prompt_preset {
|
||||
|
||||
common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates);
|
||||
|
||||
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims);
|
||||
|
||||
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const nlohmann::ordered_json & delimiters);
|
||||
|
||||
+1
-1
@@ -609,7 +609,7 @@ struct common_params {
|
||||
bool cache_prompt = true; // whether to enable prompt caching
|
||||
bool cache_idle_slots = true; // save and clear idle slots upon starting a new task
|
||||
int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
|
||||
int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints
|
||||
int32_t checkpoint_min_step = 8192; // minimum spacing between context checkpoints
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
|
||||
std::string hostname = "127.0.0.1";
|
||||
|
||||
+3
-1
@@ -799,6 +799,7 @@ common_download_model_result common_download_model(const common_params_model &
|
||||
|
||||
bool download_mmproj = opts.download_mmproj;
|
||||
bool download_mtp = opts.download_mtp;
|
||||
bool preset_only = opts.preset_only;
|
||||
bool is_hf = !model.hf_repo.empty();
|
||||
|
||||
if (is_hf) {
|
||||
@@ -806,7 +807,8 @@ common_download_model_result common_download_model(const common_params_model &
|
||||
if (!hf.preset.path.empty()) {
|
||||
// if preset.ini exists, only download that file alone
|
||||
tasks.push_back({hf.preset.url, hf.preset.local_path});
|
||||
} else {
|
||||
} else if (!preset_only) {
|
||||
// only add other files if we're NOT in preset-only mode (normal run, non-router)
|
||||
for (const auto & f : hf.model_files) {
|
||||
tasks.push_back({f.url, f.local_path});
|
||||
}
|
||||
|
||||
@@ -55,6 +55,7 @@ struct common_download_opts {
|
||||
bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid
|
||||
bool download_mmproj = false;
|
||||
bool download_mtp = false;
|
||||
bool preset_only = false; // if true, only check & download remote preset (for router mode)
|
||||
common_download_callback * callback = nullptr;
|
||||
};
|
||||
|
||||
|
||||
+89
-46
@@ -686,59 +686,62 @@ value set_statement::execute_impl(context & ctx) {
|
||||
return mk_val<value_undefined>();
|
||||
}
|
||||
|
||||
static inline void bind_parameters(const std::string & name, const statements & this_args, const func_args & args, context & ctx) {
|
||||
const size_t expected_count = this_args.size();
|
||||
const size_t input_count = args.count();
|
||||
|
||||
JJ_DEBUG("Invoking '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count);
|
||||
for (size_t i = 0; i < expected_count; ++i) {
|
||||
if (i < input_count) {
|
||||
if (is_stmt<identifier>(this_args[i])) {
|
||||
// normal parameter
|
||||
std::string param_name = cast_stmt<identifier>(this_args[i])->val;
|
||||
value param_value = args.get_kwarg_or_pos(param_name, i);
|
||||
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
|
||||
ctx.set_val(param_name, param_value);
|
||||
} else if (is_stmt<keyword_argument_expression>(this_args[i])) {
|
||||
// default argument used as normal parameter
|
||||
auto kwarg = cast_stmt<keyword_argument_expression>(this_args[i]);
|
||||
if (!is_stmt<identifier>(kwarg->key)) {
|
||||
throw std::runtime_error("Keyword argument key must be an identifier in '" + name + "'");
|
||||
}
|
||||
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
|
||||
value param_value = args.get_kwarg_or_pos(param_name, i);
|
||||
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
|
||||
ctx.set_val(param_name, param_value);
|
||||
} else {
|
||||
throw std::runtime_error("Invalid parameter type in '" + name + "'");
|
||||
}
|
||||
} else {
|
||||
auto & default_arg = this_args[i];
|
||||
if (is_stmt<keyword_argument_expression>(default_arg)) {
|
||||
auto kwarg = cast_stmt<keyword_argument_expression>(default_arg);
|
||||
if (!is_stmt<identifier>(kwarg->key)) {
|
||||
throw std::runtime_error("Keyword argument key must be an identifier in '" + name + "'");
|
||||
}
|
||||
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
|
||||
JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str());
|
||||
ctx.set_val(param_name, kwarg->val->execute(args.ctx));
|
||||
} else {
|
||||
throw std::runtime_error("Not enough arguments provided to '" + name + "'");
|
||||
}
|
||||
//std::string param_name = cast_stmt<identifier>(default_args[i])->val;
|
||||
//JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str());
|
||||
//ctx.var[param_name] = default_args[i]->execute(ctx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
value macro_statement::execute_impl(context & ctx) {
|
||||
if (!is_stmt<identifier>(this->name)) {
|
||||
throw std::runtime_error("Macro name must be an identifier");
|
||||
}
|
||||
std::string name = cast_stmt<identifier>(this->name)->val;
|
||||
|
||||
const func_handler func = [this, name, &ctx](const func_args & args) -> value {
|
||||
size_t expected_count = this->args.size();
|
||||
size_t input_count = args.count();
|
||||
const func_handler func = [this, name](const func_args & args) -> value {
|
||||
context macro_ctx(args.ctx); // new scope for macro execution
|
||||
|
||||
JJ_DEBUG("Invoking macro '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count);
|
||||
context macro_ctx(ctx); // new scope for macro execution
|
||||
|
||||
// bind parameters
|
||||
for (size_t i = 0; i < expected_count; ++i) {
|
||||
if (i < input_count) {
|
||||
if (is_stmt<identifier>(this->args[i])) {
|
||||
// normal parameter
|
||||
std::string param_name = cast_stmt<identifier>(this->args[i])->val;
|
||||
value param_value = args.get_kwarg_or_pos(param_name, i);
|
||||
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
|
||||
macro_ctx.set_val(param_name, param_value);
|
||||
} else if (is_stmt<keyword_argument_expression>(this->args[i])) {
|
||||
// default argument used as normal parameter
|
||||
auto kwarg = cast_stmt<keyword_argument_expression>(this->args[i]);
|
||||
if (!is_stmt<identifier>(kwarg->key)) {
|
||||
throw std::runtime_error("Keyword argument key must be an identifier in macro '" + name + "'");
|
||||
}
|
||||
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
|
||||
value param_value = args.get_kwarg_or_pos(param_name, i);
|
||||
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
|
||||
macro_ctx.set_val(param_name, param_value);
|
||||
} else {
|
||||
throw std::runtime_error("Invalid parameter type in macro '" + name + "'");
|
||||
}
|
||||
} else {
|
||||
auto & default_arg = this->args[i];
|
||||
if (is_stmt<keyword_argument_expression>(default_arg)) {
|
||||
auto kwarg = cast_stmt<keyword_argument_expression>(default_arg);
|
||||
if (!is_stmt<identifier>(kwarg->key)) {
|
||||
throw std::runtime_error("Keyword argument key must be an identifier in macro '" + name + "'");
|
||||
}
|
||||
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
|
||||
JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str());
|
||||
macro_ctx.set_val(param_name, kwarg->val->execute(ctx));
|
||||
} else {
|
||||
throw std::runtime_error("Not enough arguments provided to macro '" + name + "'");
|
||||
}
|
||||
//std::string param_name = cast_stmt<identifier>(default_args[i])->val;
|
||||
//JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str());
|
||||
//macro_ctx.var[param_name] = default_args[i]->execute(ctx);
|
||||
}
|
||||
}
|
||||
bind_parameters(name, this->args, args, macro_ctx);
|
||||
|
||||
// execute macro body
|
||||
JJ_DEBUG("Executing macro '%s' body with %zu statements", name.c_str(), this->body.size());
|
||||
@@ -752,6 +755,46 @@ value macro_statement::execute_impl(context & ctx) {
|
||||
return mk_val<value_undefined>();
|
||||
}
|
||||
|
||||
value call_statement::execute_impl(context & ctx) {
|
||||
auto call_expr = cast_stmt<call_expression>(this->call);
|
||||
if (!call_expr) {
|
||||
throw std::runtime_error("Call statement requires a valid call expression");
|
||||
}
|
||||
|
||||
value callee_val = call_expr->callee->execute(ctx);
|
||||
if (!is_val<value_func>(callee_val)) {
|
||||
throw std::runtime_error("Callee is not a function: got " + callee_val->type());
|
||||
}
|
||||
auto * callee_func = cast_val<value_func>(callee_val);
|
||||
|
||||
context caller_ctx(ctx); // new scope for caller execution
|
||||
|
||||
const func_handler func = [this, caller_ctx = std::move(caller_ctx)](const func_args & args) -> value {
|
||||
context block_ctx(caller_ctx); // new scope for block execution
|
||||
|
||||
bind_parameters("caller", this->caller_args, args, block_ctx);
|
||||
|
||||
JJ_DEBUG("Executing call body with %zu statements", this->body.size());
|
||||
auto res = exec_statements(this->body, block_ctx);
|
||||
JJ_DEBUG("Call body execution complete, result: %s", res->val_str.str().c_str());
|
||||
return res;
|
||||
};
|
||||
|
||||
context call_ctx(ctx);
|
||||
call_ctx.set_val("caller", mk_val<value_func>("caller", func));
|
||||
|
||||
func_args args(call_ctx);
|
||||
|
||||
for (const auto & arg_expr : call_expr->args) {
|
||||
auto arg_val = arg_expr->execute(ctx);
|
||||
JJ_DEBUG(" Argument type: %s", arg_val->type().c_str());
|
||||
args.push_back(arg_val);
|
||||
}
|
||||
|
||||
JJ_DEBUG("Calling macro '%s' with %zu arguments", callee_func->name.c_str(), args.count());
|
||||
return callee_func->invoke(args);
|
||||
}
|
||||
|
||||
value member_expression::execute_impl(context & ctx) {
|
||||
value object = this->object->execute(ctx);
|
||||
|
||||
|
||||
@@ -552,6 +552,7 @@ struct call_statement : public statement {
|
||||
for (const auto & arg : this->caller_args) chk_type<expression>(arg);
|
||||
}
|
||||
std::string type() const override { return "CallStatement"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
};
|
||||
|
||||
struct ternary_expression : public expression {
|
||||
|
||||
@@ -1,324 +0,0 @@
|
||||
#include "json-partial.h"
|
||||
|
||||
#include "log.h"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <string>
|
||||
#include <regex>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
enum common_json_stack_element_type {
|
||||
COMMON_JSON_STACK_ELEMENT_OBJECT,
|
||||
COMMON_JSON_STACK_ELEMENT_KEY,
|
||||
COMMON_JSON_STACK_ELEMENT_ARRAY,
|
||||
};
|
||||
|
||||
struct common_json_stack_element {
|
||||
common_json_stack_element_type type;
|
||||
std::string key;
|
||||
};
|
||||
|
||||
bool common_json_parse(
|
||||
const std::string & input,
|
||||
const std::string & healing_marker,
|
||||
common_json & out)
|
||||
{
|
||||
std::string::const_iterator it = input.begin();
|
||||
const auto end = input.end();
|
||||
return common_json_parse(it, end, healing_marker, out);
|
||||
}
|
||||
|
||||
bool common_json_parse(
|
||||
std::string::const_iterator & it,
|
||||
const std::string::const_iterator & end,
|
||||
const std::string & healing_marker,
|
||||
common_json & out)
|
||||
{
|
||||
// // https://json.nlohmann.me/features/parsing/sax_interface/
|
||||
struct json_error_locator : public nlohmann::json_sax<json> {
|
||||
std::size_t position;
|
||||
bool found_error;
|
||||
std::string last_token;
|
||||
std::string exception_message;
|
||||
std::vector<common_json_stack_element> stack;
|
||||
|
||||
json_error_locator() : position(0), found_error(false) {}
|
||||
|
||||
bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
|
||||
this->position = position - 1;
|
||||
this->found_error = true;
|
||||
this->last_token = last_token;
|
||||
this->exception_message = ex.what();
|
||||
return false;
|
||||
}
|
||||
void close_value() {
|
||||
if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
|
||||
stack.pop_back();
|
||||
}
|
||||
}
|
||||
bool null() override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool boolean(bool) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool number_integer(number_integer_t) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool number_unsigned(number_unsigned_t) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool number_float(number_float_t, const string_t &) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool string(string_t &) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool binary(binary_t &) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool start_object(std::size_t) override { // NOLINT
|
||||
stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
|
||||
return true;
|
||||
}
|
||||
bool end_object() override {
|
||||
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
|
||||
stack.pop_back();
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool key(string_t & key) override { // NOLINT
|
||||
stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
|
||||
return true;
|
||||
}
|
||||
bool start_array(std::size_t) override { // NOLINT
|
||||
stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
|
||||
return true;
|
||||
}
|
||||
bool end_array() override {
|
||||
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
|
||||
stack.pop_back();
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
};
|
||||
json_error_locator err_loc;
|
||||
auto start = it;
|
||||
json::sax_parse(it, end, &err_loc);
|
||||
|
||||
if (err_loc.found_error) {
|
||||
it = start;
|
||||
auto temptative_end = it + err_loc.position;
|
||||
// LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
|
||||
|
||||
auto input = std::string(it, temptative_end);
|
||||
try {
|
||||
out.json = json::parse(input);
|
||||
// out.json = json::parse(it, temptative_end);
|
||||
it = temptative_end;
|
||||
return true;
|
||||
} catch (const std::exception & ex) {
|
||||
// No, needs healing.
|
||||
LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
|
||||
}
|
||||
auto can_parse = [](const std::string & str) {
|
||||
try {
|
||||
auto _ = json::parse(str); // NOLINT
|
||||
return true;
|
||||
} catch (const std::exception &) {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
if (!healing_marker.empty() && !err_loc.stack.empty()) {
|
||||
std::string str(it, temptative_end);
|
||||
auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
|
||||
if (last_non_sp_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
|
||||
}
|
||||
auto last_non_sp_char = str[last_non_sp_pos];
|
||||
// Used to detect stops on a number, which may not be complete.
|
||||
auto was_maybe_number = [&]() {
|
||||
if (!str.empty() && std::isspace(str.back())) {
|
||||
return false;
|
||||
}
|
||||
return std::isdigit(last_non_sp_char) ||
|
||||
last_non_sp_char == '.' ||
|
||||
last_non_sp_char == 'e' ||
|
||||
last_non_sp_char == 'E' ||
|
||||
last_non_sp_char == '-';
|
||||
};
|
||||
|
||||
std::string closing;
|
||||
for (size_t i = err_loc.stack.size(); i > 0; i--) {
|
||||
auto & el = err_loc.stack[i - 1];
|
||||
if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
|
||||
closing += "}";
|
||||
} else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
|
||||
closing += "]";
|
||||
} else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
|
||||
throw std::runtime_error("Unexpected stack element type");
|
||||
}
|
||||
}
|
||||
|
||||
// Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
|
||||
static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
|
||||
|
||||
auto is_high_surrogate = [&](const std::string & s) {
|
||||
// Check if a partial of a high surrogate (U+D800-U+DBFF)
|
||||
return s.length() >= 4 &&
|
||||
s[0] == '\\' && s[1] == 'u' &&
|
||||
std::tolower(s[2]) == 'd' &&
|
||||
(s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
|
||||
};
|
||||
|
||||
// Initialize the unicode marker to a low surrogate to handle the edge case
|
||||
// where a high surrogate (U+D800-U+DBFF) is immediately followed by a
|
||||
// backslash (\)
|
||||
std::string unicode_marker_padding = "udc00";
|
||||
std::smatch last_unicode_seq;
|
||||
|
||||
if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
|
||||
std::smatch second_last_seq;
|
||||
std::string prelude = str.substr(0, last_unicode_seq.position());
|
||||
|
||||
// Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
|
||||
unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
|
||||
|
||||
if (is_high_surrogate(last_unicode_seq.str())) {
|
||||
// If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF)
|
||||
unicode_marker_padding += "\\udc00";
|
||||
} else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
|
||||
if (is_high_surrogate(second_last_seq.str())) {
|
||||
// If this follows a high surrogate, pad it to be a low surrogate
|
||||
if (last_unicode_seq.length() == 2) {
|
||||
unicode_marker_padding = "dc00";
|
||||
} else if (last_unicode_seq.length() == 3) {
|
||||
unicode_marker_padding = "c00";
|
||||
} else {
|
||||
// The original unicode_marker_padding is already padded with 0s
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
|
||||
|
||||
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
|
||||
// We're inside an object value
|
||||
if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
|
||||
// Was about to create an object value
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + ": 1" + closing)) {
|
||||
str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
|
||||
} else if (last_non_sp_char == '{' && can_parse(str + closing)) {
|
||||
// Was about to create an object
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
|
||||
} else if (can_parse(str + "\"" + closing)) {
|
||||
// Was inside an object value string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an object value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
|
||||
// Was inside an object value string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
|
||||
} else {
|
||||
// find last :
|
||||
auto last_pos = str.find_last_of(':');
|
||||
if (last_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
|
||||
}
|
||||
// Cutting back to opening : for object value
|
||||
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
}
|
||||
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
|
||||
if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
|
||||
// Was about to create an array value
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + "\"" + closing)) {
|
||||
// Was inside an array value string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an array value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
|
||||
// Was inside an array value string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
|
||||
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
|
||||
// Had just finished a value
|
||||
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
|
||||
} else {
|
||||
auto last_pos = str.find_last_of("[,");
|
||||
if (last_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
|
||||
}
|
||||
// Cutting back to last [ or , for array value
|
||||
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
}
|
||||
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
|
||||
if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
|
||||
(last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
|
||||
// Was about to create an object key+value
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
|
||||
} else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
|
||||
// Was about to create an object key+value
|
||||
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
|
||||
} else if (can_parse(str + "\": 1" + closing)) {
|
||||
// Was inside an object key string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
|
||||
// Was inside an object key string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
|
||||
// Was inside an object key string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
|
||||
} else {
|
||||
auto last_pos = str.find_last_of(':');
|
||||
if (last_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
|
||||
}
|
||||
// fprintf(stderr, "Cutting back to last : for object key+value\n");
|
||||
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
|
||||
}
|
||||
// fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
|
||||
out.json = json::parse(str);
|
||||
it = temptative_end;
|
||||
return true;
|
||||
}
|
||||
// handle unclosed top-level primitive
|
||||
if (err_loc.position != 0 && !healing_marker.empty() && err_loc.stack.empty()) {
|
||||
std::string str(it, temptative_end);
|
||||
const auto & magic_seed = out.healing_marker.marker = healing_marker;
|
||||
if (can_parse(str + "\"")) {
|
||||
// Was inside an string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"";
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"")) {
|
||||
// Was inside an string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"";
|
||||
} else {
|
||||
// TODO: handle more unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
|
||||
// fprintf(stderr, "Closing: TODO\n");
|
||||
return false;
|
||||
}
|
||||
out.json = json::parse(str);
|
||||
it = temptative_end;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
out.json = json::parse(it, end);
|
||||
it = end;
|
||||
return true;
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
// TODO: use json_fwd.hpp when possible
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
|
||||
struct common_healing_marker {
|
||||
// Raw marker.
|
||||
std::string marker;
|
||||
|
||||
// Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format).
|
||||
std::string json_dump_marker;
|
||||
};
|
||||
|
||||
// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string)
|
||||
struct common_json {
|
||||
nlohmann::ordered_json json;
|
||||
|
||||
common_healing_marker healing_marker;
|
||||
};
|
||||
|
||||
// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty.
|
||||
//
|
||||
// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON.
|
||||
// This allows to parse the resulting healed JSON string, yet be able to cut it again if needed at the healing marker.
|
||||
// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format).
|
||||
//
|
||||
// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again).
|
||||
bool common_json_parse(
|
||||
const std::string & input,
|
||||
const std::string & healing_marker,
|
||||
common_json & out);
|
||||
|
||||
// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds.
|
||||
bool common_json_parse(
|
||||
std::string::const_iterator & it,
|
||||
const std::string::const_iterator & end,
|
||||
const std::string & healing_marker,
|
||||
common_json & out);
|
||||
@@ -233,27 +233,27 @@ struct BuiltinRule {
|
||||
};
|
||||
|
||||
static std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
|
||||
{"boolean", {"(\"true\" | \"false\") space", {}}},
|
||||
{"boolean", {"(\"true\" | \"false\")", {}}},
|
||||
{"decimal-part", {"[0-9]{1,16}", {}}},
|
||||
{"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}},
|
||||
{"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}},
|
||||
{"integer", {"(\"-\"? integral-part) space", {"integral-part"}}},
|
||||
{"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)?", {"integral-part", "decimal-part"}}},
|
||||
{"integer", {"(\"-\"? integral-part)", {"integral-part"}}},
|
||||
{"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}},
|
||||
{"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}},
|
||||
{"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}},
|
||||
{"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}},
|
||||
{"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? space \"}\"", {"string", "value"}}},
|
||||
{"array", {"\"[\" space ( value (\",\" space value)* )? space \"]\"", {"value"}}},
|
||||
{"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\"", {}}},
|
||||
{"char", {"[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}},
|
||||
{"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}},
|
||||
{"null", {"\"null\" space", {}}},
|
||||
{"string", {"\"\\\"\" char* \"\\\"\"", {"char"}}},
|
||||
{"null", {"\"null\"", {}}},
|
||||
};
|
||||
|
||||
static std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
|
||||
{"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}},
|
||||
{"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}},
|
||||
{"date-time", {"date \"T\" time", {"date", "time"}}},
|
||||
{"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}},
|
||||
{"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}},
|
||||
{"date-time-string", {"\"\\\"\" date-time \"\\\"\" space", {"date-time"}}}
|
||||
{"date-string", {"\"\\\"\" date \"\\\"\"", {"date"}}},
|
||||
{"time-string", {"\"\\\"\" time \"\\\"\"", {"time"}}},
|
||||
{"date-time-string", {"\"\\\"\" date-time \"\\\"\"", {"date-time"}}}
|
||||
};
|
||||
|
||||
static bool is_reserved_name(const std::string & name) {
|
||||
@@ -551,16 +551,16 @@ private:
|
||||
}
|
||||
return join_seq();
|
||||
};
|
||||
return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space");
|
||||
return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\"");
|
||||
}
|
||||
|
||||
/*
|
||||
Returns a rule that matches a JSON string that is none of the provided strings
|
||||
|
||||
not_strings({"a"})
|
||||
-> ["] ( [a] char+ | [^"a] char* )? ["] space
|
||||
-> ["] ( [a] char+ | [^"a] char* )? ["]
|
||||
not_strings({"and", "also"})
|
||||
-> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
|
||||
-> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["]
|
||||
*/
|
||||
std::string _not_strings(const std::vector<std::string> & strings) {
|
||||
|
||||
@@ -619,7 +619,7 @@ private:
|
||||
if (!trie.is_end_of_string) {
|
||||
out << "?";
|
||||
}
|
||||
out << " [\"] space";
|
||||
out << " [\"]";
|
||||
return out.str();
|
||||
}
|
||||
|
||||
@@ -725,7 +725,7 @@ private:
|
||||
rule += " )?";
|
||||
}
|
||||
|
||||
rule += " \"}\" space";
|
||||
rule += " space \"}\"";
|
||||
|
||||
return rule;
|
||||
}
|
||||
@@ -858,14 +858,14 @@ public:
|
||||
return _add_rule(rule_name, _generate_union_rule(name, schema_types));
|
||||
}
|
||||
if (schema.contains("const")) {
|
||||
return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space");
|
||||
return _add_rule(rule_name, _generate_constant_rule(schema["const"]));
|
||||
}
|
||||
if (schema.contains("enum")) {
|
||||
std::vector<std::string> enum_values;
|
||||
for (const auto & v : schema["enum"]) {
|
||||
enum_values.push_back(_generate_constant_rule(v));
|
||||
}
|
||||
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
|
||||
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ")");
|
||||
}
|
||||
if ((schema_type.is_null() || schema_type == "object")
|
||||
&& (schema.contains("properties") ||
|
||||
@@ -933,7 +933,7 @@ public:
|
||||
}
|
||||
}
|
||||
if (!enum_intersection.empty()) {
|
||||
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space");
|
||||
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ")");
|
||||
}
|
||||
}
|
||||
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
|
||||
@@ -948,7 +948,7 @@ public:
|
||||
}
|
||||
rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i));
|
||||
}
|
||||
rule += " \"]\" space";
|
||||
rule += " space \"]\"";
|
||||
return _add_rule(rule_name, rule);
|
||||
}
|
||||
std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
|
||||
@@ -956,7 +956,7 @@ public:
|
||||
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
|
||||
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
|
||||
|
||||
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
|
||||
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " space \"]\"");
|
||||
}
|
||||
if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
|
||||
return _visit_pattern(schema["pattern"], rule_name);
|
||||
@@ -972,7 +972,7 @@ public:
|
||||
std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
|
||||
int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
|
||||
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
|
||||
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
|
||||
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\"");
|
||||
}
|
||||
if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
|
||||
int64_t min_value = std::numeric_limits<int64_t>::min();
|
||||
@@ -990,7 +990,7 @@ public:
|
||||
std::stringstream out;
|
||||
out << "(";
|
||||
build_min_max_int(min_value, max_value, out);
|
||||
out << ") space";
|
||||
out << ")";
|
||||
return _add_rule(rule_name, out.str());
|
||||
}
|
||||
if (schema.empty() || schema_type == "object") {
|
||||
|
||||
+202
-89
@@ -6,13 +6,14 @@
|
||||
#include "unicode.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <deque>
|
||||
#include <initializer_list>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <regex>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
#include <unordered_set>
|
||||
|
||||
// Trick to catch missing branches
|
||||
template <typename T>
|
||||
@@ -88,40 +89,7 @@ struct trie {
|
||||
return match_result{match_result::NO_MATCH};
|
||||
}
|
||||
|
||||
struct prefix_and_next {
|
||||
std::vector<uint32_t> prefix;
|
||||
std::vector<uint32_t> next_chars;
|
||||
};
|
||||
|
||||
std::vector<prefix_and_next> collect_prefix_and_next() {
|
||||
std::vector<uint32_t> prefix;
|
||||
std::vector<prefix_and_next> result;
|
||||
collect_prefix_and_next(0, prefix, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
void collect_prefix_and_next(size_t index, std::vector<uint32_t> & prefix, std::vector<prefix_and_next> & out) {
|
||||
if (!nodes[index].is_word) {
|
||||
if (!nodes[index].children.empty()) {
|
||||
std::vector<uint32_t> chars;
|
||||
chars.reserve(nodes[index].children.size());
|
||||
for (const auto & p : nodes[index].children) {
|
||||
chars.push_back(p.first);
|
||||
}
|
||||
out.emplace_back(prefix_and_next{prefix, chars});
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & p : nodes[index].children) {
|
||||
uint32_t ch = p.first;
|
||||
auto child = p.second;
|
||||
prefix.push_back(ch);
|
||||
collect_prefix_and_next(child, prefix, out);
|
||||
prefix.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
size_t create_node() {
|
||||
size_t index = nodes.size();
|
||||
nodes.emplace_back();
|
||||
@@ -153,6 +121,65 @@ struct trie {
|
||||
}
|
||||
};
|
||||
|
||||
// Aho-Corasick automaton
|
||||
struct aho_corasick {
|
||||
trie t;
|
||||
std::vector<size_t> fail; // failure links
|
||||
std::vector<size_t> order; // states in BFS order
|
||||
std::vector<bool> terminal; // match states (directly or via a suffix link)
|
||||
std::set<uint32_t> alphabet; // every character with a transition
|
||||
|
||||
aho_corasick(const std::vector<std::string> & strings) : t(strings) {
|
||||
const auto & nodes = t.nodes;
|
||||
const size_t n = nodes.size();
|
||||
|
||||
fail.assign(n, 0);
|
||||
order.reserve(n);
|
||||
|
||||
std::deque<size_t> queue{ 0 };
|
||||
while (!queue.empty()) {
|
||||
size_t u = queue.front();
|
||||
queue.pop_front();
|
||||
order.push_back(u);
|
||||
for (const auto & [ch, v] : nodes[u].children) {
|
||||
if (u != 0) {
|
||||
size_t f = fail[u];
|
||||
while (f && nodes[f].children.find(ch) == nodes[f].children.end()) {
|
||||
f = fail[f];
|
||||
}
|
||||
auto it = nodes[f].children.find(ch);
|
||||
fail[v] = (it != nodes[f].children.end() && it->second != v) ? it->second : 0;
|
||||
}
|
||||
queue.push_back(v);
|
||||
}
|
||||
}
|
||||
|
||||
terminal.assign(n, false);
|
||||
for (size_t u : order) {
|
||||
terminal[u] = nodes[u].is_word || (u != 0 && terminal[fail[u]]);
|
||||
}
|
||||
|
||||
for (const auto & node : nodes) {
|
||||
for (const auto & [ch, v] : node.children) {
|
||||
alphabet.insert(ch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t num_states() const { return t.nodes.size(); }
|
||||
bool is_terminal(size_t s) const { return terminal[s]; }
|
||||
|
||||
// follow failure links until a transition on `ch` exists.
|
||||
size_t next(size_t state, uint32_t ch) const {
|
||||
const auto & nodes = t.nodes;
|
||||
while (state && nodes[state].children.find(ch) == nodes[state].children.end()) {
|
||||
state = fail[state];
|
||||
}
|
||||
auto it = nodes[state].children.find(ch);
|
||||
return it != nodes[state].children.end() ? it->second : 0;
|
||||
}
|
||||
};
|
||||
|
||||
static std::pair<uint32_t, size_t> parse_hex_escape(const std::string & str, size_t pos, int hex_count) {
|
||||
if (pos + hex_count > str.length()) {
|
||||
return {0, 0};
|
||||
@@ -894,6 +921,10 @@ struct parser_executor {
|
||||
common_peg_parse_result operator()(const common_peg_gbnf_parser & p) {
|
||||
return arena.parse(p.child, ctx, start_pos);
|
||||
}
|
||||
|
||||
common_peg_parse_result operator()(const common_peg_ac_parser & p) {
|
||||
return arena.parse(p.child, ctx, start_pos);
|
||||
}
|
||||
};
|
||||
|
||||
common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const {
|
||||
@@ -962,7 +993,8 @@ void common_peg_arena::resolve_refs() {
|
||||
std::is_same_v<T, common_peg_not_parser> ||
|
||||
std::is_same_v<T, common_peg_tag_parser> ||
|
||||
std::is_same_v<T, common_peg_atomic_parser> ||
|
||||
std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
std::is_same_v<T, common_peg_gbnf_parser> ||
|
||||
std::is_same_v<T, common_peg_ac_parser>) {
|
||||
p.child = resolve_ref(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
p.child = resolve_ref(p.child);
|
||||
@@ -992,12 +1024,12 @@ void common_peg_arena::resolve_refs() {
|
||||
}
|
||||
|
||||
std::string common_peg_arena::dump(common_peg_parser_id id) const {
|
||||
std::unordered_set<common_peg_parser_id> visited;
|
||||
std::set<common_peg_parser_id> visited;
|
||||
return dump_impl(id, visited);
|
||||
}
|
||||
|
||||
std::string common_peg_arena::dump_impl(common_peg_parser_id id,
|
||||
std::unordered_set<common_peg_parser_id> & visited) const {
|
||||
std::set<common_peg_parser_id> & visited) const {
|
||||
// Check for cycles
|
||||
if (visited.count(id)) {
|
||||
return "[cycle]";
|
||||
@@ -1043,6 +1075,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
|
||||
return "Atomic(" + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
|
||||
return "Ac(" + string_join(p.delimiters, " | ") + ", " + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_any_parser>) {
|
||||
return "Any";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_space_parser>) {
|
||||
@@ -1342,7 +1376,7 @@ common_peg_parser common_peg_parser_builder::json_object() {
|
||||
common_peg_parser common_peg_parser_builder::json_array() {
|
||||
return rule("json-array", [this]() {
|
||||
auto ws = space();
|
||||
auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))});
|
||||
auto elements = sequence({json(), zero_or_more(sequence({ws, literal(","), ws, json()}))});
|
||||
return sequence({
|
||||
literal("["),
|
||||
ws,
|
||||
@@ -1452,6 +1486,13 @@ common_peg_parser common_peg_parser_builder::json_member(const std::string & key
|
||||
});
|
||||
}
|
||||
|
||||
common_peg_parser common_peg_parser_builder::ac(const common_peg_parser & p, const std::vector<std::string> & delimiters) {
|
||||
if (delimiters.empty()) {
|
||||
throw std::runtime_error("ac parser requires at least one delimiter");
|
||||
}
|
||||
return add(common_peg_ac_parser{p, delimiters});
|
||||
}
|
||||
|
||||
static std::string gbnf_escape_char_class(uint32_t c) {
|
||||
if (c == '-' || c == ']' || c == '[' || c == '\\') {
|
||||
return "\\" + std::string(1, (char) c);
|
||||
@@ -1502,61 +1543,118 @@ static std::string gbnf_escape_char_class(uint32_t c) {
|
||||
return std::string(buf);
|
||||
}
|
||||
|
||||
static std::string gbnf_excluding_pattern(const std::vector<std::string> & strings) {
|
||||
trie matcher(strings);
|
||||
auto pieces = matcher.collect_prefix_and_next();
|
||||
|
||||
std::string pattern;
|
||||
std::string trailing; // optional proper-prefix of a delimiter, allowed only at the very end
|
||||
for (size_t i = 0; i < pieces.size(); ++i) {
|
||||
if (i > 0) {
|
||||
pattern += " | ";
|
||||
}
|
||||
|
||||
const auto & pre = pieces[i].prefix;
|
||||
const auto & chars = pieces[i].next_chars;
|
||||
|
||||
std::string cls;
|
||||
cls.reserve(chars.size());
|
||||
for (uint32_t ch : chars) {
|
||||
cls += gbnf_escape_char_class(ch);
|
||||
}
|
||||
|
||||
if (!pre.empty()) {
|
||||
std::string pre_literal = gbnf_format_literal(common_unicode_cpts_to_utf8(pre));
|
||||
pattern += pre_literal + " [^" + cls + "]";
|
||||
// Each interior alternative consumes a delimiter-prefix plus a disambiguating
|
||||
// char, so the repetition alone cannot match a value that *ends* on a proper
|
||||
// prefix of a delimiter (e.g. a trailing "\n" when the delimiter is
|
||||
// "\n</parameter>\n"). The runtime until() (greedy first-match) accepts such
|
||||
// values, so without this the grammar would reject input the parser accepts.
|
||||
// Allow the value to terminate on any proper prefix as an optional tail.
|
||||
// This makes the grammar a slight superset of the runtime language (a value
|
||||
// may end on the longest prefix, which greedy first-match would not itself
|
||||
// produce); harmless for constrained generation, which only needs to admit
|
||||
// every runtime-valid string.
|
||||
if (!trailing.empty()) {
|
||||
trailing += " | ";
|
||||
}
|
||||
trailing += pre_literal;
|
||||
} else {
|
||||
pattern += "[^" + cls + "]";
|
||||
}
|
||||
static std::string gbnf_char_class(const std::vector<uint32_t> & chars, bool negate) {
|
||||
std::string s = negate ? "[^" : "[";
|
||||
for (uint32_t ch : chars) {
|
||||
s += gbnf_escape_char_class(ch);
|
||||
}
|
||||
|
||||
std::string result = "(" + pattern + ")*";
|
||||
if (!trailing.empty()) {
|
||||
result += " (" + trailing + ")?";
|
||||
}
|
||||
return result;
|
||||
return s + "]";
|
||||
}
|
||||
|
||||
static std::unordered_set<std::string> collect_reachable_rules(
|
||||
static std::string gbnf_ac_grammar(
|
||||
const common_grammar_builder & builder,
|
||||
const std::string & prefix,
|
||||
const std::vector<std::string> & strings,
|
||||
const std::function<std::string(const std::vector<uint32_t> &,
|
||||
const std::map<size_t, std::vector<uint32_t>> &,
|
||||
const std::vector<uint32_t> &,
|
||||
const std::function<std::string(size_t)> &)> & build_rule) {
|
||||
aho_corasick ac(strings);
|
||||
|
||||
auto state_name = [&](size_t s) -> std::string {
|
||||
if (s == 0) {
|
||||
return prefix;
|
||||
}
|
||||
std::string num = std::to_string(s);
|
||||
num = num.size() == 1 ? ("0" + num) : num;
|
||||
return prefix + "-" + num;
|
||||
};
|
||||
|
||||
for (size_t q = 0; q < ac.num_states(); q++) {
|
||||
if (ac.is_terminal(q)) {
|
||||
continue; // match states
|
||||
}
|
||||
|
||||
std::map<size_t, std::vector<uint32_t>> buckets;
|
||||
std::vector<uint32_t> completing; // chars that complete a delimiter
|
||||
std::vector<uint32_t> specific; // chars with an explicit transition
|
||||
for (uint32_t c : ac.alphabet) {
|
||||
size_t d = ac.next(q, c);
|
||||
if (ac.is_terminal(d)) {
|
||||
completing.push_back(c);
|
||||
specific.push_back(c);
|
||||
} else if (d != 0) {
|
||||
buckets[d].push_back(c); // specific non-root destination
|
||||
specific.push_back(c);
|
||||
}
|
||||
}
|
||||
|
||||
builder.add_rule(state_name(q), build_rule(completing, buckets, specific, state_name));
|
||||
}
|
||||
|
||||
// An empty delimiter makes the start state terminal. Emit an entry rule
|
||||
// that matches the empty string so the returned reference stays valid.
|
||||
if (ac.is_terminal(0)) {
|
||||
builder.add_rule(prefix, "|");
|
||||
}
|
||||
|
||||
return state_name(0);
|
||||
}
|
||||
|
||||
// GBNF grammar matching strings that contain no string in `strings` as a
|
||||
// substring. Emits the complement of an Aho-Corasick automaton DFA and returns
|
||||
// the start state rule name.
|
||||
//
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/24839
|
||||
static std::string gbnf_excluding_grammar(const common_grammar_builder & builder,
|
||||
const std::string & prefix,
|
||||
const std::vector<std::string> & strings) {
|
||||
return gbnf_ac_grammar(builder, prefix, strings,
|
||||
[](const std::vector<uint32_t> & /*completing*/,
|
||||
const std::map<size_t, std::vector<uint32_t>> & buckets,
|
||||
const std::vector<uint32_t> & specific,
|
||||
const std::function<std::string(size_t)> & state_name) {
|
||||
// every state is accepting and completing chars get no
|
||||
// alternative, so a forbidden string can never be matched
|
||||
std::string rhs = "|";
|
||||
for (const auto & [d, chars] : buckets) {
|
||||
rhs += " " + gbnf_char_class(chars, false) + " " + state_name(d) + " |";
|
||||
}
|
||||
rhs += " " + gbnf_char_class(specific, true) + " " + state_name(0);
|
||||
return rhs;
|
||||
});
|
||||
}
|
||||
|
||||
// GBNF grammar matching everything up to and including the first occurrence of
|
||||
// any string in `strings`. Emits the Aho-Corasick automaton DFA and returns
|
||||
// the start state rule name.
|
||||
static std::string gbnf_including_grammar(const common_grammar_builder & builder,
|
||||
const std::string & prefix,
|
||||
const std::vector<std::string> & strings) {
|
||||
return gbnf_ac_grammar(builder, prefix, strings,
|
||||
[](const std::vector<uint32_t> & completing,
|
||||
const std::map<size_t, std::vector<uint32_t>> & buckets,
|
||||
const std::vector<uint32_t> & specific,
|
||||
const std::function<std::string(size_t)> & state_name) {
|
||||
std::vector<std::string> alts;
|
||||
if (!completing.empty()) {
|
||||
alts.push_back(gbnf_char_class(completing, false)); // terminate on match
|
||||
}
|
||||
for (const auto & [d, chars] : buckets) {
|
||||
alts.push_back(gbnf_char_class(chars, false) + " " + state_name(d));
|
||||
}
|
||||
// every other character keeps scanning from the start state
|
||||
alts.push_back(gbnf_char_class(specific, true) + " " + state_name(0));
|
||||
return string_join(alts, " | ");
|
||||
});
|
||||
}
|
||||
|
||||
static std::set<std::string> collect_reachable_rules(
|
||||
const common_peg_arena & arena,
|
||||
const common_peg_parser_id & rule
|
||||
) {
|
||||
std::unordered_set<std::string> reachable;
|
||||
std::unordered_set<std::string> visited;
|
||||
std::set<std::string> reachable;
|
||||
std::set<std::string> visited;
|
||||
|
||||
std::function<void(common_peg_parser_id)> visit = [&](common_peg_parser_id id) {
|
||||
const auto & parser = arena.get(id);
|
||||
@@ -1588,6 +1686,7 @@ static std::unordered_set<std::string> collect_reachable_rules(
|
||||
std::is_same_v<T, common_peg_tag_parser> ||
|
||||
std::is_same_v<T, common_peg_atomic_parser> ||
|
||||
std::is_same_v<T, common_peg_gbnf_parser> ||
|
||||
std::is_same_v<T, common_peg_ac_parser> ||
|
||||
std::is_same_v<T, common_peg_schema_parser>) {
|
||||
visit(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
@@ -1765,7 +1864,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
if (p.delimiters.empty()) {
|
||||
return ".*";
|
||||
}
|
||||
return gbnf_excluding_pattern(p.delimiters);
|
||||
return gbnf_excluding_grammar(builder, "until-" + std::to_string(id), p.delimiters);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
|
||||
if (schema_delegates(p)) {
|
||||
return to_gbnf(p.child);
|
||||
@@ -1782,6 +1881,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
return to_gbnf(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return p.grammar;
|
||||
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
|
||||
return gbnf_including_grammar(builder, "ac-" + std::to_string(id), p.delimiters);
|
||||
} else {
|
||||
static_assert(is_always_false_v<T>);
|
||||
}
|
||||
@@ -1789,7 +1890,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
};
|
||||
|
||||
// Collect reachable rules
|
||||
std::unordered_set<std::string> reachable_rules;
|
||||
std::set<std::string> reachable_rules;
|
||||
|
||||
if (lazy) {
|
||||
// Collect rules reachable from trigger rules
|
||||
@@ -1918,6 +2019,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
|
||||
};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
|
||||
return json{{"type", "ac"}, {"child", p.child}, {"delimiters", p.delimiters}};
|
||||
}
|
||||
}, variant);
|
||||
}
|
||||
@@ -2090,6 +2193,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
|
||||
};
|
||||
}
|
||||
|
||||
if (type == "ac") {
|
||||
if (!j.contains("child") || !j.contains("delimiters") || !j["delimiters"].is_array() || j["delimiters"].empty()) {
|
||||
throw std::runtime_error("ac parser requires 'child' and a non-empty 'delimiters' array");
|
||||
}
|
||||
return common_peg_ac_parser{
|
||||
j["child"].get<common_peg_parser_id>(),
|
||||
j["delimiters"].get<std::vector<std::string>>(),
|
||||
};
|
||||
}
|
||||
|
||||
throw std::runtime_error("Unknown parser type: " + type);
|
||||
}
|
||||
|
||||
|
||||
+16
-3
@@ -3,8 +3,8 @@
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <functional>
|
||||
@@ -275,6 +275,11 @@ struct common_peg_gbnf_parser {
|
||||
std::string grammar;
|
||||
};
|
||||
|
||||
struct common_peg_ac_parser {
|
||||
common_peg_parser_id child;
|
||||
std::vector<std::string> delimiters;
|
||||
};
|
||||
|
||||
// Variant holding all parser types
|
||||
using common_peg_parser_variant = std::variant<
|
||||
common_peg_epsilon_parser,
|
||||
@@ -296,7 +301,8 @@ using common_peg_parser_variant = std::variant<
|
||||
common_peg_ref_parser,
|
||||
common_peg_atomic_parser,
|
||||
common_peg_tag_parser,
|
||||
common_peg_gbnf_parser
|
||||
common_peg_gbnf_parser,
|
||||
common_peg_ac_parser
|
||||
>;
|
||||
|
||||
class common_peg_arena {
|
||||
@@ -335,7 +341,7 @@ class common_peg_arena {
|
||||
friend class common_peg_parser_builder;
|
||||
|
||||
private:
|
||||
std::string dump_impl(common_peg_parser_id id, std::unordered_set<common_peg_parser_id> & visited) const;
|
||||
std::string dump_impl(common_peg_parser_id id, std::set<common_peg_parser_id> & visited) const;
|
||||
|
||||
common_peg_parser_id add_parser(common_peg_parser_variant parser);
|
||||
void add_rule(const std::string & name, common_peg_parser_id id);
|
||||
@@ -514,6 +520,13 @@ class common_peg_parser_builder {
|
||||
// the child's grammar. Parsing delegates entirely to the child.
|
||||
common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); }
|
||||
|
||||
// Wraps a child parser but emits a GBNF grammar built from the Aho-Corasick
|
||||
// automaton of `delimiters`, matching everything up to and including the
|
||||
// first delimiter. Parsing delegates entirely to the child, which is
|
||||
// responsible for consuming the delimiter (e.g. until(D) + literal(D)).
|
||||
common_peg_parser ac(const common_peg_parser & p, const std::vector<std::string> & delimiters);
|
||||
common_peg_parser ac(const common_peg_parser & p, const std::string & delimiter) { return ac(p, std::vector<std::string>{delimiter}); }
|
||||
|
||||
void set_root(const common_peg_parser & p);
|
||||
|
||||
common_peg_arena build();
|
||||
|
||||
+102
-35
@@ -905,7 +905,13 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
int32_t n_embd = 0;
|
||||
|
||||
bool is_mem_shared = false;
|
||||
// One MTP draft driver, three modes (set once in the ctor):
|
||||
// is_mem_shared (gemma4): shares the target KV, runs all heads in one graph.
|
||||
// chain_heads (step35): n_mtp_layers trained heads, one per draft step.
|
||||
// neither (qwen35 / qwen35moe): a single trained MTP head.
|
||||
int32_t n_mtp_layers = 1;
|
||||
bool is_mem_shared = false; // gemma4
|
||||
bool chain_heads = false; // derived in the ctor: n_mtp_layers > 1 && !is_mem_shared
|
||||
|
||||
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
|
||||
// The last h-row of one process() call needs the first token of the NEXT
|
||||
@@ -920,10 +926,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
std::vector<std::vector<float>> verify_h;
|
||||
std::vector<int32_t> verify_h_rows;
|
||||
|
||||
// Per-seq draft length from the last draft() call, used in accept() to
|
||||
// roll back ctx_dft's recurrent state past the AR draft's redundant
|
||||
// pre-advancement before process() mirrored the verify batch.
|
||||
std::vector<uint16_t> last_n_drafted;
|
||||
std::vector<int> i_last;
|
||||
std::vector<std::vector<float>> chain_h;
|
||||
|
||||
common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq)
|
||||
@@ -936,6 +940,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft));
|
||||
GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) &&
|
||||
"MTP input row width must match the target h_nextn width");
|
||||
n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft)));
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
|
||||
@@ -982,16 +987,25 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
|
||||
|
||||
is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt;
|
||||
chain_heads = n_mtp_layers > 1 && !is_mem_shared;
|
||||
|
||||
if (chain_heads) {
|
||||
this->params.n_max = std::min(this->params.n_max, n_mtp_layers);
|
||||
|
||||
chain_h.assign(n_seq, {});
|
||||
for (auto & c : chain_h) {
|
||||
c.reserve((size_t) (this->params.n_max + 1) * n_embd);
|
||||
}
|
||||
}
|
||||
|
||||
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
|
||||
|
||||
i_last.assign(n_seq, -1);
|
||||
i_batch_beg.assign(n_seq, -1);
|
||||
i_batch_end.assign(n_seq, -1);
|
||||
|
||||
verify_h.assign(n_seq, {});
|
||||
verify_h_rows.assign(n_seq, 0);
|
||||
|
||||
last_n_drafted.assign(n_seq, 0);
|
||||
}
|
||||
|
||||
~common_speculative_impl_draft_mtp() override {
|
||||
@@ -1097,9 +1111,34 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
|
||||
}
|
||||
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
|
||||
auto * mem_dft = llama_get_memory(ctx_dft);
|
||||
|
||||
bool ok = true;
|
||||
for (int head = 0; head < n_mtp_layers; ++head) {
|
||||
if (chain_heads) {
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/24340/changes#r3413498544
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
if (i_batch_beg[seq_id] < 0) {
|
||||
continue;
|
||||
}
|
||||
llama_memory_seq_rm(mem_dft, seq_id, batch_in.pos[i_batch_beg[seq_id]], -1);
|
||||
}
|
||||
llama_set_nextn_layer_offset(ctx_dft, head);
|
||||
}
|
||||
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
|
||||
__func__, head, (int) rc, (int) batch_in.pos[0]);
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (chain_heads) {
|
||||
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
|
||||
}
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1134,7 +1173,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
int n_drafting = 0;
|
||||
std::vector<bool> drafting(n_seq);
|
||||
|
||||
const float * h_row = nullptr;
|
||||
const size_t row_bytes = (size_t) n_embd * sizeof(float);
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
@@ -1149,22 +1187,43 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
common_sampler_reset(smpls[seq_id].get());
|
||||
|
||||
common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);
|
||||
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, pending_h[seq_id].data(), row_bytes);
|
||||
|
||||
h_row = pending_h[seq_id].data();
|
||||
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
|
||||
}
|
||||
i_last[seq_id] = batch.n_tokens - 1;
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
|
||||
return;
|
||||
if (chain_heads) {
|
||||
chain_h[seq_id].assign(pending_h[seq_id].begin(), pending_h[seq_id].end());
|
||||
}
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
|
||||
while (n_drafting > 0) {
|
||||
int i_batch = 0;
|
||||
// each step decodes under a different head, i.e. a different decoder layer, and
|
||||
// KV is per layer. process() filled this layer's KV only for positions < n_past
|
||||
// (prompt + accepted prefix) — nothing in the draft region yet. so reset the
|
||||
// draft region (the seq_rm lower bound is n_past, leaving the prompt KV intact)
|
||||
// and select head i so it rebuilds its own layer's KV there; decoding just the
|
||||
// latest token would leave its attention reading cells only another head wrote.
|
||||
if (chain_heads) {
|
||||
auto * mem_dft = llama_get_memory(ctx_dft);
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
if (drafting[seq_id]) {
|
||||
llama_memory_seq_rm(mem_dft, seq_id, dparams[seq_id].n_past, -1);
|
||||
}
|
||||
}
|
||||
llama_set_nextn_layer_offset(ctx_dft, i);
|
||||
}
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
// rebuild the batch for the next step: the growing-KV paths re-add only the
|
||||
// new token (the KV already holds the prefix), while chained heads re-add the
|
||||
// whole prefix at the next head. dropped sequences are simply not re-added.
|
||||
common_batch_clear(batch);
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
@@ -1174,9 +1233,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
auto * smpl = smpls[seq_id].get();
|
||||
|
||||
common_sampler_sample(smpl, ctx_dft, i_batch, true);
|
||||
h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_batch);
|
||||
++i_batch;
|
||||
common_sampler_sample(smpl, ctx_dft, i_last[seq_id], true);
|
||||
const float * h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_last[seq_id]);
|
||||
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
@@ -1210,30 +1268,41 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (is_mem_shared) {
|
||||
if (chain_heads) {
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448031546
|
||||
chain_h[seq_id].insert(chain_h[seq_id].end(), h_row, h_row + n_embd);
|
||||
|
||||
const int n_rows = (int) result.size() + 1; // id_last + tokens drafted so far
|
||||
for (int t = 0; t < n_rows; ++t) {
|
||||
const llama_token tok = (t == 0) ? dp.id_last : result[t - 1];
|
||||
common_batch_add(batch, tok, dp.n_past + t, { seq_id }, t == n_rows - 1);
|
||||
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd,
|
||||
chain_h[seq_id].data() + (size_t) t * n_embd, row_bytes);
|
||||
}
|
||||
} else if (is_mem_shared) {
|
||||
// note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens
|
||||
// ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37
|
||||
common_batch_add(batch, id, dp.n_past, { seq_id }, true);
|
||||
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
|
||||
} else {
|
||||
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
|
||||
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
|
||||
}
|
||||
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
|
||||
|
||||
i_last[seq_id] = batch.n_tokens - 1;
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
// evaluate the drafted tokens on the draft model
|
||||
ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
++i;
|
||||
}
|
||||
|
||||
if (chain_heads) {
|
||||
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
|
||||
}
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
auto & dp = dparams[seq_id];
|
||||
if (!dp.drafting) {
|
||||
@@ -1243,8 +1312,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
if (dp.result->size() < (size_t) params.n_min) {
|
||||
dp.result->clear();
|
||||
}
|
||||
|
||||
last_n_drafted[seq_id] = (uint16_t) dp.result->size();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1857,7 +1924,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
|
||||
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
|
||||
bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr;
|
||||
bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
|
||||
bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
|
||||
|
||||
|
||||
|
||||
@@ -1895,7 +1962,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
if (has_draft_eagle3) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params));
|
||||
}
|
||||
if (has_mtp) {
|
||||
if (has_draft_mtp) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,6 +46,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"DbrxForCausalLM": "dbrx",
|
||||
"DeciLMForCausalLM": "deci",
|
||||
"DeepseekForCausalLM": "deepseek",
|
||||
"DeepseekOCRForCausalLM": "deepseek",
|
||||
"DeepseekV2ForCausalLM": "deepseek",
|
||||
"DeepseekV3ForCausalLM": "deepseek",
|
||||
"DeepseekV32ForCausalLM": "deepseek",
|
||||
@@ -96,6 +97,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"GraniteMoeHybridForCausalLM": "granite",
|
||||
"GraniteMoeSharedForCausalLM": "granite",
|
||||
"GraniteSpeechForConditionalGeneration": "granite",
|
||||
"GraniteSpeechPlusForConditionalGeneration": "granite",
|
||||
"Grok1ForCausalLM": "grok",
|
||||
"GrokForCausalLM": "grok",
|
||||
"GroveMoeForCausalLM": "grovemoe",
|
||||
@@ -123,6 +125,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"LLaDAModelLM": "llada",
|
||||
"LLaMAForCausalLM": "llama",
|
||||
"Lfm25AudioTokenizer": "lfm2",
|
||||
"Lfm2BidirectionalModel": "lfm2",
|
||||
"Lfm2ForCausalLM": "lfm2",
|
||||
"Lfm2Model": "lfm2",
|
||||
"Lfm2MoeForCausalLM": "lfm2",
|
||||
@@ -231,6 +234,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"UMT5ForConditionalGeneration": "t5",
|
||||
"UMT5Model": "t5",
|
||||
"UltravoxModel": "ultravox",
|
||||
"UnlimitedOCRForCausalLM": "deepseek",
|
||||
"VLlama3ForCausalLM": "llama",
|
||||
"VoxtralForConditionalGeneration": "llama",
|
||||
"WavTokenizerDec": "wavtokenizer",
|
||||
@@ -261,6 +265,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
|
||||
"GlmasrModel": "ultravox",
|
||||
"Granite4VisionForConditionalGeneration": "granite",
|
||||
"GraniteSpeechForConditionalGeneration": "granite",
|
||||
"GraniteSpeechPlusForConditionalGeneration": "granite",
|
||||
"HunYuanVLForConditionalGeneration": "hunyuan",
|
||||
"Idefics3ForConditionalGeneration": "smolvlm",
|
||||
"InternVisionModel": "internvl",
|
||||
@@ -296,6 +301,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
|
||||
"StepVLForConditionalGeneration": "step3",
|
||||
"Step3p7ForConditionalGeneration": "step3",
|
||||
"UltravoxModel": "ultravox",
|
||||
"UnlimitedOCRForCausalLM": "deepseek",
|
||||
"VoxtralForConditionalGeneration": "ultravox",
|
||||
"YoutuVLForConditionalGeneration": "youtuvl",
|
||||
}
|
||||
|
||||
+10
-2
@@ -14,7 +14,7 @@ from .base import MmprojModel, ModelBase, TextModel, gguf, logger
|
||||
from .qwen import QwenModel
|
||||
|
||||
|
||||
@ModelBase.register("DeepseekOCRForCausalLM")
|
||||
@ModelBase.register("DeepseekOCRForCausalLM", "UnlimitedOCRForCausalLM")
|
||||
class DeepseekOCRVisionModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -205,6 +205,8 @@ class DeepseekModel(TextModel):
|
||||
@ModelBase.register(
|
||||
"DeepseekV2ForCausalLM",
|
||||
"DeepseekV3ForCausalLM",
|
||||
"DeepseekOCRForCausalLM",
|
||||
"UnlimitedOCRForCausalLM",
|
||||
"KimiVLForConditionalGeneration",
|
||||
"KimiK25ForConditionalGeneration",
|
||||
"YoutuForCausalLM",
|
||||
@@ -224,7 +226,7 @@ class DeepseekV2Model(TextModel):
|
||||
self.origin_hf_arch = hparams.get('architectures', [None])[0]
|
||||
|
||||
# special handling for Deepseek OCR
|
||||
if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM"):
|
||||
if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM", "UnlimitedOCRForCausalLM"):
|
||||
self.model_arch = gguf.MODEL_ARCH.DEEPSEEK2OCR
|
||||
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
|
||||
self.gguf_writer.add_architecture()
|
||||
@@ -350,6 +352,12 @@ class DeepseekV2Model(TextModel):
|
||||
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||
|
||||
# Unlimited-OCR sliding window; written for metadata, the decoder ignores it (full MHA)
|
||||
if is_ocr:
|
||||
sliding_window = hparams.get("sliding_window_size") or hparams.get("sliding_window")
|
||||
if sliding_window:
|
||||
self.gguf_writer.add_sliding_window(sliding_window)
|
||||
|
||||
if (rope_mscale_all := self.rope_parameters.get("mscale_all_dim")) is not None:
|
||||
# [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
|
||||
# note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
|
||||
|
||||
@@ -348,6 +348,34 @@ class GraniteSpeechMmprojModel(MmprojModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("GraniteSpeechPlusForConditionalGeneration")
|
||||
class GraniteSpeechPlusMmprojModel(GraniteSpeechMmprojModel):
|
||||
"""Conversion for GraniteSpeechPlus - extends GraniteSpeech with feature layer concatenation"""
|
||||
has_vision_encoder = False
|
||||
has_audio_encoder = True
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
assert self.hparams_audio is not None
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# Add feature_layer if present in encoder config
|
||||
if feature_layers := self.hparams_audio.get("cat_hidden_layers"):
|
||||
self.gguf_writer.add_audio_feature_layers(feature_layers)
|
||||
logger.info(f"gguf: audio feature_layers = {feature_layers}")
|
||||
|
||||
# Validate projector dimension matches concatenated encoder output
|
||||
hidden_dim = self.hparams_audio["hidden_dim"]
|
||||
expected_dim = hidden_dim * (len(feature_layers) + 1)
|
||||
projector_dim = self.global_config["projector_config"]["encoder_hidden_size"]
|
||||
|
||||
if projector_dim != expected_dim:
|
||||
raise ValueError(
|
||||
f"Projector encoder_hidden_size ({projector_dim}) does not match "
|
||||
f"expected concatenated dimension ({expected_dim}). "
|
||||
f"Expected: hidden_dim ({hidden_dim}) * (len(feature_layers) + 1) = {expected_dim}"
|
||||
)
|
||||
|
||||
|
||||
@ModelBase.register("Granite4VisionForConditionalGeneration")
|
||||
class Granite4VisionMmprojModel(MmprojModel):
|
||||
has_vision_encoder = True
|
||||
|
||||
+10
-3
@@ -64,11 +64,17 @@ class LFM2Model(TextModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Lfm2Model")
|
||||
@ModelBase.register("Lfm2Model", "Lfm2BidirectionalModel")
|
||||
class LFM2ColBertModel(LFM2Model):
|
||||
model_arch = gguf.MODEL_ARCH.LFM2
|
||||
dense_tensor_name = "dense_2"
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if self.hf_arch == "Lfm2BidirectionalModel":
|
||||
self.gguf_writer.add_causal_attention(False)
|
||||
self._try_set_pooling_type()
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if not name.startswith(self.dense_tensor_name):
|
||||
name = "model." + name
|
||||
@@ -76,10 +82,11 @@ class LFM2ColBertModel(LFM2Model):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
# dense tensor is stored in a separate safetensors file
|
||||
# optional dense tensor is stored in a separate safetensors file
|
||||
from safetensors.torch import load_file
|
||||
tensors_file = self.dir_model / "1_Dense" / "model.safetensors"
|
||||
assert tensors_file.is_file()
|
||||
if not tensors_file.is_file():
|
||||
return
|
||||
tensor = load_file(tensors_file)["linear.weight"]
|
||||
self.gguf_writer.add_embedding_length_out(tensor.shape[0])
|
||||
yield f"{self.dense_tensor_name}.weight", tensor.clone()
|
||||
|
||||
+1
-1
@@ -29,7 +29,7 @@ With Termux, you can install and run `llama.cpp` as if the environment were Linu
|
||||
|
||||
```
|
||||
$ apt update && apt upgrade -y
|
||||
$ apt install git cmake
|
||||
$ apt install git cmake libandroid-spawn
|
||||
```
|
||||
|
||||
Then, follow the [build instructions](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md), specifically for CMake.
|
||||
|
||||
@@ -24,7 +24,6 @@
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
"GGML_OPENCL": "ON",
|
||||
"GGML_HEXAGON": "ON",
|
||||
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
|
||||
"LLAMA_OPENSSL": "OFF"
|
||||
}
|
||||
},
|
||||
@@ -47,7 +46,6 @@
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
"GGML_OPENCL": "ON",
|
||||
"GGML_HEXAGON": "ON",
|
||||
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
|
||||
"LLAMA_OPENSSL": "OFF"
|
||||
}
|
||||
},
|
||||
@@ -73,7 +71,6 @@
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
"GGML_OPENCL": "OFF",
|
||||
"GGML_HEXAGON": "ON",
|
||||
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
|
||||
"LLAMA_OPENSSL": "OFF"
|
||||
}
|
||||
},
|
||||
|
||||
@@ -198,18 +198,18 @@ class BuiltinRule:
|
||||
SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}'
|
||||
|
||||
PRIMITIVE_RULES = {
|
||||
'boolean' : BuiltinRule('("true" | "false") space', []),
|
||||
'boolean' : BuiltinRule('("true" | "false")', []),
|
||||
'decimal-part' : BuiltinRule('[0-9]{1,16}', []),
|
||||
'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []),
|
||||
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
|
||||
'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']),
|
||||
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?', ['integral-part', 'decimal-part']),
|
||||
'integer' : BuiltinRule('("-"? integral-part)', ['integral-part']),
|
||||
'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
|
||||
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
|
||||
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
|
||||
'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []),
|
||||
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? space "}"', ['string', 'value']),
|
||||
'array' : BuiltinRule('"[" space ( value ("," space value)* )? space "]"', ['value']),
|
||||
'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\""', []),
|
||||
'char' : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []),
|
||||
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
|
||||
'null' : BuiltinRule('"null" space', []),
|
||||
'string' : BuiltinRule(r'"\"" char* "\""', ['char']),
|
||||
'null' : BuiltinRule('"null"', []),
|
||||
}
|
||||
|
||||
# TODO: support "uri", "email" string formats
|
||||
@@ -217,9 +217,9 @@ STRING_FORMAT_RULES = {
|
||||
'date' : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
|
||||
'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
|
||||
'date-time' : BuiltinRule('date "T" time', ['date', 'time']),
|
||||
'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']),
|
||||
'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']),
|
||||
'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
|
||||
'date-string' : BuiltinRule('"\\"" date "\\""', ['date']),
|
||||
'time-string' : BuiltinRule('"\\"" time "\\""', ['time']),
|
||||
'date-time-string': BuiltinRule('"\\"" date-time "\\""', ['date-time']),
|
||||
}
|
||||
|
||||
DOTALL = '[\\U00000000-\\U0010FFFF]'
|
||||
@@ -319,7 +319,7 @@ class SchemaConverter:
|
||||
out.append(f'[^"{"".join(rejects)}] {char_rule}*')
|
||||
visit(trie)
|
||||
|
||||
out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space')
|
||||
out.append(f' ){"" if trie.is_end_of_string else "?"} ["]')
|
||||
return ''.join(out)
|
||||
|
||||
def _add_rule(self, name, rule):
|
||||
@@ -549,7 +549,7 @@ class SchemaConverter:
|
||||
return self._add_rule(
|
||||
name,
|
||||
to_rule(transform()) if self._raw_pattern \
|
||||
else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space")
|
||||
else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\"")
|
||||
|
||||
|
||||
def _resolve_ref(self, ref):
|
||||
@@ -580,10 +580,10 @@ class SchemaConverter:
|
||||
return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type]))
|
||||
|
||||
elif 'const' in schema:
|
||||
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space')
|
||||
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))
|
||||
|
||||
elif 'enum' in schema:
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space'
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ')'
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
elif schema_type in (None, 'object') and \
|
||||
@@ -624,7 +624,7 @@ class SchemaConverter:
|
||||
enum_intersection &= s
|
||||
|
||||
if enum_intersection:
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ')'
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
|
||||
@@ -638,12 +638,12 @@ class SchemaConverter:
|
||||
' "," space '.join(
|
||||
self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
|
||||
for i, item in enumerate(items)) +
|
||||
' "]" space')
|
||||
' space "]"')
|
||||
else:
|
||||
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
|
||||
min_items = schema.get("minItems", 0)
|
||||
max_items = schema.get("maxItems")
|
||||
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')
|
||||
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' space "]"')
|
||||
|
||||
elif schema_type in (None, 'string') and 'pattern' in schema:
|
||||
return self._visit_pattern(schema['pattern'], rule_name)
|
||||
@@ -663,7 +663,7 @@ class SchemaConverter:
|
||||
min_len = schema.get('minLength', 0)
|
||||
max_len = schema.get('maxLength')
|
||||
|
||||
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')
|
||||
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\""')
|
||||
|
||||
elif schema_type in (None, 'integer') and \
|
||||
('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema):
|
||||
@@ -680,7 +680,7 @@ class SchemaConverter:
|
||||
|
||||
out = ["("]
|
||||
_generate_min_max_int(min_value, max_value, out)
|
||||
out.append(") space")
|
||||
out.append(")")
|
||||
return self._add_rule(rule_name, ''.join(out))
|
||||
|
||||
elif (schema_type == 'object') or (len(schema) == 0):
|
||||
@@ -765,7 +765,7 @@ class SchemaConverter:
|
||||
rule += ' )'
|
||||
rule += ' )?'
|
||||
|
||||
rule += ' "}" space'
|
||||
rule += ' space "}"'
|
||||
|
||||
return rule
|
||||
|
||||
|
||||
@@ -266,7 +266,6 @@ set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
|
||||
"ggml: OpenCL API version to target")
|
||||
|
||||
option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF)
|
||||
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)")
|
||||
|
||||
# toolchain for vulkan-shaders-gen
|
||||
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
|
||||
|
||||
+50
-23
@@ -3688,8 +3688,6 @@ static void ggml_compute_forward_norm_f32(
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -3703,25 +3701,49 @@ static void ggml_compute_forward_norm_f32(
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
||||
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
float sum = 0.0;
|
||||
ggml_vec_sum_f32(ne00, &sum, x);
|
||||
float mean = sum/ne00;
|
||||
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
|
||||
const float * xf = (const float *) x;
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
float variance = 0;
|
||||
float sum = 0.0;
|
||||
ggml_vec_sum_f32(ne00, &sum, xf);
|
||||
float mean = sum/ne00;
|
||||
|
||||
float * yf = (float *) y;
|
||||
float variance = 0;
|
||||
|
||||
#ifdef GGML_USE_ACCELERATE
|
||||
mean = -mean;
|
||||
vDSP_vsadd(x, 1, &mean, y, 1, ne00);
|
||||
vDSP_measqv(y, 1, &variance, ne00);
|
||||
mean = -mean;
|
||||
vDSP_vsadd(xf, 1, &mean, yf, 1, ne00);
|
||||
vDSP_measqv(yf, 1, &variance, ne00);
|
||||
#else
|
||||
variance = ggml_vec_cvar_f32(ne00, y, x, mean);
|
||||
variance = ggml_vec_cvar_f32(ne00, yf, xf, mean);
|
||||
#endif //GGML_USE_ACCELERATE
|
||||
|
||||
const float scale = 1.0f/sqrtf(variance + eps);
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
const float scale = 1.0f/sqrtf(variance + eps);
|
||||
ggml_vec_scale_f32(ne00, yf, scale);
|
||||
} else {
|
||||
float sum = 0.0;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
sum += *(const float *) (x + i00*nb00);
|
||||
}
|
||||
const float mean = sum/ne00;
|
||||
|
||||
float variance = 0.0f;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const float v = *(const float *) (x + i00*nb00) - mean;
|
||||
*(float *) (y + i00*nb0) = v;
|
||||
variance += v * v;
|
||||
}
|
||||
variance /= ne00;
|
||||
|
||||
const float scale = 1.0f/sqrtf(variance + eps);
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
*(float *) (y + i00*nb0) *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4142,8 +4164,6 @@ static void ggml_compute_forward_l2_norm_f32(
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -4158,20 +4178,27 @@ static void ggml_compute_forward_l2_norm_f32(
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
||||
|
||||
ggml_float sum = 0.0;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
sum += (ggml_float)(x[i00] * x[i00]);
|
||||
const float xi = *(const float *) (x + i00*nb00);
|
||||
sum += (ggml_float)(xi * xi);
|
||||
}
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
memcpy(y, x, ne00 * sizeof(float));
|
||||
|
||||
const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
|
||||
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
|
||||
memcpy(y, x, ne00 * sizeof(float));
|
||||
ggml_vec_scale_f32(ne00, (float *) y, scale);
|
||||
} else {
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const float xi = *(const float *) (x + i00*nb00);
|
||||
*(float *) (y + i00*nb0) = xi * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5334,7 +5334,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return true;
|
||||
return ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
break;
|
||||
|
||||
@@ -25,7 +25,6 @@ include(ExternalProject)
|
||||
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
|
||||
option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF)
|
||||
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
|
||||
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
|
||||
|
||||
add_library(htp_iface OBJECT
|
||||
${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c)
|
||||
@@ -72,15 +71,12 @@ function(build_htp_skel V)
|
||||
-DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT}
|
||||
-DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT}
|
||||
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
|
||||
-DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}
|
||||
-DDSP_VERSION=${V}
|
||||
-DPREBUILT_LIB_DIR="toolv19_${V}")
|
||||
list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so)
|
||||
set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
build_htp_skel(v68)
|
||||
build_htp_skel(v69)
|
||||
build_htp_skel(v73)
|
||||
build_htp_skel(v75)
|
||||
build_htp_skel(v79)
|
||||
|
||||
+1359
-1274
File diff suppressed because it is too large
Load Diff
@@ -5,10 +5,12 @@
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-common.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <stdio.h>
|
||||
#include "htp-ops.h"
|
||||
#include "htp/matmul-ops.h"
|
||||
|
||||
struct htp_opnode {
|
||||
ggml_tensor * node = nullptr;
|
||||
@@ -17,6 +19,13 @@ struct htp_opnode {
|
||||
|
||||
htp_op_code opcode = HTP_OP_INVALID;
|
||||
|
||||
std::vector<ggml_tensor *> extra_dsts;
|
||||
|
||||
int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS] = {0};
|
||||
|
||||
htp_opnode(ggml_tensor * node = nullptr, std::vector<ggml_tensor *> fused = {}, htp_op_code opcode = HTP_OP_INVALID, std::vector<ggml_tensor *> extra_dsts = {})
|
||||
: node(node), fused(std::move(fused)), opcode(opcode), extra_dsts(std::move(extra_dsts)) {}
|
||||
|
||||
ggml_op op() const {
|
||||
return node->op;
|
||||
}
|
||||
@@ -25,6 +34,26 @@ struct htp_opnode {
|
||||
return fused.empty() ? node : fused.back();
|
||||
}
|
||||
|
||||
void add_fused(ggml_tensor * t, bool extra_dst = false) {
|
||||
fused.push_back(t);
|
||||
if (extra_dst) {
|
||||
extra_dsts.push_back(t);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const ggml_tensor *> get_outputs() const {
|
||||
std::vector<const ggml_tensor *> res;
|
||||
if (extra_dsts.empty()) {
|
||||
res.push_back(dst());
|
||||
} else {
|
||||
res.push_back(node);
|
||||
for (const auto * x : extra_dsts) {
|
||||
res.push_back(x);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
const ggml_tensor * src0() const {
|
||||
return node->src[0];
|
||||
}
|
||||
@@ -37,10 +66,6 @@ struct htp_opnode {
|
||||
return ggml_op_is_empty(node->op);
|
||||
}
|
||||
|
||||
void add_fused(ggml_tensor * t) {
|
||||
fused.push_back(t);
|
||||
}
|
||||
|
||||
bool stackable() const {
|
||||
switch (this->op()) {
|
||||
case GGML_OP_MUL_MAT:
|
||||
@@ -131,87 +156,117 @@ struct htp_opformat {
|
||||
char types[16 * GGML_MAX_SRC];
|
||||
char buffs[64 * GGML_MAX_SRC];
|
||||
char names[64 * GGML_MAX_SRC];
|
||||
char kparams[128];
|
||||
|
||||
int format_tensor_dims(char * str, const struct ggml_tensor * t) {
|
||||
int format_tensor_dims(char * str, size_t max_size, const struct ggml_tensor * t) {
|
||||
if (!t) {
|
||||
return sprintf(str, "NONE");
|
||||
return snprintf(str, max_size, "NONE");
|
||||
}
|
||||
if (t->ne[2] == 1 && t->ne[3] == 1) {
|
||||
return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]);
|
||||
return snprintf(str, max_size, "%d:%d", (int) t->ne[0], (int) t->ne[1]);
|
||||
} else {
|
||||
return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
|
||||
return snprintf(str, max_size, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
|
||||
}
|
||||
}
|
||||
|
||||
void format_op_dims(char * str, const htp_opnode & node) {
|
||||
void format_op_dims(char * str, size_t max_size, const htp_opnode & node) {
|
||||
char * p = str;
|
||||
char * p_end = str + max_size;
|
||||
auto inputs = node.get_inputs();
|
||||
|
||||
if (!inputs.empty()) {
|
||||
p += format_tensor_dims(p, inputs[0]);
|
||||
p += std::min((size_t)format_tensor_dims(p, p_end - p, inputs[0]), (size_t)(p_end - p));
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += format_tensor_dims(p, inputs[i]);
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
|
||||
}
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)format_tensor_dims(p, p_end - p, inputs[i]), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
char self[64];
|
||||
format_tensor_dims(self, node.dst());
|
||||
p += sprintf(p, "%s", self);
|
||||
format_tensor_dims(self, sizeof(self), node.dst());
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", self), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
int format_tensor_strides(char * str, const struct ggml_tensor * t) {
|
||||
int format_tensor_strides(char * str, size_t max_size, const struct ggml_tensor * t) {
|
||||
if (!t) {
|
||||
return sprintf(str, "NONE");
|
||||
return snprintf(str, max_size, "NONE");
|
||||
}
|
||||
const char * c = ggml_is_contiguous(t) ? "" : "!";
|
||||
|
||||
if (t->ne[2] == 1 && t->ne[3] == 1) {
|
||||
return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c);
|
||||
return snprintf(str, max_size, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c);
|
||||
} else {
|
||||
return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c);
|
||||
return snprintf(str, max_size, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c);
|
||||
}
|
||||
}
|
||||
|
||||
void format_op_strides(char * str, const htp_opnode & node) {
|
||||
void format_op_strides(char * str, size_t max_size, const htp_opnode & node) {
|
||||
char * p = str;
|
||||
char * p_end = str + max_size;
|
||||
auto inputs = node.get_inputs();
|
||||
|
||||
if (!inputs.empty()) {
|
||||
p += format_tensor_strides(p, inputs[0]);
|
||||
p += std::min((size_t)format_tensor_strides(p, p_end - p, inputs[0]), (size_t)(p_end - p));
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += format_tensor_strides(p, inputs[i]);
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
|
||||
}
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)format_tensor_strides(p, p_end - p, inputs[i]), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
char self[64];
|
||||
format_tensor_strides(self, node.dst());
|
||||
p += sprintf(p, "%s", self);
|
||||
format_tensor_strides(self, sizeof(self), node.dst());
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", self), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
void format_op_types(char * str, const htp_opnode & node) {
|
||||
void format_op_types(char * str, size_t max_size, const htp_opnode & node) {
|
||||
char * p = str;
|
||||
char * p_end = str + max_size;
|
||||
auto inputs = node.get_inputs();
|
||||
|
||||
if (!inputs.empty()) {
|
||||
p += sprintf(p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE");
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += sprintf(p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE");
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE"), (size_t)(p_end - p));
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
|
||||
}
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE"), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
p += sprintf(p, "%s", ggml_type_name(node.dst()->type));
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", ggml_type_name(node.dst()->type)), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
const char * tensor_buff_name(const struct ggml_tensor * t) {
|
||||
@@ -221,51 +276,102 @@ struct htp_opformat {
|
||||
return "NONE";
|
||||
}
|
||||
|
||||
void format_op_buffs(char * str, const htp_opnode & node) {
|
||||
void format_op_buffs(char * str, size_t max_size, const htp_opnode & node) {
|
||||
char * p = str;
|
||||
char * p_end = str + max_size;
|
||||
auto inputs = node.get_inputs();
|
||||
|
||||
if (!inputs.empty()) {
|
||||
p += sprintf(p, "%s", tensor_buff_name(inputs[0]));
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += sprintf(p, "%s", tensor_buff_name(inputs[i]));
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(inputs[0])), (size_t)(p_end - p));
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
|
||||
}
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(inputs[i])), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
p += sprintf(p, "%s", tensor_buff_name(node.dst()));
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(node.dst())), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
void format_op_names(char * str, const htp_opnode & node) {
|
||||
void format_op_names(char * str, size_t max_size, const htp_opnode & node) {
|
||||
char * p = str;
|
||||
char * p_end = str + max_size;
|
||||
auto inputs = node.get_inputs();
|
||||
|
||||
if (!inputs.empty()) {
|
||||
p += sprintf(p, "%s", inputs[0] ? inputs[0]->name : "NONE");
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += sprintf(p, "%s", inputs[i] ? inputs[i]->name : "NONE");
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[0] ? inputs[0]->name : "NONE"), (size_t)(p_end - p));
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
|
||||
}
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[i] ? inputs[i]->name : "NONE"), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
p += sprintf(p, "%s", node.dst()->name);
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", node.dst()->name), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
void format_kernel_params(char * str, size_t max_size, const htp_opnode & node) {
|
||||
if (node.opcode == HTP_OP_MUL_MAT || node.opcode == HTP_OP_MUL_MAT_ID ||
|
||||
node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN) {
|
||||
const auto * kparams = (const struct htp_mm_kernel_params *) node.kernel_params;
|
||||
const char * path = "unknown";
|
||||
int32_t type = kparams->kernel_type;
|
||||
if (type == HTP_MM_KERNEL_HMX_2D || type == HTP_MM_KERNEL_HMX_F16_BATCHED) {
|
||||
path = "hmx-tiled";
|
||||
} else if (type == HTP_MM_KERNEL_HVX_F16_F16_VTCM || type == HTP_MM_KERNEL_HVX_F32_F32_VTCM ||
|
||||
type == HTP_MM_KERNEL_HVX_QUANT_ROW || type == HTP_MM_KERNEL_HVX_QUANT_BLOCK) {
|
||||
path = "hvx-tiled";
|
||||
} else if (type == HTP_MM_KERNEL_HVX_F16_F16_DDR || type == HTP_MM_KERNEL_HVX_F16_F32_DDR ||
|
||||
type == HTP_MM_KERNEL_HVX_F32_F32_DDR || type == HTP_MM_KERNEL_HVX_F32_F16_DDR ||
|
||||
type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) {
|
||||
path = "hvx-flat";
|
||||
}
|
||||
snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size);
|
||||
} else {
|
||||
snprintf(str, max_size, "----");
|
||||
}
|
||||
}
|
||||
|
||||
void format(const htp_opnode & node) {
|
||||
format_op_dims(dims, node);
|
||||
format_op_strides(strides, node);
|
||||
format_op_types(types, node);
|
||||
format_op_buffs(buffs, node);
|
||||
format_op_names(names, node);
|
||||
format_op_dims(dims, sizeof(dims), node);
|
||||
format_op_strides(strides, sizeof(strides), node);
|
||||
format_op_types(types, sizeof(types), node);
|
||||
format_op_buffs(buffs, sizeof(buffs), node);
|
||||
format_op_names(names, sizeof(names), node);
|
||||
format_kernel_params(kparams, sizeof(kparams), node);
|
||||
}
|
||||
|
||||
htp_opformat() {}
|
||||
htp_opformat() {
|
||||
strides[0] = '\0';
|
||||
dims[0] = '\0';
|
||||
types[0] = '\0';
|
||||
buffs[0] = '\0';
|
||||
names[0] = '\0';
|
||||
kparams[0] = '\0';
|
||||
}
|
||||
htp_opformat(const htp_opnode & node) { format(node); }
|
||||
};
|
||||
|
||||
|
||||
@@ -19,43 +19,9 @@ add_library(${HTP_LIB} SHARED
|
||||
htp_iface_skel.c
|
||||
worker-pool.c
|
||||
hex-dma.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
|
||||
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
|
||||
|
||||
if (GGML_HEXAGON_FA_EXP2_HF)
|
||||
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
|
||||
endif()
|
||||
|
||||
# HMX acceleration: available on v73+ architectures
|
||||
set(HTP_HMX_VERSIONS v73 v75 v79 v81)
|
||||
list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
|
||||
|
||||
if (_hmx_idx GREATER_EQUAL 0)
|
||||
target_sources(${HTP_LIB} PRIVATE
|
||||
hmx-flash-attn-ops.c
|
||||
hmx-matmul-ops.c
|
||||
hmx-queue.c
|
||||
)
|
||||
|
||||
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
|
||||
set_source_files_properties(
|
||||
hmx-flash-attn-ops.c
|
||||
hmx-matmul-ops.c
|
||||
hmx-queue.c
|
||||
PROPERTIES COMPILE_OPTIONS "-mhmx"
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1)
|
||||
endif()
|
||||
|
||||
build_idl(htp_iface.idl ${HTP_LIB})
|
||||
|
||||
target_sources(${HTP_LIB} PRIVATE
|
||||
hmx-queue.c
|
||||
flash-attn-ops.c
|
||||
hmx-flash-attn-ops.c
|
||||
matmul-ops.c
|
||||
binary-ops.c
|
||||
unary-ops.c
|
||||
@@ -63,7 +29,6 @@ target_sources(${HTP_LIB} PRIVATE
|
||||
softmax-ops.c
|
||||
act-ops.c
|
||||
rope-ops.c
|
||||
flash-attn-ops.c
|
||||
set-rows-ops.c
|
||||
get-rows-ops.c
|
||||
cpy-ops.c
|
||||
@@ -79,6 +44,17 @@ target_sources(${HTP_LIB} PRIVATE
|
||||
pad-ops.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>)
|
||||
|
||||
if (GGML_HEXAGON_FA_EXP2_HF)
|
||||
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
|
||||
endif()
|
||||
|
||||
build_idl(htp_iface.idl ${HTP_LIB})
|
||||
|
||||
set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
install(TARGETS ${HTP_LIB})
|
||||
|
||||
@@ -3,7 +3,7 @@ if (HEXAGON_TOOLCHAIN_INCLUDED)
|
||||
endif()
|
||||
set(HEXAGON_TOOLCHAIN_INCLUDED true)
|
||||
|
||||
#Cross Compiling for Hexagon
|
||||
# Cross Compiling for Hexagon
|
||||
set(HEXAGON TRUE)
|
||||
set(CMAKE_SYSTEM_NAME QURT)
|
||||
set(CMAKE_SYSTEM_PROCESSOR Hexagon)
|
||||
@@ -14,7 +14,6 @@ set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
|
||||
set(CUSTOM_RUNELF_PATH "")
|
||||
|
||||
#To fix backward compatibility with EAI addon.
|
||||
if (NOT HEXAGON_SDK_ROOT)
|
||||
set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT})
|
||||
endif()
|
||||
@@ -31,7 +30,6 @@ endif()
|
||||
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
|
||||
file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT)
|
||||
|
||||
#Get the Binary extension of the Hexagon Toolchain
|
||||
if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows)
|
||||
set(HEXAGON_TOOLCHAIN_SUFFIX .exe)
|
||||
endif()
|
||||
@@ -48,12 +46,12 @@ set(CMAKE_TRY_COMPILE_PLATFORM_VARIABLES
|
||||
HEXAGON_TOOLS_ROOT
|
||||
)
|
||||
|
||||
#QURT Related includes and linker flags
|
||||
# QURT Related includes and linker flags
|
||||
set(V_ARCH ${HEXAGON_ARCH})
|
||||
set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/ADSP${V_ARCH}MP${V_ARCH_EXTN}")
|
||||
set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/compute${V_ARCH}${V_ARCH_EXTN}")
|
||||
|
||||
if( ${TREE} MATCHES PAKMAN )
|
||||
if (${TREE} MATCHES PAKMAN)
|
||||
set(_QURT_INSTALL_DIR "${QURT_IMAGE_DIR}/compute${V_ARCH}${V_ARCH_EXTN}")
|
||||
endif()
|
||||
message(DEBUG "_QURT_INSTALL_DIR:${_QURT_INSTALL_DIR}")
|
||||
@@ -83,11 +81,9 @@ set(QURT_START_LINK_LIBS
|
||||
)
|
||||
STRING(REPLACE ";" " " QURT_START_LINK_LIBS "${QURT_START_LINK_LIBS}")
|
||||
|
||||
set(QURT_END_LINK_LIBS
|
||||
${TARGET_DIR}/fini.o
|
||||
)
|
||||
set(QURT_END_LINK_LIBS ${TARGET_DIR}/fini.o)
|
||||
|
||||
#Non QURT related includes and linker flags
|
||||
# Non QURT related includes and linker flags
|
||||
|
||||
set(TARGET_DIR_NOOS "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/${HEXAGON_ARCH}")
|
||||
|
||||
@@ -99,8 +95,10 @@ if (NOT NO_WRAP_MEM_API)
|
||||
set(WRAP_MEMALIGN -Wl,--wrap=memalign)
|
||||
endif()
|
||||
|
||||
set(ARCH_FLAGS "-mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -mhmx")
|
||||
|
||||
set(PIC_SHARED_LD_FLAGS
|
||||
-mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH}
|
||||
${ARCH_FLAGS}
|
||||
-G0
|
||||
-fpic
|
||||
-Wl,-Bsymbolic
|
||||
@@ -120,13 +118,13 @@ STRING(REPLACE ";" " " PIC_SHARED_LD_FLAGS "${PIC_SHARED_LD_FLAGS}")
|
||||
|
||||
set(HEXAGON_PIC_SHARED_LINK_OPTIONS "${PIC_SHARED_LD_FLAGS}")
|
||||
|
||||
#System include paths
|
||||
# System include paths
|
||||
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs)
|
||||
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/stddef)
|
||||
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/ipc/fastrpc/incs)
|
||||
|
||||
#LLVM toolchain setup
|
||||
#Compiler paths, options and architecture
|
||||
# LLVM toolchain setup
|
||||
# Compiler paths, options and architecture
|
||||
set(CMAKE_C_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||
set(CMAKE_CXX_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||
set(CMAKE_AR ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-ar${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||
@@ -137,8 +135,8 @@ set(CMAKE_PREFIX_PATH ${HEXAGON_TOOLCHAIN}/Tools/target/hexagon)
|
||||
set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,")
|
||||
set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,")
|
||||
|
||||
#Compiler Options
|
||||
set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
|
||||
# Compiler Options
|
||||
set(COMMON_FLAGS "${ARCH_FLAGS} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
|
||||
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g")
|
||||
|
||||
@@ -18,7 +18,8 @@
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hmx-ops.h"
|
||||
|
||||
int hmx_flash_attn_ext(struct htp_ops_context * octx);
|
||||
|
||||
// Must be multiple of 32
|
||||
#define FLASH_ATTN_BLOCK_SIZE (32 * 2)
|
||||
@@ -633,7 +634,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
// HMX path: head_dim multiple of 64, F16 KV, and no sinks
|
||||
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 64 == 0 && v->ne[0] % 64 == 0 && octx->src[4] == NULL) {
|
||||
int ret = hmx_flash_attn_ext(octx);
|
||||
@@ -642,7 +642,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
}
|
||||
// VTCM too small or other failure -> fall through to HVX path
|
||||
}
|
||||
#endif
|
||||
|
||||
struct htp_fa_context factx;
|
||||
factx.octx = octx;
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
#ifndef HEX_COMMON_H
|
||||
#define HEX_COMMON_H
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifndef SIZE_MAX
|
||||
#define SIZE_MAX ((size_t)-1)
|
||||
#endif
|
||||
|
||||
#ifndef MAX
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#endif
|
||||
|
||||
#ifndef MIN
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#endif
|
||||
|
||||
static inline uint32_t hex_ceil_pow2(uint32_t x) {
|
||||
if (x <= 1) { return 1; }
|
||||
int p = 2;
|
||||
x--;
|
||||
while (x >>= 1) { p <<= 1; }
|
||||
return p;
|
||||
}
|
||||
|
||||
static inline size_t hmx_ceil_div(size_t num, size_t den) {
|
||||
return (num + den - 1) / den;
|
||||
}
|
||||
|
||||
static inline int32_t hex_is_aligned(const void * addr, uint32_t align) {
|
||||
return ((size_t) addr & (align - 1)) == 0;
|
||||
}
|
||||
|
||||
static inline size_t hex_align_up(size_t v, size_t align) {
|
||||
return hmx_ceil_div(v, align) * align;
|
||||
}
|
||||
|
||||
static inline size_t hex_align_down(size_t v, size_t align) {
|
||||
return (v / align) * align;
|
||||
}
|
||||
|
||||
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
|
||||
uint32_t left_off = (size_t) addr & (chunk_size - 1);
|
||||
uint32_t right_off = left_off + n;
|
||||
return right_off <= chunk_size;
|
||||
}
|
||||
|
||||
static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
|
||||
return m * ((n + m - 1) / m);
|
||||
}
|
||||
|
||||
static inline size_t hex_smin(size_t a, size_t b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
|
||||
static inline size_t hex_smax(size_t a, size_t b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
static inline void hex_swap_ptr(void ** p1, void ** p2) {
|
||||
void * t = *p1;
|
||||
*p1 = *p2;
|
||||
*p2 = t;
|
||||
}
|
||||
|
||||
static inline bool hex_mul_overflow(size_t a, size_t b, size_t *out) {
|
||||
if (a != 0 && b > SIZE_MAX / a) return true;
|
||||
*out = a * b;
|
||||
return false;
|
||||
}
|
||||
|
||||
static inline bool hex_add_overflow(size_t a, size_t b, size_t *out) {
|
||||
if (a > SIZE_MAX - b) return true;
|
||||
*out = a + b;
|
||||
return false;
|
||||
}
|
||||
|
||||
#endif // HEX_COMMON_H
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <hexagon_types.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include "hex-utils.h"
|
||||
|
||||
#include "hex-profile.h"
|
||||
|
||||
@@ -127,13 +128,8 @@ static inline dma_ptr dma_make_ptr(void *dst, const void *src)
|
||||
return p;
|
||||
}
|
||||
|
||||
#if __HVX_ARCH__ < 73
|
||||
static const uint32_t dma_src_l2_bypass_on = 1;
|
||||
static const uint32_t dma_dst_l2_bypass_on = 0;
|
||||
#else
|
||||
static const uint32_t dma_src_l2_bypass_on = 1;
|
||||
static const uint32_t dma_dst_l2_bypass_on = 1;
|
||||
#endif
|
||||
|
||||
static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t size) {
|
||||
if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) {
|
||||
|
||||
@@ -11,14 +11,7 @@
|
||||
|
||||
#include "hex-fastdiv.h"
|
||||
#include "hex-dump.h"
|
||||
|
||||
#ifndef MAX
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#endif
|
||||
|
||||
#ifndef MIN
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#endif
|
||||
#include "hex-common.h"
|
||||
|
||||
static inline uint64_t hex_get_cycles() {
|
||||
uint64_t cycles = 0;
|
||||
@@ -32,54 +25,6 @@ static inline uint64_t hex_get_pktcnt() {
|
||||
return pktcnt;
|
||||
}
|
||||
|
||||
static inline uint32_t hex_ceil_pow2(uint32_t x) {
|
||||
if (x <= 1) { return 1; }
|
||||
int p = 2;
|
||||
x--;
|
||||
while (x >>= 1) { p <<= 1; }
|
||||
return p;
|
||||
}
|
||||
|
||||
static inline size_t hmx_ceil_div(size_t num, size_t den) {
|
||||
return (num + den - 1) / den;
|
||||
}
|
||||
|
||||
static inline int32_t hex_is_aligned(const void * addr, uint32_t align) {
|
||||
return ((size_t) addr & (align - 1)) == 0;
|
||||
}
|
||||
|
||||
static inline size_t hex_align_up(size_t v, size_t align) {
|
||||
return hmx_ceil_div(v, align) * align;
|
||||
}
|
||||
|
||||
static inline size_t hex_align_down(size_t v, size_t align) {
|
||||
return (v / align) * align;
|
||||
}
|
||||
|
||||
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
|
||||
uint32_t left_off = (size_t) addr & (chunk_size - 1);
|
||||
uint32_t right_off = left_off + n;
|
||||
return right_off <= chunk_size;
|
||||
}
|
||||
|
||||
static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
|
||||
return m * ((n + m - 1) / m);
|
||||
}
|
||||
|
||||
static inline size_t hex_smin(size_t a, size_t b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
|
||||
static inline size_t hex_smax(size_t a, size_t b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
static inline void hex_swap_ptr(void ** p1, void ** p2) {
|
||||
void * t = *p1;
|
||||
*p1 = *p2;
|
||||
*p2 = t;
|
||||
}
|
||||
|
||||
static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
|
||||
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
|
||||
Q6_l2fetch_AP((void *) p, control);
|
||||
|
||||
@@ -49,7 +49,7 @@
|
||||
// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions.
|
||||
// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales
|
||||
// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax.
|
||||
static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool use_pipeline) {
|
||||
static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool pipeline) {
|
||||
const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS);
|
||||
const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK]
|
||||
const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong
|
||||
@@ -70,7 +70,7 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV,
|
||||
+ k_dma_size * 2 // K DMA x2
|
||||
+ v_dma_size * 2 // V DMA x2
|
||||
+ k_tile_size * 1 // K tiles
|
||||
+ v_tile_size * (use_pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
|
||||
+ v_tile_size * (pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
|
||||
+ s_tile_size * 2 // S + P
|
||||
+ d_tile_size * 1 // D (diagonal matrix)
|
||||
+ col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum
|
||||
@@ -290,7 +290,7 @@ static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) =
|
||||
|
||||
struct hmx_fa_context {
|
||||
const struct htp_ops_context * octx;
|
||||
bool use_pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2
|
||||
bool pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2
|
||||
uint32_t n_threads;
|
||||
|
||||
// Op parameters
|
||||
@@ -409,7 +409,7 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data)
|
||||
return;
|
||||
}
|
||||
|
||||
__fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0];
|
||||
__fp16 * v_tiles_dest = factx->pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0];
|
||||
|
||||
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
@@ -1312,13 +1312,13 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS);
|
||||
|
||||
const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc;
|
||||
const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2);
|
||||
const bool pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2);
|
||||
|
||||
// Bypass thread pool dispatch for small prompts/non-pipelined prefill by setting n_threads = 1
|
||||
const uint32_t n_threads = use_pipeline ? n_threads_init : 1;
|
||||
const uint32_t n_threads = pipeline ? n_threads_init : 1;
|
||||
|
||||
FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu",
|
||||
neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget);
|
||||
neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, pipeline, vtcm_budget);
|
||||
|
||||
// ======== Build context ========
|
||||
struct hmx_fa_context factx;
|
||||
@@ -1339,7 +1339,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
factx.n_kv_blocks = n_kv_blocks;
|
||||
factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
|
||||
factx.is_dst_fp32 = (dst->type == HTP_TYPE_F32);
|
||||
factx.use_pipeline = use_pipeline;
|
||||
factx.pipeline = pipeline;
|
||||
factx.mask_broadcast = (mask != NULL && mask->ne[2] == 1);
|
||||
|
||||
// Extract op parameters (mutable during softcap adjustment, then stored as const in factx)
|
||||
@@ -1405,7 +1405,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes);
|
||||
factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes);
|
||||
factx.vtcm_v_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes);
|
||||
if (use_pipeline) {
|
||||
if (pipeline) {
|
||||
factx.vtcm_v_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes);
|
||||
} else {
|
||||
factx.vtcm_v_tiles[1] = NULL;
|
||||
@@ -1456,7 +1456,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
// ======== HMX lock strategy ========
|
||||
// Pipeline: queue thread auto-acquires HMX lock on first push; released by suspend.
|
||||
// Fallback: main thread holds the lock (original behavior).
|
||||
if (!factx.use_pipeline) {
|
||||
if (!factx.pipeline) {
|
||||
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
|
||||
}
|
||||
|
||||
@@ -1550,7 +1550,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
const size_t k_src_stride = size_k_row_padded / sizeof(__fp16);
|
||||
const size_t v_src_stride = size_v_row_padded / sizeof(__fp16);
|
||||
|
||||
if (factx.use_pipeline) {
|
||||
if (factx.pipeline) {
|
||||
// ==================================================================
|
||||
// Pipeline path: HVX phases ‖ HMX queue worker
|
||||
// ==================================================================
|
||||
@@ -1780,7 +1780,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br);
|
||||
|
||||
// HMX: O_final = diag(1/l) @ O_prev
|
||||
if (factx.use_pipeline) {
|
||||
if (factx.pipeline) {
|
||||
on_job.o_curr = o_tile_curr;
|
||||
on_job.o_prev = o_tile_prev;
|
||||
on_job.d_tiles = factx.vtcm_d_tiles;
|
||||
@@ -1826,7 +1826,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
} // end KV head loop
|
||||
} // end batch loop
|
||||
|
||||
if (factx.use_pipeline) {
|
||||
if (factx.pipeline) {
|
||||
hmx_queue_suspend(ctx->hmx_queue);
|
||||
} else {
|
||||
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,6 +0,0 @@
|
||||
// HMX operations compiled as a single translation unit.
|
||||
// This allows interprocedural optimizations within HMX ops without requiring global HTP LTO.
|
||||
|
||||
#include "hmx-queue.c"
|
||||
#include "hmx-matmul-ops.c"
|
||||
#include "hmx-flash-attn-ops.c"
|
||||
@@ -1,88 +0,0 @@
|
||||
// HMX operation entry-point declarations.
|
||||
// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib)
|
||||
|
||||
#ifndef HMX_OPS_H
|
||||
#define HMX_OPS_H
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "htp-ops.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef struct {
|
||||
float *dst;
|
||||
const float *activation;
|
||||
const __fp16 *permuted_weight;
|
||||
int m;
|
||||
int k;
|
||||
int n;
|
||||
int act_stride;
|
||||
int weight_stride;
|
||||
int dst_stride;
|
||||
int ne02;
|
||||
int ne03;
|
||||
int ne12;
|
||||
int ne13;
|
||||
size_t src0_nb2;
|
||||
size_t src0_nb3;
|
||||
size_t src1_nb2;
|
||||
size_t src1_nb3;
|
||||
size_t dst_nb2;
|
||||
size_t dst_nb3;
|
||||
} hmx_matmul_f16_f32_batched_params_t;
|
||||
|
||||
// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output
|
||||
// act_stride: activation row stride in elements (= k for contiguous, or
|
||||
// nb[1]/sizeof(float) for permuted tensors like attention Q).
|
||||
// weight_stride: weight row stride in elements (= k for compact weights, or
|
||||
// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK).
|
||||
int hmx_matmul_f16_f32(struct htp_context *ctx,
|
||||
float *restrict dst,
|
||||
const float *activation,
|
||||
const __fp16 *permuted_weight,
|
||||
int m, int k, int n,
|
||||
int act_stride,
|
||||
int weight_stride);
|
||||
|
||||
// Batched F16 wrapper over hmx_mat_mul_f16_f32.
|
||||
// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3.
|
||||
int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params);
|
||||
|
||||
// HMX matrix multiplication — all supported weight types (F16/F32/Q4_0/Q4_1/Q8_0/IQ4_NL/MXFP4)
|
||||
int hmx_matmul_2d_f32(struct htp_context *ctx,
|
||||
float *restrict dst,
|
||||
const float *activation,
|
||||
const uint8_t *permuted_weight,
|
||||
int m, int k, int n,
|
||||
int act_stride,
|
||||
int weight_stride,
|
||||
int weight_type);
|
||||
|
||||
struct mmid_row_mapping;
|
||||
|
||||
int hmx_matmul_id_2d_f32(struct htp_context *ctx,
|
||||
float *restrict dst,
|
||||
const float *activation,
|
||||
const uint8_t *permuted_weight,
|
||||
int m, int k, int n,
|
||||
int ne11,
|
||||
size_t act_nb1, size_t act_nb2,
|
||||
size_t dst_nb1, size_t dst_nb2,
|
||||
int weight_stride,
|
||||
int weight_type,
|
||||
const struct mmid_row_mapping *matrix_rows,
|
||||
int cur_a,
|
||||
int mapping_stride);
|
||||
|
||||
// HMX flash attention
|
||||
int hmx_flash_attn_ext(struct htp_ops_context * octx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // HMX_OPS_H
|
||||
@@ -13,7 +13,9 @@
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifndef HTP_MAX_NTHREADS
|
||||
#define HTP_MAX_NTHREADS 10
|
||||
#endif
|
||||
#define HTP_MAX_MMAPS 16
|
||||
|
||||
// Memory mapping
|
||||
@@ -42,9 +44,13 @@ struct htp_ops_context {
|
||||
|
||||
enum htp_op_code op; // FIXME: rename to opcode
|
||||
int32_t op_params[HTP_OP_MAX_PARAMS];
|
||||
int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS];
|
||||
|
||||
const struct htp_tensor * src[HTP_OP_MAX_INPUTS];
|
||||
const struct htp_tensor * dst;
|
||||
union {
|
||||
const struct htp_tensor * dst;
|
||||
const struct htp_tensor * dsts[HTP_OP_MAX_OUTPUTS];
|
||||
};
|
||||
|
||||
// TODO convert these to an array
|
||||
struct htp_spad src0_spad;
|
||||
@@ -87,13 +93,13 @@ struct htp_context {
|
||||
|
||||
struct htp_ops_context octx;
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
struct hmx_queue * hmx_queue; // Async HMX queue for pipeline overlap
|
||||
#endif
|
||||
};
|
||||
|
||||
int op_matmul(struct htp_ops_context * octx);
|
||||
int op_matmul_id(struct htp_ops_context * octx);
|
||||
int op_matmul_qkv(struct htp_ops_context * octx);
|
||||
int op_matmul_ffn(struct htp_ops_context * octx);
|
||||
int op_binary(struct htp_ops_context * octx);
|
||||
int op_unary(struct htp_ops_context * octx);
|
||||
int op_sum_rows(struct htp_ops_context * octx);
|
||||
|
||||
@@ -28,18 +28,19 @@ enum htp_data_type {
|
||||
HTP_TYPE_MXFP4 = 39,
|
||||
|
||||
// types used internally for repack, dyn.quant, etc
|
||||
HTP_TYPE_Q4_0x4x2 = 200,
|
||||
HTP_TYPE_Q4_1x4x2,
|
||||
HTP_TYPE_Q8_0x4x2,
|
||||
HTP_TYPE_MXFP4x4x2,
|
||||
HTP_TYPE_Q4_0_TILED = 200,
|
||||
HTP_TYPE_Q4_1_TILED,
|
||||
HTP_TYPE_Q8_0_TILED,
|
||||
HTP_TYPE_MXFP4_TILED,
|
||||
|
||||
HTP_TYPE_INVALID
|
||||
};
|
||||
|
||||
// Constats for internal types
|
||||
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
|
||||
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
|
||||
#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks
|
||||
#define QK_Q4_0_TILED 256 // 32x32 Q4_0 tiled layout
|
||||
#define QK_Q8_0_TILED 128 // 32x32 Q8_0 tiled layout
|
||||
#define QK_MXFP4_TILED 256 // 32x32 MXFP4 tiled layout
|
||||
|
||||
|
||||
|
||||
// Mask to enable various stages of the Ops.
|
||||
@@ -57,6 +58,8 @@ enum htp_op_code {
|
||||
HTP_OP_DIV = 3,
|
||||
HTP_OP_MUL_MAT,
|
||||
HTP_OP_MUL_MAT_ID,
|
||||
HTP_OP_MUL_MAT_QKV,
|
||||
HTP_OP_MUL_MAT_FFN,
|
||||
HTP_OP_RMS_NORM,
|
||||
HTP_OP_RMS_NORM_MUL,
|
||||
HTP_OP_UNARY_SILU,
|
||||
@@ -99,7 +102,9 @@ enum htp_op_code {
|
||||
|
||||
#define HTP_OP_MAX_DIMS 4 // aka GGML_MAX_DIMS
|
||||
#define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS
|
||||
#define HTP_OP_MAX_OUTPUTS 4
|
||||
#define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS
|
||||
#define HTP_OP_MAX_KERN_PARAMS 32
|
||||
|
||||
#define HTP_OP_MAX_BUFS 16
|
||||
#define HTP_OP_MAX_REQS 256
|
||||
@@ -142,8 +147,10 @@ struct htp_op_desc {
|
||||
uint32_t opcode; // GGML/HTP Op
|
||||
uint32_t flags; // Op flags
|
||||
int32_t params[HTP_OP_MAX_PARAMS]; // Params for the op, e.g. epsilon of RMS norm
|
||||
int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS]; // generic blob for host-precomputed parameters
|
||||
uint16_t src[HTP_OP_MAX_INPUTS]; // Input tensors indices
|
||||
uint16_t dst; // Output tensor index
|
||||
uint16_t dst[HTP_OP_MAX_OUTPUTS]; // Output tensor indices
|
||||
uint16_t pad[2]; // padding to align to 64 bits
|
||||
};
|
||||
|
||||
#ifndef HTP_MAX_NTHREADS
|
||||
|
||||
@@ -11,12 +11,13 @@ struct htp_iface_pmu_conf {
|
||||
};
|
||||
|
||||
interface htp_iface : remote_handle64 {
|
||||
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx, in uint64 max_vmem);
|
||||
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 n_hmx, in uint64 max_vmem);
|
||||
AEEResult stop();
|
||||
AEEResult mmap(in uint32 fd, in uint32 size);
|
||||
AEEResult munmap(in uint32 fd);
|
||||
AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu);
|
||||
AEEResult etm(in uint32 enable);
|
||||
AEEResult hwinfo(rout uint32 n_threads, rout uint32 n_hvx, rout uint32 n_hmx, rout uint64 vtcm_size);
|
||||
};
|
||||
|
||||
#endif /* HTP_IDL */
|
||||
|
||||
@@ -170,25 +170,7 @@ static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
|
||||
}
|
||||
#endif
|
||||
|
||||
/* Q6_Vsf_equals_Vw is only available on v73+.*/
|
||||
#if __HVX_ARCH__ < 73
|
||||
static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in)
|
||||
{
|
||||
HVX_Vector const vzero = Q6_V_vzero();
|
||||
HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero);
|
||||
HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in);
|
||||
HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift);
|
||||
HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift);
|
||||
HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized);
|
||||
HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp));
|
||||
return ret;
|
||||
}
|
||||
|
||||
static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
|
||||
{
|
||||
return Q6_Vsf_equals_Vqf32(hvx_vec_i32_to_qf32(in));
|
||||
}
|
||||
#endif
|
||||
|
||||
static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
|
||||
// This looks complicated.
|
||||
@@ -305,4 +287,17 @@ static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) {
|
||||
|
||||
#endif // __HVX_ARCH__ < 79
|
||||
|
||||
static inline HVX_Vector hvx_vec_load_act_tile(const uint8_t * y_q, uint32_t kt, HVX_Vector * v_act_all) {
|
||||
if (kt % 4 == 0) {
|
||||
*v_act_all = hvx_vmem(y_q + kt * 32);
|
||||
return *v_act_all;
|
||||
} else if (kt % 4 == 1) {
|
||||
return Q6_V_vror_VR(*v_act_all, 32);
|
||||
} else if (kt % 4 == 2) {
|
||||
return Q6_V_vror_VR(*v_act_all, 64);
|
||||
} else {
|
||||
return Q6_V_vror_VR(*v_act_all, 96);
|
||||
}
|
||||
}
|
||||
|
||||
#endif /* HVX_BASE_H */
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -361,7 +361,7 @@ static void vtcm_free(struct htp_context * ctx) {
|
||||
static void htp_packet_callback(dspqueue_t queue, int error, void * context);
|
||||
static void htp_error_callback(dspqueue_t queue, int error, void * context);
|
||||
|
||||
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx, uint64_t max_vmem) {
|
||||
AEEResult htp_iface_start(remote_handle64 handle, uint32_t sess_id, uint64_t dsp_queue_id, uint32_t n_hvx, uint32_t n_hmx, uint64_t max_vmem) {
|
||||
struct htp_context * ctx = (struct htp_context *) handle;
|
||||
|
||||
if (!ctx) {
|
||||
@@ -395,10 +395,9 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
|
||||
return AEE_ENOMEMORY;
|
||||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
ctx->hmx_enabled = use_hmx;
|
||||
ctx->hmx_enabled = n_hmx;
|
||||
ctx->hmx_queue = NULL;
|
||||
if (use_hmx) {
|
||||
if (n_hmx) {
|
||||
ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx);
|
||||
if (ctx->hmx_queue) {
|
||||
ctx->hmx_queue->trace = &ctx->trace[HTP_MAX_NTHREADS];
|
||||
@@ -407,8 +406,7 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
|
||||
ctx->hmx_enabled = false;
|
||||
}
|
||||
}
|
||||
FARF(HIGH, "HMX %s (use_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", use_hmx);
|
||||
#endif
|
||||
FARF(HIGH, "HMX %s (n_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", n_hmx);
|
||||
|
||||
qurt_sysenv_max_hthreads_t hw_threads;
|
||||
qurt_sysenv_get_max_hw_threads(&hw_threads);
|
||||
@@ -481,13 +479,11 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
|
||||
dma_queue_delete(ctx->dma[i]);
|
||||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
if (ctx->hmx_queue) {
|
||||
hmx_queue_delete(ctx->hmx_queue);
|
||||
ctx->hmx_queue = NULL;
|
||||
}
|
||||
ctx->hmx_enabled = false;
|
||||
#endif
|
||||
|
||||
vtcm_free(ctx);
|
||||
|
||||
@@ -500,6 +496,36 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
AEEResult htp_iface_hwinfo(remote_handle64 handle, uint32_t * n_threads, uint32_t * n_hvx, uint32_t * n_hmx, uint64_t * vtcm_size) {
|
||||
(void)handle;
|
||||
if (!n_threads || !n_hvx || !n_hmx || !vtcm_size) {
|
||||
return AEE_EBADPARM;
|
||||
}
|
||||
|
||||
qurt_sysenv_max_hthreads_t hw_threads;
|
||||
qurt_sysenv_get_max_hw_threads(&hw_threads);
|
||||
uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF;
|
||||
|
||||
uint32_t n_hvx_val = hw_nhvx;
|
||||
if (n_hvx_val > hw_threads.max_hthreads) {
|
||||
n_hvx_val = hw_threads.max_hthreads;
|
||||
}
|
||||
if (n_hvx_val > HTP_MAX_NTHREADS) {
|
||||
n_hvx_val = HTP_MAX_NTHREADS;
|
||||
}
|
||||
|
||||
// for now we force n_threads == n_hvx
|
||||
*n_threads = n_hvx_val;
|
||||
*n_hvx = n_hvx_val;
|
||||
*n_hmx = 1;
|
||||
|
||||
uint32_t vtcm_sz = 8 * 1024 * 1024; // 8MB default fallback
|
||||
HAP_compute_res_query_VTCM(0, (unsigned int *)&vtcm_sz, NULL, NULL, NULL);
|
||||
*vtcm_size = vtcm_sz;
|
||||
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
static void htp_error_callback(dspqueue_t queue, int error, void * context) {
|
||||
// No errors expected on the DSP.
|
||||
FARF(ERROR, "Error callback: 0x%08x", (unsigned) error);
|
||||
@@ -554,6 +580,12 @@ static int execute_op(struct htp_ops_context * octx) {
|
||||
case HTP_OP_MUL_MAT_ID:
|
||||
return op_matmul_id(octx);
|
||||
|
||||
case HTP_OP_MUL_MAT_QKV:
|
||||
return op_matmul_qkv(octx);
|
||||
|
||||
case HTP_OP_MUL_MAT_FFN:
|
||||
return op_matmul_ffn(octx);
|
||||
|
||||
case HTP_OP_MUL:
|
||||
case HTP_OP_ADD:
|
||||
case HTP_OP_SUB:
|
||||
@@ -762,8 +794,9 @@ static void prep_tensors(struct htp_context *ctx, struct htp_buf_desc *bufs, str
|
||||
}
|
||||
}
|
||||
|
||||
static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) {
|
||||
static int proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) {
|
||||
memcpy(octx->op_params, op->params, sizeof(octx->op_params));
|
||||
memcpy(octx->kernel_params, op->kernel_params, sizeof(octx->kernel_params));
|
||||
octx->flags = op->flags;
|
||||
octx->op = op->opcode;
|
||||
|
||||
@@ -785,22 +818,41 @@ static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens,
|
||||
src->ne[0], src->ne[1], src->ne[3], src->ne[3]);
|
||||
}
|
||||
|
||||
// Prep output tensor
|
||||
struct htp_tensor *dst = tens + op->dst;
|
||||
// Prep output tensors
|
||||
for (uint32_t i = 0; i < HTP_OP_MAX_OUTPUTS; i++) {
|
||||
uint16_t dst_idx = op->dst[i];
|
||||
if (dst_idx == 0xffff) {
|
||||
octx->dsts[i] = NULL;
|
||||
continue;
|
||||
}
|
||||
struct htp_tensor *dst = tens + dst_idx;
|
||||
octx->dsts[i] = dst;
|
||||
|
||||
octx->dst = dst;
|
||||
FARF(HIGH, "prep-dst[%u] #%u: data %p size %u : %u:%u:%u:%u", i, dst_idx, (void*) dst->data, dst->size,
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
|
||||
}
|
||||
|
||||
FARF(HIGH, "prep-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size,
|
||||
dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]);
|
||||
int status = execute_op(octx);
|
||||
|
||||
(void) execute_op(octx);
|
||||
octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.src = NULL;
|
||||
octx->src2_spad.src = NULL;
|
||||
octx->src3_spad.src = NULL;
|
||||
octx->dst_spad.src = NULL;
|
||||
|
||||
// flush buffers on output
|
||||
hex_l2flush((void *) dst->data, dst->size);
|
||||
dst->flags |= HTP_TENSOR_FLUSHED;
|
||||
for (uint32_t i = 0; i < HTP_OP_MAX_OUTPUTS; i++) {
|
||||
if (octx->dsts[i]) {
|
||||
struct htp_tensor *dst = (struct htp_tensor *)octx->dsts[i];
|
||||
hex_l2flush((void *) dst->data, dst->size);
|
||||
dst->flags |= HTP_TENSOR_FLUSHED;
|
||||
|
||||
FARF(HIGH, "post-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size,
|
||||
dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]);
|
||||
FARF(HIGH, "post-dst[%u] #%u: data %p size %u : %u:%u:%u:%u", i, op->dst[i], (void*) dst->data, dst->size,
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
|
||||
}
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
#define DSPQUEUE_POLL_TIMEOUT_USEC 100
|
||||
@@ -892,20 +944,26 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
}
|
||||
}
|
||||
|
||||
int op_status = HTP_STATUS_OK;
|
||||
uint32_t op_wakeup = n_ops / 2; // half-way throgh the batch
|
||||
|
||||
for (uint32_t i=0; i < n_ops; i++) {
|
||||
struct profile_data prof;
|
||||
|
||||
if (i == (n_ops-1)) {
|
||||
// wake up the host before starting the last op
|
||||
if (i == op_wakeup) {
|
||||
dspqueue_write_early_wakeup_noblock(queue, 0, 0);
|
||||
}
|
||||
|
||||
profile_start(ctx->profiler, &prof);
|
||||
|
||||
proc_op_req(octx, tens, i, &ops[i]);
|
||||
op_status = proc_op_req(octx, tens, i, &ops[i]);
|
||||
|
||||
profile_stop(ctx->profiler, &prof);
|
||||
|
||||
if (op_status != HTP_STATUS_OK) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (ctx->profiler) {
|
||||
pds[i].opcode = ops[i].opcode;
|
||||
pds[i].usecs = prof.usecs;
|
||||
@@ -919,7 +977,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
|
||||
struct htp_opbatch_rsp rsp;
|
||||
rsp.id = req.id;
|
||||
rsp.status = HTP_STATUS_OK;
|
||||
rsp.status = op_status;
|
||||
rsp.n_bufs = n_bufs;
|
||||
rsp.n_tensors = n_tens;
|
||||
rsp.n_ops = n_ops;
|
||||
|
||||
+2729
-4117
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,508 @@
|
||||
#ifndef HTP_MATMUL_OPS_H
|
||||
#define HTP_MATMUL_OPS_H
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include "htp-ops.h"
|
||||
#include "hex-fastdiv.h"
|
||||
#include "hex-common.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// --- HMX Tile Constraints ---
|
||||
#define HTP_MM_HMX_TILE_N_COLS 32
|
||||
#define HTP_MM_HMX_TILE_N_ROWS 32
|
||||
#define HTP_MM_HMX_TILE_SIZE (32 * 32 * sizeof(__fp16)) // 2048 bytes
|
||||
#define HTP_MM_HMX_TILE_N_ELMS 1024
|
||||
#define HTP_MM_HMX_MIN_NROWS 4
|
||||
|
||||
// --- Weight Repacked Tile Sizes ---
|
||||
#define HTP_MM_WEIGHT_TILE_SIZE_Q4_0 576
|
||||
#define HTP_MM_WEIGHT_TILE_SIZE_Q4_1 640
|
||||
#define HTP_MM_WEIGHT_TILE_SIZE_Q8_0 1088
|
||||
#define HTP_MM_WEIGHT_TILE_SIZE_IQ4_NL 576
|
||||
#define HTP_MM_WEIGHT_TILE_SIZE_MXFP4 544
|
||||
|
||||
// --- Weight Repacked Aligned Tile Sizes ---
|
||||
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_0 640
|
||||
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_1 640
|
||||
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q8_0 1152
|
||||
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_IQ4_NL 640
|
||||
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_MXFP4 640
|
||||
|
||||
// --- Activation Tiled Block Sizes (including padding) ---
|
||||
#define HTP_MM_ACT_TILE_SIZE_Q8_0 1152
|
||||
#define HTP_MM_ACT_TILE_SIZE_Q8_1 1280
|
||||
|
||||
#define HTP_MM_MAX_PREFETCH 16
|
||||
|
||||
// --- Solver Cost Model Penalty Weights (HMX-specific) ---
|
||||
#define HTP_MM_HMX_COST_W_DEQUANT 3 // cost penalty for quantized weight loading/dequantization
|
||||
#define HTP_MM_HMX_COST_A_CONVERT 2 // cost penalty for activation loading/conversion
|
||||
|
||||
// --- DMA Activation Transfer Configuration ---
|
||||
#define HTP_MM_DMA_ACT_ROWS_PER_STEP 2
|
||||
#define HTP_MM_DMA_ACT_MULTIPLIER 4
|
||||
|
||||
enum htp_mm_kernel_type {
|
||||
HTP_MM_KERNEL_UNSUPPORTED = 0,
|
||||
|
||||
// HMX paths
|
||||
HTP_MM_KERNEL_HMX_2D,
|
||||
HTP_MM_KERNEL_HMX_F16_BATCHED,
|
||||
|
||||
// HVX floating-point paths
|
||||
HTP_MM_KERNEL_HVX_F16_F16_VTCM,
|
||||
HTP_MM_KERNEL_HVX_F16_F16_DDR,
|
||||
HTP_MM_KERNEL_HVX_F16_F32_DDR,
|
||||
|
||||
HTP_MM_KERNEL_HVX_F32_F32_VTCM,
|
||||
HTP_MM_KERNEL_HVX_F32_F32_DDR,
|
||||
HTP_MM_KERNEL_HVX_F32_F16_DDR,
|
||||
|
||||
// HVX quantized paths
|
||||
HTP_MM_KERNEL_HVX_QUANT_ROW, // standard row-wise parallel quantization
|
||||
HTP_MM_KERNEL_HVX_QUANT_BLOCK, // parallel block-wise quantization
|
||||
HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT, // row-wise fallback flat quantization
|
||||
};
|
||||
|
||||
// Op-specific struct for precomputed matmul params
|
||||
struct htp_mm_kernel_params {
|
||||
int32_t kernel_type; // enum htp_mm_kernel_type
|
||||
int32_t pipeline; // 1 = pipelined execution, 0 = standard
|
||||
int32_t m_chunk; // Row chunk size (M chunk)
|
||||
int32_t n_chunk; // Col chunk size (N chunk)
|
||||
int32_t n_threads; // Number of threads to spawn
|
||||
int32_t n_act_threads; // Number of threads for activation preparation
|
||||
int32_t n_hmx; // 1 = use HMX, 0 = use HVX
|
||||
int32_t n_prefetch; // Prefetch lookahead buffers/rows in VTCM
|
||||
int32_t tile_size; // Weight tile size
|
||||
int32_t aligned_tile_size; // Aligned weight tile size (padded to 128)
|
||||
int32_t src1_row_size; // Row size for quantized activation
|
||||
int32_t vtcm_size; // Total required scratchpad size in VTCM
|
||||
int32_t vtcm_src0_size; // src0 scratchpad size in VTCM
|
||||
int32_t vtcm_src1_size; // src1 scratchpad size in VTCM
|
||||
int32_t vtcm_src2_size; // src2 scratchpad size in VTCM (fused only)
|
||||
int32_t vtcm_src3_size; // src3 scratchpad size in VTCM (fused only)
|
||||
int32_t vtcm_dst_size; // dst scratchpad size in VTCM
|
||||
|
||||
// Precomputed division values
|
||||
struct fastdiv_values div_ne12_ne1;
|
||||
struct fastdiv_values div_ne1;
|
||||
struct fastdiv_values div_r2;
|
||||
struct fastdiv_values div_r3;
|
||||
struct fastdiv_values div_ne11;
|
||||
};
|
||||
|
||||
#if defined(__cplusplus)
|
||||
static_assert(sizeof(struct htp_mm_kernel_params) <= 128, "htp_matmul_kernel_params is too large for kernel_params blob");
|
||||
#else
|
||||
_Static_assert(sizeof(struct htp_mm_kernel_params) <= 128, "htp_matmul_kernel_params is too large for kernel_params blob");
|
||||
#endif
|
||||
|
||||
struct mmid_row_mapping {
|
||||
uint32_t i1;
|
||||
uint32_t i2;
|
||||
};
|
||||
|
||||
// Search for optimal (mc, nc) chunk sizes within VTCM budget.
|
||||
static inline int htp_mm_hmx_compute_chunks(size_t vtcm_total,
|
||||
size_t overhead,
|
||||
size_t per_n_cost,
|
||||
size_t per_m_cost,
|
||||
size_t per_mn_cost,
|
||||
size_t m,
|
||||
size_t n,
|
||||
size_t m_block_cost,
|
||||
size_t n_block_cost,
|
||||
size_t * m_chunk_out,
|
||||
size_t * n_chunk_out,
|
||||
size_t * total_out) {
|
||||
if (m == 0 || n == 0) return -1;
|
||||
if (vtcm_total <= overhead) return -1;
|
||||
if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1;
|
||||
|
||||
const size_t usable = vtcm_total - overhead;
|
||||
|
||||
size_t best_cost = SIZE_MAX;
|
||||
size_t best_mn = 0;
|
||||
size_t best_m = 0, best_n = 0;
|
||||
|
||||
const size_t n_max = hex_align_down((size_t)n, HTP_MM_HMX_TILE_N_COLS);
|
||||
for (size_t nc = n_max; nc >= HTP_MM_HMX_TILE_N_COLS; nc -= HTP_MM_HMX_TILE_N_COLS) {
|
||||
size_t n_fixed = 0, ncmn = 0, mc_denom = 0;
|
||||
if (hex_mul_overflow(nc, per_n_cost, &n_fixed)) continue;
|
||||
if (n_fixed >= usable) goto next_nc;
|
||||
|
||||
if (hex_mul_overflow(nc, per_mn_cost, &ncmn)) goto next_nc;
|
||||
if (hex_add_overflow(per_m_cost, ncmn, &mc_denom) || mc_denom == 0) goto next_nc;
|
||||
|
||||
{
|
||||
size_t remain = usable - n_fixed;
|
||||
size_t mc = remain / mc_denom;
|
||||
mc = hex_align_down(mc, HTP_MM_HMX_TILE_N_ROWS);
|
||||
mc = hex_smin(mc, m);
|
||||
|
||||
if (mc == 0) {
|
||||
goto next_nc;
|
||||
}
|
||||
|
||||
size_t mblocks = ((size_t) m + mc - 1) / mc;
|
||||
size_t nblocks = ((size_t) n + nc - 1) / nc;
|
||||
size_t cost = mblocks * m_block_cost + nblocks * n_block_cost;
|
||||
size_t mn = mc * nc;
|
||||
if (cost < best_cost || (cost == best_cost && mn > best_mn)) {
|
||||
best_cost = cost;
|
||||
best_mn = mn;
|
||||
best_m = mc;
|
||||
best_n = nc;
|
||||
}
|
||||
}
|
||||
|
||||
next_nc:
|
||||
if (nc == HTP_MM_HMX_TILE_N_COLS) break; // avoid size_t underflow
|
||||
}
|
||||
|
||||
if (best_m == 0 || best_n == 0) return -1;
|
||||
|
||||
// Compute exact total (with overflow checks)
|
||||
size_t t0 = 0, t1 = 0, t2 = 0, mn = 0, total = 0;
|
||||
if (hex_mul_overflow(best_n, per_n_cost, &t0)) return -1;
|
||||
if (hex_mul_overflow(best_m, per_m_cost, &t1)) return -1;
|
||||
if (hex_mul_overflow(best_m, best_n, &mn)) return -1;
|
||||
if (hex_mul_overflow(mn, per_mn_cost, &t2)) return -1;
|
||||
if (hex_add_overflow(t0, t1, &total)) return -1;
|
||||
if (hex_add_overflow(total, t2, &total)) return -1;
|
||||
if (hex_add_overflow(total, overhead, &total)) return -1;
|
||||
|
||||
*m_chunk_out = best_m;
|
||||
*n_chunk_out = best_n;
|
||||
*total_out = total;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// --- Tile Size Helpers ---
|
||||
static inline uint32_t htp_mm_get_weight_tile_size(int weight_type) {
|
||||
switch (weight_type) {
|
||||
case HTP_TYPE_Q4_0:
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
return HTP_MM_WEIGHT_TILE_SIZE_Q4_0;
|
||||
case HTP_TYPE_Q4_1:
|
||||
return HTP_MM_WEIGHT_TILE_SIZE_Q4_1;
|
||||
case HTP_TYPE_Q8_0:
|
||||
return HTP_MM_WEIGHT_TILE_SIZE_Q8_0;
|
||||
case HTP_TYPE_MXFP4:
|
||||
return HTP_MM_WEIGHT_TILE_SIZE_MXFP4;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
static inline uint32_t htp_mm_get_weight_aligned_tile_size(int weight_type) {
|
||||
switch (weight_type) {
|
||||
case HTP_TYPE_Q4_0:
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_0;
|
||||
case HTP_TYPE_Q4_1:
|
||||
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_1;
|
||||
case HTP_TYPE_Q8_0:
|
||||
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q8_0;
|
||||
case HTP_TYPE_MXFP4:
|
||||
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_MXFP4;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// --- Activation/Row Size Helpers ---
|
||||
static inline size_t htp_mm_q8_0_tiled_row_size(uint32_t ne) {
|
||||
const uint32_t ne_padded = ((ne + 127) / 128) * 128;
|
||||
const uint32_t nb_32 = ne_padded / 32;
|
||||
return nb_32 * HTP_MM_ACT_TILE_SIZE_Q8_0;
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_q8_1_tiled_row_size(uint32_t ne) {
|
||||
const uint32_t ne_padded = ((ne + 127) / 128) * 128;
|
||||
const uint32_t nb_32 = ne_padded / 32;
|
||||
return nb_32 * HTP_MM_ACT_TILE_SIZE_Q8_1;
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_q8_0_flat_row_size(uint32_t ne) {
|
||||
const uint32_t quants_size = hex_align_up(ne, 128);
|
||||
const uint32_t num_scales = (ne + 31) / 32;
|
||||
const uint32_t scales_size = hex_align_up(num_scales * 2, 128);
|
||||
return quants_size + scales_size;
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_q8_1_flat_row_size(uint32_t ne) {
|
||||
const uint32_t quants_size = hex_align_up(ne, 128);
|
||||
const uint32_t num_scales = (ne + 31) / 32;
|
||||
const uint32_t scales_size = hex_align_up(num_scales * 4, 128);
|
||||
return quants_size + scales_size;
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_get_tiled_row_stride(int weight_type, uint32_t k) {
|
||||
uint32_t nb = (k + QK_Q4_0_TILED - 1) / QK_Q4_0_TILED;
|
||||
switch (weight_type) {
|
||||
case HTP_TYPE_Q4_0:
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
case HTP_TYPE_Q4_1:
|
||||
case HTP_TYPE_Q8_0:
|
||||
case HTP_TYPE_MXFP4:
|
||||
return (size_t) nb * htp_mm_get_weight_tile_size(weight_type);
|
||||
case HTP_TYPE_F16:
|
||||
return (size_t) k * sizeof(__fp16);
|
||||
case HTP_TYPE_F32:
|
||||
return (size_t) k * sizeof(float);
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_round_up(size_t n, size_t m) {
|
||||
return ((n + m - 1) / m) * m;
|
||||
}
|
||||
|
||||
static inline bool htp_mm_hmx_pipeline(uint32_t m) {
|
||||
return m > 32;
|
||||
}
|
||||
|
||||
static inline void htp_mm_hmx_get_2d_chunk_costs(
|
||||
int wtype, uint32_t k, bool pipeline, uint32_t aligned_tile_size,
|
||||
size_t * size_per_n_out, size_t * size_per_m_out, size_t * size_per_mn_out
|
||||
) {
|
||||
const bool is_quant = (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32);
|
||||
const size_t row_stride = htp_mm_get_tiled_row_stride(wtype, k);
|
||||
const size_t vec_dot_size = k * sizeof(uint16_t);
|
||||
const uint32_t n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS;
|
||||
const size_t qweight_row_stride = is_quant ? (size_t)(n_k_tiles * aligned_tile_size) / 32 : 0;
|
||||
|
||||
*size_per_n_out = (pipeline ? 2 : 1) * (is_quant ? qweight_row_stride : row_stride) +
|
||||
(pipeline ? 2 * vec_dot_size : vec_dot_size);
|
||||
*size_per_m_out = vec_dot_size;
|
||||
*size_per_mn_out = (pipeline ? 2 : 1) * sizeof(uint16_t);
|
||||
}
|
||||
|
||||
static inline void htp_mm_hmx_get_batched_chunk_costs(
|
||||
uint32_t k, uint32_t group_size,
|
||||
size_t * size_per_n_out, size_t * size_per_m_out, size_t * size_per_mn_out
|
||||
) {
|
||||
const size_t vec_dot_size = k * sizeof(uint16_t);
|
||||
*size_per_n_out = 3 * vec_dot_size;
|
||||
*size_per_m_out = group_size * vec_dot_size;
|
||||
*size_per_mn_out = sizeof(uint16_t);
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_hmx_get_2d_vtcm_size(
|
||||
int wtype, uint32_t k, size_t mc, size_t nc, bool pipeline, uint32_t act_threads, uint32_t aligned_tile_size
|
||||
) {
|
||||
const uint32_t n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS;
|
||||
const bool is_quant = (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32);
|
||||
const size_t row_stride = htp_mm_get_tiled_row_stride(wtype, k);
|
||||
const size_t vec_dot_size = k * sizeof(uint16_t);
|
||||
|
||||
const size_t act_f32_size = htp_mm_round_up(act_threads * 4 * k * sizeof(float), HTP_MM_HMX_TILE_SIZE);
|
||||
size_t weight_area_size = is_quant
|
||||
? htp_mm_round_up((nc / 32) * n_k_tiles * aligned_tile_size, HTP_MM_HMX_TILE_SIZE)
|
||||
: htp_mm_round_up(nc * row_stride, HTP_MM_HMX_TILE_SIZE);
|
||||
if (pipeline) {
|
||||
weight_area_size *= 2;
|
||||
}
|
||||
const size_t act_area_size = htp_mm_round_up(mc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
|
||||
const size_t output_area_size = htp_mm_round_up(mc * nc * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE);
|
||||
|
||||
size_t scratch0_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
|
||||
size_t scratch1_size = pipeline ? scratch0_size : 0;
|
||||
size_t scratch2_size = pipeline ? output_area_size : 0;
|
||||
|
||||
return weight_area_size + act_area_size + act_f32_size + output_area_size +
|
||||
scratch0_size + scratch1_size + scratch2_size + 256;
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_hmx_get_batched_vtcm_size(
|
||||
int wtype, uint32_t k, size_t mc, size_t nc, uint32_t group_size, bool use_dma_activation, bool pipeline, uint32_t act_threads) {
|
||||
(void)wtype;
|
||||
(void)pipeline;
|
||||
const size_t vec_dot_size = k * sizeof(uint16_t);
|
||||
const size_t f32_scratch_size = use_dma_activation
|
||||
? htp_mm_round_up(act_threads * 4 * k * sizeof(float), HTP_MM_HMX_TILE_SIZE) : 0;
|
||||
|
||||
const size_t act_head_stride = mc * k;
|
||||
const size_t weight_area_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
|
||||
const size_t act_area_size = htp_mm_round_up(group_size * act_head_stride * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE);
|
||||
const size_t output_area_size = htp_mm_round_up(group_size * mc * nc * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE);
|
||||
const size_t scratch_area_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
|
||||
|
||||
return weight_area_size + act_area_size + output_area_size +
|
||||
2 * scratch_area_size + 256 + f32_scratch_size;
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_hvx_get_vtcm_sizes(
|
||||
int kernel_type,
|
||||
int wtype,
|
||||
uint32_t ne10, // k
|
||||
uint32_t src1_nrows, // m_total (or act_nrows)
|
||||
uint32_t n_threads,
|
||||
size_t dst_row_size,
|
||||
size_t src0_row_size,
|
||||
size_t src1_row_size,
|
||||
uint32_t n_prefetch,
|
||||
size_t * vtcm_src0_size_out,
|
||||
size_t * vtcm_src1_size_out,
|
||||
size_t * vtcm_dst_size_out
|
||||
) {
|
||||
size_t vtcm_src0_size = 0;
|
||||
size_t vtcm_src1_size = 0;
|
||||
size_t vtcm_dst_size = 0;
|
||||
|
||||
const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 ||
|
||||
wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL ||
|
||||
wtype == HTP_TYPE_MXFP4);
|
||||
|
||||
const size_t src0_row_size_padded = htp_mm_round_up(src0_row_size, 128);
|
||||
const size_t dst_nrows = (src1_nrows > 1) ? 0 : 1;
|
||||
|
||||
switch (kernel_type) {
|
||||
case HTP_MM_KERNEL_HVX_F16_F16_VTCM: {
|
||||
size_t f16_src1_row_size = htp_mm_round_up(ne10 * 2, 128);
|
||||
vtcm_src1_size = htp_mm_round_up(f16_src1_row_size * src1_nrows, 256);
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256) * n_threads;
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0;
|
||||
break;
|
||||
}
|
||||
case HTP_MM_KERNEL_HVX_F16_F32_DDR:
|
||||
case HTP_MM_KERNEL_HVX_F16_F16_DDR:
|
||||
case HTP_MM_KERNEL_HVX_F32_F32_DDR:
|
||||
case HTP_MM_KERNEL_HVX_F32_F16_DDR: {
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size, 256) * n_threads;
|
||||
vtcm_src1_size = htp_mm_round_up(n_prefetch * src1_row_size, 256) * n_threads;
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0;
|
||||
break;
|
||||
}
|
||||
case HTP_MM_KERNEL_HVX_F32_F32_VTCM: {
|
||||
size_t f32_src1_row_size = htp_mm_round_up(ne10 * 4, 128);
|
||||
vtcm_src1_size = htp_mm_round_up(f32_src1_row_size * src1_nrows, 256);
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256) * n_threads;
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0;
|
||||
break;
|
||||
}
|
||||
case HTP_MM_KERNEL_HVX_QUANT_BLOCK:
|
||||
case HTP_MM_KERNEL_HVX_QUANT_ROW: {
|
||||
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10);
|
||||
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
|
||||
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
|
||||
|
||||
// src0 spad is also used in dynamic quantizer to store padded src1 rows
|
||||
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, QK_Q8_0_TILED * sizeof(float));
|
||||
if (vtcm_src0_size < src1_row_size_padded) {
|
||||
vtcm_src0_size = src1_row_size_padded;
|
||||
}
|
||||
|
||||
vtcm_src0_size = vtcm_src0_size * n_threads;
|
||||
vtcm_dst_size = vtcm_dst_size * n_threads;
|
||||
|
||||
if (is_repack) {
|
||||
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
uint32_t n_k_tiles = ne10 / 32;
|
||||
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
vtcm_src0_size = repacked_vtcm_size * n_threads;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT: {
|
||||
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10);
|
||||
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
|
||||
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
|
||||
|
||||
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, 256);
|
||||
if (vtcm_src0_size < src1_row_size_padded) {
|
||||
vtcm_src0_size = src1_row_size_padded;
|
||||
}
|
||||
|
||||
vtcm_src0_size = vtcm_src0_size * n_threads;
|
||||
vtcm_dst_size = vtcm_dst_size * n_threads;
|
||||
|
||||
if (is_repack) {
|
||||
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
uint32_t n_k_tiles = ne10 / 32;
|
||||
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
vtcm_src0_size = repacked_vtcm_size * n_threads;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
*vtcm_src0_size_out = vtcm_src0_size;
|
||||
*vtcm_src1_size_out = vtcm_src1_size;
|
||||
*vtcm_dst_size_out = vtcm_dst_size;
|
||||
|
||||
return vtcm_src0_size + vtcm_src1_size + vtcm_dst_size;
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_hvx_id_get_vtcm_sizes(
|
||||
int wtype,
|
||||
uint32_t ne10, // k
|
||||
uint32_t src1_nrows,
|
||||
uint32_t n_threads,
|
||||
size_t src0_row_size, // nb01
|
||||
uint32_t n_prefetch,
|
||||
size_t * vtcm_src0_size_out,
|
||||
size_t * vtcm_src1_size_out
|
||||
) {
|
||||
const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 ||
|
||||
wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL ||
|
||||
wtype == HTP_TYPE_MXFP4);
|
||||
|
||||
const size_t src0_row_size_padded = htp_mm_round_up(src0_row_size, 128);
|
||||
const size_t src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10)
|
||||
: htp_mm_q8_0_tiled_row_size(ne10);
|
||||
|
||||
size_t src0_sz_per_thread = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
|
||||
size_t src1_sz = htp_mm_round_up(src1_row_size * src1_nrows, 256);
|
||||
|
||||
// src0 spad also holds temporary transposed src1 columns during dynamic quantization.
|
||||
const size_t src1_row_size_padded = htp_mm_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
|
||||
if (src0_sz_per_thread < src1_row_size_padded) {
|
||||
src0_sz_per_thread = src1_row_size_padded;
|
||||
}
|
||||
|
||||
if (is_repack) {
|
||||
const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
const uint32_t n_k_tiles = ne10 / 32;
|
||||
const uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
src0_sz_per_thread = repacked_vtcm_size;
|
||||
}
|
||||
|
||||
const size_t vtcm_src0_size = src0_sz_per_thread * n_threads;
|
||||
|
||||
*vtcm_src0_size_out = vtcm_src0_size;
|
||||
*vtcm_src1_size_out = src1_sz;
|
||||
|
||||
return vtcm_src0_size + src1_sz;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // HTP_MATMUL_OPS_H
|
||||
@@ -183,24 +183,25 @@ static inline void hvx_transpose_32x32_f32(HVX_Vector m[32]) {
|
||||
// transposed into VTCM.
|
||||
//
|
||||
// VTCM layouts (per thread):
|
||||
// src1_T : {d_inner_per_thread, d_conv} — staged once per launch (small).
|
||||
// src0_T : {d_inner_tile, ncs} — staged per d_inner-tile.
|
||||
// src1_T : {d_inner_stride, d_conv} - staged once per launch (small).
|
||||
// src0_T : {d_inner_tile, ncs} - staged per d_inner-tile.
|
||||
//
|
||||
// d_inner_tile is chosen so that per-thread VTCM stays under the budget.
|
||||
// Each thread iterates ceil(d_inner_per_thread d_inner_tile) tiles serially.
|
||||
#define HTP_SSM_CONV_VTCM_BUDGET (1u << 20) // 1 MiB per thread
|
||||
|
||||
// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_per_thread, d_conv} (VTCM)
|
||||
// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_stride, d_conv} (VTCM)
|
||||
static inline void transpose_src1(const float * src1_data,
|
||||
uint32_t src1_stride_inner,
|
||||
uint32_t i1_off,
|
||||
uint32_t d_inner_per_thread,
|
||||
uint32_t d_inner_stride,
|
||||
uint32_t d_conv,
|
||||
float * src1_T) {
|
||||
for (uint32_t i = 0; i < d_inner_per_thread; ++i) {
|
||||
const float * src_row = src1_data + (i1_off + i) * src1_stride_inner;
|
||||
for (uint32_t j = 0; j < d_conv; ++j) {
|
||||
src1_T[j * d_inner_per_thread + i] = src_row[j];
|
||||
src1_T[j * d_inner_stride + i] = src_row[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -280,6 +281,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
|
||||
}
|
||||
|
||||
const uint32_t d_inner_per_thread = ir1 - ir0;
|
||||
const uint32_t d_inner_stride = scctx->nrows_per_thread;
|
||||
const uint32_t d_inner_tile = scctx->d_inner_tile;
|
||||
|
||||
const float * src0_data = (const float *) src0->data;
|
||||
@@ -290,8 +292,8 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
|
||||
float * src0_T = (float *)(octx->src0_spad.data + ith * octx->src0_spad.size_per_thread);
|
||||
float * src1_T = (float *)(octx->src1_spad.data + ith * octx->src1_spad.size_per_thread);
|
||||
|
||||
// Stage src1 weights once into VTCM in {d_inner_per_thread, d_conv} layout.
|
||||
transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_conv, src1_T);
|
||||
// Stage src1 weights once into VTCM in {d_inner_stride, d_conv} layout.
|
||||
transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_inner_stride, d_conv, src1_T);
|
||||
|
||||
const uint32_t C_TILE = VLEN_FP32;
|
||||
|
||||
@@ -314,7 +316,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
|
||||
HVX_Vector acc = hvx_vec_splat_f32(0.0f);
|
||||
for (uint32_t j = 0; j < d_conv; ++j) {
|
||||
HVX_Vector x = *(const HVX_Vector *) (src0_T + (t + j) * d_inner_tile + cb);
|
||||
HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_per_thread + tile_off + cb);
|
||||
HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_stride + tile_off + cb);
|
||||
acc = Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(x, w));
|
||||
}
|
||||
HVX_Vector res = Q6_Vsf_equals_Vqf32(acc);
|
||||
@@ -362,8 +364,7 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
|
||||
use_hvx = 1;
|
||||
}
|
||||
|
||||
scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads;
|
||||
scctx.nrows_per_thread += (scctx.nrows_per_thread & 1);
|
||||
scctx.nrows_per_thread = hex_round_up((d_inner + n_threads - 1) / n_threads, VLEN_FP32);
|
||||
|
||||
const uint32_t d_inner_per_thread = scctx.nrows_per_thread;
|
||||
const uint32_t ncs = src0->ne[0];
|
||||
|
||||
@@ -14,8 +14,6 @@ Drivers_Dir = 13
|
||||
1 = %DiskId%
|
||||
|
||||
[SourceDisksFiles]
|
||||
libggml-htp-v68.so = 1
|
||||
libggml-htp-v69.so = 1
|
||||
libggml-htp-v73.so = 1
|
||||
libggml-htp-v75.so = 1
|
||||
libggml-htp-v79.so = 1
|
||||
@@ -28,8 +26,6 @@ ExcludeFromSelect = *
|
||||
CopyFiles=Drivers_Dir
|
||||
|
||||
[Drivers_Dir]
|
||||
libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v79.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
|
||||
@@ -174,7 +174,7 @@ __kernel void kernel_gemv_noshuffle_q8_0_f32(
|
||||
regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
|
||||
regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
|
||||
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, convert_float(regS), regB);
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #wave=4
|
||||
|
||||
@@ -293,6 +293,11 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
|
||||
(sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2,
|
||||
ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0),
|
||||
ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_BF16) {
|
||||
op()((const sycl::ext::oneapi::bfloat16 *) src0->data, (const float *) src1->data,
|
||||
(sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2,
|
||||
ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0),
|
||||
ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
|
||||
#endif
|
||||
} else {
|
||||
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type),
|
||||
|
||||
@@ -43,14 +43,44 @@ static __dpct_inline__ T op_sgn(T x) {
|
||||
return x > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_abs(T x) {
|
||||
return sycl::fabs(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fabs(x); // or experimental namespace if needed
|
||||
} else {
|
||||
return sycl::fabs(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_expm1(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return static_cast<sycl::ext::oneapi::bfloat16>(
|
||||
sycl::expm1(static_cast<float>(x))
|
||||
);
|
||||
} else {
|
||||
return sycl::expm1(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_elu(T x) {
|
||||
return (x > static_cast<T>(0.f)) ? x : sycl::expm1(x);
|
||||
return (x > static_cast<T>(0.f)) ? x : op_expm1(x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_tanh(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
constexpr int ver = __INTEL_LLVM_COMPILER;
|
||||
#if defined(__INTEL_LLVM_COMPILER) && (__INTEL_LLVM_COMPILER >= 20260000)
|
||||
return sycl::ext::oneapi::experimental::tanh(x);
|
||||
#else
|
||||
return static_cast<T>(sycl::tanh(static_cast<float>(x)));
|
||||
#endif
|
||||
} else {
|
||||
return sycl::tanh(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -59,74 +89,106 @@ static __dpct_inline__ T op_gelu(T x) {
|
||||
const T SQRT_2_OVER_PI = static_cast<T>(0.79788456080286535587989211986876f);
|
||||
return static_cast<T>(0.5f) * x *
|
||||
(static_cast<T>(1.0f) +
|
||||
sycl::tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
|
||||
op_tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_exp(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::exp(x);
|
||||
} else {
|
||||
return sycl::exp(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_silu(T x) {
|
||||
return x / (static_cast<T>(1.0f) + sycl::native::exp(-x));
|
||||
return x / (static_cast<T>(1.0f) + op_exp(-x));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_gelu_quick(T x) {
|
||||
const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
|
||||
return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x)));
|
||||
static __dpct_inline__ T op_erf(T x) {
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return static_cast<sycl::ext::oneapi::bfloat16>(
|
||||
sycl::erf(static_cast<float>(x))
|
||||
);
|
||||
} else {
|
||||
return sycl::erf(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_gelu_erf(T x) {
|
||||
const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
|
||||
return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + sycl::erf(x * SQRT_2_INV));
|
||||
return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + op_erf(x * SQRT_2_INV));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_tanh(T x) {
|
||||
return sycl::tanh(x);
|
||||
static __dpct_inline__ T op_gelu_quick(T x) {
|
||||
const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
|
||||
return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + op_exp(GELU_QUICK_COEF_LOCAL * x)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_relu(T x) {
|
||||
return sycl::fmax(x, static_cast<T>(0));
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fmax(x, static_cast<T>(0));
|
||||
} else {
|
||||
return sycl::fmax(x, static_cast<T>(0));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_sigmoid(T x) {
|
||||
return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(-x));
|
||||
return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + op_exp(-x));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_sqrt(T x) {
|
||||
return sycl::sqrt(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::sqrt(x);
|
||||
} else {
|
||||
return sycl::sqrt(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_sin(T x) {
|
||||
return sycl::sin(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::sin(x);
|
||||
} else {
|
||||
return sycl::sin(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_cos(T x) {
|
||||
return sycl::cos(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::cos(x);
|
||||
} else {
|
||||
return sycl::cos(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_hardsigmoid(T x) {
|
||||
return sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fmin(
|
||||
static_cast<T>(1.0f), sycl::ext::oneapi::experimental::fmax(
|
||||
static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
} else {
|
||||
return sycl::fmin(static_cast<T>(1.0f),
|
||||
sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_hardswish(T x) {
|
||||
return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_exp(T x) {
|
||||
return sycl::exp(x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_expm1(T x) {
|
||||
return sycl::expm1(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return x * sycl::ext::oneapi::experimental::fmin(static_cast<T>(1.0f), sycl::ext::oneapi::experimental::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
} else {
|
||||
return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -134,13 +196,17 @@ static __dpct_inline__ T op_log(T x) {
|
||||
if (x <= static_cast<T>(0)) {
|
||||
return neg_infinity<T>();
|
||||
}
|
||||
return sycl::log(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::log(x);
|
||||
} else {
|
||||
return sycl::log(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_softplus(T x) {
|
||||
const float xf = (float) x;
|
||||
const float ax = sycl::fabs(xf);
|
||||
const float ax = op_abs(xf);
|
||||
const float m = sycl::fmax(xf, 0.0f);
|
||||
const float y = m + sycl::log1p(sycl::exp(-ax));
|
||||
return (T) y;
|
||||
@@ -159,8 +225,14 @@ static __dpct_inline__ T op_step(T x) {
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) {
|
||||
T neg_slope_T = static_cast<T>(negative_slope);
|
||||
return sycl::fmax(x, static_cast<T>(0)) +
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::fmax(x, static_cast<T>(0)) +
|
||||
sycl::ext::oneapi::experimental::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
|
||||
|
||||
} else {
|
||||
return sycl::fmax(x, static_cast<T>(0)) +
|
||||
sycl::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -175,22 +247,40 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_floor(T x) {
|
||||
return sycl::floor(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::floor(x);
|
||||
} else {
|
||||
return sycl::floor(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_ceil(T x) {
|
||||
return sycl::ceil(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::ceil(x);
|
||||
} else {
|
||||
return sycl::ceil(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_round(T x) {
|
||||
return sycl::round(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return static_cast<sycl::ext::oneapi::bfloat16>(
|
||||
sycl::round(static_cast<float>(x))
|
||||
);
|
||||
} else {
|
||||
return sycl::round(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_trunc(T x) {
|
||||
return sycl::trunc(x);
|
||||
if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) {
|
||||
return sycl::ext::oneapi::experimental::trunc(x);
|
||||
} else {
|
||||
return sycl::trunc(x);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename F>
|
||||
@@ -339,7 +429,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
||||
const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) {
|
||||
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
|
||||
});
|
||||
}
|
||||
@@ -354,8 +444,8 @@ static void arange_kernel(T * dst, const int k, T start, T step,
|
||||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16 || dst->src[0]->type == GGML_TYPE_BF16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_BF16);
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
@@ -367,6 +457,14 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx,
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#ifdef GGML_SYCL_HAS_BF16
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
auto data_pts = cast_data<sycl::ext::oneapi::bfloat16>(dst);
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
auto data_pts = cast_data<float>(dst);
|
||||
@@ -480,7 +578,7 @@ static inline void ggml_sycl_op_unary(
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
||||
sycl::range<1>(256)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_generic_kernel(
|
||||
src, dst_ptr, k_elements,
|
||||
ne0, ne1, ne2, ne3,
|
||||
@@ -508,7 +606,7 @@ static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_ten
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
arange_kernel(dst_ptr, k, start, step, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -602,7 +700,7 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -640,7 +738,7 @@ static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tenso
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -653,7 +751,7 @@ static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -666,7 +764,7 @@ static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -681,7 +779,7 @@ static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1);
|
||||
});
|
||||
}, negative_slope);
|
||||
@@ -694,7 +792,7 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -711,7 +809,7 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1);
|
||||
});
|
||||
}, min_val, max_val);
|
||||
@@ -774,7 +872,8 @@ static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tens
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -785,7 +884,8 @@ static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tens
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -796,7 +896,8 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -811,7 +912,6 @@ __dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alp
|
||||
return out_glu;
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
static void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k,
|
||||
const int64_t n, const int64_t o0, const int64_t o1,
|
||||
@@ -845,7 +945,7 @@ static void swiglu_oai_sycl(const T * x,
|
||||
const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -899,7 +999,8 @@ static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
@@ -910,7 +1011,8 @@ static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggm
|
||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
|
||||
sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -108,6 +108,9 @@ if (Vulkan_FOUND)
|
||||
|
||||
if (GGML_VULKAN_CHECK_RESULTS)
|
||||
add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
|
||||
# the result-checking path computes a CPU reference graph via
|
||||
# ggml_graph_compute_with_ctx(), which is defined in ggml-cpu
|
||||
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
|
||||
endif()
|
||||
|
||||
if (GGML_VULKAN_DEBUG)
|
||||
@@ -129,6 +132,8 @@ if (Vulkan_FOUND)
|
||||
|
||||
if (GGML_VULKAN_RUN_TESTS)
|
||||
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
|
||||
# the test path also calls ggml_graph_compute_with_ctx() (ggml-cpu)
|
||||
target_link_libraries(ggml-vulkan PRIVATE ggml-cpu)
|
||||
endif()
|
||||
|
||||
# Set up toolchain for host compilation whether cross-compiling or not
|
||||
|
||||
@@ -493,6 +493,20 @@ struct vk_conv2d_pipeline_state {
|
||||
}
|
||||
};
|
||||
|
||||
struct vk_conv3d_pipeline_state {
|
||||
vk_conv3d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t s2, uint32_t p0, uint32_t p1, uint32_t p2,
|
||||
uint32_t d0, uint32_t d1, uint32_t d2, uint32_t KW, uint32_t KH, uint32_t KD, uint32_t aligned)
|
||||
: s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), KW(KW), KH(KH), KD(KD), aligned(aligned) {}
|
||||
|
||||
uint32_t s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD;
|
||||
uint32_t aligned;
|
||||
|
||||
bool operator<(const vk_conv3d_pipeline_state &b) const {
|
||||
return std::tie(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned) <
|
||||
std::tie(b.s0, b.s1, b.s2, b.p0, b.p1, b.p2, b.d0, b.d1, b.d2, b.KW, b.KH, b.KD, b.aligned);
|
||||
}
|
||||
};
|
||||
|
||||
struct vk_solve_tri_pipeline_state {
|
||||
vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
|
||||
: N(N), K(K) {}
|
||||
@@ -685,6 +699,7 @@ struct vk_device_struct {
|
||||
|
||||
bool add_rms_fusion;
|
||||
uint32_t partials_binding_alignment;
|
||||
uint32_t max_nodes_per_submit;
|
||||
|
||||
bool shader_64b_indexing;
|
||||
|
||||
@@ -777,6 +792,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
|
||||
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_get_rows_back_f32;
|
||||
vk_pipeline pipeline_acc_f32;
|
||||
vk_pipeline pipeline_set_f32;
|
||||
|
||||
@@ -801,14 +817,10 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_concat_i8, pipeline_concat_i16, pipeline_concat_i32, pipeline_concat_i64;
|
||||
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32;
|
||||
vk_pipeline pipeline_scale_f32;
|
||||
vk_pipeline pipeline_sqr_f32;
|
||||
vk_pipeline pipeline_sqrt_f32;
|
||||
vk_pipeline pipeline_sin_f32;
|
||||
vk_pipeline pipeline_cos_f32;
|
||||
vk_pipeline pipeline_log[2];
|
||||
vk_pipeline pipeline_tri[2];
|
||||
vk_pipeline pipeline_diag[2];
|
||||
vk_pipeline pipeline_clamp_f32;
|
||||
vk_pipeline pipeline_clamp[2];
|
||||
vk_pipeline pipeline_pad_f32;
|
||||
vk_pipeline pipeline_roll_f32;
|
||||
vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32;
|
||||
@@ -840,6 +852,10 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_gelu_quick[2];
|
||||
vk_pipeline pipeline_silu[2];
|
||||
vk_pipeline pipeline_relu[2];
|
||||
vk_pipeline pipeline_sqr[2];
|
||||
vk_pipeline pipeline_sqrt[2];
|
||||
vk_pipeline pipeline_sin[2];
|
||||
vk_pipeline pipeline_cos[2];
|
||||
vk_pipeline pipeline_xielu[2];
|
||||
vk_pipeline pipeline_neg[2];
|
||||
vk_pipeline pipeline_tanh[2];
|
||||
@@ -871,7 +887,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_geglu_erf[2];
|
||||
vk_pipeline pipeline_geglu_quick[2];
|
||||
|
||||
vk_pipeline pipeline_leaky_relu_f32;
|
||||
vk_pipeline pipeline_leaky_relu[2];
|
||||
vk_pipeline pipeline_silu_back_f32;
|
||||
vk_pipeline pipeline_diag_mask_inf_f32;
|
||||
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
|
||||
@@ -924,6 +940,8 @@ struct vk_device_struct {
|
||||
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
|
||||
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
|
||||
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
|
||||
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f32[CONV_SHAPE_COUNT];
|
||||
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f16_f32[CONV_SHAPE_COUNT];
|
||||
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
|
||||
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
|
||||
|
||||
@@ -1669,6 +1687,41 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
|
||||
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
|
||||
}
|
||||
|
||||
struct vk_op_conv3d_push_constants {
|
||||
uint32_t OC;
|
||||
uint32_t IC;
|
||||
uint32_t N;
|
||||
|
||||
uint32_t IW;
|
||||
uint32_t IH;
|
||||
uint32_t ID;
|
||||
uint32_t OW;
|
||||
uint32_t OH;
|
||||
uint32_t OD;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
|
||||
uint32_t nb1;
|
||||
uint32_t nb2;
|
||||
uint32_t nb3;
|
||||
|
||||
uint32_t OWmp; uint32_t OWL;
|
||||
uint32_t OWOHmp; uint32_t OWOHL;
|
||||
uint32_t OWOHODmp; uint32_t OWOHODL;
|
||||
};
|
||||
|
||||
template <> void init_pushconst_fastdiv(vk_op_conv3d_push_constants &p) {
|
||||
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
|
||||
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
|
||||
init_fastdiv_values(p.OW*p.OH*p.OD, p.OWOHODmp, p.OWOHODL);
|
||||
}
|
||||
|
||||
struct vk_op_conv2d_dw_push_constants {
|
||||
uint32_t ne;
|
||||
uint32_t batches;
|
||||
@@ -4074,19 +4127,35 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
}
|
||||
#endif
|
||||
|
||||
auto const &ggml_vk_mul_mm_spec = [](std::vector<uint32_t> spec, bool aligned) {
|
||||
spec.push_back(aligned ? 1u : 0u);
|
||||
return spec;
|
||||
};
|
||||
|
||||
const int mul_mat_id_param_count = 5;
|
||||
|
||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (device->coopmat2) {
|
||||
auto const &ggml_vk_mul_mm_cm2_spec = [](std::vector<uint32_t> spec, bool aligned, bool mul_mat_id) {
|
||||
if (mul_mat_id && spec.size() > 5) {
|
||||
spec.insert(spec.begin() + 5, aligned ? 1u : 0u);
|
||||
} else {
|
||||
spec.push_back(aligned ? 1u : 0u);
|
||||
}
|
||||
if (mul_mat_id && spec.size() == 6) {
|
||||
spec.push_back(32);
|
||||
}
|
||||
return spec;
|
||||
};
|
||||
|
||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), l_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), m_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), s_align, true); \
|
||||
|
||||
// Create 2 variants, {f16,f32} accumulator
|
||||
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
||||
@@ -4161,17 +4230,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, true); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, true); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, true); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, true); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, true); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, true); \
|
||||
|
||||
// Create 2 variants, {f16,f32} accumulator
|
||||
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
@@ -4284,32 +4353,32 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
// Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true
|
||||
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
|
||||
// bf16 scalar path promotes to f32, no dot2 variant
|
||||
#define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
|
||||
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l_int[TYPE]) { \
|
||||
@@ -4474,17 +4543,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
|
||||
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
if (device->mul_mat ## ID ## _l_int[TYPE]) \
|
||||
@@ -4879,6 +4948,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_back_f32, "get_rows_back_f32", get_rows_back_f32_len, get_rows_back_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {256, 1, 1}, {}, 1, true);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
|
||||
@@ -4903,7 +4973,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
}
|
||||
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_nc_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
|
||||
@@ -5023,11 +5093,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
@@ -5037,8 +5102,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
@@ -5058,6 +5121,12 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
CREATE_UNARY(gelu_quick)
|
||||
CREATE_UNARY(silu)
|
||||
CREATE_UNARY(relu)
|
||||
CREATE_UNARY(sqr)
|
||||
CREATE_UNARY(sqrt)
|
||||
CREATE_UNARY(sin)
|
||||
CREATE_UNARY(cos)
|
||||
CREATE_UNARY(clamp)
|
||||
CREATE_UNARY(leaky_relu)
|
||||
CREATE_UNARY(xielu)
|
||||
CREATE_UNARY(neg)
|
||||
CREATE_UNARY(tanh)
|
||||
@@ -5097,7 +5166,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
CREATE_GLU(geglu_quick)
|
||||
#undef CREATE_GLU
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
|
||||
@@ -5314,7 +5382,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
// conv2d, conv_transpose_2d
|
||||
// conv2d, conv_transpose_2d, conv3d
|
||||
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
|
||||
// smaller WG for the small-tile fallback gives more concurrent WGs per SM
|
||||
uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256;
|
||||
@@ -5377,8 +5445,8 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size;
|
||||
};
|
||||
|
||||
// coopmat1 needs to store the output through shared memory, so check up front
|
||||
// whether it'll fit and disable it before applying coopmat1 parameters.
|
||||
// 2D, transpose-2D, and 3D conv use the same KxCRS @ CRSxNPQ shmem
|
||||
// layout. cm1 needs Csh for output, so check before applying cm1 params.
|
||||
if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) {
|
||||
conv2d_use_cm1 = false;
|
||||
}
|
||||
@@ -5470,6 +5538,53 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
}
|
||||
#undef CREATE_CONV
|
||||
#undef CREATE_CONVS
|
||||
|
||||
std::vector<uint32_t> conv3d_spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, conv2d_SHMEM_PAD };
|
||||
#define CREATE_CONV3D(type_suffix, spv_suffix) \
|
||||
for (auto &c : device->pipeline_conv3d##type_suffix[s]) { \
|
||||
const vk_conv3d_pipeline_state &state = c.first; \
|
||||
std::vector<uint32_t> spec_constants_cpy = conv3d_spec_constants; \
|
||||
spec_constants_cpy.push_back(state.s0); \
|
||||
spec_constants_cpy.push_back(state.s1); \
|
||||
spec_constants_cpy.push_back(state.s2); \
|
||||
spec_constants_cpy.push_back(state.p0); \
|
||||
spec_constants_cpy.push_back(state.p1); \
|
||||
spec_constants_cpy.push_back(state.p2); \
|
||||
spec_constants_cpy.push_back(state.d0); \
|
||||
spec_constants_cpy.push_back(state.d1); \
|
||||
spec_constants_cpy.push_back(state.d2); \
|
||||
spec_constants_cpy.push_back(state.KW); \
|
||||
spec_constants_cpy.push_back(state.KH); \
|
||||
spec_constants_cpy.push_back(state.KD); \
|
||||
spec_constants_cpy.push_back(state.aligned); \
|
||||
spec_constants_cpy.push_back(conv2d_csh_store); \
|
||||
spec_constants_cpy.push_back(conv2d_WM); \
|
||||
spec_constants_cpy.push_back(conv2d_WN); \
|
||||
ggml_vk_create_pipeline( \
|
||||
device, c.second, "conv3d" #type_suffix, \
|
||||
conv3d##type_suffix##spv_suffix##_len, conv3d##type_suffix##spv_suffix##_data, "main", 3, \
|
||||
sizeof(vk_op_conv3d_push_constants), wg_denoms, spec_constants_cpy, 1, true, conv2d_required_subgroup_size != 0, conv2d_required_subgroup_size); \
|
||||
}
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (device->coopmat2) {
|
||||
CREATE_CONV3D(_f32, _cm2)
|
||||
CREATE_CONV3D(_f16_f32, _cm2)
|
||||
} else
|
||||
#endif
|
||||
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (conv2d_use_cm1) {
|
||||
CREATE_CONV3D(_f32, _cm1)
|
||||
CREATE_CONV3D(_f16_f32, _cm1)
|
||||
} else
|
||||
#endif
|
||||
if (conv2d_UNROLL) {
|
||||
CREATE_CONV3D(_f32, _unroll)
|
||||
CREATE_CONV3D(_f16_f32, _unroll)
|
||||
} else {
|
||||
CREATE_CONV3D(_f32, )
|
||||
CREATE_CONV3D(_f16_f32, )
|
||||
}
|
||||
#undef CREATE_CONV3D
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
||||
@@ -5764,6 +5879,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
device->subgroup_vote = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
|
||||
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eVote);
|
||||
|
||||
// Submit at least every 100 nodes, in case there are workloads without as much matmul.
|
||||
device->max_nodes_per_submit = 100;
|
||||
const char* GGML_VK_MAX_NODES_PER_SUBMIT = getenv("GGML_VK_MAX_NODES_PER_SUBMIT");
|
||||
if (GGML_VK_MAX_NODES_PER_SUBMIT != nullptr) {
|
||||
uint32_t max_nodes_per_submit = std::stoul(GGML_VK_MAX_NODES_PER_SUBMIT);
|
||||
device->max_nodes_per_submit = std::max(max_nodes_per_submit, 1u);
|
||||
}
|
||||
|
||||
const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
|
||||
|
||||
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
|
||||
@@ -10294,6 +10417,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_get_rows_f32[src0->type];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_get_rows_back_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_ACC:
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_acc_f32;
|
||||
@@ -10400,23 +10528,27 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SQR:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_sqr_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_sqr[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SQRT:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_sqrt_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_sqrt[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SIN:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_sin_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_sin[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_COS:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_cos_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_cos[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_LOG:
|
||||
@@ -10438,8 +10570,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CLAMP:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_clamp_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_clamp[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_PAD:
|
||||
@@ -10807,8 +10940,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_leaky_relu_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_leaky_relu[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CONV_2D:
|
||||
@@ -10885,6 +11019,61 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CONV_3D:
|
||||
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
const uint32_t OC = (uint32_t)ggml_get_op_params_i32(dst, 11);
|
||||
const uint32_t IC = (uint32_t)ggml_get_op_params_i32(dst, 9);
|
||||
const uint32_t N = (uint32_t)ggml_get_op_params_i32(dst, 10);
|
||||
const uint32_t NPQ = N * dst->ne[2] * dst->ne[1] * dst->ne[0];
|
||||
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, OC, NPQ);
|
||||
|
||||
const uint32_t KW = (uint32_t)src0->ne[0];
|
||||
const uint32_t KH = (uint32_t)src0->ne[1];
|
||||
const uint32_t KD = (uint32_t)src0->ne[2];
|
||||
const uint32_t s0 = (uint32_t)ggml_get_op_params_i32(dst, 0);
|
||||
const uint32_t s1 = (uint32_t)ggml_get_op_params_i32(dst, 1);
|
||||
const uint32_t s2 = (uint32_t)ggml_get_op_params_i32(dst, 2);
|
||||
const uint32_t p0 = (uint32_t)ggml_get_op_params_i32(dst, 3);
|
||||
const uint32_t p1 = (uint32_t)ggml_get_op_params_i32(dst, 4);
|
||||
const uint32_t p2 = (uint32_t)ggml_get_op_params_i32(dst, 5);
|
||||
const uint32_t d0 = (uint32_t)ggml_get_op_params_i32(dst, 6);
|
||||
const uint32_t d1 = (uint32_t)ggml_get_op_params_i32(dst, 7);
|
||||
const uint32_t d2 = (uint32_t)ggml_get_op_params_i32(dst, 8);
|
||||
|
||||
const uint32_t CRS = IC * KW * KH * KD;
|
||||
const uint32_t BS_K = vk_conv_block_sizes[shape].K;
|
||||
const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS;
|
||||
const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ;
|
||||
const uint32_t aligned = ((OC % BS_K == 0) &&
|
||||
(CRS % BS_CRS == 0) &&
|
||||
(NPQ % BS_NPQ == 0)) ? 1u : 0u;
|
||||
|
||||
vk_conv3d_pipeline_state conv3d_pipeline_state(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned);
|
||||
|
||||
std::map<vk_conv3d_pipeline_state, vk_pipeline> *pipelines = nullptr;
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
pipelines = &ctx->device->pipeline_conv3d_f32[shape];
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
pipelines = &ctx->device->pipeline_conv3d_f16_f32[shape];
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
vk_pipeline pipeline = nullptr;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
|
||||
auto it = pipelines->find(conv3d_pipeline_state);
|
||||
if (it != pipelines->end()) {
|
||||
pipeline = it->second;
|
||||
} else {
|
||||
(*pipelines)[conv3d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
|
||||
}
|
||||
}
|
||||
|
||||
return pipeline;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_ADD1:
|
||||
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||
return ctx->device->pipeline_add1_f16_f16;
|
||||
@@ -11135,6 +11324,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
||||
break;
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
elements = { (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], 1 };
|
||||
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||
break;
|
||||
case GGML_OP_ARGSORT:
|
||||
GGML_ASSERT(0);
|
||||
break;
|
||||
@@ -11220,6 +11413,21 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
GGML_ABORT("invalid push constant type for CONV_2D");
|
||||
}
|
||||
break;
|
||||
case GGML_OP_CONV_3D:
|
||||
if constexpr (std::is_same_v<PC, vk_op_conv3d_push_constants>) {
|
||||
const uint32_t NPQ = pc.N * pc.OD * pc.OH * pc.OW;
|
||||
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.OC, NPQ);
|
||||
const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ);
|
||||
|
||||
elements = { pc.OC, NPQ_blocks, 1 };
|
||||
if (elements[1] > 512) {
|
||||
elements[2] = CEIL_DIV(elements[1], 512);
|
||||
elements[1] = 512;
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("invalid push constant type for CONV_3D");
|
||||
}
|
||||
break;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_DIV:
|
||||
@@ -11236,6 +11444,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
case GGML_OP_TRI:
|
||||
case GGML_OP_DIAG:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ROLL:
|
||||
case GGML_OP_REPEAT:
|
||||
@@ -11380,6 +11589,21 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_vk_get_rows_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS_BACK, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2], (uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f, 0,
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||
@@ -12087,8 +12311,10 @@ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||
|
||||
static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
float * op_params = (float *)dst->op_params;
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
p.param1 = op_params[0];
|
||||
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, std::move(p));
|
||||
}
|
||||
|
||||
static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
@@ -13118,6 +13344,51 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
|
||||
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p));
|
||||
}
|
||||
|
||||
static void ggml_vk_conv_3d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
|
||||
const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nb10 == sizeof(float));
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
|
||||
vk_op_conv3d_push_constants p{};
|
||||
p.IC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 9));
|
||||
p.N = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 10));
|
||||
p.OC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 11));
|
||||
GGML_ASSERT(src0->ne[3] == (int64_t)p.IC * p.OC);
|
||||
GGML_ASSERT(src1->ne[3] == (int64_t)p.IC * p.N);
|
||||
GGML_ASSERT(dst->ne[3] == (int64_t)p.OC * p.N);
|
||||
|
||||
p.IW = static_cast<uint32_t>(ne10);
|
||||
p.IH = static_cast<uint32_t>(ne11);
|
||||
p.ID = static_cast<uint32_t>(ne12);
|
||||
p.OW = static_cast<uint32_t>(ne0);
|
||||
p.OH = static_cast<uint32_t>(ne1);
|
||||
p.OD = static_cast<uint32_t>(ne2);
|
||||
|
||||
// the shader clamps src addresses to p.IC * p.N * p.IW * p.IH * p.ID - 1 in uint32, so the
|
||||
// total input element count must fit in a uint32.
|
||||
GGML_ASSERT((uint64_t)p.IC * p.N * p.IW * p.IH * p.ID <= 0xFFFFFFFFull);
|
||||
|
||||
p.nb01 = static_cast<uint32_t>(nb01 / nb00);
|
||||
p.nb02 = static_cast<uint32_t>(nb02 / nb00);
|
||||
p.nb03 = static_cast<uint32_t>(nb03 / nb00);
|
||||
|
||||
p.nb11 = static_cast<uint32_t>(nb11 / nb10);
|
||||
p.nb12 = static_cast<uint32_t>(nb12 / nb10);
|
||||
p.nb13 = static_cast<uint32_t>(nb13 / nb10);
|
||||
|
||||
p.nb1 = static_cast<uint32_t>(nb1 / nb0);
|
||||
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
|
||||
p.nb3 = static_cast<uint32_t>(nb3 / nb0);
|
||||
|
||||
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_3D, std::move(p));
|
||||
}
|
||||
|
||||
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
vk_op_conv2d_dw_push_constants p{};
|
||||
p.ne = ggml_nelements(dst);
|
||||
@@ -13144,7 +13415,10 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
|
||||
static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
const float * op_params = (const float *)dst->op_params;
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
p.param1 = op_params[0];
|
||||
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, std::move(p));
|
||||
}
|
||||
|
||||
#ifdef GGML_VULKAN_RUN_TESTS
|
||||
@@ -14247,6 +14521,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_OP_GET_ROWS:
|
||||
ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
ggml_vk_get_rows_back(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_ADD:
|
||||
if (ctx->num_additional_fused_ops) {
|
||||
@@ -14515,6 +14793,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_CONV_3D:
|
||||
ggml_vk_conv_3d(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node);
|
||||
@@ -15900,8 +16182,6 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
// Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
|
||||
// Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
|
||||
// (and scaled down based on model size, so smaller models submit earlier).
|
||||
// Also submit at least every 100 nodes, in case there are workloads without as much matmul.
|
||||
int nodes_per_submit = 100;
|
||||
int submitted_nodes = 0;
|
||||
int submit_count = 0;
|
||||
uint64_t mul_mat_bytes = 0;
|
||||
@@ -16127,7 +16407,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
|
||||
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
|
||||
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
|
||||
bool submit = (submitted_nodes >= nodes_per_submit) ||
|
||||
bool submit = ((uint32_t)submitted_nodes >= ctx->device->max_nodes_per_submit) ||
|
||||
(mul_mat_bytes_per_submit != 0 && mul_mat_bytes >= mul_mat_bytes_per_submit) ||
|
||||
(i + ctx->num_additional_fused_ops >= last_node) ||
|
||||
(almost_ready && !ctx->almost_ready_fence_pending);
|
||||
@@ -16964,6 +17244,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
return false;
|
||||
}
|
||||
}
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
switch (op->type) {
|
||||
@@ -17060,12 +17342,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_RMS_NORM:
|
||||
return true;
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return ggml_is_contiguous_rows(op->src[0]) &&
|
||||
op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
@@ -17084,8 +17365,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||
op->type == op->src[0]->type;
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
@@ -17285,6 +17567,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
ggml_is_contiguous(op->src[1]) &&
|
||||
ggml_is_contiguous(op));
|
||||
}
|
||||
case GGML_OP_CONV_3D:
|
||||
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||
op->src[1]->type == GGML_TYPE_F32 &&
|
||||
op->type == GGML_TYPE_F32 &&
|
||||
ggml_is_contiguous(op->src[0]) &&
|
||||
ggml_is_contiguous(op->src[1]) &&
|
||||
ggml_is_contiguous(op);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -18128,6 +18417,20 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||
const int32_t d0 = tensor->op_params[4];
|
||||
const int32_t d1 = tensor->op_params[5];
|
||||
tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
|
||||
} else if (tensor->op == GGML_OP_CONV_3D) {
|
||||
const int32_t s0 = tensor->op_params[0];
|
||||
const int32_t s1 = tensor->op_params[1];
|
||||
const int32_t s2 = tensor->op_params[2];
|
||||
const int32_t p0 = tensor->op_params[3];
|
||||
const int32_t p1 = tensor->op_params[4];
|
||||
const int32_t p2 = tensor->op_params[5];
|
||||
const int32_t d0 = tensor->op_params[6];
|
||||
const int32_t d1 = tensor->op_params[7];
|
||||
const int32_t d2 = tensor->op_params[8];
|
||||
const int32_t IC = tensor->op_params[9];
|
||||
const int32_t N = tensor->op_params[10];
|
||||
const int32_t OC = tensor->op_params[11];
|
||||
tensor_clone = ggml_conv_3d_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, s2, p0, p1, p2, d0, d1, d2, IC, N, OC);
|
||||
} else if (tensor->op == GGML_OP_CONV_2D_DW) {
|
||||
const int32_t s0 = tensor->op_params[0];
|
||||
const int32_t s1 = tensor->op_params[1];
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
|
||||
}
|
||||
@@ -0,0 +1,431 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#ifdef COOPMAT2
|
||||
#extension GL_NV_cooperative_matrix2 : enable
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#endif
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
|
||||
layout(binding = 0) readonly buffer A {
|
||||
A_TYPE knl_data[];
|
||||
}; // src0 - kernel: [KW, KH, KD, IC*OC]
|
||||
|
||||
layout(binding = 1) readonly buffer B {
|
||||
B_TYPE src_data[];
|
||||
}; // src1 - input: [IW, IH, ID, IC*N] -- channel_first format
|
||||
|
||||
layout(binding = 2) writeonly buffer D {
|
||||
D_TYPE dst_data[];
|
||||
}; // dst - result: [OW, OH, OD, OC*N]
|
||||
|
||||
layout(push_constant) uniform parameter {
|
||||
// I/O channels, batch size
|
||||
uint32_t OC;
|
||||
uint32_t IC;
|
||||
uint32_t N;
|
||||
|
||||
// Tensor spatial sizes: input, output
|
||||
uint32_t IW;
|
||||
uint32_t IH;
|
||||
uint32_t ID;
|
||||
uint32_t OW;
|
||||
uint32_t OH;
|
||||
uint32_t OD;
|
||||
|
||||
// Strides in elements
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
|
||||
uint32_t nb1;
|
||||
uint32_t nb2;
|
||||
uint32_t nb3;
|
||||
|
||||
// fastdiv helper values
|
||||
uint32_t OWmp; uint32_t OWL;
|
||||
uint32_t OWOHmp; uint32_t OWOHL;
|
||||
uint32_t OWOHODmp; uint32_t OWOHODL;
|
||||
}
|
||||
|
||||
p;
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
// Blocktile sizes
|
||||
layout(constant_id = 1) const uint BS_K = 128;
|
||||
layout(constant_id = 2) const uint BS_CRS = 16;
|
||||
layout(constant_id = 3) const uint BS_NPQ = 128;
|
||||
// Thread-tile sizes
|
||||
layout(constant_id = 4) const uint TS_K = 8;
|
||||
layout(constant_id = 5) const uint SHMEM_PAD = 4;
|
||||
// Stride, padding, dilation
|
||||
layout(constant_id = 6) const uint s0 = 1;
|
||||
layout(constant_id = 7) const uint s1 = 1;
|
||||
layout(constant_id = 8) const uint s2 = 1;
|
||||
layout(constant_id = 9) const uint p0 = 0;
|
||||
layout(constant_id = 10) const uint p1 = 0;
|
||||
layout(constant_id = 11) const uint p2 = 0;
|
||||
layout(constant_id = 12) const uint d0 = 1;
|
||||
layout(constant_id = 13) const uint d1 = 1;
|
||||
layout(constant_id = 14) const uint d2 = 1;
|
||||
// Kernel spatial sizes
|
||||
layout(constant_id = 15) const uint KW = 1;
|
||||
layout(constant_id = 16) const uint KH = 1;
|
||||
layout(constant_id = 17) const uint KD = 1;
|
||||
// when set, skip bounds checks and address clamps (K/CRS/NPQ are tile-aligned)
|
||||
layout(constant_id = 18) const uint aligned = 0;
|
||||
// stage cm2 result through shmem (Csh) for coalesced stores. cm1 always does this.
|
||||
layout(constant_id = 19) const uint csh_store = 0;
|
||||
|
||||
#ifdef COOPMAT
|
||||
// cm1 subgroup tile: each subgroup computes a WM x WN region as a grid of
|
||||
// TM x TN x TK fragments. Requires WM%TM == WN%TN == BS_K%WM == BS_NPQ%WN ==
|
||||
// BS_CRS%TK == 0, and WG_SIZE == (BS_K/WM) * (BS_NPQ/WN) * subgroup_size.
|
||||
layout(constant_id = 20) const uint WM = 32;
|
||||
layout(constant_id = 21) const uint WN = 32;
|
||||
const uint TM = 16;
|
||||
const uint TN = 16;
|
||||
const uint TK = 16;
|
||||
const uint cms_per_row = WM / TM;
|
||||
const uint cms_per_col = WN / TN;
|
||||
const uint warps_M = BS_K / WM;
|
||||
const uint warps_N = BS_NPQ / WN;
|
||||
#endif
|
||||
|
||||
// without padding, ID_idx/IH_idx/IW_idx are in bounds by construction
|
||||
const bool dhw_in_bounds = (p0 == 0) && (p1 == 0) && (p2 == 0);
|
||||
|
||||
uint32_t tid = gl_LocalInvocationID.x;
|
||||
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
|
||||
|
||||
uint splitWork(uint work_size, uint block_size) {
|
||||
return (block_size + work_size - 1) / block_size;
|
||||
}
|
||||
|
||||
uint32_t K = p.OC;
|
||||
uint32_t CRS = p.IC * KD * KH * KW;
|
||||
uint32_t NPQ = p.N * p.OD * p.OH * p.OW;
|
||||
|
||||
// Number of blocktiles per input
|
||||
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
|
||||
|
||||
#if defined(COOPMAT2) || defined(COOPMAT)
|
||||
#define SHMEM_TYPE float16_t
|
||||
#else
|
||||
#define SHMEM_TYPE float
|
||||
#endif
|
||||
|
||||
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
|
||||
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
|
||||
|
||||
const uint32_t Ash_len = BS_K * Ash_stride;
|
||||
const uint32_t Bsh_len = BS_CRS * Bsh_stride;
|
||||
|
||||
shared SHMEM_TYPE Ash[Ash_len]; // K x CRS
|
||||
shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ
|
||||
|
||||
#if defined(COOPMAT2) || defined(COOPMAT)
|
||||
// stage matC through shmem so global stores are row-major (NPQ-contiguous)
|
||||
const uint32_t Csh_stride = BS_NPQ;
|
||||
#ifdef COOPMAT
|
||||
const uint32_t Csh_len = BS_K * Csh_stride;
|
||||
#else
|
||||
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1;
|
||||
#endif
|
||||
shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ
|
||||
#endif
|
||||
|
||||
// Threadtile sizes
|
||||
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
|
||||
|
||||
// Number of threadtiles per blocktile
|
||||
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
|
||||
|
||||
/*
|
||||
Compute
|
||||
KxCRS @ CRSxNPQ = K x NPQ
|
||||
K=OC
|
||||
C=IC
|
||||
D,R,S=KD,KH,KW
|
||||
Z,P,Q=OD,OH,OW
|
||||
*/
|
||||
|
||||
uint32_t B_idx_K = gl_WorkGroupID.x;
|
||||
uint32_t B_idx_NPQ = gl_WorkGroupID.y + gl_WorkGroupID.z * 512;
|
||||
|
||||
uint32_t T_y = tid / NT_NPQ;
|
||||
uint32_t T_x = tid % NT_NPQ;
|
||||
|
||||
uint32_t Ar = tid / BS_CRS;
|
||||
uint32_t Ac = tid % BS_CRS;
|
||||
const uint32_t ArpWg = WG_SIZE / BS_CRS;
|
||||
|
||||
uint32_t Br = tid / BS_NPQ;
|
||||
uint32_t Bc = tid % BS_NPQ;
|
||||
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
|
||||
|
||||
// see init_fastdiv_values in ggml-vulkan.cpp
|
||||
uint fastdiv(uint n, uint mp, uint L) {
|
||||
uint msbs, lsbs;
|
||||
// msbs = mulhi(n, mp)
|
||||
umulExtended(n, mp, msbs, lsbs);
|
||||
return (msbs + n) >> L;
|
||||
}
|
||||
|
||||
void split_crs(uint32_t crs_idx, out uint32_t ic, out uint32_t kd, out uint32_t kh, out uint32_t kw) {
|
||||
const uint32_t KHKW = KH * KW;
|
||||
const uint32_t KDKHKW = KD * KHKW;
|
||||
ic = crs_idx / KDKHKW;
|
||||
uint32_t rem = crs_idx - ic * KDKHKW;
|
||||
kd = rem / KHKW;
|
||||
rem = rem - kd * KHKW;
|
||||
kh = rem / KW;
|
||||
kw = rem - kh * KW;
|
||||
}
|
||||
|
||||
void split_npq(uint32_t npq_idx, out uint32_t n, out uint32_t od, out uint32_t oh, out uint32_t ow) {
|
||||
const uint32_t OWOH = p.OW * p.OH;
|
||||
n = fastdiv(npq_idx, p.OWOHODmp, p.OWOHODL);
|
||||
uint32_t rem = npq_idx - n * p.OD * OWOH;
|
||||
od = fastdiv(rem, p.OWOHmp, p.OWOHL);
|
||||
rem = rem - od * OWOH;
|
||||
oh = fastdiv(rem, p.OWmp, p.OWL);
|
||||
ow = rem - oh * p.OW;
|
||||
}
|
||||
|
||||
#ifdef COOPMAT2
|
||||
#define ACC_TYPE float16_t
|
||||
|
||||
ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
|
||||
{
|
||||
uint32_t K_idx = B_idx_K * BS_K + r;
|
||||
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;
|
||||
uint32_t N_idx;
|
||||
uint32_t OD_idx;
|
||||
uint32_t OH_idx;
|
||||
uint32_t OW_idx;
|
||||
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
|
||||
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
|
||||
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
|
||||
dst_data[dst_idx] = D_TYPE(elem);
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
if (B_idx_NPQ * BS_NPQ >= NPQ) {
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef COOPMAT2
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
|
||||
matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
|
||||
#elif defined(COOPMAT)
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
|
||||
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
|
||||
sums[i] = coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0);
|
||||
}
|
||||
const uint warp_r = gl_SubgroupID / warps_N;
|
||||
const uint warp_c = gl_SubgroupID % warps_N;
|
||||
#else
|
||||
float regC[TS_K][TS_NPQ];
|
||||
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
||||
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
||||
regC[T_ly][T_lx] = 0.0;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
/* Advance block in CRS dim */
|
||||
[[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
|
||||
uint32_t CRS_idx_a = B_idx_CRS * BS_CRS + Ac;
|
||||
uint32_t IC_idx_a;
|
||||
uint32_t KD_idx_a;
|
||||
uint32_t KH_idx_a;
|
||||
uint32_t KW_idx_a;
|
||||
split_crs(CRS_idx_a, IC_idx_a, KD_idx_a, KH_idx_a, KW_idx_a);
|
||||
|
||||
/* Load kernel to A_block: (BS_K x BS_CRS)*/
|
||||
UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
|
||||
uint32_t B_ly = r_offset + Ar;
|
||||
uint32_t B_lx = Ac;
|
||||
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
|
||||
uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + KD_idx_a * p.nb02 + (K_idx * p.IC + IC_idx_a) * p.nb03;
|
||||
if (aligned == 0) {
|
||||
knl_idx = min(knl_idx, K * CRS - 1);
|
||||
}
|
||||
float val = knl_data[knl_idx];
|
||||
if (aligned == 0 && (K_idx >= K || CRS_idx_a >= CRS)) {
|
||||
val = 0.0;
|
||||
}
|
||||
Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
|
||||
}
|
||||
/* Load input to B_block: (BS_CRS x BS_NPQ) */
|
||||
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
|
||||
uint32_t B_ly = r_offset + Br; /* Row index of B block */
|
||||
uint32_t B_lx = Bc;
|
||||
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
|
||||
uint32_t N_idx;
|
||||
uint32_t OD_idx;
|
||||
uint32_t OH_idx;
|
||||
uint32_t OW_idx;
|
||||
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
|
||||
|
||||
uint32_t CRS_idx_b = B_idx_CRS * BS_CRS + B_ly;
|
||||
uint32_t IC_idx_b;
|
||||
uint32_t KD_idx_b;
|
||||
uint32_t KH_idx_b;
|
||||
uint32_t KW_idx_b;
|
||||
split_crs(CRS_idx_b, IC_idx_b, KD_idx_b, KH_idx_b, KW_idx_b);
|
||||
|
||||
uint32_t ID_idx = OD_idx * s2 + KD_idx_b * d2 - p2;
|
||||
uint32_t IH_idx = OH_idx * s1 + KH_idx_b * d1 - p1;
|
||||
uint32_t IW_idx = OW_idx * s0 + KW_idx_b * d0 - p0;
|
||||
|
||||
uint32_t src_idx = IW_idx + IH_idx * p.nb11 + ID_idx * p.nb12 + (N_idx * p.IC + IC_idx_b) * p.nb13;
|
||||
// skip clamp when address can't go OOB
|
||||
if (aligned == 0 || !dhw_in_bounds) {
|
||||
src_idx = min(src_idx, p.IC * p.N * p.IW * p.IH * p.ID - 1);
|
||||
}
|
||||
float val = src_data[src_idx];
|
||||
bool oob = false;
|
||||
if (aligned == 0 && (CRS_idx_b >= CRS || NPQ_idx >= NPQ)) {
|
||||
oob = true;
|
||||
}
|
||||
// also catches lower-bound underflow (idx wraps to 0x80000000+)
|
||||
if (!dhw_in_bounds && (ID_idx >= p.ID || IH_idx >= p.IH || IW_idx >= p.IW)) {
|
||||
oob = true;
|
||||
}
|
||||
if (oob) {
|
||||
val = 0.0;
|
||||
}
|
||||
Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
|
||||
}
|
||||
barrier();
|
||||
#ifdef COOPMAT2
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;
|
||||
|
||||
coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
matC = coopMatMulAdd(matA, matB, matC);
|
||||
#elif defined(COOPMAT)
|
||||
// each subgroup multiplies its grid of fragments per TK-sized CRS chunk
|
||||
[[unroll]] for (uint k_step = 0; k_step < BS_CRS / TK; k_step++) {
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a[cms_per_row];
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
const uint a_off = (warp_r * WM + cm_row * TM) * Ash_stride + k_step * TK;
|
||||
coopMatLoad(cache_a[cm_row], Ash, a_off, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
|
||||
const uint b_off = k_step * TK * Bsh_stride + warp_c * WN + cm_col * TN;
|
||||
coopMatLoad(cache_b, Bsh, b_off, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row], cache_b, sums[cm_col * cms_per_row + cm_row]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (T_y * TS_K < K) {
|
||||
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
|
||||
float regA[TS_K];
|
||||
float regB[TS_NPQ];
|
||||
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
||||
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
|
||||
}
|
||||
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
||||
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
|
||||
}
|
||||
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
||||
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
||||
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
barrier();
|
||||
}
|
||||
/* Save C* */
|
||||
#if defined(COOPMAT2) || defined(COOPMAT)
|
||||
// stage matC into Csh, then write to dst with coalesced NPQ-contiguous stores
|
||||
#ifdef COOPMAT
|
||||
const bool use_staged_store = true;
|
||||
#else
|
||||
const bool use_staged_store = (csh_store != 0);
|
||||
#endif
|
||||
if (use_staged_store) {
|
||||
#ifdef COOPMAT
|
||||
// cm1: each subgroup stores its fragment grid into its Csh slot
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
const uint csh_off = (warp_r * WM + cm_row * TM) * Csh_stride + warp_c * WN + cm_col * TN;
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], Csh, csh_off, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
}
|
||||
#else
|
||||
coopMatStore(matC, Csh, 0, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
#endif
|
||||
barrier();
|
||||
|
||||
// cooperative shmem->global: WG threads spread across BS_NPQ (the
|
||||
// contiguous direction of dst), each iter covers store_rows_per_iter K-rows
|
||||
const uint32_t store_rows_per_iter = WG_SIZE / BS_NPQ;
|
||||
const uint32_t store_iters = BS_K / store_rows_per_iter;
|
||||
const uint32_t k_thread_offset = tid / BS_NPQ;
|
||||
const uint32_t npq_thread = tid % BS_NPQ;
|
||||
[[unroll]] for (uint32_t i = 0; i < store_iters; i++) {
|
||||
uint32_t k_local = i * store_rows_per_iter + k_thread_offset;
|
||||
uint32_t K_idx = B_idx_K * BS_K + k_local;
|
||||
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + npq_thread;
|
||||
uint32_t N_idx;
|
||||
uint32_t OD_idx;
|
||||
uint32_t OH_idx;
|
||||
uint32_t OW_idx;
|
||||
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
|
||||
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
|
||||
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
|
||||
dst_data[dst_idx] = D_TYPE(Csh[k_local * Csh_stride + npq_thread]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef COOPMAT2
|
||||
else {
|
||||
coopMatPerElementNV(matC, matC, perElemOpStore);
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
if (T_y * TS_K < K) {
|
||||
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
||||
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
||||
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
|
||||
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
|
||||
uint32_t N_idx;
|
||||
uint32_t OD_idx;
|
||||
uint32_t OH_idx;
|
||||
uint32_t OW_idx;
|
||||
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
|
||||
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
|
||||
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
|
||||
dst_data[dst_idx] = D_TYPE(regC[T_ly][T_lx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val));
|
||||
}
|
||||
@@ -463,6 +463,7 @@ void main() {
|
||||
}
|
||||
rowmaxf = max(rowmaxf, float(Sf[r][c]));
|
||||
}
|
||||
rowmaxf += FATTN_KQ_MAX_OFFSET;
|
||||
float Moldf = Mf[r];
|
||||
|
||||
// M = max(rowmax, Mold)
|
||||
|
||||
@@ -352,6 +352,7 @@ void main() {
|
||||
}
|
||||
rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));
|
||||
}
|
||||
rowmaxf += FATTN_KQ_MAX_OFFSET;
|
||||
float Moldf = Mf[r];
|
||||
|
||||
// Compute max across the row
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_binary_head.glsl"
|
||||
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint col = gl_GlobalInvocationID.x;
|
||||
|
||||
if (col >= p.ne20) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (uint row = gl_GlobalInvocationID.y; row < p.ne21; row += gl_WorkGroupSize.y * gl_NumWorkGroups.y) {
|
||||
float sum = 0.0f;
|
||||
for (uint i = 0; i < p.ne10; ++i) {
|
||||
if (data_b[get_boffset() + i*p.nb10] == int(row)) {
|
||||
sum += data_a[get_aoffset() + i*p.nb01 + col*p.nb00];
|
||||
}
|
||||
}
|
||||
|
||||
data_d[get_doffset() + row*p.nb21 + col*p.nb20] = sum;
|
||||
}
|
||||
}
|
||||
@@ -14,16 +14,13 @@ void main() {
|
||||
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
const uint i3 = row / (p.ne11 * p.ne12);
|
||||
const uint i3_offset = i3 * p.ne12 * p.ne11;
|
||||
const uint i2 = (row - i3_offset) / p.ne11;
|
||||
const uint i2_offset = i2 * p.ne11;
|
||||
const uint i1 = row - i3_offset - i2_offset;
|
||||
const uint a_base = get_aoffset() + src0_idx(row * p.ne00);
|
||||
const uint d_base = get_doffset() + dst_idx(row * p.ne10);
|
||||
|
||||
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]);
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_base + i0*p.nb00]);
|
||||
sum[tid] += xi * xi;
|
||||
}
|
||||
|
||||
@@ -39,6 +36,6 @@ void main() {
|
||||
const FLOAT_TYPE scale = 1.0f / max(sqrt(sum[0]), FLOAT_TYPE(p.param1));
|
||||
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]));
|
||||
data_d[d_base + i0*p.nb10] = D_TYPE(scale * FLOAT_TYPE(data_a[a_base + i0*p.nb00]));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.KX) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float val = float(data_a[i]);
|
||||
data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1);
|
||||
}
|
||||
@@ -38,17 +38,7 @@
|
||||
#define LOAD_VEC_B 1
|
||||
#endif
|
||||
|
||||
// Load 2 values at once without affecting index calculations through LOAD_VEC
|
||||
#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)
|
||||
#define LOAD_VEC_BATCH_A 2
|
||||
#else
|
||||
#define LOAD_VEC_BATCH_A 1
|
||||
#endif
|
||||
#if !defined(ALIGNED)
|
||||
#define LOAD_VEC_BATCH_B 2
|
||||
#else
|
||||
#define LOAD_VEC_BATCH_B 1
|
||||
#endif
|
||||
layout (constant_id = 11) const uint ALIGNED = 0;
|
||||
|
||||
#if !defined(TO_FLOAT_TYPE)
|
||||
#define TO_FLOAT_TYPE FLOAT_TYPE
|
||||
@@ -57,6 +47,13 @@
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
#if defined(DATA_A_F32)
|
||||
layout (binding = 0) readonly buffer A_SCALAR {float data_a_scalar[];};
|
||||
#elif defined(DATA_A_F16)
|
||||
layout (binding = 0) readonly buffer A_SCALAR {float16_t data_a_scalar[];};
|
||||
#elif defined(DATA_A_BF16)
|
||||
layout (binding = 0) readonly buffer A_SCALAR {uint16_t data_a_scalar[];};
|
||||
#endif
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
|
||||
#endif
|
||||
@@ -65,6 +62,7 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32
|
||||
#endif
|
||||
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
layout (binding = 1) readonly buffer B_SCALAR {B_TYPE_SCALAR data_b_scalar[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
@@ -194,13 +192,23 @@ void main() {
|
||||
const uint warp_r = warp_i % (BM / WM);
|
||||
const uint warp_c = warp_i / (BM / WM);
|
||||
|
||||
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
|
||||
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
|
||||
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
|
||||
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
|
||||
const uint LOAD_VEC_A_EFF = (ALIGNED != 0) ? LOAD_VEC_A : 1;
|
||||
const uint LOAD_VEC_BATCH_A = (ALIGNED != 0) ? 1 : 2;
|
||||
#else
|
||||
const uint LOAD_VEC_A_EFF = LOAD_VEC_A;
|
||||
const uint LOAD_VEC_BATCH_A = 1;
|
||||
#endif
|
||||
const uint LOAD_VEC_B_EFF = (ALIGNED != 0) ? LOAD_VEC_B : 1;
|
||||
const uint LOAD_VEC_BATCH_B = (ALIGNED != 0) ? 1 : 2;
|
||||
|
||||
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;
|
||||
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;
|
||||
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A);
|
||||
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A);
|
||||
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B);
|
||||
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B);
|
||||
|
||||
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A_EFF * LOAD_VEC_BATCH_A / BK;
|
||||
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B_EFF * LOAD_VEC_BATCH_B / BK;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||
@@ -239,15 +247,15 @@ void main() {
|
||||
|
||||
uint pos_a =
|
||||
#ifdef MUL_MAT_ID
|
||||
expert_idx * (p.batch_stride_a / LOAD_VEC_A) +
|
||||
expert_idx * (p.batch_stride_a / LOAD_VEC_A_EFF) +
|
||||
#else
|
||||
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) +
|
||||
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A_EFF) +
|
||||
#endif
|
||||
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
|
||||
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A_EFF;
|
||||
#ifdef MUL_MAT_ID
|
||||
uint pos_b = 0;
|
||||
#else
|
||||
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
|
||||
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B_EFF;
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
@@ -287,8 +295,8 @@ void main() {
|
||||
|
||||
barrier();
|
||||
|
||||
pos_a += BK / LOAD_VEC_A;
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
pos_a += BK / LOAD_VEC_A_EFF;
|
||||
pos_b += BK / LOAD_VEC_B_EFF;
|
||||
|
||||
#ifdef COOPMAT
|
||||
[[unroll]] for (uint i = 0; i < BK; i += TK) {
|
||||
|
||||
@@ -36,6 +36,7 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working wit
|
||||
layout (constant_id = 4) const bool enable_smaller_matrices = false;
|
||||
const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
|
||||
const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
|
||||
layout (constant_id = 5) const uint ALIGNED = 0;
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
@@ -111,7 +112,7 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
|
||||
};
|
||||
|
||||
uint _ne1;
|
||||
layout (constant_id = 5) const uint subgroup_size = 32;
|
||||
layout (constant_id = 6) const uint subgroup_size = 32;
|
||||
shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
|
||||
|
||||
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
@@ -297,12 +298,12 @@ void main() {
|
||||
|
||||
// Hint to the compiler that values are aligned (want 16B alignment).
|
||||
// Quants are always block-aligned, no alignment needed.
|
||||
#if ALIGNED
|
||||
if (ALIGNED != 0) {
|
||||
#if QUANT_K == 1
|
||||
stride_a &= ~7;
|
||||
#endif
|
||||
stride_b &= ~7;
|
||||
stride_a &= ~7;
|
||||
#endif
|
||||
stride_b &= ~7;
|
||||
}
|
||||
|
||||
// Create layouts for both clamped and unclamped accesses
|
||||
tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
|
||||
|
||||
@@ -1,50 +1,57 @@
|
||||
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) {
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
||||
#if LOAD_VEC_A == 8
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa[0].xy;
|
||||
buf_a[buf_idx + 1] = aa[0].zw;
|
||||
buf_a[buf_idx + 2] = aa[1].xy;
|
||||
buf_a[buf_idx + 3] = aa[1].zw;
|
||||
if (ALIGNED != 0) {
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa[0].xy;
|
||||
buf_a[buf_idx + 1] = aa[0].zw;
|
||||
buf_a[buf_idx + 2] = aa[1].xy;
|
||||
buf_a[buf_idx + 3] = aa[1].zw;
|
||||
return;
|
||||
}
|
||||
#elif LOAD_VEC_A == 4
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
#else // LOAD_VEC_BATCH_A == 2
|
||||
if (ALIGNED != 0) {
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
const uint idx = pos_a + col * p.stride_a + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx],
|
||||
data_a[idx + 1]);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx],
|
||||
data_a_scalar[idx + 1]);
|
||||
} else if (idx_m < p.M && block + row * 2 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx], 0.0f);
|
||||
} else {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
#elif defined(DATA_A_BF16)
|
||||
#if LOAD_VEC_A == 4
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
#else // LOAD_VEC_BATCH_A == 2
|
||||
if (ALIGNED != 0) {
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
const uint idx = pos_a + col * p.stride_a + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]),
|
||||
TO_FLOAT_TYPE(data_a[idx + 1]));
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]),
|
||||
TO_FLOAT_TYPE(data_a_scalar[idx + 1]));
|
||||
} else if (idx_m < p.M && block + row * 2 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]), 0.0f);
|
||||
} else {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
#elif defined(DATA_A_Q4_0)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
|
||||
@@ -526,75 +533,85 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
#if !defined(MUL_MAT_ID)
|
||||
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) {
|
||||
#if LOAD_VEC_B == 8
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
buf_b[buf_idx + 3] = bb[1].zw;
|
||||
if (ALIGNED != 0) {
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
buf_b[buf_idx + 3] = bb[1].zw;
|
||||
return;
|
||||
}
|
||||
#elif LOAD_VEC_B == 4
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
if (ALIGNED != 0) {
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
#if defined(DATA_B_BF16)
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
#else
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
#else // LOAD_VEC_BATCH_B == 2
|
||||
const uint idx = pos_b + col * p.stride_b + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
|
||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]),
|
||||
TO_FLOAT_TYPE(data_b_scalar[idx + 1]));
|
||||
} else if (idx_n < p.N && block + row * 2 < end_k) {
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f);
|
||||
} else {
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) {
|
||||
#if LOAD_VEC_B == 8
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
buf_b[buf_idx + 3] = bb[1].zw;
|
||||
if (ALIGNED != 0) {
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
buf_b[buf_idx + 3] = bb[1].zw;
|
||||
return;
|
||||
}
|
||||
#elif LOAD_VEC_B == 4
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
if (ALIGNED != 0) {
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
#if defined(DATA_B_BF16)
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
#else
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
#else // LOAD_VEC_BATCH_B == 2
|
||||
const uint row_i = ic * BN + col;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
|
||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]),
|
||||
TO_FLOAT_TYPE(data_b_scalar[idx + 1]));
|
||||
} else if (row_i < _ne1 && block + row * 2 < end_k) {
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f);
|
||||
} else {
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#define BLOCK_SIZE 512
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
shared vec2 sum[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
const uint a_base = get_aoffset() + src0_idx(row * p.ne00);
|
||||
const uint d_base = get_doffset() + dst_idx(row * p.ne10);
|
||||
|
||||
sum[tid] = vec2(0.0f, 0.0f);
|
||||
|
||||
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||
const float xi = float(data_a[row*p.KX + col]);
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
const float xi = float(data_a[a_base + i0*p.nb00]);
|
||||
sum[tid].x += xi;
|
||||
sum[tid].y += xi * xi;
|
||||
}
|
||||
@@ -34,11 +34,11 @@ void main() {
|
||||
barrier();
|
||||
}
|
||||
|
||||
const float mean = sum[0].x / p.KX;
|
||||
const float var = sum[0].y / p.KX - mean * mean;
|
||||
const float mean = sum[0].x / p.ne00;
|
||||
const float var = sum[0].y / p.ne00 - mean * mean;
|
||||
const float inv_std = inversesqrt(var + p.param1);
|
||||
|
||||
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||
data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std);
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
data_d[d_base + i0*p.nb10] = D_TYPE((float(data_a[a_base + i0*p.nb00]) - mean) * inv_std);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val));
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sqrt(val));
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val);
|
||||
}
|
||||
@@ -17,6 +17,30 @@ float op_neg(float x) {
|
||||
return -x;
|
||||
}
|
||||
|
||||
float op_sqr(float x) {
|
||||
return x * x;
|
||||
}
|
||||
|
||||
float op_sqrt(float x) {
|
||||
return sqrt(x);
|
||||
}
|
||||
|
||||
float op_sin(float x) {
|
||||
return sin(x);
|
||||
}
|
||||
|
||||
float op_cos(float x) {
|
||||
return cos(x);
|
||||
}
|
||||
|
||||
float op_clamp(float x) {
|
||||
return clamp(x, p.param1, p.param2);
|
||||
}
|
||||
|
||||
float op_leaky_relu(float x) {
|
||||
return max(x, 0.0f) + min(x, 0.0f) * p.param1;
|
||||
}
|
||||
|
||||
float op_step(float x) {
|
||||
return x >= 0.0f ? 1.0f : 0.0f;
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <future>
|
||||
#include <queue>
|
||||
#include <condition_variable>
|
||||
#include <atomic>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <cstdlib>
|
||||
@@ -34,6 +35,9 @@
|
||||
|
||||
std::mutex lock;
|
||||
std::vector<std::pair<std::string, std::string>> shader_fnames;
|
||||
// Set when any shader subprocess fails (non-zero exit / stderr / launch failure) so the
|
||||
// build is stopped instead of silently producing a broken libggml-vulkan. (issue #24393)
|
||||
static std::atomic<bool> compile_failed{false};
|
||||
std::locale c_locale("C");
|
||||
|
||||
std::string GLSLC = "glslc";
|
||||
@@ -78,7 +82,7 @@ enum MatMulIdType {
|
||||
|
||||
namespace {
|
||||
|
||||
void execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
|
||||
int execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
|
||||
#ifdef _WIN32
|
||||
HANDLE stdout_read, stdout_write;
|
||||
HANDLE stderr_read, stderr_write;
|
||||
@@ -127,8 +131,11 @@ void execute_command(std::vector<std::string>& command, std::string& stdout_str,
|
||||
CloseHandle(stdout_read);
|
||||
CloseHandle(stderr_read);
|
||||
WaitForSingleObject(pi.hProcess, INFINITE);
|
||||
DWORD exit_code = 1;
|
||||
GetExitCodeProcess(pi.hProcess, &exit_code);
|
||||
CloseHandle(pi.hProcess);
|
||||
CloseHandle(pi.hThread);
|
||||
return (int)exit_code;
|
||||
#else
|
||||
int stdout_pipe[2];
|
||||
int stderr_pipe[2];
|
||||
@@ -175,7 +182,9 @@ void execute_command(std::vector<std::string>& command, std::string& stdout_str,
|
||||
|
||||
close(stdout_pipe[0]);
|
||||
close(stderr_pipe[0]);
|
||||
waitpid(pid, nullptr, 0);
|
||||
int status = 0;
|
||||
waitpid(pid, &status, 0);
|
||||
return WIFEXITED(status) ? WEXITSTATUS(status) : -1;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -372,13 +381,14 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
|
||||
// }
|
||||
// std::cout << std::endl;
|
||||
|
||||
execute_command(cmd, stdout_str, stderr_str);
|
||||
if (!stderr_str.empty()) {
|
||||
std::cerr << "cannot compile " << name << "\n\n";
|
||||
int exit_code = execute_command(cmd, stdout_str, stderr_str);
|
||||
if (exit_code != 0 || !stderr_str.empty()) {
|
||||
std::cerr << "cannot compile " << name << " (exit code " << exit_code << ")\n\n";
|
||||
for (const auto& part : cmd) {
|
||||
std::cerr << part << " ";
|
||||
}
|
||||
std::cerr << "\n\n" << stderr_str << std::endl;
|
||||
compile_failed = true;
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -398,6 +408,7 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
|
||||
shader_fnames.push_back(std::make_pair(name, out_path));
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
|
||||
compile_failed = true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -539,11 +550,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
};
|
||||
|
||||
// Shaders with f16 B_TYPE
|
||||
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
|
||||
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
|
||||
// bf16
|
||||
{
|
||||
@@ -565,8 +574,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
#endif
|
||||
{
|
||||
if (!dot2) {
|
||||
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE_SCALAR", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -583,8 +591,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
}
|
||||
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
// For unaligned, load one at a time for f32/f16, or two at a time for quants
|
||||
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
|
||||
// For aligned matmul loads
|
||||
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
|
||||
|
||||
@@ -597,13 +603,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
|
||||
// don't generate f32 variants for coopmat2
|
||||
if (!coopmat2) {
|
||||
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE_SCALAR", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
|
||||
if (tname != "f16" && tname != "f32") {
|
||||
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
@@ -850,21 +854,12 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("repeat_i32", "repeat.comp", {{"A_TYPE", "int32_t"}, {"D_TYPE", "int32_t"}});
|
||||
string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("get_rows_back_f32", "get_rows_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("repeat_i16", "repeat.comp", {{"A_TYPE", "int16_t"}, {"D_TYPE", "int16_t"}});
|
||||
|
||||
string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("concat_i8", "concat.comp", {{"A_TYPE", "uint8_t"}, {"B_TYPE", "uint8_t"}, {"D_TYPE", "uint8_t"}});
|
||||
@@ -891,6 +886,18 @@ void process_shaders() {
|
||||
string_to_spv("silu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_silu"}});
|
||||
string_to_spv("relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_relu"}});
|
||||
string_to_spv("relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_relu"}});
|
||||
string_to_spv("sqr_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqr"}});
|
||||
string_to_spv("sqr_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqr"}});
|
||||
string_to_spv("sqrt_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqrt"}});
|
||||
string_to_spv("sqrt_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqrt"}});
|
||||
string_to_spv("sin_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sin"}});
|
||||
string_to_spv("sin_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sin"}});
|
||||
string_to_spv("cos_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_cos"}});
|
||||
string_to_spv("cos_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_cos"}});
|
||||
string_to_spv("clamp_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_clamp"}});
|
||||
string_to_spv("clamp_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_clamp"}});
|
||||
string_to_spv("leaky_relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_leaky_relu"}});
|
||||
string_to_spv("leaky_relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_leaky_relu"}});
|
||||
string_to_spv("neg_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_neg"}});
|
||||
string_to_spv("neg_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_neg"}});
|
||||
string_to_spv("tanh_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_tanh"}});
|
||||
@@ -948,7 +955,6 @@ void process_shaders() {
|
||||
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
@@ -1060,6 +1066,31 @@ void process_shaders() {
|
||||
}
|
||||
}
|
||||
|
||||
for (auto unroll : {false, true}) {
|
||||
for (auto a_f16 : {false, true}) {
|
||||
std::map<std::string, std::string> defines = {
|
||||
{"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"},
|
||||
{"UNROLL", unroll ? "[[unroll]]" : ""},
|
||||
};
|
||||
std::string name = std::string("conv3d") + (a_f16 ? "_f16" : "") + "_f32";
|
||||
string_to_spv(name + (unroll ? "_unroll" : ""), "conv3d_mm.comp", defines);
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (unroll) {
|
||||
auto cm2_defines = defines;
|
||||
cm2_defines["COOPMAT2"] = "1";
|
||||
string_to_spv(name, "conv3d_mm.comp", cm2_defines, true, false, true);
|
||||
}
|
||||
#endif
|
||||
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (unroll) {
|
||||
auto cm1_defines = defines;
|
||||
cm1_defines["COOPMAT"] = "1";
|
||||
string_to_spv(name, "conv3d_mm.comp", cm1_defines, true, true, false);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
|
||||
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
|
||||
string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
|
||||
@@ -1251,6 +1282,11 @@ int main(int argc, char** argv) {
|
||||
|
||||
process_shaders();
|
||||
|
||||
if (compile_failed) {
|
||||
std::cerr << "vulkan-shaders-gen: one or more shaders failed to compile" << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
write_output_files();
|
||||
|
||||
return EXIT_SUCCESS;
|
||||
|
||||
@@ -905,11 +905,12 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
int vectorized;
|
||||
uint32_t num_cols;
|
||||
bool use_mmvq;
|
||||
|
||||
bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
|
||||
use_mmvq == other.use_mmvq;
|
||||
num_cols == other.num_cols && use_mmvq == other.use_mmvq;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -919,6 +920,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
|
||||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
ggml_webgpu_hash_combine(seed, key.src1_type);
|
||||
ggml_webgpu_hash_combine(seed, key.vectorized);
|
||||
ggml_webgpu_hash_combine(seed, key.num_cols);
|
||||
ggml_webgpu_hash_combine(seed, key.use_mmvq);
|
||||
return seed;
|
||||
}
|
||||
@@ -993,11 +995,12 @@ struct ggml_webgpu_mul_mat_id_pipeline_key {
|
||||
ggml_type src0_type;
|
||||
ggml_type src1_type;
|
||||
uint32_t n_experts;
|
||||
uint32_t num_cols;
|
||||
int vectorized;
|
||||
|
||||
bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const {
|
||||
return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts &&
|
||||
vectorized == other.vectorized;
|
||||
num_cols == other.num_cols && vectorized == other.vectorized;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1007,6 +1010,7 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash {
|
||||
ggml_webgpu_hash_combine(seed, key.src0_type);
|
||||
ggml_webgpu_hash_combine(seed, key.src1_type);
|
||||
ggml_webgpu_hash_combine(seed, key.n_experts);
|
||||
ggml_webgpu_hash_combine(seed, key.num_cols);
|
||||
ggml_webgpu_hash_combine(seed, key.vectorized);
|
||||
return seed;
|
||||
}
|
||||
@@ -1107,7 +1111,7 @@ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0,
|
||||
const ggml_tensor * src1,
|
||||
bool supports_dot_product,
|
||||
const std::string & vendor) {
|
||||
if (src1->ne[1] == 1) {
|
||||
if (src1->ne[1] <= 4) {
|
||||
bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia";
|
||||
if (supports_dp4a && supports_dot_product) {
|
||||
switch (src1->type) {
|
||||
@@ -1889,6 +1893,7 @@ class ggml_webgpu_shader_lib {
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
key.num_cols = context.dst->ne[1];
|
||||
key.use_mmvq =
|
||||
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
|
||||
|
||||
@@ -2004,6 +2009,7 @@ class ggml_webgpu_shader_lib {
|
||||
if (key.vectorized) {
|
||||
variant += "_vectorized";
|
||||
}
|
||||
defines.push_back(std::string("NUM_COLS=") + std::to_string(key.num_cols));
|
||||
|
||||
auto processed = preprocessor.preprocess(shader_src, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
|
||||
@@ -2421,6 +2427,7 @@ class ggml_webgpu_shader_lib {
|
||||
if (key.vectorized) {
|
||||
variant += "_vectorized";
|
||||
}
|
||||
defines.push_back(std::string("NUM_COLS=1"));
|
||||
|
||||
defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts));
|
||||
|
||||
|
||||
@@ -1418,15 +1418,17 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
|
||||
const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
|
||||
const size_t q8_src1_align_offset = ROUNDUP_POW2(
|
||||
dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
const size_t q8_src1_binding_size =
|
||||
ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
const size_t q8_src1_binding_size = ROUNDUP_POW2(
|
||||
src1->ne[3] * src1->ne[2] * src1->ne[1] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)),
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
|
||||
std::vector<uint32_t> q8_params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||
(uint32_t) src1->ne[0],
|
||||
(uint32_t) src1->ne[1],
|
||||
(uint32_t) src1->ne[2],
|
||||
(uint32_t) src1->ne[3],
|
||||
};
|
||||
@@ -1442,7 +1444,7 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context &
|
||||
uint32_t q8_wg_x = 1;
|
||||
uint32_t q8_wg_y = 1;
|
||||
const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size;
|
||||
const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec;
|
||||
const uint32_t q8_total_wg = src1->ne[1] * src1->ne[2] * src1->ne[3] * wg_per_vec;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y);
|
||||
|
||||
@@ -1456,7 +1458,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
// Determine if this is a mat-vec operation
|
||||
bool is_vec = (dst->ne[1] == 1);
|
||||
bool use_mat_vec = (dst->ne[1] <= 4);
|
||||
|
||||
// use MMVQ path for mat-vec
|
||||
bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product,
|
||||
@@ -1482,7 +1484,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
webgpu_pipeline pipeline;
|
||||
std::vector<webgpu_dispatch_desc> dispatches;
|
||||
|
||||
if (is_vec) {
|
||||
if (use_mat_vec) {
|
||||
if (use_mmvq) {
|
||||
ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches);
|
||||
}
|
||||
@@ -1529,7 +1531,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
uint32_t wg_y = 1;
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
if (is_vec) {
|
||||
if (use_mat_vec) {
|
||||
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t batches = dst->ne[2] * dst->ne[3];
|
||||
@@ -3691,8 +3693,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
||||
ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product,
|
||||
ctx->webgpu_global_ctx->vendor);
|
||||
if (use_mmvq) {
|
||||
const size_t q8_src1_size =
|
||||
src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
|
||||
const size_t q8_src1_size = src1->ne[3] * src1->ne[2] * src1->ne[1] *
|
||||
(36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32));
|
||||
res = ROUNDUP_POW2(res + q8_src1_size +
|
||||
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
@@ -4268,7 +4270,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
supports_op = (op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32) && ggml_is_contiguous_rows(src0);
|
||||
break;
|
||||
case GGML_OP_ROPE:
|
||||
supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
|
||||
|
||||
@@ -103,7 +103,7 @@ fn main(
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let subgroup_total = subgroupAdd(acc[row]);
|
||||
let subgroup_total = subgroupAdd(acc[0][row]);
|
||||
if (subgroup_invocation_id == 0u) {
|
||||
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
|
||||
}
|
||||
@@ -126,7 +126,7 @@ fn main(
|
||||
|
||||
#ifdef USE_WORKGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] = acc[row];
|
||||
partial_sums[partial_index(row, thread_id)] = acc[0][row];
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
@@ -91,61 +91,67 @@ fn main(
|
||||
let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base;
|
||||
|
||||
#ifdef MMVQ
|
||||
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u);
|
||||
let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * params.n * (params.k / 32u);
|
||||
let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base);
|
||||
#else
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base);
|
||||
#endif
|
||||
|
||||
for (var col = 0u;col < NUM_COLS;col += 1) {
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let subgroup_total = subgroupAdd(acc[row]);
|
||||
if (subgroup_invocation_id == 0u) {
|
||||
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
|
||||
}
|
||||
}
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let subgroup_total = subgroupAdd(acc[col][row]);
|
||||
if (subgroup_invocation_id == 0u) {
|
||||
partial_sums[partial_index(row, subgroup_id)] = subgroup_total;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
workgroupBarrier();
|
||||
|
||||
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
|
||||
let output_row = row_base + row;
|
||||
var row_acc = 0.0f;
|
||||
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
|
||||
row_acc += partial_sums[partial_index(row, k)];
|
||||
}
|
||||
let row_total = subgroupAdd(row_acc);
|
||||
if (subgroup_invocation_id == 0) {
|
||||
dst[dst_idx_base + row] = row_total;
|
||||
}
|
||||
}
|
||||
for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) {
|
||||
let output_row = row_base + row;
|
||||
var row_acc = 0.0f;
|
||||
for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) {
|
||||
row_acc += partial_sums[partial_index(row, k)];
|
||||
}
|
||||
let row_total = subgroupAdd(row_acc);
|
||||
if (subgroup_invocation_id == 0) {
|
||||
dst[dst_idx_base + col * params.m + row] = row_total;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_WORKGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] = acc[row];
|
||||
}
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] = acc[col][row];
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
var stride = WG_SIZE / 2u;
|
||||
|
||||
while (stride > 0) {
|
||||
if (thread_id < stride) {
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
stride = stride / 2;
|
||||
}
|
||||
|
||||
if (thread_id < OUTPUTS_PER_WG) {
|
||||
let output_row = row_base + thread_id;
|
||||
if (output_row < params.m) {
|
||||
dst[dst_idx_base + col * params.m + thread_id] = partial_sums[partial_index(thread_id, 0)];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
var stride = WG_SIZE / 2u;
|
||||
|
||||
while (stride > 0) {
|
||||
if (thread_id < stride) {
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)];
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
stride = stride / 2;
|
||||
}
|
||||
|
||||
if (thread_id < OUTPUTS_PER_WG) {
|
||||
let output_row = row_base + thread_id;
|
||||
if (output_row < params.m) {
|
||||
dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -51,10 +51,7 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
|
||||
fn get_dm(block_byte_base: u32) -> f32 {
|
||||
return f32(load_f16_at_src0(block_byte_base));
|
||||
}
|
||||
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q4_0
|
||||
|
||||
#ifdef MUL_ACC_Q4_1
|
||||
#define BLOCK_SIZE_BYTES 20
|
||||
@@ -85,10 +82,7 @@ fn get_dm(block_byte_base: u32) -> vec2<f32> {
|
||||
f32(load_f16_at_src0(block_byte_base + 2u))
|
||||
);
|
||||
}
|
||||
fn mul_q8_1(row_sum: i32, dma: vec2<f32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK;
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q4_1
|
||||
|
||||
#ifdef MUL_ACC_Q8_0
|
||||
#define BLOCK_SIZE_BYTES 34
|
||||
@@ -111,46 +105,48 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE {
|
||||
fn get_dm(block_byte_base: u32) -> f32 {
|
||||
return f32(load_f16_at_src0(block_byte_base));
|
||||
}
|
||||
fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 {
|
||||
return f32(row_sum) * (da * b_ds);
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q8_0
|
||||
|
||||
#ifdef LEGACY_QUANTS
|
||||
fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
var row_sum = 0;
|
||||
let a_repacked = repack_a(a_byte_base, b_inner_id);
|
||||
|
||||
row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]);
|
||||
row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]);
|
||||
|
||||
return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds);
|
||||
}
|
||||
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
|
||||
var acc: array<f32, OUTPUTS_PER_WG>;
|
||||
#if defined(LEGACY_QUANTS)
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
|
||||
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
|
||||
let b_inner_id = thread_id % THREADS_PER_BLOCK;
|
||||
let b_block_idx = src1q_idx_base + block;
|
||||
|
||||
let b_repacked = repack_b_qs(b_block_idx, b_inner_id);
|
||||
let b_ds = repack_b_dm(b_block_idx);
|
||||
|
||||
let inner_id = thread_id % THREADS_PER_BLOCK;
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds);
|
||||
let a_repacked = repack_a(block_byte_base, inner_id);
|
||||
let da = get_dm(block_byte_base);
|
||||
for (var col = 0u;col < NUM_COLS;col += 1) {
|
||||
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + block;
|
||||
let b_repacked = repack_b_qs(src1q_idx, inner_id);
|
||||
let b_ds = repack_b_dm(src1q_idx);
|
||||
|
||||
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]);
|
||||
|
||||
#if defined(MUL_ACC_Q4_0)
|
||||
acc[col][row] += f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK;
|
||||
#endif // MUL_ACC_Q4_0
|
||||
|
||||
#if defined(MUL_ACC_Q4_1)
|
||||
acc[col][row] += f32(row_sum) * (da.x * b_ds.x) + da.y * b_ds.y / THREADS_PER_BLOCK;
|
||||
#endif // MUL_ACC_Q4_1
|
||||
|
||||
#if defined(MUL_ACC_Q8_0)
|
||||
acc[col][row] += f32(row_sum) * (da * b_ds);
|
||||
#endif // MUL_ACC_Q8_0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
#endif
|
||||
#endif // LEGACY_QUANTS
|
||||
|
||||
#ifdef MUL_ACC_Q2_K
|
||||
#define BLOCK_SIZE_BYTES 84
|
||||
@@ -191,22 +187,7 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
|
||||
let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u);
|
||||
return vec2<f32>(f32(scale & 0xFu), f32(scale >> 4u));
|
||||
}
|
||||
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
let a_repacked = repack_a(a_byte_base, tid);
|
||||
let dm = get_dm(a_byte_base);
|
||||
let scale_min = get_scale_min(a_byte_base, tid);
|
||||
|
||||
let scale_q = i32(scale_min.x);
|
||||
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
|
||||
|
||||
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
|
||||
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
|
||||
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
|
||||
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
|
||||
|
||||
return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q2_K
|
||||
|
||||
#ifdef MUL_ACC_Q4_K
|
||||
#define BLOCK_SIZE_BYTES 144
|
||||
@@ -265,39 +246,52 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> {
|
||||
|
||||
return vec2<f32>(scale, min_val);
|
||||
}
|
||||
fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 {
|
||||
let a_repacked = repack_a(a_byte_base, tid);
|
||||
let dm = get_dm(a_byte_base);
|
||||
let scale_min = get_scale_min(a_byte_base, tid);
|
||||
|
||||
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
|
||||
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
|
||||
|
||||
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
|
||||
return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
|
||||
}
|
||||
#endif
|
||||
#endif // MUL_ACC_Q4_K
|
||||
|
||||
#ifdef K_QUANTS
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> {
|
||||
var acc: array<f32, OUTPUTS_PER_WG>;
|
||||
fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<array<f32, OUTPUTS_PER_WG>, NUM_COLS> {
|
||||
var acc: array<array<f32, OUTPUTS_PER_WG>, NUM_COLS>;
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
|
||||
for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) {
|
||||
let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
|
||||
let b_repacked = repack_b_qs(src1q_idx, tid);
|
||||
let b_ds = repack_b_dm(src1q_idx);
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds);
|
||||
let a_repacked = repack_a(block_byte_base, tid);
|
||||
let dm = get_dm(block_byte_base);
|
||||
let scale_min = get_scale_min(block_byte_base, tid);
|
||||
for (var col = 0u;col < NUM_COLS;col += 1) {
|
||||
let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE;
|
||||
let b_repacked = repack_b_qs(src1q_idx, tid);
|
||||
let b_ds = repack_b_dm(src1q_idx);
|
||||
|
||||
#if defined(MUL_ACC_Q2_K)
|
||||
let scale_q = i32(scale_min.x);
|
||||
let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u;
|
||||
|
||||
let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1])
|
||||
+ dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q;
|
||||
let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4)
|
||||
+ dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4);
|
||||
|
||||
acc[col][row] += b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m));
|
||||
#endif // MUL_ACC_Q2_K
|
||||
|
||||
#if defined(MUL_ACC_Q4_K)
|
||||
let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1])
|
||||
+ dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]);
|
||||
|
||||
// Each thread covers half of the Q8_1 block, so add only b_ds.y/2.
|
||||
acc[col][row] += b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD));
|
||||
#endif // MUL_ACC_Q4_K
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return acc;
|
||||
}
|
||||
#endif
|
||||
#endif // K_QUANTS
|
||||
|
||||
@@ -9,9 +9,11 @@ requires packed_4x8_integer_dot_product;
|
||||
|
||||
struct Params {
|
||||
offset_src1: u32,
|
||||
stride_11: u32,
|
||||
stride_12: u32,
|
||||
stride_13: u32,
|
||||
ne0: u32,
|
||||
ne1: u32,
|
||||
ne2: u32,
|
||||
ne3: u32,
|
||||
};
|
||||
@@ -57,25 +59,28 @@ fn main(
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>
|
||||
) {
|
||||
let thread_id = local_id.x;
|
||||
let num_vec4 = params.ne0 / 4u;
|
||||
let ne0_vec4 = params.ne0 / 4u;
|
||||
|
||||
let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
|
||||
let total_batches = wg_per_vec * params.ne2 * params.ne3;
|
||||
let wg_per_vec = (ne0_vec4 + (WG_SIZE - 1u)) / WG_SIZE;
|
||||
let total_batches = wg_per_vec * params.ne1 * params.ne2 * params.ne3;
|
||||
|
||||
let wg_linear = wg_id.y * num_wg.x + wg_id.x;
|
||||
if (wg_linear >= total_batches) {
|
||||
return;
|
||||
}
|
||||
|
||||
let src13_idx = wg_linear / (params.ne2 * wg_per_vec);
|
||||
let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec;
|
||||
let src11_wg_idx = wg_linear % wg_per_vec;
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
|
||||
let vec_idx = wg_linear / wg_per_vec;
|
||||
let src13_idx = vec_idx / (params.ne2 * params.ne1);
|
||||
let vec_ne12_num = vec_idx % (params.ne2 * params.ne1);
|
||||
let src12_idx = vec_ne12_num / params.ne1;
|
||||
let src11_idx = vec_ne12_num % params.ne1;
|
||||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + src11_idx * params.stride_11;
|
||||
let src1_idx_vec4_base = src1_idx_base / 4u;
|
||||
|
||||
let blocks_per_row = params.ne0 / 32u;
|
||||
let blocks_per_wg = (WG_SIZE * 4u) / 32u;
|
||||
let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row;
|
||||
let src1q_idx_base = ((src13_idx * params.ne2 + src12_idx) * params.ne1 + src11_idx) * blocks_per_row;
|
||||
let src11_wg_idx = wg_linear % wg_per_vec;
|
||||
let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u;
|
||||
let qs_idx = thread_id % 8u;
|
||||
|
||||
@@ -85,7 +90,7 @@ fn main(
|
||||
var thread_amax = 0.0;
|
||||
|
||||
let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id;
|
||||
let is_valid = src11_vec4_idx < num_vec4;
|
||||
let is_valid = src11_vec4_idx < ne0_vec4;
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
|
||||
|
||||
@@ -359,6 +359,7 @@ class Keys:
|
||||
CHUNK_SIZE = "clip.audio.chunk_size"
|
||||
CONV_KERNEL_SIZE = "clip.audio.conv_kernel_size"
|
||||
MAX_POS_EMB = "clip.audio.max_pos_emb"
|
||||
FEATURE_LAYERS = "clip.audio.feature_layer" # Granite Speech Plus
|
||||
|
||||
class Attention:
|
||||
HEAD_COUNT = "clip.audio.attention.head_count"
|
||||
|
||||
@@ -1310,6 +1310,9 @@ class GGUFWriter:
|
||||
def add_audio_max_pos_emb(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.MAX_POS_EMB, value)
|
||||
|
||||
def add_audio_feature_layers(self, layers: Sequence[int]) -> None:
|
||||
self.add_array(Keys.ClipAudio.FEATURE_LAYERS, layers)
|
||||
|
||||
def add_audio_projector_window_size(self, value: int) -> None:
|
||||
self.add_uint32(Keys.ClipAudio.Projector.WINDOW_SIZE, value)
|
||||
|
||||
|
||||
+9
-8
@@ -558,14 +558,15 @@ extern "C" {
|
||||
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
|
||||
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
|
||||
|
||||
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_ctx_train (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_layer_nextn(const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
|
||||
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
|
||||
|
||||
// Get the model's RoPE frequency scaling factor
|
||||
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
|
||||
|
||||
@@ -57,19 +57,25 @@ oppoll=
|
||||
opflt=
|
||||
[ "$OF" != "" ] && opflt="GGML_HEXAGON_OPFILTER=$OF"
|
||||
|
||||
opfuse=
|
||||
[ "$OC" != "" ] && opfuse="GGML_HEXAGON_OPFUSION=$OC"
|
||||
|
||||
vmem=
|
||||
[ "$VM" != "" ] && vmem="GGML_HEXAGON_VMEM=$VM"
|
||||
|
||||
mbuf=
|
||||
[ "$MB" != "" ] && mbuf="GGML_HEXAGON_MBUF=$MB"
|
||||
|
||||
mmsel=
|
||||
[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM"
|
||||
|
||||
set -x
|
||||
|
||||
adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $vmem $mbuf \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $opfuse $vmem $mbuf $mmsel \
|
||||
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --ubatch-size 1024 -fa on \
|
||||
|
||||
@@ -51,6 +51,12 @@ opqueue=
|
||||
oppoll=
|
||||
[ "$OP" != "" ] && oppoll="GGML_HEXAGON_OPPOLL=$OP"
|
||||
|
||||
opfuse=
|
||||
[ "$OC" != "" ] && opfuse="GGML_HEXAGON_OPFUSION=$OC"
|
||||
|
||||
mmsel=
|
||||
[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM"
|
||||
|
||||
set -x
|
||||
|
||||
tool=$1; shift
|
||||
@@ -59,5 +65,5 @@ adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll ./$branch/bin/$tool $@ \
|
||||
$verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opfuse $mmsel ./$branch/bin/$tool $@ \
|
||||
"
|
||||
|
||||
@@ -26,7 +26,7 @@ COL_MAP = {
|
||||
}
|
||||
|
||||
op_pattern = re.compile(
|
||||
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+.*?\s+(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?"
|
||||
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+.*?\s+:\s+(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?"
|
||||
)
|
||||
|
||||
trace_pattern = re.compile(
|
||||
@@ -93,9 +93,40 @@ def parse_log(file_path, pmu_index=None):
|
||||
+ int(ts_match.group('us'))
|
||||
)
|
||||
|
||||
op_match = op_pattern.search(line)
|
||||
if "|" in line and "profile-op" in line:
|
||||
parts = [p.strip() for p in line.split("|")]
|
||||
prefix = parts[0]
|
||||
prefix_match = re.search(r"profile-op\s+(?P<op_name>[A-Z_0-9+]+)", prefix)
|
||||
if not prefix_match:
|
||||
continue
|
||||
|
||||
if len(parts) == 7:
|
||||
dims, types, timings = parts[2], parts[3], parts[6]
|
||||
elif len(parts) == 6:
|
||||
dims, types, timings = parts[2], parts[3], parts[5]
|
||||
else:
|
||||
continue
|
||||
|
||||
timing_match = re.search(
|
||||
r"(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?",
|
||||
timings
|
||||
)
|
||||
if not timing_match:
|
||||
continue
|
||||
|
||||
op_match = timing_match
|
||||
op_name = prefix_match.group("op_name")
|
||||
else:
|
||||
op_match = op_pattern.search(line)
|
||||
if op_match:
|
||||
op_name = op_match.group('op_name')
|
||||
dims = op_match.group('dims').strip()
|
||||
types = op_match.group('types').strip()
|
||||
else:
|
||||
op_match = None
|
||||
|
||||
if op_match:
|
||||
pmu_raw = op_match.group('pmu')
|
||||
pmu_raw = op_match.group('pmu') if 'pmu' in op_match.groupdict() else None
|
||||
pmu_val = None
|
||||
if pmu_raw and pmu_index is not None:
|
||||
try:
|
||||
@@ -105,7 +136,7 @@ def parse_log(file_path, pmu_index=None):
|
||||
except (ValueError, IndexError):
|
||||
pmu_val = None
|
||||
|
||||
evt_raw = op_match.group('evt')
|
||||
evt_raw = op_match.group('evt') if 'evt' in op_match.groupdict() else None
|
||||
evt_val = None
|
||||
if evt_raw:
|
||||
try:
|
||||
@@ -122,9 +153,9 @@ def parse_log(file_path, pmu_index=None):
|
||||
op_text = line[idx + 11:].strip() if idx != -1 else line.strip()
|
||||
|
||||
current_op = {
|
||||
'name': op_match.group('op_name'),
|
||||
'dims': op_match.group('dims').strip(),
|
||||
'types': op_match.group('types').strip(),
|
||||
'name': op_name,
|
||||
'dims': dims,
|
||||
'types': types,
|
||||
'op_text': op_text,
|
||||
'usec': int(op_match.group('usec')),
|
||||
'cycles': int(op_match.group('cycles')),
|
||||
|
||||
@@ -12,7 +12,7 @@ from collections import defaultdict
|
||||
logger = logging.getLogger("ggml-hexagon-trace")
|
||||
|
||||
op_pattern = re.compile(
|
||||
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+(?P<strides>[\d:x\s\->!]+)\s+:\s+(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?"
|
||||
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+(?P<strides>[\d:x\s\->!]+?)\s+:\s+(?:(?P<params>.*?)\s+:\s+)?(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?"
|
||||
)
|
||||
|
||||
trace_pattern = re.compile(
|
||||
@@ -66,7 +66,40 @@ def parse_log(file_path):
|
||||
|
||||
for line in f:
|
||||
line_idx += 1
|
||||
op_match = op_pattern.search(line)
|
||||
if "|" in line and "profile-op" in line:
|
||||
parts = [p.strip() for p in line.split("|")]
|
||||
prefix = parts[0]
|
||||
prefix_match = re.search(r"profile-op\s+(?P<op_name>[A-Z_0-9+]+)", prefix)
|
||||
if not prefix_match:
|
||||
continue
|
||||
|
||||
if len(parts) == 7:
|
||||
dims, types, strides, params, timings = parts[2], parts[3], parts[4], parts[5], parts[6]
|
||||
elif len(parts) == 6:
|
||||
dims, types, strides, params, timings = parts[2], parts[3], parts[4], "", parts[5]
|
||||
else:
|
||||
continue
|
||||
|
||||
timing_match = re.search(
|
||||
r"(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?",
|
||||
timings
|
||||
)
|
||||
if not timing_match:
|
||||
continue
|
||||
|
||||
op_match = timing_match
|
||||
op_name = prefix_match.group("op_name")
|
||||
else:
|
||||
op_match = op_pattern.search(line)
|
||||
if op_match:
|
||||
op_name = op_match.group('op_name')
|
||||
dims = op_match.group('dims').strip() if op_match.group('dims') else ''
|
||||
types = op_match.group('types').strip() if op_match.group('types') else ''
|
||||
strides = op_match.group('strides').strip() if op_match.group('strides') else ''
|
||||
params = op_match.group('params').strip() if ('params' in op_match.groupdict() and op_match.group('params')) else ''
|
||||
else:
|
||||
op_match = None
|
||||
|
||||
if op_match:
|
||||
cycles_start_raw = op_match.group('start')
|
||||
unwrapped_cycles_start = None
|
||||
@@ -77,10 +110,11 @@ def parse_log(file_path):
|
||||
op_text = line[idx + 11:].strip() if idx != -1 else line.strip()
|
||||
|
||||
current_op = {
|
||||
'name': op_match.group('op_name'),
|
||||
'dims': op_match.group('dims').strip() if op_match.group('dims') else '',
|
||||
'types': op_match.group('types').strip() if op_match.group('types') else '',
|
||||
'strides': op_match.group('strides').strip() if op_match.group('strides') else '',
|
||||
'name': op_name,
|
||||
'dims': dims,
|
||||
'types': types,
|
||||
'strides': strides,
|
||||
'params': params,
|
||||
'op_text': op_text,
|
||||
'usec': int(op_match.group('usec')),
|
||||
'cycles': int(op_match.group('cycles')),
|
||||
@@ -397,6 +431,8 @@ def generate_perfetto_trace(filtered_ops, output_path):
|
||||
debug_annots.append(make_debug_annotation("line", int_val=op['line_num']))
|
||||
if 'strides' in op and op['strides']:
|
||||
debug_annots.append(make_debug_annotation("strides", string_val=op['strides']))
|
||||
if 'params' in op and op['params'] and op['params'] != '----':
|
||||
debug_annots.append(make_debug_annotation("params", string_val=op['params']))
|
||||
|
||||
# Slice Begin
|
||||
evt_begin = make_track_event(1, 2, name=f"{op['name']} ({op['dims']})", category="operator", debug_annotations=debug_annots)
|
||||
|
||||
@@ -1156,6 +1156,10 @@ void llama_context::set_embeddings_layer_inp(uint32_t lid, bool enable) {
|
||||
sched_need_reserve = true;
|
||||
}
|
||||
|
||||
void llama_context::set_nextn_layer_offset(int32_t offset) {
|
||||
cparams.nextn_layer_offset = offset;
|
||||
}
|
||||
|
||||
void llama_context::set_causal_attn(bool value) {
|
||||
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
||||
|
||||
@@ -3699,6 +3703,10 @@ void llama_set_embeddings_layer_inp(llama_context * ctx, uint32_t lid, bool valu
|
||||
ctx->set_embeddings_layer_inp(lid, value);
|
||||
}
|
||||
|
||||
void llama_set_nextn_layer_offset(llama_context * ctx, int32_t offset) {
|
||||
ctx->set_nextn_layer_offset(offset);
|
||||
}
|
||||
|
||||
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
|
||||
if (!ctx) {
|
||||
return nullptr;
|
||||
|
||||
@@ -115,6 +115,7 @@ struct llama_context {
|
||||
void set_embeddings (bool value);
|
||||
void set_embeddings_nextn(bool value, bool masked);
|
||||
void set_embeddings_layer_inp(uint32_t lid, bool enable);
|
||||
void set_nextn_layer_offset(int32_t offset);
|
||||
void set_causal_attn(bool value);
|
||||
void set_warmup(bool value);
|
||||
|
||||
|
||||
@@ -18,6 +18,8 @@ struct llama_cparams {
|
||||
int32_t n_threads; // number of threads to use for generation
|
||||
int32_t n_threads_batch; // number of threads to use for batch processing
|
||||
|
||||
int32_t nextn_layer_offset = 0;
|
||||
|
||||
float rope_freq_base;
|
||||
float rope_freq_scale;
|
||||
|
||||
|
||||
@@ -95,6 +95,11 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c
|
||||
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
|
||||
LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked);
|
||||
|
||||
// Select which appended NextN block the DECODER_MTP graph runs (offset past
|
||||
// the trunk: il = n_layer() + offset). Used by the speculative NextN driver to
|
||||
// chain multiple trained NextN heads. Default 0 (first head).
|
||||
LLAMA_API void llama_set_nextn_layer_offset(struct llama_context * ctx, int32_t offset);
|
||||
|
||||
// mirrors:
|
||||
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
|
||||
LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);
|
||||
|
||||
+9
-2
@@ -682,9 +682,16 @@ struct llm_graph_params {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448035248
|
||||
if (cparams.nextn_layer_offset != other.cparams.nextn_layer_offset) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return
|
||||
cparams.embeddings == other.cparams.embeddings &&
|
||||
cparams.causal_attn == other.cparams.causal_attn &&
|
||||
cparams.embeddings == other.cparams.embeddings &&
|
||||
cparams.embeddings_nextn == other.cparams.embeddings_nextn &&
|
||||
cparams.embeddings_nextn_masked == other.cparams.embeddings_nextn_masked &&
|
||||
cparams.causal_attn == other.cparams.causal_attn &&
|
||||
arch == other.arch &&
|
||||
gtype == other.gtype &&
|
||||
cvec == other.cvec &&
|
||||
|
||||
@@ -2312,6 +2312,10 @@ int32_t llama_model_n_layer(const llama_model * model) {
|
||||
return model->hparams.n_layer();
|
||||
}
|
||||
|
||||
int32_t llama_model_n_layer_nextn(const llama_model * model) {
|
||||
return model->hparams.n_layer_nextn;
|
||||
}
|
||||
|
||||
int32_t llama_model_n_head(const llama_model * model) {
|
||||
return model->hparams.n_head();
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user