Compare commits

...

40 Commits

Author SHA1 Message Date
YiChen Lv 5d3eb999b6 metal : per-op source split + parallel compile (#24021)
* preliminary extract common header

* op source split

* split metallib into 8 libs && load in parallel

* derive kernel->library routing from functionNames

* x-macro lib list + underscore filenames, dedup QK_NL, MRC fixes

* op source split 8 to 20

* improve robustness of source fallback

* clean up

* change bool -> atomic_bool

* only prepend headers that source actually includes

* no semaphore, use GCD global queue

* dedup library compile path, fix NSError lifetime, rename gla

* relocate upstream concat/rope_back/repeat kernel changes into split files

* move ggml-common.h from common.h into dequantize.h to shrink binary size

---------

Co-authored-by: lvyichen <lvyichen@stepfun.com>
2026-06-20 13:36:32 +03:00
Georgi Gerganov 5fd2dc2c41 sync : ggml 2026-06-19 10:19:14 +03:00
Georgi Gerganov 1868af13ac ggml : bump version to 0.15.2 (ggml/1548) 2026-06-19 10:19:14 +03:00
Georgi Gerganov 5bd21b8555 pi : remove docs from system prompt (#24791) 2026-06-19 09:34:00 +03:00
Georgi Gerganov 80452d65b9 server : consolidate slot selection into get_available_slot (#24755)
Absorb get_slot_by_id logic into get_available_slot so slot selection
is handled by a single function call. When a specific slot id is
requested, the LCP similarity check still runs to enable proper
prompt cache updates.

Assisted-by: pi:llama.cpp/Qwen3.6-27B
2026-06-19 09:22:34 +03:00
shalinib-ibm 8141e730f1 ggml-cpu: support K tails in power10 Q8/Q4 MMA matmul (#24753)
* ggml-cpu: support K tails in Power10 MMA Q8/Q4 matmul

This patch removes the requirement that K be divisible by kc in the tinyBlas_Q0_PPC tiled matmul path. Process the final K panel using its actual depth and pass the reduced panel size through packing and kernel execution.  This allows more workloads to use the MMA kernel and reduces fallback to mnpack.

* Apply suggestion from @taronaeo

Co-authored-by: Aaron Teo <taronaeo@gmail.com>

---------

Co-authored-by: Aaron Teo <taronaeo@gmail.com>
2026-06-19 08:55:38 +03:00
Xuan-Son Nguyen db52540f73 mtmd: add batching support for internvl (#24775) 2026-06-19 01:16:16 +02:00
Pascal 3a3edc9ac6 Ggml/cuda col2im 1d (#24417)
* cuda: add GGML_OP_COL2IM_1D, follow-up to the CPU op

* cuda: col2im_1d use fast_div_modulo for the index decomposition

* cuda: col2im_1d tighten supports_op, type match and contiguous dst
2026-06-18 22:23:01 +02:00
Reguna 40f3aafc45 server: add "X-Accel-Buffering": "no" header to streaming endpoints (#24774)
* server: add "X-Accel-Buffering": "no" header to streaming endpoints

This header tells Nginx (as a reverse proxy) to NOT buffer responses. (only affects streaming endpoints)
Without it, Nginx will break streaming with certain applications (notably the Pi coding harness).
2026-06-18 22:01:24 +02:00
Xuan-Son Nguyen a6b3260a42 mtmd: add batching for mtmd-cli, add video tests (#24778) 2026-06-18 21:55:04 +02:00
o7si 32eddaf2ea cmake : fix ui build with read-only source (#24752) 2026-06-18 18:59:18 +02:00
Xuan-Son Nguyen 060ce1bf72 mtmd: refactor llava-uhd overview image handling (always use ov_img_first) (#24769)
* add dedicated "overview" for mtmd_image_preproc_out

* corrections

* correct (again)

* nits

* nits (2)
2026-06-18 18:53:49 +02:00
Max Krasnyansky d2c67959b3 hexagon: support for op-trace (fine-grain tracing of HVX/HMX/DMA events) (#24592)
* hex-optrace: add support for optrace and instrument matmul and flash-atten code

* hex-trace: improve trace event and prefetto generator

* hex-trace: add new script dedicated to handling traces, specifically perfetto traces

* hex-trace: add --head/--tail options to profile and trace tools

* hex-trace: fix whitespaces

* hex-trace: fix flake8 warnings

* hex-trace: fix flake8 warnings

* hmx-fa: restore q_tiles clearing

* hex-profile: remove circular dep in includes

* hex-trace: simplify trace sizing check

* hex-profile: sort events in the summary by name
2026-06-18 08:35:02 -07:00
Kangjia Gao 7b6c5a2aed docs: fix export-lora --lora-scaled syntax [no release] (#24703)
Assisted-by: Codex
2026-06-18 16:46:17 +02:00
Xuan-Son Nguyen fe7c8b2414 server: (router) fix stopping_thread potentially hang (#24728)
* server: (router) fix stopping_thread potentially hang

* fix windows build
2026-06-18 15:41:09 +02:00
Xuan-Son Nguyen e1efd0991d server: add "schema" and validation (#24150)
* wip

* working

* correct some limits

* add field name to error message
2026-06-18 15:40:58 +02:00
Aarni Koskela 08023072ef server : add last-5-seconds generation speed display (#24291)
* server : add last-5-seconds generation speed display

* cont : clean-up

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-06-18 14:02:20 +02:00
Amos Wong 20832179e2 ui: provide touch accessible model selection UI (#24604)
* ui : add model selector storybook stories

Covers list, favorites, single-model, all status states
(loading/loaded/sleeping/failed/idle), and selection states.

* ui : improve model selector mobile UX with hover media queries

Use @media (hover:none) to show action buttons directly on touch
devices and color-code them by model status (amber=sleeping,
green=loaded, muted=idle). Status dots hidden on touch. Desktop
hover behavior unchanged.
2026-06-18 13:14:20 +02:00
Anuj Attri 10786217e9 server : return HTTP 400 on invalid grammar (#24144) (#24154)
Throw on grammar parse failure so the server returns HTTP 400
instead of silently dropping the constraint.
Add a regression test for the invalid-grammar response.

Fixes #24144
2026-06-18 12:49:14 +02:00
Xuan-Son Nguyen 552258c535 server: (router) rework -hf preset repo (#24739)
* server: temporary remove HF remote preset

* rework remove preset.ini support

* rm unused get_remote_preset_whitelist()

* print warning

* add docs

* rm stray file
2026-06-18 12:45:23 +02:00
Xuan-Son Nguyen 968c43891a server: fix router args not being forwarded to child instances (#24760) 2026-06-18 12:15:46 +02:00
Xuan-Son Nguyen 24bba7b98e mtmd: refactor preprocessor, add mtmd_image_preproc_out (#24736)
* add mtmd_image_preproc_out

* add dev docs

* remove unused clip API

* rm unused clip_image_f32_batch::grid

* change preprocess() call signature
2026-06-18 12:04:39 +02:00
Neo Zhang 9724f664e8 [SYCL] rename GGML_SYCL_SUPPORT_LEVEL_ZERO (#24719)
* rename GGML_SYCL_SUPPORT_LEVEL_ZERO to GGML_SYCL_SUPPORT_LEVEL_ZERO_API, and GGML_SYCL_ENABLE_LEVEL_ZERO to  GGML_SYCL_USE_LEVEL_ZERO_API

* fix code format

* fix error when rebase
2026-06-18 11:18:26 +03:00
Neo Zhang dd69db2924 sycl : support MUL_MAT and OUT_PROD with Q1_0 (#24721) 2026-06-18 11:17:37 +03:00
Adrien Gallouët 6ec59ddaea app : enable self-update only when built with llama-install.sh (#24754)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-18 09:57:59 +02:00
Sigbjørn Skjæret 32e806b9c1 ci : fix check-release message parsing (#24751) 2026-06-18 09:32:56 +02:00
Neo Zhang 6f1034b32a [SYCL] support OPs: conv_2d, conv_2d_dw, conv2d_transpose (#24600)
* fix conflict

* fix format issue, rename

* rm debug code

* correct the file name
2026-06-18 09:40:03 +03:00
Aleksander Grygier 0b73fc79fe ui: Update code formatting command in pre-commit hook (#24685) 2026-06-18 08:33:50 +02:00
Ravi Panchumarthy 4a79037b8b ci : fix Windows x64 (OpenVINO) release link (#24731) 2026-06-18 08:30:08 +02:00
Georgi Gerganov cae0a3b0b0 metal : check for BF16 support in concat kernel (#24747) 2026-06-18 09:16:06 +03:00
Xuan-Son Nguyen f3e1828164 mtmd: llava_uhd should no longer use batch dim (#24732) 2026-06-17 22:40:50 +02:00
shalinib-ibm 2e88c49c90 ggml-cpu: Conditionally enable power11 backend based on compiler support (#24687)
* ggml: Conditionally enable power11 backend based on compiler support

Guard POWER11 backend creation behind a compiler flag check for -mcpu=power11. This avoids build failures on current GCC/Clang toolchains while preserving forward compatibility once POWER11 support becomes available.

* Update CMakeLists.txt

ggml-cpu: Use -mcpu=power10 for P10 and P11
2026-06-18 02:45:19 +08:00
Georgi Gerganov 0843245cb1 metal : implement rope_back operator (#24725)
Reuse existing rope kernels with a function constant to toggle forward/backward
rotation, avoiding duplicate kernel code.

Assisted-by: pi:llama.cpp/Qwen3.6-27B
2026-06-17 20:36:05 +03:00
Georgi Gerganov 8d2e580632 metal : add f16 and bf16 support for concat operator (#24724)
* metal : add f16 and bf16 support for concat operator

Extend the Metal backend concat operator to support f16 and bf16 tensor
types in addition to the existing f32 and i32 support.

- Template kernel_concat on type T with specializations for float, half,
  bfloat, and int
- Add type-specific pipeline getter ggml_metal_library_get_pipeline_concat()
- Update device support check to allow f16 unconditionally and bf16 when
  device supports bfloat16
- Update dispatch to select the correct kernel specialization by type

Assisted-by: pi:llama.cpp/Qwen3.6-27B

* metal : extend concat operator to support f16, bf16, i8, i16 and i64

Assisted-by: pi:llama.cpp/Qwen3.6-27B
2026-06-17 19:38:55 +03:00
Xuan-Son Nguyen 4b4d13ae72 server: (router) add model management API (#23976)
* wip

* server: (router) add SSE realtime updates API

* nits

* wip

* add download API

* add download api

* update docs

* add delete endpoint

* fix std::terminate

* fix crash

* fix 2

* add tests

* nits
2026-06-17 18:04:58 +02:00
Dev-iL b4024af6c2 llama : skip main_gpu validation when no devices are available (#23405) 2026-06-17 17:30:26 +03:00
Ruixiang Wang 1a2dea29b9 spec: fix segfault error on long prompts for eagle3 (#24707) 2026-06-17 17:29:49 +03:00
Neo Zhang 74a80dd9c0 [SYCL] add dev2dev memcpy by SYCL API (#24476)
* add dev2dev memcpy by SYCL API

* mv GGML_SYCL_DEV2DEV_MEMCPY to runntime table

* update the detect method for p2p comm

* fix the erro created during fix confilct

---------

Co-authored-by: Neo Zhang <NA>
2026-06-17 17:21:34 +03:00
Neo Zhang d1759e4156 [SYCL] Add conv_3d (#24691)
* add conv_3d

* optimize

* update ops.md

* restore test script

* rm unused code

* rm copyright notes
2026-06-17 17:20:01 +03:00
Julien Chaumond 8086439a4c webui: export conversations as jsonl (#24688)
* webui: export conversations as jsonl

each session is one jsonl file, a session header line followed by one line per message
exporting multiple conversations bundles them into a zip, one jsonl file each

* webui: import jsonl and zip conversation exports

parse the new jsonl session format and zip archives on import
keep supporting the legacy json format
2026-06-17 13:25:47 +02:00
139 changed files with 18487 additions and 14322 deletions
+4 -1
View File
@@ -46,11 +46,13 @@ jobs:
steps:
- id: check
env:
COMMIT_MESSAGE: ${{ github.event.head_commit.message }}
run: |
if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
echo "should_release=true" >> $GITHUB_OUTPUT
elif [[ "${{ github.event_name }}" == "push" && "${{ github.ref }}" == "refs/heads/master" ]]; then
if echo "${{ github.event.head_commit.message }}" | grep -q '\[no release\]'; then
if echo "$COMMIT_MESSAGE" | grep -q '\[no release\]'; then
echo "should_release=false" >> $GITHUB_OUTPUT
else
echo "should_release=true" >> $GITHUB_OUTPUT
@@ -542,6 +544,7 @@ jobs:
steps:
- name: Set OpenVINO version output
id: openvino_version
shell: bash
run: echo "value=${{ env.OPENVINO_VERSION_MAJOR }}" >> $GITHUB_OUTPUT
- name: Clone
-10
View File
@@ -25,13 +25,3 @@ Commits:
- Do not explicitly set the git author in commits - rely on the default git config
- Always use `--no-gpg-sign` when committing
- Never `git push` without explicit confirmation from the user
Resources (read on demand):
- [CONTRIBUTING.md](CONTRIBUTING.md)
- [Build documentation](docs/build.md)
- [Server usage documentation](tools/server/README.md)
- [Server development documentation](tools/server/README-dev.md)
- [PEG parser](docs/development/parsing.md)
- [Auto parser](docs/autoparser.md)
- [Jinja engine](common/jinja/README.md)
- [PR template](.github/pull_request_template.md)
+26 -13
View File
@@ -20,16 +20,21 @@ int llama_fit_params(int argc, char ** argv);
int llama_quantize(int argc, char ** argv);
int llama_perplexity(int argc, char ** argv);
// hands the update over to the install script, which downloads and swaps the binary
// Self-update is only supported for binaries built with llama-install.sh
static int llama_update(int argc, char ** argv) {
(void) argc;
(void) argv;
#ifdef LLAMA_INSTALL_BUILD
#if defined(_WIN32)
return system("powershell -NoProfile -ExecutionPolicy Bypass -Command \"irm https://llama.app/install.ps1 | iex\"");
#else
return system("curl -fsSL https://llama.app/install.sh | sh");
#endif
#else
printf("Updates are available only when installed from https://llama.app\n");
return 1;
#endif
}
static const char * progname;
@@ -46,21 +51,29 @@ struct command {
int (*func)(int, char **);
};
#ifdef LLAMA_INSTALL_BUILD
#define UPDATE_HIDDEN false
#else
#define UPDATE_HIDDEN true
#endif
static const command cmds[] = {
{"serve", "HTTP API server", {"server"}, false, llama_server },
{"cli", "Command-line interactive interface", {"client"}, false, llama_cli },
{"update", "Update llama to the latest release", {}, false, llama_update },
{"completion", "Text completion", {"complete"}, true, llama_completion },
{"bench", "Benchmark prompt processing and text generation", {}, true, llama_bench },
{"batched-bench", "Benchmark batched decoding performance", {}, true, llama_batched_bench},
{"fit-params", "Compute parameters to fit a model in device memory", {}, true, llama_fit_params },
{"quantize", "Quantize a model", {}, true, llama_quantize },
{"perplexity", "Compute model perplexity and KL divergence", {}, true, llama_perplexity },
{"version", "Show version", {}, false, version },
{"licenses", "Show third-party licenses", {"credits"}, false, licenses },
{"help", "Show available commands", {}, false, help },
{"serve", "HTTP API server", {"server"}, false, llama_server },
{"cli", "Command-line interactive interface", {"client"}, false, llama_cli },
{"update", "Update llama to the latest release", {}, UPDATE_HIDDEN, llama_update },
{"completion", "Text completion", {"complete"}, true, llama_completion },
{"bench", "Benchmark prompt processing and text generation", {}, true, llama_bench },
{"batched-bench", "Benchmark batched decoding performance", {}, true, llama_batched_bench},
{"fit-params", "Compute parameters to fit a model in device memory", {}, true, llama_fit_params },
{"quantize", "Quantize a model", {}, true, llama_quantize },
{"perplexity", "Compute model perplexity and KL divergence", {}, true, llama_perplexity },
{"version", "Show version", {}, false, version },
{"licenses", "Show third-party licenses", {"credits"}, false, licenses },
{"help", "Show available commands", {}, false, help },
};
#undef UPDATE_HIDDEN
static int version(int argc, char ** argv) {
printf("%s\n", llama_build_info());
return 0;
+34 -78
View File
@@ -285,58 +285,15 @@ static std::string clean_file_name(const std::string & fname) {
return clean_fname;
}
static bool common_params_handle_remote_preset(common_params & params, llama_example ex) {
GGML_ASSERT(!params.model.hf_repo.empty());
// the returned hf_repo is without tag
auto [hf_repo, hf_tag] = common_download_split_repo_tag(params.model.hf_repo);
// "latest" tag (default if not specified) is translated to "default" preset
if (hf_tag == "latest") {
hf_tag = "default";
}
std::string model_endpoint = common_get_model_endpoint();
auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini";
// prepare local path for caching
auto preset_fname = clean_file_name(hf_repo + "_preset.ini");
auto preset_path = fs_get_cache_file(preset_fname);
common_download_opts opts;
opts.bearer_token = params.hf_token;
opts.offline = params.offline;
LOG_TRC("%s: looking for remote preset at %s\n", __func__, preset_url.c_str());
const int status = common_download_file_single(preset_url, preset_path, opts);
const bool has_preset = status >= 200 && status < 400;
// remote preset is optional, so we don't error out if not found
if (has_preset) {
LOG_TRC("%s: applying remote preset from %s\n", __func__, preset_url.c_str());
common_preset_context ctx(ex, /* only_remote_allowed */ true);
common_preset global;
auto remote_presets = ctx.load_from_ini(preset_path, global);
remote_presets = ctx.cascade(global, remote_presets);
if (remote_presets.find(hf_tag) != remote_presets.end()) {
common_preset preset = remote_presets.at(hf_tag);
LOG_INF("\n%s", preset.to_ini().c_str()); // to_ini already added trailing newline
preset.apply_to_params(params);
} else {
throw std::runtime_error("Remote preset.ini does not contain [" + std::string(hf_tag) + "] section");
}
} else {
LOG_TRC("%s: no remote preset found, skipping\n", __func__);
}
return has_preset;
}
struct handle_model_result {
bool found_mmproj = false;
common_params_model mmproj;
bool found_mtp = false;
common_params_model mtp;
bool found_preset = false;
std::string preset_path;
};
static handle_model_result common_params_handle_model(struct common_params_model & model,
@@ -355,6 +312,12 @@ static handle_model_result common_params_handle_model(struct common_params_model
common_download_opts hf_opts = opts;
auto download_result = common_download_model(model, hf_opts);
if (!download_result.preset_path.empty()) {
result.found_preset = true;
result.preset_path = download_result.preset_path;
return result; // skip everything else if preset.ini is used
}
if (download_result.model_path.empty()) {
throw std::runtime_error("failed to download model from Hugging Face");
}
@@ -454,6 +417,17 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex)
try {
auto res = common_params_handle_model(params.model, opts);
if (res.found_preset) {
if (!params.models_preset.empty()) {
throw std::invalid_argument("cannot use both --models-preset and -hf with a preset.ini file");
}
// if HF repo is a preset repo, we simply run server in router mode with the preset.ini file
params.models_preset_hf = params.model.hf_repo; // only for showing a warning
params.models_preset = res.preset_path;
params.model = common_params_model{}; // make sure to clear model, so server starts in router mode
return true;
}
if (params.no_mmproj) {
params.mmproj = {};
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
@@ -601,30 +575,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
// parse the first time to get -hf option (used for remote preset)
parse_cli_args();
// export_graph_ops loads only metadata
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
// maybe handle remote preset
if (!params.model.hf_repo.empty() && !skip_model_download) {
std::string cli_hf_repo = params.model.hf_repo;
bool has_preset = common_params_handle_remote_preset(params, ctx_arg.ex);
// special case: if hf_repo explicitly set by preset, we need to preserve it (ignore CLI value)
// this is useful when we have one HF repo pointing to other HF repos (one model - multiple GGUFs)
std::string preset_hf_repo = params.model.hf_repo;
bool preset_has_hf_repo = preset_hf_repo != cli_hf_repo;
if (has_preset) {
// re-parse CLI args to override preset values
parse_cli_args();
}
// preserve hf_repo from preset if needed
if (preset_has_hf_repo) {
params.model.hf_repo = preset_hf_repo;
}
}
postprocess_cpu_params(params.cpuparams, nullptr);
postprocess_cpu_params(params.cpuparams_batch, &params.cpuparams);
@@ -635,15 +585,21 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
}
// handle model and download
if (!skip_model_download) {
common_params_handle_models(params, ctx_arg.ex);
}
// export_graph_ops loads only metadata
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !skip_model_download && !params.usage && !params.completion) {
throw std::invalid_argument("error: --model is required\n");
if (!skip_model_download) {
// handle model and download
common_params_handle_models(params, ctx_arg.ex);
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
if (params.model.path.empty()
&& ctx_arg.ex != LLAMA_EXAMPLE_SERVER
&& !params.usage
&& !params.completion) {
throw std::invalid_argument("error: --model is required\n");
}
}
if (params.escape) {
+5 -4
View File
@@ -642,10 +642,11 @@ struct common_params {
std::vector<std::string> server_tools;
// router server configs
std::string models_dir = ""; // directory containing models for the router server
std::string models_preset = ""; // directory containing model presets for the router server
int models_max = 4; // maximum number of models to load simultaneously
bool models_autoload = true; // automatically load models when requested via the router server
std::string models_dir = ""; // directory containing models for the router server
std::string models_preset = ""; // directory containing model presets for the router server
int models_max = 4; // maximum number of models to load simultaneously
bool models_autoload = true; // automatically load models when requested via the router server
std::string models_preset_hf = ""; // show a warning about remote presets on router loaded (if not empty)
bool log_json = false;
+120 -17
View File
@@ -696,6 +696,7 @@ struct hf_plan {
hf_cache::hf_files model_files;
hf_cache::hf_file mmproj;
hf_cache::hf_file mtp;
hf_cache::hf_file preset; // if set, only this file is downloaded
};
static hf_plan get_hf_plan(const common_params_model & model,
@@ -717,6 +718,14 @@ static hf_plan get_hf_plan(const common_params_model & model,
return plan;
}
// if preset.ini exists in the repo root, download only that file
for (const auto & f : all) {
if (f.path == "preset.ini") {
plan.preset = f;
return plan;
}
}
hf_cache::hf_file primary;
if (!model.hf_file.empty()) {
@@ -794,14 +803,19 @@ common_download_model_result common_download_model(const common_params_model &
if (is_hf) {
hf = get_hf_plan(model, opts, download_mmproj, download_mtp);
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
if (!hf.mmproj.path.empty()) {
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
}
if (!hf.mtp.path.empty()) {
tasks.push_back({hf.mtp.url, hf.mtp.local_path});
if (!hf.preset.path.empty()) {
// if preset.ini exists, only download that file alone
tasks.push_back({hf.preset.url, hf.preset.local_path});
} else {
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
if (!hf.mmproj.path.empty()) {
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
}
if (!hf.mtp.path.empty()) {
tasks.push_back({hf.mtp.url, hf.mtp.local_path});
}
}
} else if (!model.url.empty()) {
tasks = get_url_tasks(model);
@@ -835,17 +849,22 @@ common_download_model_result common_download_model(const common_params_model &
}
if (is_hf) {
for (const auto & f : hf.model_files) {
hf_cache::finalize_file(f);
}
result.model_path = hf.primary.final_path;
if (!hf.preset.path.empty()) {
// if preset.ini is used, do not set other paths
result.preset_path = hf_cache::finalize_file(hf.preset);
} else {
for (const auto & f : hf.model_files) {
hf_cache::finalize_file(f);
}
result.model_path = hf.primary.final_path;
if (!hf.mmproj.path.empty()) {
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
}
if (!hf.mmproj.path.empty()) {
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
}
if (!hf.mtp.path.empty()) {
result.mtp_path = hf_cache::finalize_file(hf.mtp);
if (!hf.mtp.path.empty()) {
result.mtp_path = hf_cache::finalize_file(hf.mtp);
}
}
} else {
result.model_path = model.path;
@@ -997,3 +1016,87 @@ std::vector<common_cached_model_info> common_list_cached_models() {
return result;
}
bool common_download_remove(const std::string & hf_repo_with_tag) {
namespace fs = std::filesystem;
auto [repo_id, tag] = common_download_split_repo_tag(hf_repo_with_tag);
if (tag.empty()) {
return hf_cache::remove_cached_repo(repo_id);
}
std::string tag_upper = tag;
for (char & c : tag_upper) {
c = (char) std::toupper((unsigned char) c);
}
auto files = hf_cache::get_cached_files(repo_id);
if (files.empty()) {
return false;
}
// collect snapshot entries whose tag matches
std::vector<fs::path> to_remove;
for (const auto & f : files) {
auto split = get_gguf_split_info(f.path);
if (split.tag == tag_upper) {
to_remove.emplace_back(f.local_path);
}
}
if (to_remove.empty()) {
return false;
}
// resolve blob paths from symlinks before deleting snapshot entries
std::vector<fs::path> blobs_to_check;
for (const auto & p : to_remove) {
std::error_code ec;
if (fs::is_symlink(p, ec)) {
auto target = fs::read_symlink(p, ec);
if (!ec) {
blobs_to_check.push_back((p.parent_path() / target).lexically_normal());
}
}
}
// remove snapshot entries
for (const auto & p : to_remove) {
std::error_code ec;
fs::remove(p, ec);
if (ec) {
LOG_WRN("%s: failed to remove %s: %s\n", __func__, p.string().c_str(), ec.message().c_str());
}
}
if (blobs_to_check.empty()) {
return true;
}
// collect blobs still referenced by remaining snapshot entries
std::unordered_set<std::string> still_referenced;
for (const auto & f : hf_cache::get_cached_files(repo_id)) {
fs::path p(f.local_path);
std::error_code ec;
if (fs::is_symlink(p, ec)) {
auto target = fs::read_symlink(p, ec);
if (!ec) {
still_referenced.insert((p.parent_path() / target).lexically_normal().string());
}
}
}
// remove orphaned blobs
for (const auto & blob : blobs_to_check) {
if (still_referenced.find(blob.string()) == still_referenced.end()) {
std::error_code ec;
fs::remove(blob, ec);
if (ec) {
LOG_WRN("%s: failed to remove blob %s: %s\n", __func__, blob.string().c_str(), ec.message().c_str());
}
}
}
return true;
}
+8
View File
@@ -63,6 +63,7 @@ struct common_download_model_result {
std::string model_path;
std::string mmproj_path;
std::string mtp_path;
std::string preset_path;
};
// throw if the file is missing or invalid (e.g. ETag check failed)
@@ -115,3 +116,10 @@ int common_download_file_single(const std::string & url,
// resolve and download model from Docker registry
// return local path to downloaded model file
std::string common_docker_resolve_model(const std::string & docker);
// Remove a cached model from disk
// input format: "user/model" or "user/model:tag"
// - if tag is omitted, removes the entire repo cache directory
// - if tag is present, removes only files matching that tag (and orphaned blobs)
// returns true if anything was removed
bool common_download_remove(const std::string & hf_repo_with_tag);
+15
View File
@@ -495,4 +495,19 @@ std::string finalize_file(const hf_file & file) {
return file.final_path;
}
bool remove_cached_repo(const std::string & repo_id) {
if (!is_valid_repo_id(repo_id)) {
LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str());
return false;
}
fs::path repo_path = get_repo_path(repo_id);
std::error_code ec;
auto removed = fs::remove_all(repo_path, ec);
if (ec) {
LOG_ERR("%s: failed to remove repo cache %s: %s\n", __func__, repo_path.string().c_str(), ec.message().c_str());
return false;
}
return removed > 0;
}
} // namespace hf_cache
+3
View File
@@ -29,4 +29,7 @@ hf_files get_cached_files(const std::string & repo_id = {});
// Create snapshot path (link or move/copy) and return it
std::string finalize_file(const hf_file & file);
// Remove the entire cached directory for a repo, returns true if removed
bool remove_cached_repo(const std::string & repo_id);
} // namespace hf_cache
+1 -49
View File
@@ -16,48 +16,6 @@ static std::string rm_leading_dashes(const std::string & str) {
return str.substr(pos);
}
// only allow a subset of args for remote presets for security reasons
// do not add more args unless absolutely necessary
// args that output to files are strictly prohibited
static std::set<std::string> get_remote_preset_whitelist(const std::map<std::string, common_arg> & key_to_opt) {
static const std::set<std::string> allowed_options = {
"model-url",
"hf-repo",
"hf-repo-draft",
"hf-repo-v", // vocoder
"hf-file-v", // vocoder
"mmproj-url",
"pooling",
"jinja",
"batch-size",
"ubatch-size",
"cache-reuse",
"chat-template-kwargs",
"mmap",
// note: sampling params are automatically allowed by default
// negated args will be added automatically if the positive arg is specified above
};
std::set<std::string> allowed_keys;
for (const auto & it : key_to_opt) {
const std::string & key = it.first;
const common_arg & opt = it.second;
if (allowed_options.find(key) != allowed_options.end() || opt.is_sampling) {
allowed_keys.insert(key);
// also add variant keys (args without leading dashes and env vars)
for (const auto & arg : opt.get_args()) {
allowed_keys.insert(rm_leading_dashes(arg));
}
for (const auto & env : opt.get_env()) {
allowed_keys.insert(env);
}
}
}
return allowed_keys;
}
std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
std::vector<std::string> args;
@@ -300,16 +258,10 @@ static std::string parse_bool_arg(const common_arg & arg, const std::string & ke
return value;
}
common_preset_context::common_preset_context(llama_example ex, bool only_remote_allowed)
common_preset_context::common_preset_context(llama_example ex)
: ctx_params(common_params_parser_init(default_params, ex)) {
common_params_add_preset_options(ctx_params.options);
key_to_opt = get_map_key_opt(ctx_params);
// setup allowed keys if only_remote_allowed is true
if (only_remote_allowed) {
filter_allowed_keys = true;
allowed_keys = get_remote_preset_whitelist(key_to_opt);
}
}
common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const {
+1 -1
View File
@@ -60,7 +60,7 @@ struct common_preset_context {
std::set<std::string> allowed_keys;
// if only_remote_allowed is true, only accept whitelisted keys
common_preset_context(llama_example ex, bool only_remote_allowed = false);
common_preset_context(llama_example ex);
// load presets from INI file
common_presets load_from_ini(const std::string & path, common_preset & global) const;
+3
View File
@@ -259,6 +259,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
}
}
if (!grmr && !grammar_str.empty()) {
throw std::runtime_error("failed to parse grammar");
}
// Compute prefill tokens from the generation prompt
std::vector<llama_token> prefill_tokens;
+62 -2
View File
@@ -161,6 +161,64 @@ You could update your test result in it directly.
Please refer to [Docker with SYCL](../docker.md#docker-with-sycl) for details.
## Quick Development WOW
This chapter is for quick development & try with SYCL backend on Intel GPU.
You need to install following sofeware before development:
- Intel GPU driver
- oneAPI package
- other development tools.
Please refer to [Linux](#linux) or [Windows](#windows-1) for above installation and resolve the trouble in usage. There are the detailed guide.
- Linux
```
## build from source code
./examples/sycl/build.sh
## run CONV_2D_DW unit test cases
./build/bin/test-backend-ops -b SYCL0 -o CONV_2D_DW
## run all unit test cases
./build/bin/test-backend-ops -b SYCL0
## run with LLM on the first GPU
./examples/sycl/test.sh -mg 0 -m xxxx.gguf
## run service with LLM on the first GPU
export ONEAPI_DEVICE_SELECTOR="level_zero:0"
./examples/sycl/start-svr.sh -m xxxx.gguf
## update the docs/ops.md for new/update OPs
./examples/sycl/update-ops-doc.sh
```
- Windows
```
## build from source code
examples\sycl\win-build-sycl.bat
## run CONV_2D_DW unit test cases
build\bin\test-backend-ops.exe -b SYCL0 -o CONV_2D_DW
## run all unit test cases
build\bin\test-backend-ops.exe -b SYCL0
## run LLM on the first GPU
examples\sycl\win-test.bat -mg 0 -m xxxx.gguf
## run service with LLM on the first GPU
set ONEAPI_DEVICE_SELECTOR="level_zero:0"
examples\sycl\win-start-svr.bat -m xxxx.gguf
## update the docs/ops.md for new/update OPs
examples\sycl\win-update-ops-doc.bat
```
## Linux
### I. Setup Environment
@@ -701,7 +759,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| GGML_SYCL_GRAPH | ON *(default)* \|OFF *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). |
| GGML_SYCL_DNN | ON *(default)* \|OFF *(Optional)* | Enable build with oneDNN. |
| GGML_SYCL_HOST_MEM_FALLBACK | ON *(default)* \|OFF *(Optional)* | Allow host memory fallback when device memory is full during quantized weight reorder. Enables inference to continue at reduced speed (reading over PCIe) instead of failing. Requires Linux kernel 6.8+. |
| GGML_SYCL_SUPPORT_LEVEL_ZERO | ON *(default)* \|OFF *(Optional)* | Enable Level Zero API for device memory allocation. Requires Level Zero headers/library at build time and Intel GPU driver (Level Zero runtime) at run time. Reduces system RAM usage during multi-GPU inference. |
| GGML_SYCL_SUPPORT_LEVEL_ZERO_API | ON *(default)* \|OFF *(Optional)* | Support to use Level Zero API for device memory allocation. Requires Level Zero headers/library at build time and Intel GPU driver (Level Zero runtime) at run time. Reduces system RAM usage during multi-GPU inference. SYCL backend always runs on Level Zero running time even if it's set as OFF (The SYCL api will be usage for memory allocation).|
| CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. |
| CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. |
@@ -712,10 +770,11 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| Name | Value | Function |
|-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
| GGML_SYCL_DEV2DEV_MEMCPY | 0 (default) or 1 | Choose the SYCL or L0 API in dev2dev memory copy.<br>Value: <br>* 0: SYCL API (default)<br>* 1: L0 API -- L0 API is found to lead to abnormal crash in some case. This debug flag is used to check the issue.|
| GGML_SYCL_ENABLE_FLASH_ATTN | 1 (default) or 0| Enable Flash-Attention. It can reduce memory usage. The performance impact depends on the LLM.|
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features for Intel GPUs. (Recommended to 1 for Intel devices older than Gen 10) |
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because SYCL Graph is still on development, no better performance. |
| GGML_SYCL_ENABLE_LEVEL_ZERO | 1 (default) or 0 | Use Level Zero API for device memory allocation instead of SYCL. Reduces system RAM usage on Intel dGPUs by avoiding DMA-buf/TTM host memory staging. Requires GGML_SYCL_SUPPORT_LEVEL_ZERO=ON at build time. |
| GGML_SYCL_USE_LEVEL_ZERO_API | 1 (default) or 0 | Use Level Zero API for device memory allocation instead of SYCL. Reduces system RAM usage on Intel dGPUs by avoiding DMA-buf/TTM host memory staging. Requires GGML_SYCL_SUPPORT_LEVEL_ZERO_API=ON at build time. SYCL backend always runs on Level Zero running time even if it's set as OFF (The SYCL api will be usage for memory allocation).|
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
| GGML_SYCL_ENABLE_VMM | 0 or 1 (default) | Enable the virtual-memory device pool. |
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
@@ -731,6 +790,7 @@ Pass these via `CXXFLAGS` or add a one-off `#define` to enable a flag on the spo
| DEBUG_SYCL_POOL | Enable device memory pool logging on teardown. Useful for profiling allocations. |
| DEBUG_SYCL_MALLOC | Enable verbose per-call logging of device pool alloc/free operations. |
## Design Rule
- Open to all contributors.
+3 -2
View File
@@ -1,10 +1,11 @@
# Multimodal
llama.cpp supports multimodal input via `libmtmd`. Currently, there are 2 tools support this feature:
- [llama-mtmd-cli](../tools/mtmd/README.md)
- [llama-cli](../tools/cli/README.md)
- [llama-server](../tools/server/README.md) via OpenAI-compatible `/chat/completions` API
- [llama-mtmd-cli](../tools/mtmd/README.md), for testing and development
Currently, we support **image** and **audio** input. Audio is highly experimental and may have reduced quality.
Currently, we support **image**, **audio** and **video** input.
To enable it, you can use one of the 2 methods below:
+4 -4
View File
@@ -27,11 +27,11 @@ Legend:
| COL2IM_1D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ | ❌ | ❌ |
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | | ❌ | ❌ | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ | ❌ | ❌ |
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
+1840 -1840
View File
File diff suppressed because it is too large Load Diff
+36 -38
View File
@@ -8,55 +8,53 @@ The INI preset feature, introduced in [PR#17859](https://github.com/ggml-org/lla
When running multiple models on the server (router mode), INI preset files can be used to configure model-specific parameters. Please refer to the [server documentation](../tools/server/README.md) for more details.
### Using a Remote Preset
### Using a Hugging Face Preset
> [!NOTE]
> [!IMPORTANT]
>
> This feature is currently only supported via the `-hf` option.
> Please only use presets that you can trust! Unknown presets may be unsafe
For GGUF models hosted on Hugging Face, you can include a `preset.ini` file in the root directory of the repository to define specific configurations for that model.
You can push your preset to Hugging Face Hub and share with other users by:
1. Creating an empty model repository on Hugging Face
2. Creating a `preset.ini` file in the root directory of the repository
Example:
Example of a `preset.ini`:
```ini
hf-repo-draft = username/my-draft-model-GGUF
temp = 0.5
top-k = 20
top-p = 0.95
[*]
ctx-size = 0
mmap = 1
kv-unified = 1
parallel = 4
spec-default = 1
[Qwen3.5-4B]
hf = unsloth/Qwen3.5-4B-GGUF:Q4_K_M
ctx-size = 262144
batch-size = 2048
ubatch-size = 2048
top-p = 1.0
top-k = 0
min-p = 0.01
temp = 1.0
[gpt-oss-120b-hf]
hf = ggml-org/gpt-oss-120b-GGUF
ctx-size = 262144
batch-size = 2048
ubatch-size = 2048
top-p = 1.0
top-k = 0
min-p = 0.01
temp = 1.0
chat-template-kwargs = {"reasoning_effort": "high"}
```
For security reasons, only certain options are allowed. Please refer to [preset.cpp](../common/preset.cpp) for the complete list of permitted options.
Example usage:
Assuming your repository `username/my-model-with-preset` contains a `preset.ini` with the configuration above:
```sh
llama-cli -hf username/my-model-with-preset
# This is equivalent to:
llama-cli -hf username/my-model-with-preset \
--hf-repo-draft username/my-draft-model-GGUF \
--temp 0.5 \
--top-k 20 \
--top-p 0.95
```
You can also override preset arguments by specifying them on the command line:
The preset will be loaded similarly to the `--models-preset` option. Therefore, you can also override certain params via CLI arguments:
```sh
# Force temp = 0.1, overriding the preset value
llama-cli -hf username/my-model-with-preset --temp 0.1
```
If you want to define multiple preset configurations for one or more GGUF models, you can create a blank HF repo for each preset. Each HF repo should contain a `preset.ini` file that references the actual model(s):
```ini
hf-repo = user/my-model-main
hf-repo-draft = user/my-model-draft
temp = 0.8
ctx-size = 1024
; (and other configurations)
llama-cli -hf username/my-preset --temp 0.1
```
### Named presets
+9
View File
@@ -0,0 +1,9 @@
#!/bin/bash
# MIT license
# Copyright (C) 2026 Intel Corporation
# SPDX-License-Identifier: MIT
./build/bin/test-backend-ops support --output csv > docs/ops/SYCL.csv
./scripts/create_ops_docs.py
+8
View File
@@ -0,0 +1,8 @@
@echo off
rem MIT license
rem Copyright (C) 2026 Intel Corporation
rem SPDX-License-Identifier: MIT
build\bin\test-backend-ops support --output csv > docs\ops\SYCL.csv
python scripts\create_ops_docs.py
+2 -5
View File
@@ -5,7 +5,7 @@ project("ggml" C CXX ASM)
### GGML Version
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 15)
set(GGML_VERSION_PATCH 1)
set(GGML_VERSION_PATCH 2)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
@@ -249,7 +249,7 @@ option(GGML_SYCL "ggml: use SYCL"
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
option(GGML_SYCL_HOST_MEM_FALLBACK "ggml: allow host memory fallback in SYCL reorder (requires kernel 6.8+)" ON)
option(GGML_SYCL_SUPPORT_LEVEL_ZERO "ggml: use Level Zero API in SYCL backend" ON)
option(GGML_SYCL_SUPPORT_LEVEL_ZERO_API "ggml: use Level Zero API in SYCL backend" ON)
option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON)
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
"ggml: sycl target device")
@@ -341,9 +341,6 @@ set(GGML_PUBLIC_HEADERS
include/gguf.h)
set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
#if (GGML_METAL)
# set_target_properties(ggml PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/src/ggml-metal.metal")
#endif()
install(TARGETS ggml LIBRARY PUBLIC_HEADER)
install(TARGETS ggml-base LIBRARY)
+8 -1
View File
@@ -438,7 +438,14 @@ if (GGML_CPU_ALL_VARIANTS)
ggml_add_cpu_backend_variant(power8_2 POWER8 VSX)
ggml_add_cpu_backend_variant(power9 POWER9 VSX)
ggml_add_cpu_backend_variant(power10 POWER10 VSX)
ggml_add_cpu_backend_variant(power11 POWER11 VSX)
# POWER11 backend: only if compiler supports -mcpu=power11
check_cxx_compiler_flag("-mcpu=power11" GGML_CXX_SUPPORTS_POWER11)
if (GGML_CXX_SUPPORTS_POWER11)
message(STATUS "Compiler supports -mcpu=power11, enabling POWER11 backend")
ggml_add_cpu_backend_variant(power11 POWER11 VSX)
else()
message(STATUS "Skipping POWER11 backend: compiler does not support -mcpu=power11")
endif()
else()
message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
endif()
+1 -1
View File
@@ -389,7 +389,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M_UPPER}")
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
if (EXTRACTED_NUMBER EQUAL 10 OR EXTRACTED_NUMBER EQUAL 11)
list(APPEND ARCH_FLAGS -mcpu=power10)
elseif (EXTRACTED_NUMBER EQUAL 9)
list(APPEND ARCH_FLAGS -mcpu=power9)
+6 -5
View File
@@ -2345,7 +2345,7 @@ class tinyBLAS_Q0_PPC {
else if (n_aligned % 16 == 0) nc = 16;
else nc = 8;
}
bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0);
bool can_use_tiled = n_aligned > 0 && (m % mc == 0);
if (can_use_tiled) {
matmul_tiled(m, n_aligned, mc, nc, kc);
if (n > n_aligned) {
@@ -3063,13 +3063,14 @@ class tinyBLAS_Q0_PPC {
int64_t ii = (job / xtiles) * mc;
int64_t jj = (job % xtiles) * nc;
for (int64_t kk = 0; kk < k; kk += kc) {
int64_t k_cur = MIN(kc, k - kk);
if constexpr(is_Ablock_q4) {
packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
packNormal_q4_fp16(A + ii * lda + kk, lda, mc, k_cur, (uint8_t *)A_pack);
} else {
packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
packNormal_q8_fp16(A + ii * lda + kk, lda, mc, k_cur, (uint8_t *)A_pack);
}
packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack);
KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack);
packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, k_cur, (uint8_t *)B_pack);
KERNEL_Q0(ii, jj, mc, nc, k_cur, kk, A_pack, B_pack);
}
}
}
+81
View File
@@ -0,0 +1,81 @@
#include "col2im-1d.cuh"
#include "convert.cuh"
// col2im_1d: scatter-add GEMM columns to 1D signal (gather approach)
// columns: [K*OC, T_in] -> output: [T_out, OC]
// Supports F32, F16, BF16 data with F32 accumulator.
template <typename T>
static __global__ void col2im_1d_kernel(
const T * __restrict__ col,
T * __restrict__ dst,
const int T_in, const uint3 T_out_fd,
const int OC, const int K, const int K_OC,
const int s0, const int p0, const int total) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= total) return;
// dst layout: [T_out, OC], ne[0]=T_out fastest
const uint2 qr = fast_div_modulo((uint32_t)idx, T_out_fd); // qr.x = idx / T_out, qr.y = idx % T_out
const int oc = (int)qr.x;
const int t_out = (int)qr.y;
const int t_abs = t_out + p0; // absolute position in uncropped signal
// Gather: find all (t_in, k) where t_in*s + k == t_abs, 0 <= k < K
int t_in_min = (t_abs - K + s0) / s0; // ceil((t_abs - K + 1) / s)
if (t_in_min < 0) t_in_min = 0;
int t_in_max = t_abs / s0;
if (t_in_max >= T_in) t_in_max = T_in - 1;
float sum = 0.0f;
for (int t_in = t_in_min; t_in <= t_in_max; t_in++) {
const int k = t_abs - t_in * s0;
// col layout: [K*OC, T_in], column index = oc * K + k
sum += ggml_cuda_cast<float>(col[(oc * K + k) + t_in * K_OC]);
}
dst[idx] = ggml_cuda_cast<T>(sum);
}
void ggml_cuda_op_col2im_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t OC = ((const int32_t *)(dst->op_params))[1];
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
const int K_OC = (int) src0->ne[0];
const int T_in = (int) src0->ne[1];
const int K = K_OC / OC;
const int T_out = (int) dst->ne[0];
const uint3 T_out_fd = init_fastdiv_values((uint32_t)T_out);
const int total = T_out * OC;
const int block_size = 256;
const int num_blocks = (total + block_size - 1) / block_size;
switch (src0->type) {
case GGML_TYPE_F32: {
col2im_1d_kernel<<<num_blocks, block_size, 0, stream>>>(
(const float *)src0->data, (float *)dst->data,
T_in, T_out_fd, OC, K, K_OC, s0, p0, total);
} break;
case GGML_TYPE_F16: {
col2im_1d_kernel<<<num_blocks, block_size, 0, stream>>>(
(const half *)src0->data, (half *)dst->data,
T_in, T_out_fd, OC, K, K_OC, s0, p0, total);
} break;
case GGML_TYPE_BF16: {
col2im_1d_kernel<<<num_blocks, block_size, 0, stream>>>(
(const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
T_in, T_out_fd, OC, K, K_OC, s0, p0, total);
} break;
default:
GGML_ABORT("col2im_1d: unsupported type");
}
}
+3
View File
@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_col2im_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+12
View File
@@ -11,6 +11,7 @@
#include "ggml-cuda/argsort.cuh"
#include "ggml-cuda/binbcast.cuh"
#include "ggml-cuda/clamp.cuh"
#include "ggml-cuda/col2im-1d.cuh"
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/conv2d.cuh"
@@ -3051,6 +3052,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CONV_TRANSPOSE_1D:
ggml_cuda_op_conv_transpose_1d(ctx,dst);
break;
case GGML_OP_COL2IM_1D:
ggml_cuda_op_col2im_1d(ctx, dst);
break;
case GGML_OP_POOL_2D:
ggml_cuda_op_pool2d(ctx, dst);
break;
@@ -5316,6 +5320,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
}
return false;
} break;
case GGML_OP_COL2IM_1D:
{
ggml_type src0_type = op->src[0]->type;
return (src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_F16 || src0_type == GGML_TYPE_BF16) &&
op->type == src0_type &&
ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op);
} break;
case GGML_OP_SILU_BACK:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
break;
+104 -12
View File
@@ -69,6 +69,7 @@ static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE;
static int opt_opbatch = 1024; // max number of ops in a batch
static int opt_opqueue = 16; // max number of pending batches
static int opt_oppoll = 0; // polling for batch completions
static int opt_optrace = 0; // trace buffer size per thread (0 means default)
static std::regex* opt_opfilter = NULL; // regex of ops to not claim
@@ -118,20 +119,39 @@ static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct
ggml_op_desc(op), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, supp ? "yes" : "no");
}
static const char * htp_event_name(uint16_t id) {
switch (id) {
case HTP_TRACE_EVT_DMA: return "DMA";
case HTP_TRACE_EVT_HVX_COMP: return "HVX_COMP";
case HTP_TRACE_EVT_HVX_A_QUANT: return "HVX_A_QUANT";
case HTP_TRACE_EVT_HVX_A_PREP: return "HVX_A_PREP";
case HTP_TRACE_EVT_HVX_W_DEQUANT: return "HVX_W_DEQUANT";
case HTP_TRACE_EVT_HVX_W_PREP: return "HVX_W_PREP";
case HTP_TRACE_EVT_HVX_O_PROC: return "HVX_O_PROC";
case HTP_TRACE_EVT_HMX_COMP: return "HMX_COMP";
default: return "UNKNOWN";
}
}
static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const htp_opnode & node,
uint32_t op_usec, uint32_t op_cycles, const uint32_t pmu[]) {
const htp_prof_desc & pd) {
if (!opt_profile) return;
uint32_t op_usec = pd.usecs;
uint32_t op_cycles = pd.cycles_stop - pd.cycles_start;
const uint32_t * pmu = pd.pmu;
char pmu_str[256] = "";
if (opt_profile > 1) {
if (opt_profile == 2) {
static_assert(HTP_PROF_PMU_NCNT == 8, "current implementation assumes 8 PMU counters");
sprintf(pmu_str, " pmu [%u,%u,%u,%u,%u,%u,%u,%u]",
pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]);
}
htp_opformat fmt(node);
GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u%s\n", sess_name.c_str(),
node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, op_usec, op_cycles, pmu_str);
float mhz = op_usec > 0 ? (float) op_cycles / op_usec : 0.0f;
GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u start %u mhz %.1f%s\n", sess_name.c_str(),
node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, op_usec, op_cycles, pd.cycles_start, mhz, pmu_str);
}
// ** backend sessions
@@ -1995,10 +2015,16 @@ struct ggml_hexagon_opqueue {
size_t n_ops = batch_size;
size_t n_tensors = n_ops + n_ops * HTP_OP_MAX_INPUTS;
size_t tr_size = 0;
if (opt_profile == 3) {
tr_size = (HTP_MAX_NTHREADS + 1) * opt_optrace * sizeof(htp_trace_desc);
}
shm_blk_size = sizeof(htp_buf_desc) * n_bufs +
sizeof(htp_tensor) * n_tensors +
sizeof(htp_op_desc) * n_ops +
sizeof(htp_prof_desc) * n_ops;
sizeof(htp_prof_desc) * n_ops +
tr_size;
shm_buf = new ggml_hexagon_shared_buffer(sess, shm_blk_size * depth, true /* pinned */);
@@ -2042,11 +2068,19 @@ struct ggml_hexagon_opqueue {
const size_t o_size = sizeof(htp_op_desc) * req.n_ops;
const size_t p_size = sizeof(htp_prof_desc) * req.n_ops;
size_t tr_size = 0;
if (opt_profile == 3) {
req.n_traces = opt_optrace;
tr_size = (HTP_MAX_NTHREADS + 1) * req.n_traces * sizeof(htp_trace_desc);
} else {
req.n_traces = 0;
}
dbuf.ptr = shm_buf->base + (req.id * shm_blk_size);
dbuf.fd = shm_buf->fd;
dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT;
dbuf.offset = (uint8_t*) dbuf.ptr - (uint8_t*) shm_buf->base;
dbuf.size = b_size + t_size + o_size + p_size;
dbuf.size = b_size + t_size + o_size + p_size + tr_size;
GGML_ASSERT(dbuf.size <= shm_blk_size);
@@ -2092,7 +2126,14 @@ struct ggml_hexagon_opqueue {
const size_t o_size = sizeof(htp_op_desc) * rsp.n_ops;
const size_t p_size = sizeof(htp_prof_desc) * rsp.n_ops;
const size_t m_size = b_size + t_size + o_size + p_size;
size_t tr_size = 0;
uint32_t n_traces = 0;
if (opt_profile == 3) {
n_traces = opt_optrace;
tr_size = (HTP_MAX_NTHREADS + 1) * n_traces * sizeof(htp_trace_desc);
}
const size_t m_size = b_size + t_size + o_size + p_size + tr_size;
GGML_ASSERT(m_size <= shm_blk_size);
HEX_VERBOSE("ggml-hex: %s op-queue pop batch #%u : n-bufs %u n-tensors %u n-ops %u : m-size %zu b-size %zu t-size %zu o-size %zu\n",
@@ -2111,13 +2152,62 @@ struct ggml_hexagon_opqueue {
GGML_ASSERT(rsp.n_ops <= ops.size());
const htp_prof_desc * pd = (const htp_prof_desc *) p_ptr;
for (uint32_t i = 0; i < rsp.n_ops; i++) {
htp_usec += pd[i].usecs;
ggml_hexagon_dump_op_prof(shm_buf->sess->name, ops[i], pd[i].usecs, pd[i].cycles, pd[i].pmu);
const htp_trace_desc * trace_events = nullptr;
if (opt_profile == 3) {
trace_events = (const htp_trace_desc *) (p_ptr + p_size);
}
GGML_LOG_DEBUG("ggml-hex: %s profile-batch n-ops %u batch-dur-usec %lld htp-ops-usec %u\n",
shm_buf->sess->c_name(), rsp.n_ops, (long long) batch_usec, htp_usec);
uint32_t trace_idx[HTP_MAX_NTHREADS + 1] = {0};
uint32_t valid_cnt[HTP_MAX_NTHREADS + 1] = {0};
if (opt_profile == 3) {
for (uint32_t t = 0; t <= HTP_MAX_NTHREADS; t++) {
uint32_t count = rsp.n_traces[t];
valid_cnt[t] = count > n_traces ? n_traces : count;
}
}
for (uint32_t i = 0; i < rsp.n_ops; i++) {
htp_usec += pd[i].usecs;
ggml_hexagon_dump_op_prof(shm_buf->sess->name, ops[i], pd[i]);
if (opt_profile == 3) {
uint32_t op_duration = pd[i].cycles_stop - pd[i].cycles_start;
for (uint32_t t = 0; t <= HTP_MAX_NTHREADS; t++) {
while (trace_idx[t] < valid_cnt[t]) {
const auto & e = trace_events[t * n_traces + trace_idx[t]];
uint32_t offset = e.cycles - pd[i].cycles_start;
if (offset >= 0x80000000) {
trace_idx[t]++;
continue;
}
if (offset > op_duration) {
break;
}
bool is_stop = (e.info & 0x8000) != 0;
uint16_t info = e.info & 0x7FFF;
GGML_LOG_DEBUG("ggml-hex: %s trace-op %s: thread %u event %s info %u %s %u\n",
shm_buf->sess->c_name(), ops[i].op_name().c_str(), t, htp_event_name(e.id), info, is_stop ? "stop" : "start", e.cycles);
trace_idx[t]++;
}
}
}
}
char evt_str[256] = "";
if (opt_profile == 3) {
sprintf(evt_str, " evt [%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u]",
rsp.n_traces[0], rsp.n_traces[1], rsp.n_traces[2], rsp.n_traces[3],
rsp.n_traces[4], rsp.n_traces[5], rsp.n_traces[6], rsp.n_traces[7],
rsp.n_traces[8], rsp.n_traces[9], rsp.n_traces[10]);
}
GGML_LOG_DEBUG("ggml-hex: %s profile-batch n-ops %u batch-dur-usec %lld htp-ops-usec %u%s\n",
shm_buf->sess->c_name(), rsp.n_ops, (long long) batch_usec, htp_usec, evt_str);
}
}
};
@@ -3901,6 +3991,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH");
const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE");
const char * str_oppoll = getenv("GGML_HEXAGON_OPPOLL");
const char * str_optrace = getenv("GGML_HEXAGON_OPTRACE");
const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER");
const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
const char * str_etm = getenv("GGML_HEXAGON_ETM");
@@ -3939,6 +4030,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch;
opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue;
opt_oppoll = str_oppoll ? strtoul(str_oppoll, NULL, 0) : opt_oppoll;
opt_optrace = str_optrace ? strtoul(str_optrace, NULL, 0) : (opt_opbatch * 128);
opt_profile = str_profile ? atoi(str_profile) : 0;
opt_etm = str_etm ? atoi(str_etm) : 0;
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
+1 -1
View File
@@ -37,8 +37,8 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
if (_hmx_idx GREATER_EQUAL 0)
target_sources(${HTP_LIB} PRIVATE
hmx-matmul-ops.c
hmx-flash-attn-ops.c
hmx-matmul-ops.c
hmx-queue.c
)
@@ -339,6 +339,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
if (ir0 >= ir1) return;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
dma_queue * dma = octx->ctx->dma[ith];
const uint32_t DK = nek0;
@@ -615,6 +618,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
}
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
}
int op_flash_attn_ext(struct htp_ops_context * octx) {
+10 -3
View File
@@ -6,6 +6,8 @@
#include <stdbool.h>
#include <stdint.h>
#include "hex-profile.h"
#ifdef __cplusplus
extern "C" {
#endif
@@ -88,6 +90,7 @@ typedef struct {
uint32_t pop_idx;
uint32_t capacity;
uint32_t idx_mask;
struct htp_thread_trace * trace;
} dma_queue;
dma_queue * dma_queue_create(size_t capacity);
@@ -152,6 +155,7 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
q->dptr[q->push_idx] = dptr;
if (size) {
htp_trace_event_start(q->trace, HTP_TRACE_EVT_DMA, q->push_idx);
dmlink(q->tail, desc);
q->tail = (dma_descriptor_2d *) desc;
} else {
@@ -202,6 +206,7 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t
q->dptr[q->push_idx] = dptr;
if (nrows) {
htp_trace_event_start(q->trace, HTP_TRACE_EVT_DMA, q->push_idx);
dmlink(q->tail, desc);
q->tail = desc;
} else {
@@ -223,10 +228,12 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
dma_descriptor_2d * desc = &q->desc[q->pop_idx];
// Wait for desc to complete
while (!desc->done) {
// FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
dmpoll();
if (!desc->done) {
while (!desc->done) {
dmpoll();
}
}
htp_trace_event_stop(q->trace, HTP_TRACE_EVT_DMA, q->pop_idx);
dptr = q->dptr[q->pop_idx];
+64
View File
@@ -0,0 +1,64 @@
#ifndef HEX_PROFILE_H
#define HEX_PROFILE_H
#include <stdbool.h>
#include <stdint.h>
#include <qurt.h>
#include "hex-utils.h"
#include "htp-ops.h"
#define HTP_TRACE_EVT_START 0
#define HTP_TRACE_EVT_STOP 1
#ifndef HEX_NUM_PMU_COUNTERS
#define HEX_NUM_PMU_COUNTERS 8
#endif
static inline void hex_get_pmu(uint32_t counters[]) {
#if __HVX_ARCH__ >= 79
asm volatile("%0 = upmucnt0" : "=r"(counters[0]));
asm volatile("%0 = upmucnt1" : "=r"(counters[1]));
asm volatile("%0 = upmucnt2" : "=r"(counters[2]));
asm volatile("%0 = upmucnt3" : "=r"(counters[3]));
asm volatile("%0 = upmucnt4" : "=r"(counters[4]));
asm volatile("%0 = upmucnt5" : "=r"(counters[5]));
asm volatile("%0 = upmucnt6" : "=r"(counters[6]));
asm volatile("%0 = upmucnt7" : "=r"(counters[7]));
#else
counters[0] = qurt_pmu_get(QURT_PMUCNT0);
counters[1] = qurt_pmu_get(QURT_PMUCNT1);
counters[2] = qurt_pmu_get(QURT_PMUCNT2);
counters[3] = qurt_pmu_get(QURT_PMUCNT3);
counters[4] = qurt_pmu_get(QURT_PMUCNT4);
counters[5] = qurt_pmu_get(QURT_PMUCNT5);
counters[6] = qurt_pmu_get(QURT_PMUCNT6);
counters[7] = qurt_pmu_get(QURT_PMUCNT7);
#endif
}
struct htp_thread_trace {
uint32_t count;
uint32_t max_events;
struct htp_trace_desc * events;
};
static inline void htp_trace_event(struct htp_thread_trace * tr, uint16_t id, uint16_t info, uint32_t type) {
if (tr && tr->events && tr->count < tr->max_events) {
uint32_t idx = tr->count;
tr->events[idx].id = id;
tr->events[idx].info = info | (type == HTP_TRACE_EVT_STOP ? 0x8000 : 0);
tr->events[idx].cycles = (uint32_t) hex_get_cycles();
tr->count++;
}
}
static inline void htp_trace_event_start(struct htp_thread_trace * tr, uint16_t id, uint16_t info) {
htp_trace_event(tr, id, info, HTP_TRACE_EVT_START);
}
static inline void htp_trace_event_stop(struct htp_thread_trace * tr, uint16_t id, uint16_t info) {
htp_trace_event(tr, id, info, HTP_TRACE_EVT_STOP);
}
#endif /* HEX_PROFILE_H */
-27
View File
@@ -107,31 +107,4 @@ static inline void hex_pause() {
asm volatile(" pause(#255)\n");
}
#ifndef HEX_NUM_PMU_COUNTERS
#define HEX_NUM_PMU_COUNTERS 8
#endif
static inline void hex_get_pmu(uint32_t counters[]) {
#if __HVX_ARCH__ >= 79
asm volatile("%0 = upmucnt0" : "=r"(counters[0]));
asm volatile("%0 = upmucnt1" : "=r"(counters[1]));
asm volatile("%0 = upmucnt2" : "=r"(counters[2]));
asm volatile("%0 = upmucnt3" : "=r"(counters[3]));
asm volatile("%0 = upmucnt4" : "=r"(counters[4]));
asm volatile("%0 = upmucnt5" : "=r"(counters[5]));
asm volatile("%0 = upmucnt6" : "=r"(counters[6]));
asm volatile("%0 = upmucnt7" : "=r"(counters[7]));
#else
counters[0] = qurt_pmu_get(QURT_PMUCNT0);
counters[1] = qurt_pmu_get(QURT_PMUCNT1);
counters[2] = qurt_pmu_get(QURT_PMUCNT2);
counters[3] = qurt_pmu_get(QURT_PMUCNT3);
counters[4] = qurt_pmu_get(QURT_PMUCNT4);
counters[5] = qurt_pmu_get(QURT_PMUCNT5);
counters[6] = qurt_pmu_get(QURT_PMUCNT6);
counters[7] = qurt_pmu_get(QURT_PMUCNT7);
// qurt_pmu_get_pmucnt(counters);
#endif
}
#endif /* HEX_UTILS_H */
+26 -66
View File
@@ -18,7 +18,7 @@
#include "ggml-common.h"
#include "hex-dma.h"
#include "hex-fastdiv.h"
#include "hmx-profile.h"
#include "hex-profile.h"
#include "hmx-queue.h"
#include "hmx-utils.h"
#include "htp-ctx.h"
@@ -367,8 +367,11 @@ static void fa_k_interleave_thread(unsigned int n, unsigned int i, void * data)
return;
}
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
hmx_interleave_rows_to_tiles(factx->vtcm_k_tiles, factx->vtcm_k_fp16[args->buf_idx], total_rows, (int) factx->DK,
(int) args->src_stride, start, end);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
}
static void fa_phase_k_interleave(struct hmx_fa_context * factx, int kv_rows, size_t src_stride, size_t buf_idx) {
@@ -408,8 +411,11 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data)
__fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0];
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
hmx_interleave_cols_to_tiles(v_tiles_dest, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV,
(int) args->src_stride, (int) args->n_col_tiles, start, end);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
}
static void fa_phase_v_interleave(struct hmx_fa_context * factx,
@@ -462,6 +468,9 @@ static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) {
return;
}
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
const struct htp_tensor * q = args->q;
const uint32_t q_start = args->q_start;
const uint32_t kv_head = args->kv_head;
@@ -515,6 +524,7 @@ static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) {
}
}
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
}
static void fa_phase_q_load(struct hmx_fa_context * factx,
@@ -566,6 +576,9 @@ static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) {
return;
}
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
const struct htp_tensor * dst = args->dst;
const __fp16 * o_tile_src = args->o_tile_src;
const uint32_t q_start = args->q_start;
@@ -611,6 +624,7 @@ static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) {
}
}
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
}
static void fa_phase_o_store(struct hmx_fa_context * factx,
@@ -680,6 +694,9 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
return;
}
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, vec_start);
// Per-thread row scratch: thread i uses bufs at offset i * 2 * stride
const size_t row_buf_stride = factx->row_buf_stride;
HVX_Vector * my_row_buf0 = factx->vtcm_row_bufs + i * 2 * row_buf_stride;
@@ -950,6 +967,7 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
factx->vtcm_s_rowmax[r_vec_idx] = rowmax_acc_v;
factx->vtcm_p_rowsum[r_vec_idx] = rowsum_acc_v;
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, vec_start);
}
// Serial m/l update + build_D. Must run after softmax barrier (s_rowmax written by all threads).
@@ -1245,6 +1263,7 @@ static __attribute__((noinline)) void fa_compute_slopes(
// ============================================================================
int hmx_flash_attn_ext(struct htp_ops_context * octx) {
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[HTP_MAX_NTHREADS] : NULL;
const struct htp_tensor * q = octx->src[0];
const struct htp_tensor * k = octx->src[1];
const struct htp_tensor * v = octx->src[2];
@@ -1422,19 +1441,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
return HTP_STATUS_OK;
}
// Profiling timers
TIMER_DEFINE(total);
TIMER_DEFINE(q_load);
TIMER_DEFINE(kv_dma);
TIMER_DEFINE(k_interleave);
TIMER_DEFINE(v_interleave);
TIMER_DEFINE(qk_dot);
TIMER_DEFINE(softmax);
TIMER_DEFINE(o_update);
TIMER_DEFINE(o_norm);
TIMER_DEFINE(o_store);
TIMER_START(total);
// ======== DMA setup ========
dma_queue * const dma = ctx->dma[0];
@@ -1474,12 +1480,10 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
const size_t n_row_tiles = g_br_actual / HMX_FP16_TILE_N_ROWS;
// ---- Load Q block [g_br, D] -> tiles, interleaving G heads ----
TIMER_START(q_load);
if (n_rows_g < g_br) {
hvx_splat_u8_a(factx.vtcm_q_tiles, 0, q_tile_bytes);
}
fa_phase_q_load(&factx, q, q_start, kv_head, ib3, n_rows_g);
TIMER_STOP(q_load);
// ---- Initialize per-block state ----
hvx_splat_u8_a(factx.vtcm_l_vec, 0, col_vec_bytes);
@@ -1558,10 +1562,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS);
// Wait for current KV DMA
TIMER_START(kv_dma);
dma_queue_pop(dma); // K
dma_queue_pop(dma); // V
TIMER_STOP(kv_dma);
// Push mask DMA for this block (single 2D DMA when broadcast)
bool has_mask_dma = false;
@@ -1583,10 +1585,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
ou_job.DV = DV;
hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job));
}
TIMER_START(k_interleave);
fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx);
TIMER_STOP(k_interleave);
// ---- Phase 2: qk_dot(blk) on HMX ‖ V_int(blk) + DMA prefetch on HVX ----
qk_job.q_tiles = factx.vtcm_q_tiles;
@@ -1597,15 +1596,11 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
qk_job.n_dot_tiles = DK / 32;
qk_job.n_tiles_per_bc = n_tiles_per_bc;
qk_job.hmx_scales = factx.vtcm_hmx_scales_qk;
TIMER_START(qk_dot);
hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_qk_dot_worker, &qk_job));
// DMA push next block (non-blocking, before worker_pool)
DMA_PREFETCH_KV(kv_blk + 1);
TIMER_START(v_interleave);
fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc);
TIMER_STOP(v_interleave);
// Pop and swap previous block's output update (deferred HMX pop)
if (kv_blk > 0) {
@@ -1615,7 +1610,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
// Pop current block's dot product job
hmx_queue_pop(hmx_q);
TIMER_STOP(qk_dot);
// ---- Phase 3: softmax(blk) + build_D(blk) | HMX idle ----
// Pop mask DMA before softmax (ensures VTCM buffer is ready)
@@ -1641,10 +1635,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL;
sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride;
sargs.slopes = factx.vtcm_slopes;
TIMER_START(softmax);
fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br);
TIMER_STOP(softmax);
buf_idx = 1 - buf_idx;
} // end KV block loop (pipeline)
@@ -1664,11 +1655,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
ou_job.n_row_tiles_g_br = n_row_tiles_g_br;
ou_job.n_tiles_per_bc = n_tiles_per_bc;
ou_job.DV = DV;
TIMER_START(o_update);
hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job));
hmx_queue_pop(hmx_q);
TIMER_STOP(o_update);
hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev);
}
@@ -1683,23 +1671,14 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
const uint32_t kv_start = kv_blk * Bc;
const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start);
const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS);
TIMER_START(kv_dma);
dma_queue_pop(dma); // K
dma_queue_pop(dma); // V
TIMER_STOP(kv_dma);
bool has_mask_dma = false;
MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma);
DMA_PREFETCH_KV(kv_blk + 1);
// K interleave (multi-thread HVX)
TIMER_START(k_interleave);
fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx);
TIMER_STOP(k_interleave);
// QK dot (inline HMX on main thread)
TIMER_START(qk_dot);
{
const size_t n_dot_tiles = (size_t) (DK / 32);
const __fp16 * restrict q_base = factx.vtcm_q_tiles;
@@ -1709,6 +1688,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
__builtin_assume(n_col_tiles > 0);
__builtin_assume(n_dot_tiles > 0);
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_qk);
for (size_t r = 0; r < n_row_tiles; ++r) {
for (size_t c = 0; c < n_col_tiles; ++c) {
@@ -1724,8 +1704,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
Q6_mxmem_AR_after_hf(out_tile, 0);
}
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
}
TIMER_STOP(qk_dot);
// Pop mask DMA
MASK_DMA_POP(has_mask_dma);
@@ -1751,21 +1731,9 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL;
sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride;
sargs.slopes = factx.vtcm_slopes;
TIMER_START(softmax);
fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br);
TIMER_STOP(softmax);
// V interleave (multi-thread HVX)
TIMER_START(v_interleave);
// FIX(v-stride): use n_tiles_per_bc (block-invariant) as V tile layout
// stride to match o_update's v_tile access. Using per-block n_col_tiles
// misplaces DV_tile 1..3 in the last partial KV block.
fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc);
TIMER_STOP(v_interleave);
// O update (inline HMX on main thread)
TIMER_START(o_update);
{
const size_t DV_tiles = (size_t) (DV / 32);
const __fp16 * restrict d_base = factx.vtcm_d_tiles;
@@ -1777,6 +1745,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
__builtin_assume(n_col_tiles > 0);
__builtin_assume(DV_tiles > 0);
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id);
for (size_t r = 0; r < n_row_tiles; ++r) {
for (size_t c = 0; c < DV_tiles; ++c) {
@@ -1798,16 +1767,15 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
Q6_mxmem_AR_after_hf(o_tile_out, 0);
}
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev);
}
TIMER_STOP(o_update);
buf_idx = 1 - buf_idx;
} // end KV block loop (fallback)
}
// ---- Final normalization: O = diag(1/l) @ O ----
TIMER_START(o_norm);
{
fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br);
@@ -1830,6 +1798,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
__builtin_assume(n_row_tiles > 0);
__builtin_assume(DV_tiles > 0);
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id);
for (size_t r = 0; r < n_row_tiles; ++r) {
for (size_t c = 0; c < DV_tiles; ++c) {
@@ -1842,14 +1811,12 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
Q6_mxmem_AR_after_hf(o_out, 0);
}
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
}
}
TIMER_STOP(o_norm);
// ---- Store O block ----
TIMER_START(o_store);
fa_phase_o_store(&factx, dst, o_tile_curr, q_start, kv_head, ib3, n_rows_g);
TIMER_STOP(o_store);
#undef MASK_DMA_PUSH
#undef MASK_DMA_POP
@@ -1865,14 +1832,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
}
TIMER_STOP(total);
#if defined(ENABLE_PROFILE_TIMERS)
FARF(HIGH, "hmx-fa: %lld us, q_load=%lld kv_dma=%lld k_interleave=%lld v_interleave=%lld", TIMER_US(total),
TIMER_US(q_load), TIMER_US(kv_dma), TIMER_US(k_interleave), TIMER_US(v_interleave));
FARF(HIGH, " qk_dot=%lld softmax=%lld o_update=%lld o_norm=%lld o_store=%lld", TIMER_US(qk_dot), TIMER_US(softmax),
TIMER_US(o_update), TIMER_US(o_norm), TIMER_US(o_store));
#endif
return HTP_STATUS_OK;
}
+55 -41
View File
@@ -27,7 +27,7 @@
#include "hmx-ops.h"
#include "hmx-utils.h"
#include "hmx-queue.h"
#include "hmx-profile.h"
#include "hex-profile.h"
#include "vtcm-utils.h"
@@ -430,6 +430,7 @@ typedef struct {
int n_tasks;
int n_k_tiles;
struct fastdiv_values n_k_tiles_div;
struct htp_thread_trace * traces;
} x4x2_dequantize_state_t;
// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16.
@@ -533,11 +534,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(
\
static void dequantize_x4x2_worker_loop_##suffix(unsigned int n, unsigned int i, void *data) { \
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; \
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; \
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { \
int start = task_id * state->n_tiles_per_task; \
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); \
dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(state, start, end); \
} \
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \
}
DEFINE_DEQUANTIZE_Q4_TASK(q4_0, q4_0_to_fp16_lut, q4_0, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16))
@@ -657,11 +661,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(
static void dequantize_x4x2_worker_loop_mxfp4(unsigned int n, unsigned int i, void *data) {
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
int start = task_id * state->n_tiles_per_task;
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(state, start, end);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
}
static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(
@@ -717,11 +724,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(
static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, void *data) {
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
int start = task_id * state->n_tiles_per_task;
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(state, start, end);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
}
static void convert_f16_weight_to_fp16_tiles_task(
@@ -773,11 +783,14 @@ static void convert_f16_weight_to_fp16_tiles_task(
static void convert_f16_worker_loop(unsigned int n, unsigned int i, void *data) {
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
int start = task_id * state->n_tiles_per_task;
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
convert_f16_weight_to_fp16_tiles_task(state, start, end);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
}
static void quantize_f32_weight_to_fp16_tiles_task(
@@ -833,11 +846,14 @@ static void quantize_f32_weight_to_fp16_tiles_task(
static void quantize_f32_worker_loop(unsigned int n, unsigned int i, void *data) {
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
int start = task_id * state->n_tiles_per_task;
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
quantize_f32_weight_to_fp16_tiles_task(state, start, end);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
}
@@ -868,6 +884,7 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
state.weight_type = weight_type;
state.n_k_tiles = n_k_tiles;
state.n_k_tiles_div = n_k_tiles_div;
state.traces = ctx ? ctx->trace : NULL;
if (state.n_tasks == 1 || n_threads == 1) {
dequant_worker_fn(1, 0, &state);
@@ -985,10 +1002,13 @@ typedef struct {
int n_chunks_per_task;
int n_cols;
int n; // DDR row stride (total output columns)
struct htp_thread_trace * traces;
} output_transfer_task_state_t;
static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void *data) {
output_transfer_task_state_t *st = (output_transfer_task_state_t *) data;
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) {
int chunk_idx = task_id * st->n_chunks_per_task;
@@ -998,6 +1018,7 @@ static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void
const __fp16 *vtcm_src = st->vtcm_src + chunk_idx * st->n_cols;
transfer_output_chunk_fp16_to_fp32(dst, vtcm_src, chunk_size, st->n_cols, st->n);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
}
static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src,
@@ -1015,6 +1036,7 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst,
state.vtcm_src = vtcm_src;
state.n_cols = n_cols;
state.n = n;
state.traces = ctx ? ctx->trace : NULL;
if (state.n_tasks == 1 || n_threads == 1) {
transfer_output_chunk_worker_fn(1, 0, &state);
@@ -1086,10 +1108,13 @@ typedef struct {
int n_chunks_per_task;
int k_block;
int k_stride;
struct htp_thread_trace * traces;
} activation_transfer_task_state_t;
static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) {
activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data;
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) {
// one chunk: one row
@@ -1100,6 +1125,7 @@ static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i,
const float *src = st->src + chunk_idx * st->k_stride;
transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
}
static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride, int n_threads) {
@@ -1117,6 +1143,7 @@ static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *
state.src = src;
state.k_block = k_block;
state.k_stride = k_stride;
state.traces = ctx ? ctx->trace : NULL;
if (state.n_tasks == 1 || n_threads == 1) {
transfer_activation_chunk_worker_fn(1, 0, &state);
@@ -1245,13 +1272,7 @@ int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float
FARF(HIGH, "hmx-mm-2d: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu",
m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget);
TIMER_DEFINE(activation_load);
TIMER_DEFINE(weight_load);
TIMER_DEFINE(hmx_core);
TIMER_DEFINE(output_store);
TIMER_DEFINE(total);
TIMER_START(total);
int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols);
@@ -1370,7 +1391,12 @@ int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads);
// C: HMX Compute (Synchronous)
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
{
struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
}
// D: Output Store
float *output_chunk = dst + (mr * n + nc);
@@ -1380,18 +1406,7 @@ int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
}
TIMER_STOP(total);
#if defined(ENABLE_PROFILE_TIMERS)
FARF(HIGH, "hex-mm-2d: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n);
if (!use_pipeline) {
FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us",
TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store));
size_t weight_size = (size_t)n * row_stride;
float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load);
FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth);
}
#endif
return 0;
}
@@ -1523,13 +1538,7 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
m_chunk_n_rows, n_chunk_n_cols,
(size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget);
TIMER_DEFINE(activation_load);
TIMER_DEFINE(weight_load);
TIMER_DEFINE(hmx_core);
TIMER_DEFINE(output_store);
TIMER_DEFINE(total);
TIMER_START(total);
const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16);
const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16);
@@ -1549,7 +1558,6 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
// contiguous rows into a VTCM scratch buffer first, then HVX
// converts from the contiguous VTCM buffer. This avoids L2 cache
// thrashing from HVX loads at large strides.
TIMER_START(activation_load);
for (int g = 0; g < group_size; ++g) {
const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride;
__fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride;
@@ -1569,7 +1577,6 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
params->k, params->act_stride, ctx->n_threads);
}
}
TIMER_STOP(activation_load);
void *buf_curr = vtcm_scratch0;
void *buf_next = vtcm_scratch1;
@@ -1584,7 +1591,6 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols);
const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS);
TIMER_START(weight_load);
{
dma_queue_pop(ctx->dma[0]);
@@ -1601,24 +1607,22 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
0, n_cols);
hex_swap_ptr(&buf_curr, &buf_next);
}
TIMER_STOP(weight_load);
// Reuse the interleaved weight for every q_head in this GQA group
for (int g = 0; g < group_size; ++g) {
TIMER_START(hmx_core);
{
const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride;
struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles,
params->k / 32);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
}
TIMER_STOP(hmx_core);
TIMER_START(output_store);
{
float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc;
transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride, ctx->n_threads);
}
TIMER_STOP(output_store);
}
}
}
@@ -1627,14 +1631,7 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
TIMER_STOP(total);
#if defined(ENABLE_PROFILE_TIMERS)
FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d group=%d", __func__, TIMER_US(total),
params->m, params->k, params->n, group_size);
FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us",
TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store));
#endif
return 0;
}
@@ -1668,6 +1665,7 @@ typedef struct {
size_t nb12;
int start_row;
int cne1;
struct htp_thread_trace *traces;
} activation_transfer_gathered_task_state_t;
typedef struct {
@@ -1684,6 +1682,7 @@ typedef struct {
size_t dst_nb2;
int start_row;
int cne1;
struct htp_thread_trace *traces;
} output_transfer_scattered_task_state_t;
static void transfer_activation_chunk_fp32_to_fp16_gathered(
@@ -1780,6 +1779,9 @@ static void transfer_activation_chunk_fp32_to_fp16_gathered(
static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigned int i, void *data) {
activation_transfer_gathered_task_state_t *st = data;
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
int chunk_idx = i;
int chunk_size = st->n_chunks_per_task;
int start_row = st->start_row + chunk_idx * chunk_size;
@@ -1791,6 +1793,7 @@ static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigne
st->matrix_rows, st->cur_a, st->mapping_stride,
st->ne11, &st->ne11_div, st->nb11, st->nb12, st->cne1);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
}
static void transfer_activation_chunk_gathered_threaded(
@@ -1830,6 +1833,7 @@ static void transfer_activation_chunk_gathered_threaded(
.nb12 = nb12,
.start_row = start_row,
.cne1 = cne1,
.traces = ctx ? ctx->trace : NULL,
};
if (actual_threads <= 1) {
@@ -1895,6 +1899,9 @@ static void transfer_output_chunk_fp16_to_fp32_scattered(
static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned int i, void *data) {
output_transfer_scattered_task_state_t *st = data;
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
int chunk_idx = i;
int chunk_size = st->n_chunks_per_task;
int start_row = st->start_row + chunk_idx * chunk_size;
@@ -1906,6 +1913,7 @@ static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned i
st->matrix_rows, st->cur_a, st->mapping_stride,
st->dst_nb1, st->dst_nb2, st->cne1);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
}
static void transfer_output_chunk_scattered_threaded(
@@ -1942,6 +1950,7 @@ static void transfer_output_chunk_scattered_threaded(
.dst_nb2 = dst_nb2,
.start_row = start_row,
.cne1 = cne1,
.traces = ctx ? ctx->trace : NULL,
};
if (actual_threads <= 1) {
@@ -2053,7 +2062,12 @@ int hmx_matmul_id_2d_f32(struct htp_context *ctx,
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads);
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
{
struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
}
transfer_output_chunk_scattered_threaded(
ctx, dst, vtcm_output, (int) mr, (int) n_rows, (int) n_cols,
-34
View File
@@ -1,34 +0,0 @@
// Conditional fine-grained profiling macros for HMX operations.
//
// Define ENABLE_PROFILE_TIMERS (via compiler flag or before including this
// header) to instrument sub-operation latencies with HAP qtimer. When the
// macro is not defined the TIMER_* helpers expand to nothing so there is zero
// overhead.
//
// Usage:
// TIMER_DEFINE(my_phase); // declare accumulator variable
// TIMER_START(my_phase); // snapshot start time
// ... work ...
// TIMER_STOP(my_phase); // accumulate elapsed ticks
// FARF(ALWAYS, "my_phase: %lld us", TIMER_US(my_phase));
#ifndef HMX_PROFILE_H
#define HMX_PROFILE_H
#include <HAP_perf.h>
// #define ENABLE_PROFILE_TIMERS
#if defined(ENABLE_PROFILE_TIMERS)
# define TIMER_DEFINE(name) int64_t name##_ticks = 0
# define TIMER_START(name) int64_t name##_t0 = HAP_perf_get_qtimer_count()
# define TIMER_STOP(name) name##_ticks += HAP_perf_get_qtimer_count() - name##_t0
# define TIMER_US(name) HAP_perf_qtimer_count_to_us(name##_ticks)
#else
# define TIMER_DEFINE(name)
# define TIMER_START(name)
# define TIMER_STOP(name)
# define TIMER_US(name) 0LL
#endif
#endif // HMX_PROFILE_H
+2
View File
@@ -44,7 +44,9 @@ static inline void hmx_queue_process(struct hmx_queue *q, bool* killed) {
case HMX_QUEUE_SUSPEND: hmx_unlock(q); break;
default:
hmx_lock(q);
htp_trace_event_start(q->trace, HTP_TRACE_EVT_HMX_COMP, ir);
d->func(d->data);
htp_trace_event_stop(q->trace, HTP_TRACE_EVT_HMX_COMP, ir);
break;
}
+2
View File
@@ -11,6 +11,7 @@
#include <HAP_farf.h>
#include "hex-utils.h"
#include "hex-profile.h"
#ifdef __cplusplus
extern "C" {
@@ -47,6 +48,7 @@ struct hmx_queue {
void * stack;
uint32_t hap_rctx;
bool hmx_locked;
struct htp_thread_trace * trace;
};
struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx);
+2
View File
@@ -4,6 +4,7 @@
#include "hex-dma.h"
#include "hmx-queue.h"
#include "htp-ops.h"
#include "hex-profile.h"
#include "worker-pool.h"
#include <assert.h>
@@ -70,6 +71,7 @@ struct htp_context {
bool hmx_enabled;
bool etm;
uint32_t profiler;
struct htp_thread_trace trace[HTP_MAX_NTHREADS + 1];
uint8_t * vtcm_base;
size_t vtcm_size;
+31 -4
View File
@@ -146,10 +146,36 @@ struct htp_op_desc {
uint16_t dst; // Output tensor index
};
#ifndef HTP_MAX_NTHREADS
#define HTP_MAX_NTHREADS 10
#endif
#define HTP_TRACE_MAX_EVENTS 256
enum htp_profiler_mode {
HTP_PROF_DISABLED = 0,
HTP_PROF_BASIC = 1,
HTP_PROF_PMU = 2,
HTP_PROF_TRACE = 3,
};
enum htp_trace_event_id {
HTP_TRACE_EVT_DMA = 0,
HTP_TRACE_EVT_HVX_COMP = 20,
HTP_TRACE_EVT_HVX_A_QUANT = 21,
HTP_TRACE_EVT_HVX_A_PREP = 22,
HTP_TRACE_EVT_HVX_W_DEQUANT = 23,
HTP_TRACE_EVT_HVX_W_PREP = 24,
HTP_TRACE_EVT_HVX_O_PROC = 25,
HTP_TRACE_EVT_HMX_COMP = 40,
};
struct htp_trace_desc {
uint32_t cycles; // lower 32-bits of cycle counter
uint16_t id; // Event ID
uint16_t info; // bit 15: is_stop. bits 14-0: tile/chunk index or other metadata.
};
#define HTP_PROF_PMU_NCNT 8
@@ -158,8 +184,8 @@ enum htp_profiler_mode {
struct htp_prof_desc {
uint32_t opcode; // GGML/HTP Op
uint32_t usecs; // Number of usec
uint32_t cycles; // Number of cycles
uint32_t pad; // Unused
uint32_t cycles_start; // Start cycle counter
uint32_t cycles_stop; // Stop cycle counter
uint32_t pmu[HTP_PROF_PMU_NCNT]; // PMU counters
};
@@ -168,7 +194,7 @@ struct htp_opbatch_req {
uint32_t n_bufs; // Number of buffers
uint32_t n_tensors; // Number of tensors
uint32_t n_ops; // Number of ops
uint32_t flags; // unused
uint32_t n_traces; // Number of trace descriptors per thread
uint32_t pad; // unused
// struct htp_buf_desc bufs[]; -- dspqueue buf 0
// struct htp_tensor tensors[]; -- dspqueue buf 0
@@ -181,7 +207,8 @@ struct htp_opbatch_rsp {
uint32_t n_bufs; // Number of buffers
uint32_t n_tensors; // Number of tensors
uint32_t n_ops; // Number of op profile descriptors
uint32_t pad; // unused
uint32_t n_traces[HTP_MAX_NTHREADS + 1];
uint8_t pad[8]; // align to 8 bytes
// struct htp_prof_desc profs[]; -- dspqueue buf 0
};
+41 -9
View File
@@ -400,7 +400,9 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
ctx->hmx_queue = NULL;
if (use_hmx) {
ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx);
if (!ctx->hmx_queue) {
if (ctx->hmx_queue) {
ctx->hmx_queue->trace = &ctx->trace[HTP_MAX_NTHREADS];
} else {
FARF(ERROR, "hmx-queue-create failed");
ctx->hmx_enabled = false;
}
@@ -425,6 +427,9 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
ctx->n_threads = n_hvx;
for (int i = 0; i < ctx->n_threads; i++) {
ctx->dma[i] = dma_queue_create(256); // queue depth
if (ctx->dma[i]) {
ctx->dma[i]->trace = &ctx->trace[i];
}
}
ctx->ddr_spad_size = 512 * 1024; // 512 KB
@@ -502,7 +507,8 @@ static void htp_error_callback(dspqueue_t queue, int error, void * context) {
struct profile_data {
uint64_t usecs;
uint64_t cycles;
uint64_t cycles_start;
uint64_t cycles_stop;
uint32_t pmu_counters[HEX_NUM_PMU_COUNTERS];
};
@@ -512,8 +518,9 @@ static inline void profile_start(uint32_t mode, struct profile_data * d) {
hex_get_pmu(d->pmu_counters);
// fallthrough
case HTP_PROF_BASIC:
case HTP_PROF_TRACE:
d->usecs = HAP_perf_get_qtimer_count();
d->cycles = hex_get_cycles();
d->cycles_start = hex_get_cycles();
break;
default:
break;
@@ -530,8 +537,9 @@ static inline void profile_stop(uint32_t mode, struct profile_data * d) {
}
// fallthrough
case HTP_PROF_BASIC:
case HTP_PROF_TRACE:
d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs);
d->cycles = hex_get_cycles() - d->cycles;
d->cycles_stop = hex_get_cycles();
break;
default:
break;
@@ -845,14 +853,15 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
const uint32_t t_size = sizeof(struct htp_tensor) * n_tens;
const uint32_t o_size = sizeof(struct htp_op_desc) * n_ops;
const uint32_t p_size = sizeof(struct htp_prof_desc) * n_ops;
const uint32_t tr_size = (HTP_MAX_NTHREADS + 1) * req.n_traces * sizeof(struct htp_trace_desc);
if (dbuf.size < b_size + t_size + o_size + p_size) {
FARF(ERROR, "invalid opbatch memory block size %u", dbuf.size);
if (dbuf.size < b_size + t_size + o_size + p_size + tr_size) {
FARF(ERROR, "invalid opbatch memory block size %u (req %u)", dbuf.size, b_size + t_size + o_size + p_size + tr_size);
break;
}
FARF(HIGH, "processing opbatch #%u: n-bufs %u n-tensors %u n-ops %u : m-size %u b-size %u t-size %u o-size %u", req.id,
n_bufs, n_tens, n_ops, dbuf.size, b_size, t_size, o_size);
FARF(HIGH, "processing opbatch #%u: n-bufs %u n-tensors %u n-ops %u n-traces %u : m-size %u b-size %u t-size %u o-size %u", req.id,
n_bufs, n_tens, n_ops, req.n_traces, dbuf.size, b_size, t_size, o_size);
// Setup descriptor pointers
uint8_t * m_ptr = dbuf.ptr;
@@ -869,6 +878,20 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
octx->n_threads = ctx->n_threads;
octx->ctx = ctx;
if (ctx->profiler == HTP_PROF_TRACE) {
memset(ctx->trace, 0, sizeof(ctx->trace));
struct htp_trace_desc * trace_events = (struct htp_trace_desc *) (m_ptr + p_size);
for (int t = 0; t <= HTP_MAX_NTHREADS; t++) {
ctx->trace[t].events = &trace_events[t * req.n_traces];
ctx->trace[t].max_events = req.n_traces;
}
} else {
for (int t = 0; t <= HTP_MAX_NTHREADS; t++) {
ctx->trace[t].events = NULL;
ctx->trace[t].max_events = 0;
}
}
for (uint32_t i=0; i < n_ops; i++) {
struct profile_data prof;
@@ -886,7 +909,8 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
if (ctx->profiler) {
pds[i].opcode = ops[i].opcode;
pds[i].usecs = prof.usecs;
pds[i].cycles = prof.cycles;
pds[i].cycles_start = prof.cycles_start;
pds[i].cycles_stop = prof.cycles_stop;
for (int j = 0; j < HEX_NUM_PMU_COUNTERS; j++) {
pds[i].pmu[j] = prof.pmu_counters[j];
}
@@ -899,6 +923,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
rsp.n_bufs = n_bufs;
rsp.n_tensors = n_tens;
rsp.n_ops = n_ops;
memset(rsp.pad, 0, sizeof(rsp.pad));
if (ctx->profiler == HTP_PROF_TRACE) {
for (int t = 0; t <= HTP_MAX_NTHREADS; t++) {
rsp.n_traces[t] = ctx->trace[t].count;
}
} else {
memset(rsp.n_traces, 0, sizeof(rsp.n_traces));
}
dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT;
+46
View File
@@ -3350,6 +3350,7 @@ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void *
static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
@@ -3411,10 +3412,12 @@ static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, iir0);
for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, iir0);
}
}
}
@@ -3430,6 +3433,7 @@ static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
// src1 tensor is already in VTCM spad
static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows
@@ -3477,6 +3481,8 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
// Process src1 columns in pairs (2×2 tiling)
uint32_t ir1 = 0;
for (; ir1 + 1 < src1_nrows; ir1 += 2) {
@@ -3494,6 +3500,8 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
// Prefetch next (n + spad_nrows) row
const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
@@ -3511,12 +3519,14 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
src0_stride, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
#pragma unroll(2)
for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
}
t2 = HAP_perf_get_qtimer_count();
@@ -3530,6 +3540,7 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
// q8x4x2 src1 tensor is already in VTCM spad
static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const uint32_t src0_nrows = ne01;
@@ -3581,7 +3592,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
// Process src0 rows
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_4x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, ss0 + 2 * src0_stride, ss0 + 3 * src0_stride, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
// Prefetch next (n + spad_nrows) row
const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
@@ -3599,7 +3612,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
src0_stride, src0_row_size, 2);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
ir0 += 2;
}
if (ir0 < src0_end_row) {
@@ -3607,7 +3622,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
src0_stride, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
ir0 += 1;
}
} else {
@@ -3627,7 +3644,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
// Process src0 rows
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
// Prefetch next (n + spad_nrows) row
const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
@@ -3645,7 +3664,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
src0_stride, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
}
}
@@ -3669,6 +3690,7 @@ struct mmid_row_mapping {
// src1 tensor is already in VTCM spad
static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * restrict ids = octx->src[2];
struct htp_spad * restrict src2_spad = &octx->src2_spad;
@@ -3735,6 +3757,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
for (uint32_t cid = 0; cid < cne1; ++cid) {
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
const int rm1 = row_mapping.i1; // expert idx
@@ -3746,6 +3769,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
// Prefetch next (n + spad_nrows) row
const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
@@ -3764,6 +3788,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
src0_row_size_padded, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
for (uint32_t cid = 0; cid < cne1; ++cid) {
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
const int rm1 = row_mapping.i1; // expert idx
@@ -3775,6 +3800,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
}
}
@@ -3789,6 +3815,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
// src1 tensor is already in VTCM spad
static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * restrict ids = octx->src[2];
struct htp_spad * restrict src2_spad = &octx->src2_spad;
@@ -3847,7 +3874,9 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
// Process src0 rows
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
// Prefetch next (n + spad_nrows) row
const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
@@ -3865,7 +3894,9 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
src0_row_size_padded, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
}
}
@@ -4147,6 +4178,7 @@ static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, ui
static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) {
struct htp_matmul_context * mmctx = data;
struct htp_ops_context * octx = mmctx->octx;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * src = octx->src[1];
uint8_t * restrict dst = octx->src1_spad.data;
@@ -4163,6 +4195,7 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data)
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
const uint32_t ir_first = nrows_per_thread * ith; // first row
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
const size_t src_row_size = src->nb[1];
@@ -4189,6 +4222,7 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data)
FARF(HIGH, "quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
}
static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
@@ -4219,6 +4253,7 @@ static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y,
static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * data) {
struct htp_matmul_context * mmctx = data;
struct htp_ops_context * octx = mmctx->octx;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * src = octx->src[1];
uint8_t * restrict dst = octx->src1_spad.data;
@@ -4235,6 +4270,7 @@ static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * dat
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
const uint32_t ir_first = nrows_per_thread * ith; // first row
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
const size_t src_row_size = src->nb[1];
@@ -4260,11 +4296,13 @@ static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * dat
FARF(HIGH, "quantize-f32-q8_1x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
}
static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) {
struct htp_matmul_context * mmctx = data;
struct htp_ops_context * octx = mmctx->octx;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * src = octx->src[1];
uint8_t * restrict dst = octx->src1_spad.data;
@@ -4281,6 +4319,7 @@ static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) {
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
const uint32_t ir_first = nrows_per_thread * ith; // first row
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
const size_t src_row_size = ne0 * sizeof(float);
@@ -4301,11 +4340,13 @@ static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) {
FARF(HIGH, "quantize-f32-f32: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
}
static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
struct htp_matmul_context * mmctx = data;
struct htp_ops_context * octx = mmctx->octx;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * src = octx->src[1];
uint8_t * restrict dst = octx->src1_spad.data;
@@ -4322,6 +4363,7 @@ static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
const uint32_t ir_first = nrows_per_thread * ith; // first row
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
const size_t src_row_size = ne0 * sizeof(float);
@@ -4342,12 +4384,14 @@ static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
FARF(HIGH, "quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
}
// TODO just a plain copy that should be done via the DMA during the Op setup
static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
struct htp_matmul_context * mmctx = data;
struct htp_ops_context * octx = mmctx->octx;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * src = octx->src[1];
uint8_t * restrict dst = octx->src1_spad.data;
@@ -4364,6 +4408,7 @@ static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
const uint32_t ir_first = nrows_per_thread * ith; // first row
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
const size_t src_row_size = ne0 * sizeof(float);
@@ -4384,6 +4429,7 @@ static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
FARF(HIGH, "quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
}
+119 -51
View File
@@ -24,62 +24,119 @@ if (GGML_METAL_NDEBUG)
endif()
set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
set(METALLIB_KERNELS_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/kernels/common.h")
set(METALLIB_KERNELS_DEQUANTIZE "${CMAKE_CURRENT_SOURCE_DIR}/kernels/dequantize.h")
set(METALLIB_KERNELS_QUANTIZE "${CMAKE_CURRENT_SOURCE_DIR}/kernels/quantize.h")
set(METALLIB_KERNEL_SOURCES
kernels/fa.metal
kernels/mul_mv.metal
kernels/mul_mm.metal
kernels/quantize.metal
kernels/softmax.metal
kernels/norm.metal
kernels/unary.metal
kernels/binbcast.metal
kernels/reduce.metal
kernels/tri.metal
kernels/ssm.metal
kernels/wkv.metal
kernels/gated_delta_net.metal
kernels/solve_tri.metal
kernels/rope.metal
kernels/conv.metal
kernels/upscale.metal
kernels/argsort.metal
kernels/pool.metal
kernels/misc.metal
)
if (GGML_METAL_EMBED_LIBRARY)
enable_language(ASM)
add_compile_definitions(GGML_METAL_EMBED_LIBRARY)
set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
# merge ggml-common.h and ggml-metal.metal into a single file
set(METALLIB_EMBED_ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
set(METALLIB_SOURCE_EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
set(METALLIB_EMBED_ASM_FILES "")
foreach(src ${METALLIB_KERNEL_SOURCES})
get_filename_component(kind ${src} NAME_WE)
# symbol names must be valid C identifiers ('-' is not allowed)
string(REPLACE "-" "_" kind_sym ${kind})
add_custom_command(
OUTPUT "${METALLIB_EMBED_ASM}"
COMMAND echo "Embedding Metal library"
COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${METALLIB_SOURCE}" > "${METALLIB_SOURCE_EMBED_TMP}"
COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${METALLIB_SOURCE_EMBED_TMP}" > "${METALLIB_SOURCE_EMBED}"
COMMAND echo ".section __DATA,__ggml_metallib" > "${METALLIB_EMBED_ASM}"
COMMAND echo ".globl _ggml_metallib_start" >> "${METALLIB_EMBED_ASM}"
COMMAND echo "_ggml_metallib_start:" >> "${METALLIB_EMBED_ASM}"
COMMAND echo .incbin "\"${METALLIB_SOURCE_EMBED}\"" >> "${METALLIB_EMBED_ASM}"
COMMAND echo ".globl _ggml_metallib_end" >> "${METALLIB_EMBED_ASM}"
COMMAND echo "_ggml_metallib_end:" >> "${METALLIB_EMBED_ASM}"
DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h
COMMENT "Generate assembly for embedded Metal library"
VERBATIM
)
set(SRC "${CMAKE_CURRENT_SOURCE_DIR}/kernels/${kind}.metal")
set(EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed-${kind}.metal")
set(ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed-${kind}.s")
target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}")
# only prepend headers that this source actually includes
set(HEADERS_FOR_SRC ${METALLIB_KERNELS_COMMON})
file(STRINGS ${SRC} _has_dequantize REGEX "#include \"dequantize\\.h\"")
file(STRINGS ${SRC} _has_quantize REGEX "#include \"quantize\\.h\"")
if(_has_dequantize)
list(APPEND HEADERS_FOR_SRC ${METALLIB_KERNELS_DEQUANTIZE})
endif()
if(_has_quantize)
list(APPEND HEADERS_FOR_SRC ${METALLIB_KERNELS_QUANTIZE})
endif()
add_custom_command(
OUTPUT "${ASM}"
# Step 1: concatenate shared headers + this kernel source
COMMAND cat ${HEADERS_FOR_SRC} ${SRC} > "${EMBED}.tmp1"
# Step 2: remove internal #include and #pragma once
COMMAND sed -e "/\#include \"common.h\"/d" -e "/\#include \"dequantize.h\"/d" -e "/\#include \"quantize.h\"/d" -e "/\#pragma once/d" < "${EMBED}.tmp1" > "${EMBED}.tmp2"
# Step 3: inline ggml-common.h (replacing __embed_ggml-common.h__ sentinel)
COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${EMBED}.tmp2" > "${EMBED}.tmp3"
# Step 4: inline ggml-metal-impl.h
COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${EMBED}.tmp3" > "${EMBED}"
# Step 5: emit an asm chunk with kind-specific start/end symbols
# note: '-' is illegal in C symbols, so we use kind_sym; the macOS
# section name is limited to 16 chars so we keep it shared
# across kinds (__ggml_metallib) and only vary the global symbols.
COMMAND echo ".section __DATA,__ggml_metallib" > "${ASM}"
COMMAND echo ".globl _ggml_metallib_${kind_sym}_start" >> "${ASM}"
COMMAND echo "_ggml_metallib_${kind_sym}_start:" >> "${ASM}"
COMMAND echo .incbin "\"${EMBED}\"" >> "${ASM}"
COMMAND echo ".globl _ggml_metallib_${kind_sym}_end" >> "${ASM}"
COMMAND echo "_ggml_metallib_${kind_sym}_end:" >> "${ASM}"
DEPENDS ../ggml-common.h ggml-metal-impl.h
kernels/common.h kernels/dequantize.h kernels/quantize.h
kernels/${kind}.metal
COMMENT "Generate embedded Metal library for ${kind}"
VERBATIM
)
list(APPEND METALLIB_EMBED_ASM_FILES "${ASM}")
endforeach()
target_sources(ggml-metal PRIVATE ${METALLIB_EMBED_ASM_FILES})
else()
# copy metal files to bin directory
# copy header files to bin directory
configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
file(MAKE_DIRECTORY "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels")
configure_file(kernels/common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/common.h COPYONLY)
configure_file(kernels/dequantize.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/dequantize.h COPYONLY)
configure_file(kernels/quantize.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/quantize.h COPYONLY)
foreach(src ${METALLIB_KERNEL_SOURCES})
configure_file(${src} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${src} COPYONLY)
endforeach()
if (GGML_METAL_SHADER_DEBUG)
# custom command to do the following:
# xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
# xcrun -sdk macosx metallib ggml-metal.air -o default.metallib
#
# note: this is the only way I found to disable fast-math in Metal. it's ugly, but at least it works
# disabling fast math is needed in order to pass tests/test-backend-ops
# note: disabling fast math is needed in order to pass tests/test-backend-ops
# note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1
# note: unfortunately, we have to call it default.metallib instead of ggml.metallib
# ref: https://github.com/ggml-org/whisper.cpp/issues/1720
# note: adding -g causes segmentation fault during compile
#set(XC_FLAGS -fno-fast-math -fno-inline -g)
set(XC_FLAGS -fno-fast-math -fno-inline)
else()
set(XC_FLAGS -O3)
endif()
# Append macOS metal versioning flags
if (GGML_METAL_MACOSX_VERSION_MIN)
message(STATUS "Adding -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN} flag to metal compilation")
list (APPEND XC_FLAGS -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN})
@@ -90,35 +147,46 @@ else()
list (APPEND XC_FLAGS -std=${GGML_METAL_STD})
endif()
# Compile each kernel source to .air, then link into default.metallib
set(AIR_FILES "")
foreach(src ${METALLIB_KERNEL_SOURCES})
get_filename_component(name ${src} NAME_WE)
set(AIR "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${name}.air")
list(APPEND AIR_FILES ${AIR})
add_custom_command(
OUTPUT ${AIR}
COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -I ${CMAKE_RUNTIME_OUTPUT_DIRECTORY} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${src} -o ${AIR}
DEPENDS ${src} kernels/common.h kernels/dequantize.h kernels/quantize.h ${METALLIB_COMMON} ggml-metal-impl.h
COMMENT "Compiling ${src}"
VERBATIM
)
endforeach()
add_custom_command(
OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o - |
xcrun -sdk macosx metallib - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
COMMAND xcrun -sdk macosx metallib ${AIR_FILES} -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal
DEPENDS ggml-metal.metal ${METALLIB_COMMON}
COMMENT "Compiling Metal kernels"
)
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h
COMMAND rm -rf ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels
DEPENDS ${AIR_FILES}
COMMENT "Linking Metal kernels into default.metallib"
)
# FIXME: only add to the ggml-metal target?
add_custom_target(
ggml-metal-lib ALL
DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
)
)
endif() # GGML_METAL_EMBED_LIBRARY
if (NOT GGML_METAL_EMBED_LIBRARY)
install(
FILES src/ggml-metal/ggml-metal.metal
PERMISSIONS
OWNER_READ
OWNER_WRITE
GROUP_READ
WORLD_READ
DESTINATION ${CMAKE_INSTALL_BINDIR})
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/kernels/
DESTINATION ${CMAKE_INSTALL_BINDIR}/kernels
FILES_MATCHING PATTERN "*.metal" PATTERN "*.h"
)
install(
FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
DESTINATION ${CMAKE_INSTALL_BINDIR}
)
install(
FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
DESTINATION ${CMAKE_INSTALL_BINDIR}
)
endif()
+20 -3
View File
@@ -66,7 +66,6 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml
const char * op_str = "undefined";
switch (op) {
case GGML_OP_ADD_ID: op_str = "add_id"; break;
case GGML_OP_CONCAT: op_str = "concat"; break;
default: GGML_ABORT("fatal error");
};
@@ -211,6 +210,21 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_meta
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_concat(ggml_metal_library_t lib, ggml_type tsrc) {
char base[256];
char name[256];
snprintf(base, 256, "kernel_concat_%s", ggml_type_name(tsrc));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
char base[256];
char name[256];
@@ -1689,7 +1703,9 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_ROPE);
assert(op->op == GGML_OP_ROPE || op->op == GGML_OP_ROPE_BACK);
const bool is_back = op->op == GGML_OP_ROPE_BACK;
char base[256];
char name[256];
@@ -1713,13 +1729,14 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_
snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
}
snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
snprintf(name, 256, "%s_imrope=%d_is_back=%d", base, is_imrope ? 1 : 0, is_back ? 1 : 0);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
ggml_metal_cv_set_bool(cv, is_back, FC_ROPE + 1);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+1
View File
@@ -115,6 +115,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_concat (ggml_metal_library_t lib, enum ggml_type tsrc);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
+440 -133
View File
@@ -94,8 +94,63 @@ int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_wi
return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup;
}
//
// MTLLibrary collection (one library per op-source, compiled separately)
//
// Single source of truth for the per-kind metal libraries. The order here
// defines the enum values and every per-kind table below, so adding a library
// is a one-line change here (plus adding its source to CMakeLists.txt).
// X(suffix, name): name is both the kernels/<name>.metal basename and the
// ggml_metallib_<name>_{start,end} embed-symbol stem.
#define GGML_METAL_LIBS \
X(FA, fa) \
X(MUL_MV, mul_mv) \
X(MUL_MM, mul_mm) \
X(QUANTIZE, quantize) \
X(SOFTMAX, softmax) \
X(NORM, norm) \
X(UNARY, unary) \
X(BINBCAST, binbcast) \
X(REDUCE, reduce) \
X(TRI, tri) \
X(SSM, ssm) \
X(WKV, wkv) \
X(GATED_DELTA_NET, gated_delta_net)\
X(SOLVE_TRI, solve_tri) \
X(ROPE, rope) \
X(CONV, conv) \
X(UPSCALE, upscale) \
X(ARGSORT, argsort) \
X(POOL, pool) \
X(MISC, misc)
enum ggml_metal_lib_kind {
#define X(e, s) GGML_METAL_LIB_##e,
GGML_METAL_LIBS
#undef X
GGML_METAL_LIB_COUNT,
};
static const char * const k_lib_names[GGML_METAL_LIB_COUNT] = {
#define X(e, s) [GGML_METAL_LIB_##e] = #s,
GGML_METAL_LIBS
#undef X
};
struct ggml_metal_library {
id<MTLLibrary> obj;
// Per-kind compiled libraries. When single_library is true, the whole library
// (e.g. a pre-compiled default.metallib or a from-source build) lives at
// objs[0] and the remaining slots are nil.
id<MTLLibrary> objs[GGML_METAL_LIB_COUNT];
bool single_library; // true: combined library at objs[0]; false: per-kind libs in objs[*]
// Routing table: kernel function name -> objs[] index, populated from each
// compiled library's -[MTLLibrary functionNames]. The actual compiled
// libraries are the single source of truth for which library owns a kernel,
// so adding kernels later requires no manual routing maintenance.
// nil in single_library mode (everything resolves to objs[0]).
NSMutableDictionary<NSString *, NSNumber *> * fn_to_lib;
ggml_metal_device_t dev;
ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
@@ -103,160 +158,376 @@ struct ggml_metal_library {
NSLock * lock;
};
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
id<MTLLibrary> library = nil;
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
// Build the fn_to_lib routing table by querying each compiled library's public
// function names. Call once after all per-kind libraries have been compiled.
static void ggml_metal_library_build_index(ggml_metal_library_t lib) {
@autoreleasepool {
NSMutableDictionary<NSString *, NSNumber *> * index = [[NSMutableDictionary alloc] init];
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
for (NSString * fname in [lib->objs[kind] functionNames]) {
index[fname] = @(kind);
}
}
lib->fn_to_lib = index;
}
}
// load library
//
// - first check if the library is embedded
// - then check if the library is in the bundle
// - if not found, load the source and compile it
// - if that fails, return NULL
//
// TODO: move to a function
{
const int64_t t_start = ggml_time_us();
// Parse a `#include "name"` line. Returns the quoted name in *include_name on
// success. Whitespace-tolerant; ignores `#include <...>` (system headers).
static bool ggml_metal_library_parse_quoted_include(NSString * line, NSString ** include_name) {
NSScanner * scanner = [NSScanner scannerWithString:line];
scanner.charactersToBeSkipped = [NSCharacterSet whitespaceCharacterSet];
NSError * error = nil;
NSString * src = nil;
if (![scanner scanString:@"#" intoString:NULL] ||
![scanner scanString:@"include" intoString:NULL] ||
![scanner scanString:@"\"" intoString:NULL]) {
return false;
}
#if GGML_METAL_EMBED_LIBRARY
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
NSString * name = nil;
if (![scanner scanUpToString:@"\"" intoString:&name]) {
return false;
}
extern const char ggml_metallib_start[];
extern const char ggml_metallib_end[];
if (include_name) {
*include_name = name;
}
return true;
}
src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
#else
// Recursively inline `#include "name"` directives. System includes (<...>),
// `#if/#else/#endif`, and other preprocessor lines are passed through to the
// Metal compiler unchanged. `#pragma once` is dropped since `seen` already
// guards against double-inclusion.
static bool ggml_metal_library_flatten_file(NSMutableString * dst, NSString * path,
NSArray<NSString *> * search_paths,
NSMutableSet<NSString *> * seen, NSError ** error) {
NSString * key = [path stringByStandardizingPath];
if ([seen containsObject:key]) {
return true;
}
[seen addObject:key];
#ifdef SWIFT_PACKAGE
NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
#else
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
#endif
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:error];
if (!src) {
return false;
}
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
if (path_lib == nil) {
// Try to find the resource in the directory where the current binary located.
NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent];
NSFileManager * fm = [NSFileManager defaultManager];
for (NSString * line in [src componentsSeparatedByString:@"\n"]) {
NSString * trimmed = [line stringByTrimmingCharactersInSet:[NSCharacterSet whitespaceCharacterSet]];
if ([trimmed isEqualToString:@"#pragma once"]) {
continue;
}
NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
GGML_LOG_INFO("%s: found '%s'\n", __func__, [path_lib_default UTF8String]);
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error];
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
// Optionally, if this is a symlink, try to resolve it.
path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error];
if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@"/"]) {
// It is a relative path, adding the binary directory as directory prefix.
path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]];
}
if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
// Link to the resource could not be resolved.
path_lib_default = nil;
} else {
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [path_lib_default UTF8String]);
}
NSString * include_name = nil;
if (ggml_metal_library_parse_quoted_include(line, &include_name)) {
NSString * resolved = nil;
for (NSString * dir in search_paths) {
NSString * candidate = [dir stringByAppendingPathComponent:include_name];
if ([fm isReadableFileAtPath:candidate]) {
resolved = candidate;
break;
}
} else {
// The resource couldn't be found in the binary's directory.
path_lib_default = nil;
}
path_lib = path_lib_default;
if (!resolved) {
if (error) {
NSString * msg = [NSString stringWithFormat:@"could not resolve include \"%@\" from '%@'", include_name, path];
*error = [NSError errorWithDomain:@"ggml-metal-source-flatten" code:1
userInfo:@{NSLocalizedDescriptionKey: msg}];
}
return false;
}
if (!ggml_metal_library_flatten_file(dst, resolved, search_paths, seen, error)) {
return false;
}
continue;
}
if (path_lib != nil) {
// pre-compiled library found
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
[dst appendString:line];
[dst appendString:@"\n"];
}
library = [device newLibraryWithURL:libURL error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return nil;
}
} else {
GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
return true;
}
NSString * path_source;
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
static NSString * ggml_metal_library_flatten_source(NSString * path_source, NSError ** error) {
// Search paths cover both runtime layout (build/bin/kernels + build/bin)
// and source-tree layout (ggml/src/ggml-metal/kernels + ggml/src/ggml-metal + ggml/src).
NSString * path_kernels = [path_source stringByDeletingLastPathComponent];
NSString * path_base = [path_kernels stringByDeletingLastPathComponent];
NSArray<NSString *> * search_paths = @[
path_kernels,
path_base,
[path_base stringByDeletingLastPathComponent],
];
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
NSMutableString * src = [[NSMutableString alloc] init];
NSMutableSet<NSString *> * seen = [NSMutableSet set];
if (path_resource) {
path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
} else {
path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
if (!ggml_metal_library_flatten_file(src, path_source, search_paths, seen, error)) {
[src release];
return nil;
}
return src;
}
// Compile all per-kind libraries in parallel. `source_for_kind` returns the MSL
// source for a kind (the helper takes ownership and releases it), or nil with
// *err set on failure. On success the objs[] slots are populated and the routing
// index is built; on any failure every error is logged and false is returned
// (the caller is responsible for freeing `res`).
static bool ggml_metal_library_compile_all(
ggml_metal_library_t res,
id<MTLDevice> device,
NSDictionary * prep,
NSString * (^source_for_kind)(int kind, NSError ** err),
const char * origin) {
const int64_t t_start = ggml_time_us();
int64_t * t_per_lib = calloc(GGML_METAL_LIB_COUNT, sizeof(int64_t));
NSError ** err_per_lib = calloc(GGML_METAL_LIB_COUNT, sizeof(NSError *));
__block atomic_bool any_failure = false;
dispatch_group_t group = dispatch_group_create();
dispatch_queue_t queue = dispatch_get_global_queue(QOS_CLASS_USER_INITIATED, 0);
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
dispatch_group_async(group, queue, ^{
const int64_t t0 = ggml_time_us();
NSError * error = nil;
NSString * src = source_for_kind(kind, &error);
if (!src) {
err_per_lib[kind] = [error retain];
atomic_store(&any_failure, true);
return;
}
if (path_source == nil) {
GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
path_source = @"ggml-metal.metal";
}
id<MTLLibrary> lib = nil;
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return nil;
}
}
#endif
if (!library) {
@autoreleasepool {
// dictionary of preprocessor macros
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
if (ggml_metal_device_get_props(dev)->has_bfloat) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"];
}
if (ggml_metal_device_get_props(dev)->has_tensor) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"];
}
#if GGML_METAL_EMBED_LIBRARY
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
#endif
MTLCompileOptions * options = [MTLCompileOptions new];
options.preprocessorMacros = prep;
//[options setFastMathEnabled:false];
lib = [device newLibraryWithSource:src options:options error:&error];
library = [device newLibraryWithSource:src options:options error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return nil;
}
#if !__has_feature(objc_arc)
[options release];
#endif
// retain the error before the autorelease pool drains it
if (!lib) {
err_per_lib[kind] = [error retain];
}
}
[src release];
t_per_lib[kind] = ggml_time_us() - t0;
if (!lib) {
atomic_store(&any_failure, true);
return;
}
res->objs[kind] = lib;
});
}
dispatch_group_wait(group, DISPATCH_TIME_FOREVER);
dispatch_release(group);
const bool ok = !atomic_load(&any_failure);
if (ok) {
const int64_t t_total = ggml_time_us() - t_start;
int64_t t_max = 0;
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
GGML_LOG_DEBUG("%s: compiled '%s' library in %.3f sec\n",
__func__, k_lib_names[kind], t_per_lib[kind] / 1e6);
if (t_per_lib[kind] > t_max) t_max = t_per_lib[kind];
}
GGML_LOG_INFO("%s: loaded %d libraries from %s in %.3f sec (max single = %.3f sec)\n",
__func__, GGML_METAL_LIB_COUNT, origin, t_total / 1e6, t_max / 1e6);
ggml_metal_library_build_index(res);
} else {
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
if (err_per_lib[kind]) {
GGML_LOG_ERROR("%s: failed to build '%s' library: %s\n", __func__,
k_lib_names[kind], [[err_per_lib[kind] description] UTF8String]);
[err_per_lib[kind] release];
}
}
#if GGML_METAL_EMBED_LIBRARY
[src release];
#endif // GGML_METAL_EMBED_LIBRARY
GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
}
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
free(err_per_lib);
free(t_per_lib);
res->obj = library;
return ok;
}
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
res->dev = dev;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];
// shared MTLCompileOptions preprocessor macros (matches the build-time defines)
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
if (ggml_metal_device_get_props(dev)->has_bfloat) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"];
}
if (ggml_metal_device_get_props(dev)->has_tensor) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"];
}
#if GGML_METAL_EMBED_LIBRARY
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
#endif
#if GGML_METAL_EMBED_LIBRARY
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
// start/end symbols emitted by CMake (see CMakeLists.txt), one pair per kind
#define X(e, s) extern const char ggml_metallib_##s##_start[]; extern const char ggml_metallib_##s##_end[];
GGML_METAL_LIBS
#undef X
static const char * const lib_start[GGML_METAL_LIB_COUNT] = {
#define X(e, s) [GGML_METAL_LIB_##e] = ggml_metallib_##s##_start,
GGML_METAL_LIBS
#undef X
};
static const char * const lib_end[GGML_METAL_LIB_COUNT] = {
#define X(e, s) [GGML_METAL_LIB_##e] = ggml_metallib_##s##_end,
GGML_METAL_LIBS
#undef X
};
const bool ok = ggml_metal_library_compile_all(res, device, prep,
^NSString * (int kind, NSError ** err) {
(void) err;
return [[NSString alloc] initWithBytes:lib_start[kind]
length:(lib_end[kind] - lib_start[kind])
encoding:NSUTF8StringEncoding];
}, "embedded data");
if (!ok) {
ggml_metal_library_free(res);
return NULL;
}
return res;
#else
#ifdef SWIFT_PACKAGE
NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
#else
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
#endif
const int64_t t_start = ggml_time_us();
NSError * error = nil;
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
if (path_lib == nil) {
// Try to find the resource in the directory where the current binary located.
NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent];
NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
GGML_LOG_INFO("%s: found '%s'\n", __func__, [path_lib_default UTF8String]);
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error];
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
// Optionally, if this is a symlink, try to resolve it.
path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error];
if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@"/"]) {
// It is a relative path, adding the binary directory as directory prefix.
path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]];
}
if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
// Link to the resource could not be resolved.
path_lib_default = nil;
} else {
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [path_lib_default UTF8String]);
}
}
} else {
// The resource couldn't be found in the binary's directory.
path_lib_default = nil;
}
path_lib = path_lib_default;
}
if (path_lib != nil) {
// pre-compiled library found: a single combined default.metallib
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
res->objs[0] = [device newLibraryWithURL:libURL error:&error];
res->single_library = true;
if (!res->objs[0]) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
ggml_metal_library_free(res);
return NULL;
}
GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
return res;
}
// no pre-compiled metallib: fall back to compiling each kernel source separately
GGML_LOG_INFO("%s: default.metallib not found, loading kernel sources\n", __func__);
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
if (path_resource) {
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, [path_resource UTF8String]);
}
// resolve each kind's source path up front (file lookup/logging stays on the calling thread)
NSString ** path_per_kind = calloc(GGML_METAL_LIB_COUNT, sizeof(NSString *));
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
NSString * rel = [NSString stringWithFormat:@"kernels/%s.metal", k_lib_names[kind]];
NSString * path_source = nil;
if (path_resource) {
path_source = [path_resource stringByAppendingPathComponent:rel];
} else {
NSString * stem = [NSString stringWithFormat:@"kernels/%s", k_lib_names[kind]];
path_source = [bundle pathForResource:stem ofType:@"metal"];
}
if (path_source == nil || ![[NSFileManager defaultManager] isReadableFileAtPath:path_source]) {
GGML_LOG_WARN("%s: could not locate %s in bundle, falling back to cwd\n", __func__, [rel UTF8String]);
path_source = rel;
}
GGML_LOG_DEBUG("%s: loading '%s'\n", __func__, [path_source UTF8String]);
path_per_kind[kind] = [path_source retain];
}
const bool ok = ggml_metal_library_compile_all(res, device, prep,
^NSString * (int kind, NSError ** err) {
return ggml_metal_library_flatten_source(path_per_kind[kind], err);
}, "source");
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
[path_per_kind[kind] release];
}
free(path_per_kind);
if (!ok) {
ggml_metal_library_free(res);
return NULL;
}
return res;
#endif
}
ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose) {
@@ -318,10 +589,11 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
return NULL;
}
res->obj = library;
res->dev = dev;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];
res->objs[0] = library;
res->single_library = true;
res->dev = dev;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];
return res;
}
@@ -331,8 +603,14 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
return;
}
if (lib->obj) {
[lib->obj release];
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
if (lib->objs[kind]) {
[lib->objs[kind] release];
}
}
if (lib->fn_to_lib) {
[lib->fn_to_lib release];
}
ggml_metal_pipelines_free(lib->pipelines);
@@ -393,11 +671,28 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_
GGML_LOG_DEBUG("%s: compiling pipeline: base = '%s', name = '%s'\n", __func__, base, name);
// route to the library that actually defines this kernel; fn_to_lib is
// built from -[MTLLibrary functionNames] so it's always in sync
int lib_idx = 0;
if (!lib->single_library) {
NSNumber * idx = lib->fn_to_lib[base_func];
if (!idx) {
[lib->lock unlock];
GGML_LOG_ERROR("%s: kernel not found in any metal library: base = '%s', name = '%s'\n", __func__, base, name);
return res;
}
lib_idx = [idx intValue];
}
id<MTLLibrary> mtl_lib = lib->objs[lib_idx];
id<MTLFunction> mtl_function;
if (!cv) {
mtl_function = [lib->obj newFunctionWithName:base_func];
mtl_function = [mtl_lib newFunctionWithName:base_func];
} else {
mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
mtl_function = [mtl_lib newFunctionWithName:base_func constantValues:cv->obj error:&error];
}
if (!mtl_function) {
[lib->lock unlock];
@@ -1123,13 +1418,24 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return true;
case GGML_OP_CONCAT:
{
// kernel_concat copies one float-sized value per element.
// Other scalar types need a type-generic copy kernel first.
const enum ggml_type src0_type = op->src[0]->type;
const enum ggml_type src1_type = op->src[1]->type;
return src0_type == src1_type &&
src0_type == op->type &&
(src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_I32);
if (src0_type != src1_type || src0_type != op->type) {
return false;
}
switch (src0_type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
case GGML_TYPE_I64:
return true;
case GGML_TYPE_BF16:
return has_bfloat;
default:
return false;
}
}
case GGML_OP_ADD:
case GGML_OP_SUB:
@@ -1173,6 +1479,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_RMS_NORM:
return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
return true;
case GGML_OP_IM2COL:
return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
+2 -1
View File
@@ -375,6 +375,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
n_fuse = ggml_metal_op_norm(ctx, idx);
} break;
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
{
n_fuse = ggml_metal_op_rope(ctx, idx);
} break;
@@ -556,7 +557,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
/*.dim =*/ dim,
};
auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
auto pipeline = ggml_metal_library_get_pipeline_concat(lib, op->type);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
File diff suppressed because it is too large Load Diff
+232
View File
@@ -0,0 +1,232 @@
#include "common.h"
// bitonic sort implementation following the CUDA kernels as reference
typedef void (argsort_t)(
constant ggml_metal_kargs_argsort & args,
device const char * src0,
device int32_t * dst,
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]);
template<ggml_sort_order order>
kernel void kernel_argsort_f32_i32(
constant ggml_metal_kargs_argsort & args,
device const char * src0,
device int32_t * dst,
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
// bitonic sort
const int col = tpitg[0];
const int ib = tgpig[0] / args.ne01;
const int i00 = ib*ntg.x;
const int i01 = tgpig[0] % args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
// initialize indices
shmem_i32[col] = i00 + col;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int k = 2; k <= ntg.x; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (shmem_i32[col] >= args.ne00 ||
(shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
) {
SWAP(shmem_i32[col], shmem_i32[ixj]);
}
} else {
if (shmem_i32[ixj] >= args.ne00 ||
(shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
) {
SWAP(shmem_i32[col], shmem_i32[ixj]);
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
const int64_t i0 = ib*args.top_k;
// copy the result to dst without the padding
if (i0 + col < args.ne0 && col < args.top_k) {
dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
dst[col] = shmem_i32[col];
}
}
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
typedef void (argsort_merge_t)(
constant ggml_metal_kargs_argsort_merge & args,
device const char * src0,
device const int32_t * tmp,
device int32_t * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]);
template<ggml_sort_order order>
kernel void kernel_argsort_merge_f32_i32(
constant ggml_metal_kargs_argsort_merge & args,
device const char * src0,
device const int32_t * tmp,
device int32_t * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int im = tgpig[0] / args.ne01;
const int i01 = tgpig[0] % args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
const int start = im * (2 * args.len);
const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
const int total = len0 + len1;
device const int32_t * tmp0 = tmp + start
+ i01*args.ne0
+ i02*args.ne0*args.ne01
+ i03*args.ne0*args.ne01*args.ne02;
device const int32_t * tmp1 = tmp0 + args.len;
dst += start
+ i01*args.top_k
+ i02*args.top_k*args.ne01
+ i03*args.top_k*args.ne01*args.ne02;
device const float * src0_row = (device const float *)(src0
+ args.nb01*i01
+ args.nb02*i02
+ args.nb03*i03);
if (total == 0) {
return;
}
const int chunk = (total + ntg.x - 1) / ntg.x;
const int k0 = tpitg.x * chunk;
const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
if (k0 >= args.top_k) {
return;
}
if (k0 >= total) {
return;
}
int low = k0 > len1 ? k0 - len1 : 0;
int high = MIN(k0, len0);
// binary-search partition (i, j) such that i + j = k
while (low < high) {
const int mid = (low + high) >> 1;
const int32_t idx0 = tmp0[mid];
const int32_t idx1 = tmp1[k0 - mid - 1];
const float val0 = src0_row[idx0];
const float val1 = src0_row[idx1];
bool take_left;
if (order == GGML_SORT_ORDER_ASC) {
take_left = (val0 <= val1);
} else {
take_left = (val0 >= val1);
}
if (take_left) {
low = mid + 1;
} else {
high = mid;
}
}
int i = low;
int j = k0 - i;
// keep the merge fronts into registers
int32_t idx0 = 0;
float val0 = 0.0f;
if (i < len0) {
idx0 = tmp0[i];
val0 = src0_row[idx0];
}
int32_t idx1 = 0;
float val1 = 0.0f;
if (j < len1) {
idx1 = tmp1[j];
val1 = src0_row[idx1];
}
for (int k = k0; k < k1; ++k) {
int32_t out_idx;
if (i >= len0) {
while (k < k1) {
dst[k++] = tmp1[j++];
}
break;
} else if (j >= len1) {
while (k < k1) {
dst[k++] = tmp0[i++];
}
break;
} else {
bool take_left;
if (order == GGML_SORT_ORDER_ASC) {
take_left = (val0 <= val1);
} else {
take_left = (val0 >= val1);
}
if (take_left) {
out_idx = idx0;
++i;
if (i < len0) {
idx0 = tmp0[i];
val0 = src0_row[idx0];
}
} else {
out_idx = idx1;
++j;
if (j < len1) {
idx1 = tmp1[j];
val1 = src0_row[idx1];
}
}
}
dst[k] = out_idx;
}
}
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
+226
View File
@@ -0,0 +1,226 @@
#include "common.h"
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
constant bool FC_bin_cb [[function_constant(FC_BIN + 3)]];
template <typename T0, typename T1, typename T>
kernel void kernel_bin_fuse_impl(
constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
#define FC_OP FC_bin_op
#define FC_F FC_bin_f
#define FC_RB FC_bin_rb
#define FC_CB FC_bin_cb
if (FC_RB) {
// row broadcast
const uint i0 = tgpig.y*args.ne00 + tgpig.x;
const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x;
device const T0 * src0_row = (device const T0 *) (src0);
device T * dst_row = (device T *) (dst);
if (FC_F == 1) {
device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
if (FC_OP == 0) {
dst_row[i0] = src0_row[i0] + src1_row[i1];
}
if (FC_OP == 1) {
dst_row[i0] = src0_row[i0] - src1_row[i1];
}
if (FC_OP == 2) {
dst_row[i0] = src0_row[i0] * src1_row[i1];
}
if (FC_OP == 3) {
dst_row[i0] = src0_row[i0] / src1_row[i1];
}
} else {
T0 res = src0_row[i0];
if (FC_OP == 0) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res += ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
if (FC_OP == 1) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
if (FC_OP == 2) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
if (FC_OP == 3) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
dst_row[i0] = res;
}
} else {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int i01 = tgpig.x;
if (i01 >= args.ne01) {
return;
}
const int i13 = i03%args.ne13;
const int i12 = i02%args.ne12;
const int i11 = i01%args.ne11;
device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
if (FC_F == 1) {
device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = FC_CB ? i0%args.ne10 : i0;
if (FC_OP == 0) {
dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
}
if (FC_OP == 1) {
dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
}
if (FC_OP == 2) {
dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
}
if (FC_OP == 3) {
dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
}
}
} else {
device const T1 * src1_ptr[8];
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
}
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = FC_CB ? i0%args.ne10 : i0;
T res = src0_ptr[i0];
if (FC_OP == 0) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res += src1_ptr[j][i10];
}
}
if (FC_OP == 1) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res -= src1_ptr[j][i10];
}
}
if (FC_OP == 2) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res *= src1_ptr[j][i10];
}
}
if (FC_OP == 3) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res /= src1_ptr[j][i10];
}
}
dst_ptr[i0] = res;
}
}
}
#undef FC_OP
#undef FC_F
#undef FC_RB
#undef FC_CB
}
typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>;
template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>;
kernel void kernel_add_id(
constant ggml_metal_kargs_add_id & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i1 = tgpig.x;
const int i2 = tgpig.y;
const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
const size_t nb1 = args.ne0 * sizeof(float);
const size_t nb2 = args.ne1 * nb1;
device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
dst_row[i0] = src0_row[i0] + src1_row[i0];
}
}
template<typename T>
kernel void kernel_repeat(
constant ggml_metal_kargs_repeat & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
const int i03 = i3%args.ne03;
const int i02 = i2%args.ne02;
const int i01 = i1%args.ne01;
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i00 = i0%args.ne00;
*((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
}
}
typedef decltype(kernel_repeat<float>) kernel_repeat_t;
template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_repeat_bf16")]] kernel kernel_repeat_t kernel_repeat<bfloat>;
#endif
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
+126
View File
@@ -0,0 +1,126 @@
#pragma once
#include "ggml-metal-impl.h"
#include <metal_stdlib>
#ifdef GGML_METAL_HAS_TENSOR
#include <metal_tensor>
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
#endif
using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1))
#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
//
// cmd:
// .../usr/bin/metal -dM -E -c ggml/src/ggml-metal/kernels/<src>.metal
// .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/kernels/<src>.metal
//
#if __METAL_VERSION__ < 310 && defined(GGML_METAL_HAS_BF16)
#undef GGML_METAL_HAS_BF16
#endif
#if defined(GGML_METAL_HAS_BF16)
typedef matrix<bfloat, 4, 4> bfloat4x4;
typedef matrix<bfloat, 2, 4> bfloat2x4;
#endif
constexpr constant static float kvalues_iq4nl_f[16] = {
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
constexpr constant static float kvalues_mxfp4_f[16] = {
0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
};
static inline int best_index_int8(int n, constant float * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
static inline float e8m0_to_fp32(uint8_t x) {
uint32_t bits;
if (x == 0) {
bits = 0x00400000;
} else {
bits = (uint32_t) x << 23;
}
return as_type<float>(bits);
}
static inline float dot(float x, float y) {
return x*y;
}
static inline float sum(float x) {
return x;
}
static inline float sum(float4 x) {
return x[0] + x[1] + x[2] + x[3];
}
enum ggml_sort_order {
GGML_SORT_ORDER_ASC,
GGML_SORT_ORDER_DESC,
};
constant float GELU_COEF_A = 0.044715f;
constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
constant float p_erf = 0.3275911f;
constant float a1_erf = 0.254829592f;
constant float a2_erf = -0.284496736f;
constant float a3_erf = 1.421413741f;
constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;
template<typename T>
inline T erf_approx(T x) {
T sign_x = sign(x);
x = fabs(x);
T t = 1.0f / (1.0f + p_erf * x);
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
return sign_x * y;
}
template<typename T> T elu_approx(T x);
template<> inline float elu_approx<float>(float x) {
return (x > 0.f) ? x : (exp(x) - 1);
}
template<> inline float4 elu_approx<float4>(float4 x) {
float4 res;
res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
return res;
}
+485
View File
@@ -0,0 +1,485 @@
#include "common.h"
typedef void (im2col_t)(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template <typename T>
kernel void kernel_im2col(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
// const int64_t IC = tgpg[0];
const int64_t OH = tgpg[1];
const int64_t OW = tgpg[2];
const int64_t KH = ntg[1];
const int64_t KW = ntg[2];
int64_t in = tpitg[0];
const int64_t ikh = tpitg[1];
const int64_t ikw = tpitg[2];
const int64_t iic = tgpig[0];
const int64_t ioh = tgpig[1];
const int64_t iow = tgpig[2];
const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
device T * pdst = (device T *) (dst);
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
while (in < args.N) {
pdst[offset_dst] = 0.0f;
offset_dst += ntg[0]*args.CHW*OH*OW;
in += ntg[0];
}
} else {
int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
while (in < args.N) {
pdst[offset_dst] = x[offset_src];
offset_dst += ntg[0]*args.CHW*OH*OW;
offset_src += ntg[0]*args.ofs0;
in += ntg[0];
}
}
}
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
// TODO: optimize
typedef void (im2col_ext_t)(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template <typename T>
kernel void kernel_im2col_ext(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
const int64_t KHW = (int64_t)args.KHW;
const int64_t d = tgpig[0] / args.CHW;
const int64_t chw = tgpig[0] % args.CHW;
const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
const int64_t HW = tgpig[0] % KHW;
const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
if (tpitg_0 >= args.N) {
return;
}
const int64_t tpitg_1 = HW / args.KW;
const int64_t tpitg_2 = HW % args.KW;
const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
const int64_t offset_dst =
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
(tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
device T * pdst = (device T *) (dst);
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
pdst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
}
}
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
template <typename TK>
kernel void kernel_conv_2d(
constant ggml_metal_kargs_conv_2d & args,
device const char * weights,
device const char * src,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const uint threads_per_tg = ntg.x * ntg.y * ntg.z;
const uint tg_index = (tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
const uint local_thread = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;
const uint thread_index = tg_index * threads_per_tg + local_thread;
const uint64_t total_threads = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z;
const uint64_t total_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW;
for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {
uint64_t tmp = index;
const int32_t ow = tmp % args.OW; tmp /= args.OW;
const int32_t oh = tmp % args.OH; tmp /= args.OH;
const int32_t oc = tmp % args.OC; tmp /= args.OC;
const int32_t n = tmp;
float acc = 0.0f;
const int32_t base_x = ow*args.s0 - args.p0;
const int32_t base_y = oh*args.s1 - args.p1;
int32_t ky_start = 0;
if (base_y < 0) {
ky_start = (-base_y + args.d1 - 1)/args.d1;
}
int32_t ky_end = args.KH;
const int32_t y_max = args.IH - 1 - base_y;
if (y_max < 0) {
ky_end = ky_start;
} else if (base_y + (args.KH - 1)*args.d1 >= args.IH) {
ky_end = min(ky_end, y_max/args.d1 + 1);
}
int32_t kx_start = 0;
if (base_x < 0) {
kx_start = (-base_x + args.d0 - 1)/args.d0;
}
int32_t kx_end = args.KW;
const int32_t x_max = args.IW - 1 - base_x;
if (x_max < 0) {
kx_end = kx_start;
} else if (base_x + (args.KW - 1)*args.d0 >= args.IW) {
kx_end = min(kx_end, x_max/args.d0 + 1);
}
if (ky_start < ky_end && kx_start < kx_end) {
const uint64_t src_base_n = (uint64_t) n * args.nb13;
const uint64_t w_base_oc = (uint64_t) oc * args.nb03;
for (int32_t ic = 0; ic < args.IC; ++ic) {
const uint64_t src_base_nc = src_base_n + (uint64_t) ic * args.nb12;
const uint64_t w_base_ocic = w_base_oc + (uint64_t) ic * args.nb02;
for (int32_t ky = ky_start; ky < ky_end; ++ky) {
const int32_t iy = base_y + ky*args.d1;
const uint64_t src_base_row = src_base_nc + (uint64_t) iy * args.nb11;
const uint64_t w_base_row = w_base_ocic + (uint64_t) ky * args.nb01;
for (int32_t kx = kx_start; kx < kx_end; ++kx) {
const int32_t ix = base_x + kx*args.d0;
const uint64_t src_offs = src_base_row + (uint64_t) ix * args.nb10;
const uint64_t w_offs = w_base_row + (uint64_t) kx * args.nb00;
const float x = *(device const float *)(src + src_offs);
const float w = (float) (*(device const TK *)(weights + w_offs));
acc += x * w;
}
}
}
}
const uint64_t dst_offs =
(uint64_t) n * args.nb3 +
(uint64_t) oc * args.nb2 +
(uint64_t) oh * args.nb1 +
(uint64_t) ow * args.nb0;
*(device float *)(dst + dst_offs) = acc;
}
}
template [[host_name("kernel_conv_2d_f32_f32")]]
kernel void kernel_conv_2d<float>(
constant ggml_metal_kargs_conv_2d & args,
device const char * weights,
device const char * src,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template [[host_name("kernel_conv_2d_f16_f32")]]
kernel void kernel_conv_2d<half>(
constant ggml_metal_kargs_conv_2d & args,
device const char * weights,
device const char * src,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
typedef void (conv_transpose_1d_t)(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
template <typename T>
kernel void kernel_conv_transpose_1d(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const T * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]) {
// For output position j on the time axis, only input positions
// i such that i*s0 <= j < i*s0 + K
// contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)]
// intersected with [0, IL-1]. That's at most ceil(K/s0) values
// (typically 2 for stride==K/2 transposed convs).
const int32_t j = tgpig[0];
const int32_t s0 = args.s0;
const int32_t K = args.K;
const int32_t IL = args.IL;
int32_t i_min;
{
int32_t a = j - K + 1;
i_min = a <= 0 ? 0 : (a + s0 - 1) / s0; // ceil(a/s0) for a>0
}
int32_t i_max = j / s0;
if (i_max > IL - 1) i_max = IL - 1;
float v = 0.0f;
if (i_min <= i_max) {
for (int64_t c = 0; c < args.IC; c++) {
const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
const int32_t input_offset = c * IL;
for (int32_t i = i_min; i <= i_max; i++) {
v += float(src0[kernel_offset + j - i * s0]) * src1[input_offset + i];
}
}
}
device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);
dst_ptr[0] = v;
}
template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
kernel void kernel_conv_transpose_1d<float>(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
kernel void kernel_conv_transpose_1d<half>(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const half * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
typedef void (conv_transpose_2d_t)(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
template <typename T>
kernel void kernel_conv_transpose_2d(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const T * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t out_x = tgpig[0];
const int64_t out_y = tgpig[1];
const int64_t out_c = tgpig[2];
const int64_t kw = tpitg[0];
const int64_t kh = tpitg[1];
float v = 0.0f;
for (int64_t in_c = 0; in_c < args.IC; in_c++) {
int64_t in_y = out_y - kh;
if (in_y < 0 || in_y % args.s0) continue;
in_y /= args.s0;
if (in_y >= args.IH) continue;
int64_t in_x = out_x - kw;
if (in_x < 0 || in_x % args.s0) continue;
in_x /= args.s0;
if (in_x >= args.IW) continue;
const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
v += (float)src0[kernel_idx] * src1[input_idx];
}
const uint tid = tpitg.y * ntg.x + tpitg.x;
shared_sum[tid] = v;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
float total = 0.0f;
const uint num_threads = ntg.x * ntg.y;
for (uint i = 0; i < num_threads; i++) {
total += shared_sum[i];
}
device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
dst_ptr[0] = total;
}
}
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
kernel void kernel_conv_transpose_2d<float>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
kernel void kernel_conv_transpose_2d<half>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const half * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template <typename T>
kernel void kernel_conv_3d(
constant ggml_metal_kargs_conv_3d & args,
device const char * src0, // Weights [IC * OC, KD, KH, KW]
device const char * src1, // Inputs [IC * N, ID, IH, IW]
device char * dst, // Outputs [OC * N, OD, OH, OW]
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]) {
// 1. Un-flatten the spatial dimension from Grid X
int64_t spatial_idx = tgpig.x * 32 + tpitg.x;
if (spatial_idx >= args.OW * args.OH * args.OD) {
return; // Thread falls outside the spatial volume
}
int64_t od = spatial_idx / (args.OW * args.OH);
int64_t oh = (spatial_idx / args.OW) % args.OH;
int64_t ow = spatial_idx % args.OW;
// 2. Map Y to Channels, Z to Batch
int64_t oc = tgpig.y;
int64_t batch_idx = tgpig.z;
// 3. Calculate anchor coordinates in the Input volume
int64_t i_w_base = ow * args.s0 - args.p0;
int64_t i_h_base = oh * args.s1 - args.p1;
int64_t i_d_base = od * args.s2 - args.p2;
float sum = 0.0f;
// 4. Gather Loop (Iterate over Input Channels -> Depth -> Height -> Width)
for (int64_t ic = 0; ic < args.IC; ++ic) {
// ggml packs batch and channel together in the 4th dimension
int64_t src_cn_idx = batch_idx * args.IC + ic;
int64_t w_cn_idx = oc * args.IC + ic;
for (int64_t kz = 0; kz < args.KD; ++kz) {
int64_t id = i_d_base + kz * args.d2;
if (id < 0 || id >= args.ID) continue; // Boundary check (Padding)
for (int64_t ky = 0; ky < args.KH; ++ky) {
int64_t ih = i_h_base + ky * args.d1;
if (ih < 0 || ih >= args.IH) continue;
for (int64_t kx = 0; kx < args.KW; ++kx) {
int64_t iw = i_w_base + kx * args.d0;
if (iw < 0 || iw >= args.IW) continue;
// Convert multi-dimensional coordinates to flat byte offsets
int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03;
int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13;
// Dereference memory and cast weights to f32 if they were f16
float w_val = (float)*(device const T*)((device const char*)src0 + w_idx);
float i_val = *(device const float*)((device const char*)src1 + i_idx);
sum += w_val * i_val;
}
}
}
}
// 5. Write the accumulated value out to RAM
int64_t dst_cn_idx = batch_idx * args.OC + oc;
int64_t d_idx = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3;
*(device float*)(dst + d_idx) = sum;
}
// Explicit instantiations so the JIT compiler can find them by name
template [[host_name("kernel_conv_3d_f32_f32")]]
kernel void kernel_conv_3d<float>(
constant ggml_metal_kargs_conv_3d & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]);
// Explicit instantiation for f16 weights
template [[host_name("kernel_conv_3d_f16_f32")]]
kernel void kernel_conv_3d<half>(
constant ggml_metal_kargs_conv_3d & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]);
+686
View File
@@ -0,0 +1,686 @@
#pragma once
#include "common.h"
#define GGML_COMMON_DECL_METAL
#define GGML_COMMON_IMPL_METAL
#if defined(GGML_METAL_EMBED_LIBRARY)
__embed_ggml-common.h__
#else
#include "ggml-common.h"
#endif
#define QK_NL 16 // shared by mul_mm and get_rows_q instantiations
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
}
template <typename type4>
void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
reg = (type4)(*src);
}
template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
}
template <typename type4>
void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
reg = (type4)(*(src));
}
#if defined(GGML_METAL_HAS_BF16)
template <typename type4x4>
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
}
template <typename type4>
void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
reg = (type4)(*(src));
}
#endif
template <typename type4x4>
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
device const uint8_t * qs = xb->qs;
const float d = xb->d;
const float neg_d = -d;
const int byte_offset = il * 2; // il*16 bits = il*2 bytes
const uint8_t b0 = qs[byte_offset];
const uint8_t b1 = qs[byte_offset + 1];
float4x4 reg_f;
reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));
reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
const float d = xb->d;
const float neg_d = -d;
const int base = il * 4;
const uint8_t byte = xb->qs[base / 8];
const int s = base % 8;
float4 reg_f;
reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1));
reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));
reg = (type4) reg_f;
}
template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float md = -8.h * xb->d;
const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float md = -8.h * xb->d;
const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
for (int i = 0; i < 2; i++) {
reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
}
}
template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float m = xb->m;
const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float m = xb->m;
const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
for (int i = 0; i < 2; i++) {
reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
}
}
template <typename type4x4>
void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
const float d = xb->d;
const float md = -16.h * xb->d;
const ushort mask = il ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
const int x_mv = il ? 4 : 0;
const int gh_mv = il ? 12 : 0;
const int gh_bk = il ? 0 : 4;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
const float d = xb->d;
const float md = -16.h * xb->d;
const ushort mask = (il/4) ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
const int x_mv = (il/4) ? 4 : 0;
const int gh_mv = (il/4) ? 12 : 0;
const int gh_bk = (il/4) ? 0 : 4;
for (int ii = 0; ii < 2; ii++) {
int i = 2*(il%4) + ii;
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg[2*ii + 0] = d * x0 + md;
reg[2*ii + 1] = d * x1 + md;
}
}
template <typename type4x4>
void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
const float d = xb->d;
const float m = xb->m;
const ushort mask = il ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
const int x_mv = il ? 4 : 0;
const int gh_mv = il ? 12 : 0;
const int gh_bk = il ? 0 : 4;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
const float d = xb->d;
const float m = xb->m;
const ushort mask = (il/4) ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
const int x_mv = (il/4) ? 4 : 0;
const int gh_mv = (il/4) ? 12 : 0;
const int gh_bk = (il/4) ? 0 : 4;
for (int ii = 0; ii < 2; ii++) {
int i = 2*(il%4) + ii;
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg[2*ii + 0] = d * x0 + m;
reg[2*ii + 1] = d * x1 + m;
}
}
template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
const float d = xb->d;
float4x4 reg_f;
for (int i = 0; i < 16; i++) {
reg_f[i/4][i%4] = (qs[i + 16*il] * d);
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
const float d = xb->d;
for (int i = 0; i < 4; i++) {
reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
}
}
template <typename type4x4>
void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
device const uint8_t * q2 = (device const uint8_t *)xb->qs;
const float d = e8m0_to_fp32(xb->e);
const uint8_t shr = il >= 1 ? 4 : 0;
for (int i = 0; i < 4; ++i) {
reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
}
}
template <typename type4>
void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
device const uint8_t * q2 = (device const uint8_t *)xb->qs;
const float d = e8m0_to_fp32(xb->e);
const short il4 = il%4;
const uint8_t shr = il >= 4 ? 4 : 0;
reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
}
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
const float d = xb->d;
const float min = xb->dmin;
device const uint8_t * q = (device const uint8_t *)xb->qs;
float dl, ml;
uint8_t sc = xb->scales[il];
q = q + 32*(il/8) + 16*(il&1);
il = (il/2)%4;
half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
}
}
template <typename type4x4>
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
const half d_all = xb->d;
device const uint8_t * q = (device const uint8_t *)xb->qs;
device const uint8_t * h = (device const uint8_t *)xb->hmask;
device const int8_t * scales = (device const int8_t *)xb->scales;
q = q + 32 * (il/8) + 16 * (il&1);
h = h + 16 * (il&1);
uint8_t m = 1 << (il/2);
uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
((il/4)>0 ? 12 : 3);
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
const float ml = 4.f * dl;
il = (il/2) & 3;
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
dl *= coef;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
}
}
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
: uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
}
template <typename type4x4>
void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
device const uchar * q = xb->qs;
short is = (il/4) * 2;
q = q + (il/4) * 32 + 16 * (il&1);
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
const float d = il < 2 ? xb->d : xb->d / 16.h;
const float min = xb->dmin;
const float dl = d * sc[0];
const float ml = min * sc[1];
const ushort mask = il < 2 ? 0x0F : 0xF0;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
}
}
template <typename type4x4>
void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
device const uint8_t * q = xb->qs;
device const uint8_t * qh = xb->qh;
short is = (il/4) * 2;
q = q + 32 * (il/4) + 16 * (il&1);
qh = qh + 16 * (il&1);
uint8_t ul = 1 << (il/2);
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
const float d = il < 2 ? xb->d : xb->d / 16.f;
const float min = xb->dmin;
const float dl = d * sc[0];
const float ml = min * sc[1];
const ushort mask = il<2 ? 0x0F : 0xF0;
const float qh_val = il<2 ? 16.f : 256.f;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
}
}
template <typename type4x4>
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
const half d_all = xb->d;
device const uint16_t * ql = (device const uint16_t *)xb->ql;
device const uint16_t * qh = (device const uint16_t *)xb->qh;
device const int8_t * scales = (device const int8_t *)xb->scales;
ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
qh = qh + 16*(il/8) + 8*(il&1);
float sc = scales[(il%2) + 2 * ((il/2))];
il = (il/2) & 3;
const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F;
const float ml = d_all * sc * 32.f;
const float dl0 = d_all * sc;
const float dl1 = dl0 / 256.f;
const float dl2 = dl0 / (256.f * 256.f);
const float dl3 = dl0 / (256.f * 256.f * 256.f);
const uint8_t shr_h = il>2 ? 2 : 0;
const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
const uint8_t shr_l = il>1 ? 4 : 0;
for (int i = 0; i < 4; ++i) {
const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml;
reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml;
reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml;
reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
}
}
template <typename type4x4>
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
// each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
device const uint16_t * q2 = xb->qs + 4*ib32;
const uint32_t aux32_g = q2[0] | (q2[1] << 16);
const uint32_t aux32_s = q2[2] | (q2[3] << 16);
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
for (int i = 0; i < 8; ++i) {
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
for (int i = 0; i < 8; ++i) {
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
}
template <typename type4x4>
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint16_t * q2 = xb->qs + 4*ib32;
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
for (int i = 0; i < 8; ++i) {
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
signs = ksigns_iq2xs[q2[2*il+1] >> 9];
for (int i = 0; i < 8; ++i) {
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
}
template <typename type4x4>
void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint8_t * q3 = xb->qs + 8*ib32;
device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
const uint32_t aux32 = gas[0] | (gas[1] << 16);
const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
}
grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
for (int i = 0; i < 4; ++i) {
reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
}
}
template <typename type4x4>
void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint8_t * qs = xb->qs + 8*ib32;
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
const uint8_t qh = xb->qh[ib32] >> 4*il;
const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
}
grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
for (int i = 0; i < 4; ++i) {
reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
}
}
template <typename type4x4>
void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
device const uint8_t * signs = qs + QK_K/8;
const uint8_t qh = xb->qh[ib32] >> 4*il;
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
for (int i = 0; i < 8; ++i) {
reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
}
}
template <typename type4x4>
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const int ib32 = il/2;
il = il%2;
const float d = xb->d;
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
device const uint16_t * qh = xb->qh;
const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
const uint16_t h = qh[ib32] >> 6*il;
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * (grid1[i] & 0xf) + ml;
reg[1][i] = dl * (grid1[i] >> 4) + ml;
reg[2][i] = dl * (grid2[i] & 0xf) + ml;
reg[3][i] = dl * (grid2[i] >> 4) + ml;
}
}
template <typename type4x4>
void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const int ib32 = il/2;
il = il%2;
device const uint16_t * sc = (device const uint16_t *)xb->scales;
iq1m_scale_t scale;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
const float d = scale.f16;
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
device const uint8_t * qh = xb->qh + 2*ib32 + il;
const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
reg[1][i] = dl * (grid1[i] >> 4) + ml1;
reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
reg[3][i] = dl * (grid2[i] >> 4) + ml2;
}
}
template <typename type4x4>
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
const float d = xb->d;
uint32_t aux32;
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
for (int i = 0; i < 4; ++i) {
aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
}
}
template <typename type4>
void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
const float d = xb->d;
uint32_t aux32;
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;
reg[0] = d * kvalues_iq4nl_f[q8[0]];
reg[1] = d * kvalues_iq4nl_f[q8[1]];
reg[2] = d * kvalues_iq4nl_f[q8[2]];
reg[3] = d * kvalues_iq4nl_f[q8[3]];
}
template <typename type4x4>
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
const float d = (float)xb->d * (ls - 32);
uint32_t aux32;
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
for (int i = 0; i < 4; ++i) {
aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
}
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,250 @@
#include "common.h"
constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]];
#if 1
template<short NSG>
kernel void kernel_gated_delta_net_impl(
constant ggml_metal_kargs_gated_delta_net & args,
device const char * q,
device const char * k,
device const char * v,
device const char * g,
device const char * b,
device const char * s,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
#define S_v FC_gated_delta_net_ne20
#define G FC_gated_delta_net_ne30
#define K FC_gated_delta_net_K
const uint tx = tpitg.x;
const uint ty = tpitg.y;
const uint i23 = tgpig.z; // B (n_seqs)
const uint i21 = tgpig.y; // H (head)
const uint i20 = tgpig.x*NSG + ty; // row within S_v
const uint i01 = i21 % args.ne01;
const uint i11 = i21 % args.ne11;
const float scale = 1.0f / sqrt((float)S_v);
// input state layout [S_v, S_v, H, n_seqs] (s0 only): per-seq stride is H*D.
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
const uint state_in_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
device const float * s_ptr = (device const float *) (s) + state_in_base;
float ls[NSG];
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] = s_ptr[is];
}
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
// When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned.
// output state base offset: after attention scores
const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23;
// output state per-slot size: S_v * S_v * H * n_seqs
const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23;
// per-(seq,head) offset within a slot
const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
for (short t = 0; t < args.ne22; t++) {
float s_k = 0.0f;
if (G == 1) {
const float g_exp = exp(g_ptr[0]);
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] *= g_exp;
s_k += ls[j]*k_ptr[is];
}
} else {
// KDA
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] *= exp(g_ptr[is]);
s_k += ls[j]*k_ptr[is];
}
}
s_k = simd_sum(s_k);
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
float y = 0.0f;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] += k_ptr[is]*d;
y += ls[j]*q_ptr[is];
}
y = simd_sum(y);
if (tx == 0) {
dst_attn[t*args.ne21*S_v] = y*scale;
}
q_ptr += args.ns02;
k_ptr += args.ns12;
v_ptr += args.ns22;
b_ptr += args.ne21;
g_ptr += args.ne21*G;
if (K > 1) {
const int target_slot = (int)args.ne22 - 1 - (int)t;
if (target_slot >= 0 && target_slot < (int)K) {
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is] = ls[j];
}
}
}
}
if (K == 1) {
device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is] = ls[j];
}
}
#undef S_v
#undef G
#undef K
}
typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;
#else
// a simplified version of the above
// no performance improvement, so keep the above version for now
template<typename T, short NSG>
kernel void kernel_gated_delta_net_impl(
constant ggml_metal_kargs_gated_delta_net & args,
device const char * q,
device const char * k,
device const char * v,
device const char * g,
device const char * b,
device const char * s,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
#define S_v FC_gated_delta_net_ne20
#define G FC_gated_delta_net_ne30
const uint tx = tpitg.x;
const uint ty = tpitg.y;
const uint i23 = tgpig.z; // B
const uint i21 = tgpig.y; // H
const uint i20 = tgpig.x*NSG + ty;
const uint i01 = i21 % args.ne01;
const uint i11 = i21 % args.ne11;
const float scale = 1.0f / sqrt((float)S_v);
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
float lsf[NSG];
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
lsf[j] = s_ptr[is*S_v];
}
thread T * ls = (thread T *) (lsf);
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
for (short t = 0; t < args.ne22; t++) {
device const T * qt_ptr = (device const T *) (q_ptr);
device const T * kt_ptr = (device const T *) (k_ptr);
device const T * gt_ptr = (device const T *) (g_ptr);
if (G == 1) {
*ls *= exp(g_ptr[0]);
} else {
// KDA
*ls *= exp(gt_ptr[tx]);
}
const float s_k = simd_sum(dot(*ls, kt_ptr[tx]));
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
*ls += kt_ptr[tx]*d;
const float y = simd_sum(dot(*ls, qt_ptr[tx]));
if (tx == 0) {
*dst_attn = y*scale;
}
q_ptr += args.ns02;
k_ptr += args.ns12;
v_ptr += args.ns22;
b_ptr += args.ne21;
g_ptr += args.ne21*G;
dst_attn += args.ne21*S_v;
}
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
device T * dstt_state = (device T *) (dst_state);
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is*S_v] = lsf[j];
}
#undef S_v
#undef G
}
typedef decltype(kernel_gated_delta_net_impl<float4, 4>) kernel_gated_delta_net_t;
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float, 1>;
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float2, 2>;
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float4, 4>;
#endif
+347
View File
@@ -0,0 +1,347 @@
#include "common.h"
kernel void kernel_argmax_f32(
constant ggml_metal_kargs_argmax & args,
device const char * src0,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const float * x_row = (device const float *) ((device const char *) src0 + tgpig * args.nb01);
float lmax = -INFINITY;
int32_t larg = -1;
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
if (x_row[i00] > lmax) {
lmax = x_row[i00];
larg = i00;
}
}
// find the argmax value in the block
float max_val = simd_max(lmax);
int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
device int32_t * dst_i32 = (device int32_t *) dst;
threadgroup float * shared_maxval = (threadgroup float *) shmem;
threadgroup int32_t * shared_argmax = (threadgroup int32_t *) shmem + N_SIMDWIDTH;
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
shared_maxval[tiisg] = -INFINITY;
shared_argmax[tiisg] = -1;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shared_maxval[sgitg] = max_val;
shared_argmax[sgitg] = arg_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = shared_maxval[tiisg];
arg_val = shared_argmax[tiisg];
float max_val_reduced = simd_max(max_val);
int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
dst_i32[tgpig] = arg_val_reduced;
return;
}
dst_i32[tgpig] = arg_val;
}
kernel void kernel_diag_f32(
constant ggml_metal_kargs_diag & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]]) {
constexpr short NW = N_SIMDWIDTH;
const int32_t i3 = tgpig.z;
const int32_t i2 = tgpig.y;
const int32_t i1 = tgpig.x;
device const float * src0_ptr = (device const float *)(src0 + i2*args.nb02 + i3*args.nb03);
device float * dst_ptr = (device float *)(dst + i1*args.nb01 + i2*args.nb2 + i3*args.nb3);
for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {
dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
}
}
kernel void kernel_roll_f32(
constant ggml_metal_kargs_roll & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
device const float * src0_ptr = (device const float *) src0;
device float * dst_ptr = (device float *) dst;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
// apply shifts and wrap around
int64_t i00 = i0 - args.s0;
int64_t i01 = i1 - args.s1;
int64_t i02 = i2 - args.s2;
int64_t i03 = i3 - args.s3;
if (i00 < 0) { i00 += args.ne00; } else if (i00 >= args.ne00) { i00 -= args.ne00; }
if (i01 < 0) { i01 += args.ne01; } else if (i01 >= args.ne01) { i01 -= args.ne01; }
if (i02 < 0) { i02 += args.ne02; } else if (i02 >= args.ne02) { i02 -= args.ne02; }
if (i03 < 0) { i03 += args.ne03; } else if (i03 >= args.ne03) { i03 -= args.ne03; }
int64_t src_idx = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00 + i00;
int64_t dst_idx = i3 *args.ne2 *args.ne1 *args.ne0 + i2 *args.ne1 *args.ne0 + i1 *args.ne0 + i0;
dst_ptr[dst_idx] = src0_ptr[src_idx];
}
}
template <typename T>
kernel void kernel_pad_impl(
constant ggml_metal_kargs_pad & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int32_t i3 = tgpig.z;
const int32_t i2 = tgpig.y;
const int32_t k0 = tgpig.x/args.ne1;
const int32_t i1 = tgpig.x - k0*args.ne1;
const int32_t i03 = i3;
const int32_t i02 = i2;
const int32_t i01 = i1;
device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) {
const int32_t i0 = k0*1024 + tpitg.x + l0;
if (i0 >= args.ne0) {
break;
}
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
dst_ptr[i0] = src0_ptr[i0];
} else {
dst_ptr[i0] = 0.0f;
}
}
}
typedef decltype(kernel_pad_impl<float>) kernel_pad_t;
template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl<float>;
template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl<float4>;
// TODO: this is slow - optimize
kernel void kernel_pad_reflect_1d_f32(
constant ggml_metal_kargs_pad_reflect_1d & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3;
const int64_t i02 = i2;
const int64_t i01 = i1;
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
if (i0 < args.p0) {
dst_ptr[i0] = src0_ptr[args.p0 - i0];
} else if (i0 < args.ne0 - args.p1) {
dst_ptr[i0] = src0_ptr[i0 - args.p0];
} else {
dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
}
}
}
}
kernel void kernel_arange_f32(
constant ggml_metal_kargs_arange & args,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
device float * dst_ptr = (device float *) dst;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
dst_ptr[i0] = args.start + args.step * i0;
}
}
kernel void kernel_timestep_embedding_f32(
constant ggml_metal_kargs_timestep_embedding & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
int i = tgpig.x;
device float * embed_data = (device float *)(dst + i*args.nb1);
int half_ = args.dim / 2;
for (int j = tpitg.x; j < half_; j += ntg.x) {
float timestep = ((device float *)src0)[i];
float freq = (float)exp(-log((float)args.max_period) * j / half_);
float arg = timestep * freq;
embed_data[j ] = cos(arg);
embed_data[j + half_] = sin(arg);
}
if (args.dim % 2 != 0 && tpitg.x == 0) {
embed_data[2 * half_] = 0.f;
}
}
kernel void kernel_opt_step_adamw_f32(
constant ggml_metal_kargs_opt_step_adamw & args,
device float * x,
device const float * g,
device float * g_m,
device float * g_v,
device const float * pars,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
const float alpha = pars[0];
const float beta1 = pars[1];
const float beta2 = pars[2];
const float eps = pars[3];
const float wd = pars[4];
const float beta1h = pars[5];
const float beta2h = pars[6];
const float gi = g[gid];
const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);
g_m[gid] = gmi;
g_v[gid] = gvi;
const float mh = gmi * beta1h;
const float vh = sqrt(gvi * beta2h) + eps;
x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
}
kernel void kernel_opt_step_sgd_f32(
constant ggml_metal_kargs_opt_step_sgd & args,
device float * x,
device const float * g,
device const float * pars,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
}
template<typename T>
kernel void kernel_memset(
constant ggml_metal_kargs_memset & args,
device T * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;
}
typedef decltype(kernel_memset<int64_t>) kernel_memset_t;
template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;
constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];
template<typename T>
kernel void kernel_count_equal(
constant ggml_metal_kargs_count_equal & args,
device const char * src0,
device const char * src1,
device atomic_int * dst,
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const short NSG = FC_count_equal_nsg;
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
int sum = 0;
device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
const T v0 = *(device const T *)(base0 + i0*args.nb00);
const T v1 = *(device const T *)(base1 + i0*args.nb10);
sum += (v0 == v1);
}
sum = simd_sum(sum);
if (tiisg == 0) {
shmem_i32[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
float v = 0.0f;
if (tpitg.x < NSG) {
v = shmem_i32[tpitg.x];
}
float total = simd_sum(v);
if (tpitg.x == 0) {
atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed);
}
}
}
typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;
template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;
+838
View File
@@ -0,0 +1,838 @@
#include "common.h"
#include "dequantize.h"
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]];
constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]];
constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]];
constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]];
// each block_q contains 16*nl weights
#ifdef GGML_METAL_HAS_TENSOR
template<
typename SA, typename SA_4x4, typename SA_8x8,
typename SB, typename SB_2x4, typename SB_8x8,
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &),
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm(
constant ggml_metal_kargs_mul_mm & args,
device const char * srcA,
device const char * srcB,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig [[threadgroup_position_in_grid]],
ushort tiitg [[thread_index_in_threadgroup]],
ushort sgitg [[simdgroup_index_in_threadgroup]]) {
(void) sgitg;
// Matrix dimensions: A(M,K) x B(K,N) -> C(M,N)
const int K = args.ne00;
const int M = args.ne0;
const int N = args.ne1;
// Batch dimension handling
const int im = tgpig.z;
const int i12 = im % FC_mul_mm_ne12;
const int i13 = im / FC_mul_mm_ne12;
// Batch offsets for srcA and srcB
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
// Tile dimensions
constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X;
constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y;
// Tile offsets in output matrix
const int ra = tgpig.y * NRA;
const int rb = tgpig.x * NRB;
// Threadgroup memory for dequantized A tile only
threadgroup SA * sa = (threadgroup SA *)(shmem);
// Work-item count for A loading
constexpr int A_WORK_ITEMS = NRA * N_MM_NK;
constexpr int NUM_THREADS = N_SIMDWIDTH * N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y;
// tA wraps threadgroup memory
auto tA = tensor(sa, dextents<int32_t, 2>(N_MM_NK_TOTAL, NRA));
// tB wraps device memory directly
device T1 * ptrB = (device T1 *)(srcB + args.nb12*i12 + args.nb13*i13);
const int strideB = args.nb11 / sizeof(T1);
auto tB = tensor(ptrB, dextents<int32_t, 2>(K, N), array<int, 2>({1, strideB}));
// Configure matmul operation
mpp::tensor_ops::matmul2d<
mpp::tensor_ops::matmul2d_descriptor(
NRB, NRA, N_MM_NK_TOTAL, false, true, true,
mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
execution_simdgroups<N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y>> mm;
auto cT = mm.get_destination_cooperative_tensor<decltype(tB), decltype(tA), float>();
// Accumulate partial results over K dimension
for (int loop_k = 0; loop_k < K; loop_k += N_MM_NK_TOTAL) {
// === PHASE 1: Dequantization of A into threadgroup memory ===
for (int work = tiitg; work < A_WORK_ITEMS; work += NUM_THREADS) {
const int row = work / N_MM_NK;
const int k_chunk = work % N_MM_NK;
const int k_pos = loop_k + k_chunk * 16;
const short k_base = k_chunk * 16;
// Bounds check: skip device read if row is out of matrix bounds
if (ra + row < M) {
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
// Element-wise reads when K is not aligned (nb01 not aligned for half4x4/float4x4).
// MSL spec Table 2.5: half4x4 requires 8-byte alignment. When K is odd,
// nb01 = K*2 is not 8-byte aligned, so odd-row pointers are misaligned.
// Mirrors the legacy kernel's existing guard.
device const T0 * row_ptr = (device const T0 *)(srcA + args.nb01 * (ra + row) + offset0);
FOR_UNROLL (short i = 0; i < 16; i++) {
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? (SA) row_ptr[k_pos + i] : (SA)0;
}
} else {
const int block_idx = k_pos / (16 * nl);
const short il = (k_pos / 16) % nl;
device const block_q * row_ptr = (device const block_q *)(srcA + args.nb01 * (ra + row) + offset0);
SA_4x4 temp_a;
dequantize_func(row_ptr + block_idx, il, temp_a);
FOR_UNROLL (short i = 0; i < 16; i++) {
// Zero-pad A for K positions beyond valid range (handles partial K iterations)
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0;
}
}
} else {
// Zero-pad rows beyond matrix bounds
FOR_UNROLL (short i = 0; i < 16; i++) {
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (SA)0;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// === PHASE 2: Tensor matmul ===
auto mA = tA.slice(0, 0);
auto mB = tB.slice(loop_k, rb);
mm.run(mB, mA, cT);
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// Store result tile to output matrix (with batch offset)
// cT.store handles bounds checking via tD's extents (M, N)
device float * dstBatch = (device float *)dst + im * N * M;
auto tD = tensor(dstBatch, dextents<int32_t, 2>(M, N), array<int, 2>({1, M}));
cT.store(tD.slice(ra, rb));
}
#else
template<
typename S0, typename S0_4x4, typename S0_8x8,
typename S1, typename S1_2x4, typename S1_8x8,
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &),
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm(
constant ggml_metal_kargs_mul_mm & args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
constexpr int NR0 = 64;
constexpr int NR1 = 32;
constexpr int NK = 32;
constexpr int NL0 = NK/16;
constexpr int NL1 = NK/8;
const int im = tgpig.z;
const int r0 = tgpig.y*NR0;
const int r1 = tgpig.x*NR1;
// if this block is of 64x32 shape or smaller
const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1;
// a thread shouldn't load data outside of the matrix
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
const short il0 = (tiitg % NL0);
short il = il0;
const int i12 = im % FC_mul_mm_ne12;
const int i13 = im / FC_mul_mm_ne12;
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
const short offset1 = il0/nl;
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
const short iy = 8*(tiitg % NL1);
device const T1 * y = (device const T1 *)(src1
+ args.nb13*i13
+ args.nb12*i12
+ args.nb11*(r1 + lr1)
+ args.nb10*iy);
S0_8x8 ma[4];
S1_8x8 mb[2];
simdgroup_float8x8 mc[8];
for (short i = 0; i < 8; i++){
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
}
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
// load data and store to threadgroup memory
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// no need for dequantization
for (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
//const short lx = i%8;
//const short ly = (tiitg/NL0)%8;
const short lx = (tiitg/NL0)%8;
const short ly = i%8;
const short ib = 8*sx + sy;
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
}
} else {
S0_4x4 temp_a;
dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
FOR_UNROLL (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
//const short lx = i%8;
//const short ly = (tiitg/NL0)%8;
const short lx = (tiitg/NL0)%8;
const short ly = i%8;
const short ib = 8*sx + sy;
// NOTE: this is massively slower.. WTF?
//sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
*(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
}
}
if (FC_mul_mm_bc_inp) {
for (short i = 0; i < 8; ++i) {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short lx = i;
const short ly = (tiitg/NL1)%8;
//const short lx = (tiitg/NL1)%8;
//const short ly = i;
const short ib = 4*sx + sy;
*(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
}
} else {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
//const short dx = sx;
//const short dy = sy;
const short ly = (tiitg/NL1)%8;
const short ib = 4*sx + sy;
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
}
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
y += NK;
threadgroup_barrier(mem_flags::mem_threadgroup);
// load matrices from threadgroup memory and conduct outer products
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 4; i++) {
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
}
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 2; i++) {
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
}
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 8; i++){
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
}
lsma += 8*64;
lsmb += 4*64;
}
}
if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
// if no bounds checks on the output are needed, we can directly write to device memory
device float * C = (device float *) dst +
(r0 + 32*(sgitg & 1)) + \
(r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
}
} else {
// block is smaller than 64x32, we should avoid writing data outside of the matrix
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
for (int j = tiitg; j < nr1; j += NR1) {
device float * D = (device float *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0;
device float4 * D4 = (device float4 *) D;
threadgroup float * C = temp_str + (j*NR0);
threadgroup float4 * C4 = (threadgroup float4 *) C;
int i = 0;
for (; i < nr0/4; i++) {
*(D4 + i) = *(C4 + i);
}
i *= 4;
for (; i < nr0; i++) {
*(D + i) = *(C + i);
}
}
}
}
}
#endif // GGML_METAL_HAS_TENSOR
template<short ne20> // n_expert_used
kernel void kernel_mul_mm_id_map0(
constant ggml_metal_kargs_mul_mm_id_map0 & args,
device const char * src2,
device char * htpe,
device char * hids,
threadgroup char * shmem [[threadgroup(0)]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort ntg[[threads_per_threadgroup]]) {
const short ide = tpitg; // expert id
uint32_t n_all = 0;
device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;
for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
if (i21 + tpitg < args.ne21) {
device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;
#pragma unroll(ne20)
for (short i20 = 0; i20 < ne20; i20++) {
sids[i20] = src2_i32[i20];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (short t = 0; t < ntg; t++) {
if (i21 + t >= args.ne21) {
break;
}
threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;
short sel = 0;
#pragma unroll(ne20)
for (short i20 = 0; i20 < ne20; i20++) {
sel += (sids[i20] == ide)*(i20 + 1);
}
ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;
n_all += sel > 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
tpe_u32[ide] = n_all;
}
typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;
template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(
constant ggml_metal_kargs_mul_mm_id & args,
device const char * src0,
device const char * src1,
device const char * htpe,
device const char * hids,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
#ifdef GGML_METAL_HAS_TENSOR
threadgroup float * sc = (threadgroup float *)(shmem);
#endif
constexpr int NR0 = 64;
constexpr int NR1 = 32;
constexpr int NK = 32;
constexpr int NL0 = NK/16;
constexpr int NL1 = NK/8;
const int im = tgpig.z; // expert
const int r0 = tgpig.y*NR0;
const int r1 = tgpig.x*NR1;
device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
device const int32_t * ids_i32 = (device const int32_t *) (hids);
const int32_t neh1 = tpe_u32[im];
if (r1 >= neh1) {
return;
}
// if this block is of 64x32 shape or smaller
const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1;
// a thread shouldn't load data outside of the matrix
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
const short il0 = (tiitg % NL0);
short il = il0;
const int id = ids_i32[im*args.ne21 + r1 + lr1];
const short i11 = (id % args.ne20) % args.ne11;
const short i12 = (id / args.ne20);
const short i13 = 0;
const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
const short offset1 = il0/nl;
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
const short iy = 8*(tiitg % NL1);
device const T1 * y = (device const T1 *)(src1
+ args.nb13*i13
+ args.nb12*i12
+ args.nb11*i11
+ args.nb10*iy);
#ifndef GGML_METAL_HAS_TENSOR
S0_8x8 ma[4];
S1_8x8 mb[2];
simdgroup_float8x8 mc[8];
for (short i = 0; i < 8; i++){
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
}
#else
auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
mpp::tensor_ops::matmul2d<
mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
execution_simdgroups<4>> mm;
auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
#endif
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
#ifndef GGML_METAL_HAS_TENSOR
// load data and store to threadgroup memory
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// no need for dequantization
for (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
//const short lx = i%8;
//const short ly = (tiitg/NL0)%8;
const short lx = (tiitg/NL0)%8;
const short ly = i%8;
const short ib = 8*sx + sy;
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? (S0) *((device T0 *) x + i) : (S0) 0;
}
} else {
S0_4x4 temp_a;
dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
FOR_UNROLL (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
//const short lx = i%8;
//const short ly = (tiitg/NL0)%8;
const short lx = (tiitg/NL0)%8;
const short ly = i%8;
const short ib = 8*sx + sy;
// NOTE: this is massively slower.. WTF?
//sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
*(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
}
}
if (FC_mul_mm_bc_inp) {
for (short i = 0; i < 8; ++i) {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short lx = i;
const short ly = (tiitg/NL1)%8;
//const short lx = (tiitg/NL1)%8;
//const short ly = i;
const short ib = 4*sx + sy;
*(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
}
} else {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
//const short dx = sx;
//const short dy = sy;
const short ly = (tiitg/NL1)%8;
const short ib = 4*sx + sy;
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
}
#else
// load data and store to threadgroup memory
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// no need for dequantization
for (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
const short lx = i%8;
const short ly = (tiitg/NL0)%8;
//const short lx = (tiitg/NL0)%8;
//const short ly = i%8;
*(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
}
} else {
S0_4x4 temp_a;
dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
FOR_UNROLL (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
const short lx = i%8;
const short ly = (tiitg/NL0)%8;
//const short lx = (tiitg/NL0)%8;
//const short ly = i%8;
*(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
}
}
if (FC_mul_mm_bc_inp) {
for (short i = 0; i < 8; ++i) {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short lx = i;
const short ly = (tiitg/NL1)%8;
//const short lx = (tiitg/NL1)%8;
//const short ly = i;
*(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
}
} else {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
//const short lx = i;
const short ly = (tiitg/NL1)%8;
//const short lx = (tiitg/NL1)%8;
//const short ly = i;
*(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
}
#endif
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
y += NK;
threadgroup_barrier(mem_flags::mem_threadgroup);
#ifndef GGML_METAL_HAS_TENSOR
// load matrices from threadgroup memory and conduct outer products
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 4; i++) {
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
}
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 2; i++) {
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
}
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 8; i++){
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
}
lsma += 8*64;
lsmb += 4*64;
}
#else
auto sA = tA.slice(0, 0);
auto sB = tB.slice(0, 0);
mm.run(sB, sA, cT);
#endif
}
// block is smaller than 64x32, we should avoid writing data outside of the matrix
threadgroup_barrier(mem_flags::mem_threadgroup);
#ifdef GGML_METAL_HAS_TENSOR
auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
cT.store(tC);
#else
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
}
#endif
threadgroup_barrier(mem_flags::mem_threadgroup);
for (short j = sgitg; j < nr1; j += 4) {
const int id = ids_i32[im*args.ne21 + r1 + j];
const short ide = id % args.ne20;
const short idt = id / args.ne20;
device float * D = (device float *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0;
device float4 * D4 = (device float4 *) D;
threadgroup float * C = (threadgroup float *) shmem + j*NR0;
threadgroup float4 * C4 = (threadgroup float4 *) C;
int i = tiisg;
for (; i < nr0/4; i += 32) {
*(D4 + i) = *(C4 + i);
}
i = (4*(nr0/4)) + tiisg;
for (; i < nr0; i += 32) {
*(D + i) = *(C + i);
}
}
}
//
// matrix-matrix multiplication
//
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_t;
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
//
// indirect matrix-matrix multiplication
//
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_id;
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
File diff suppressed because it is too large Load Diff
+308
View File
@@ -0,0 +1,308 @@
#include "common.h"
// F == 1 : norm (no fuse)
// F == 2 : norm + mul
// F == 3 : norm + mul + add
template <typename T, short F>
kernel void kernel_norm_fuse_impl(
constant ggml_metal_kargs_norm & args,
device const char * src0,
device const char * src1_0,
device const char * src1_1,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
const int i01 = tgpig.x;
const int i02 = tgpig.y;
const int i03 = tgpig.z;
device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
T sumft(0.0f);
float sumf = 0.0f;
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
sumft += x[i00];
}
sumf = dot(sumft, T(1.0f));
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float mean = sumf/args.ne00;
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
sumf = 0.0f;
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
y[i00] = x[i00] - mean;
sumf += dot(y[i00], y[i00]);
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float variance = sumf/args.ne00;
const float scale = 1.0f/sqrt(variance + args.eps);
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
if (F == 1) {
y[i00] = (y[i00]*scale);
}
if (F == 2) {
y[i00] = (y[i00]*scale)*f0[i00];
}
if (F == 3) {
y[i00] = (y[i00]*scale)*f0[i00] + f1[i00];
}
}
}
typedef decltype(kernel_norm_fuse_impl<float4, 1>) kernel_norm_fuse_t;
template [[host_name("kernel_norm_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 1>;
template [[host_name("kernel_norm_mul_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 2>;
template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 3>;
template [[host_name("kernel_norm_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 1>;
template [[host_name("kernel_norm_mul_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 3>;
// F == 1 : rms_norm (no fuse)
// F == 2 : rms_norm + mul
// F == 3 : rms_norm + mul + add
template <typename T, short F>
kernel void kernel_rms_norm_fuse_impl(
constant ggml_metal_kargs_norm & args,
device const char * src0,
device const char * src1_0,
device const char * src1_1,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
const int i01 = tgpig.x;
const int i02 = tgpig.y;
const int i03 = tgpig.z;
device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
float sumf = 0.0f;
// parallel sum
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float mean = sumf/args.ne00;
const float scale = 1.0f/sqrt(mean + args.eps);
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
if (F == 1) {
y[i00] = (x[i00]*scale);
}
if (F == 2) {
y[i00] = (x[i00]*scale)*f0[i00];
}
if (F == 3) {
y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
}
}
}
typedef decltype(kernel_rms_norm_fuse_impl<float4, 1>) kernel_rms_norm_fuse_t;
template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 1>;
template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 3>;
template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 1>;
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
template <typename T0, typename T>
kernel void kernel_l2_norm_impl(
constant ggml_metal_kargs_l2_norm & args,
device const char * src0,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int i01 = tgpig.x;
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
float sumf = 0.0f;
// parallel sum
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float scale = 1.0f/max(sqrt(sumf), args.eps);
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
y[i00] = x[i00] * scale;
}
}
typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,
device float * dst,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
const int64_t ne = args.ne00*args.ne01*args.ne02;
const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.ngrp - 1) / args.ngrp);
int start = tgpig * gs;
int end = start + gs;
start += tpitg;
if (end >= ne) {
end = ne;
}
float tmp = 0.0f; // partial sum for thread in warp
for (int j = start; j < end; j += ntg) {
tmp += src0[j];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
tmp = simd_sum(tmp);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = tmp;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
tmp = buf[tiisg];
tmp = simd_sum(tmp);
}
const float mean = tmp / gs;
tmp = 0.0f;
for (int j = start; j < end; j += ntg) {
float xi = src0[j] - mean;
dst[j] = xi;
tmp += xi * xi;
}
tmp = simd_sum(tmp);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = tmp;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
tmp = buf[tiisg];
tmp = simd_sum(tmp);
}
const float variance = tmp / gs;
const float scale = 1.0f/sqrt(variance + args.eps);
for (int j = start; j < end; j += ntg) {
dst[j] *= scale;
}
}
+148
View File
@@ -0,0 +1,148 @@
#include "common.h"
kernel void kernel_pool_2d_max_f32(
constant ggml_metal_kargs_pool_2d & args,
device const float * src0,
device float * dst,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
const int idx = gid;
const int I_HW = args.IH * args.IW;
const int O_HW = args.OH * args.OW;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / args.OW;
const int cur_ow = idx % O_HW % args.OW;
device const float * i_ptr = src0 + nc * I_HW;
device float * o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * args.s1 - args.p1;
const int bh = MAX(0, start_h);
const int eh = MIN(args.IH, start_h + args.k1);
const int start_w = cur_ow * args.s0 - args.p0;
const int bw = MAX(0, start_w);
const int ew = MIN(args.IW, start_w + args.k0);
float res = -INFINITY;
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
res = MAX(res, i_ptr[i * args.IW + j]);
}
}
o_ptr[cur_oh * args.OW + cur_ow] = res;
}
kernel void kernel_pool_2d_avg_f32(
constant ggml_metal_kargs_pool_2d & args,
device const float * src0,
device float * dst,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
const int idx = gid;
const int I_HW = args.IH * args.IW;
const int O_HW = args.OH * args.OW;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / args.OW;
const int cur_ow = idx % O_HW % args.OW;
device const float * i_ptr = src0 + nc * I_HW;
device float * o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * args.s1 - args.p1;
const int bh = MAX(0, start_h);
const int eh = MIN(args.IH, start_h + args.k1);
const int start_w = cur_ow * args.s0 - args.p0;
const int bw = MAX(0, start_w);
const int ew = MIN(args.IW, start_w + args.k0);
// const float scale = 1. / ((eh - bh) * (ew - bw));
const float scale = 1. / (args.k0 * args.k1);
float res = 0;
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
float cur = i_ptr[i * args.IW + j];
res += cur * scale;
}
}
o_ptr[cur_oh * args.OW + cur_ow] = res;
}
kernel void kernel_pool_1d_max_f32(
constant ggml_metal_kargs_pool_1d & args,
device const float * src,
device float * dst,
uint gid [[thread_position_in_grid]]
) {
if (gid >= args.np) {
return;
}
const int ow = (int)gid % args.OW;
const int row = (int)gid / args.OW;
const int base = ow * args.s0 - args.p0;
float acc = -INFINITY;
const int src_off = row * args.IW;
const int dst_off = row * args.OW;
for (int ki = 0; ki < args.k0; ++ki) {
int j = base + ki;
if (j < 0 || j >= args.IW){
continue;
}
float v = src[src_off + j];
acc = max(acc, v);
}
dst[dst_off + ow] = acc;
}
kernel void kernel_pool_1d_avg_f32(
constant ggml_metal_kargs_pool_1d & args,
device const float * src,
device float * dst,
uint gid [[thread_position_in_grid]]
) {
if (gid >= args.np) {
return;
}
const int ow = (int)gid % args.OW;
const int row = (int)gid / args.OW;
const int base = ow * args.s0 - args.p0;
float acc = 0.0f;
int cnt = 0;
const int src_off = row * args.IW;
const int dst_off = row * args.OW;
for (int ki = 0; ki < args.k0; ++ki) {
const int j = base + ki;
if (j < 0 || j >= args.IW) {
continue;
}
acc += src[src_off + j];
cnt += 1;
}
dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
}
+213
View File
@@ -0,0 +1,213 @@
#pragma once
#include "common.h"
void quantize_q1_0(device const float * src, device block_q1_0 & dst) {
float sum_abs = 0.0f;
for (int j = 0; j < QK1_0; j++) {
sum_abs += fabs(src[j]);
}
dst.d = sum_abs / QK1_0;
for (int j = 0; j < QK1_0 / 8; j++) {
dst.qs[j] = 0;
}
for (int j = 0; j < QK1_0; j++) {
if (src[j] >= 0.0f) {
dst.qs[j / 8] |= (1 << (j % 8));
}
}
}
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < QK4_0; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
max = v;
}
}
const float d = max / -8;
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = src[0 + j]*id;
const float x1 = src[QK4_0/2 + j]*id;
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
dst.qs[j] = xi0;
dst.qs[j] |= xi1 << 4;
}
}
void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
#pragma METAL fp math_mode(safe)
float min = FLT_MAX;
float max = -FLT_MAX;
for (int j = 0; j < QK4_1; j++) {
const float v = src[j];
if (min > v) min = v;
if (max < v) max = v;
}
const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
dst.m = min;
for (int j = 0; j < QK4_1/2; ++j) {
const float x0 = (src[0 + j] - min)*id;
const float x1 = (src[QK4_1/2 + j] - min)*id;
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
dst.qs[j] = xi0;
dst.qs[j] |= xi1 << 4;
}
}
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < QK5_0; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
max = v;
}
}
const float d = max / -16;
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
uint32_t qh = 0;
for (int j = 0; j < QK5_0/2; ++j) {
const float x0 = src[0 + j]*id;
const float x1 = src[QK5_0/2 + j]*id;
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
}
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
for (int j = 0; j < 4; ++j) {
dst.qh[j] = qh8[j];
}
}
void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
#pragma METAL fp math_mode(safe)
float max = src[0];
float min = src[0];
for (int j = 1; j < QK5_1; j++) {
const float v = src[j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
dst.m = min;
uint32_t qh = 0;
for (int j = 0; j < QK5_1/2; ++j) {
const float x0 = (src[0 + j] - min)*id;
const float x1 = (src[QK5_1/2 + j] - min)*id;
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
}
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
for (int j = 0; j < 4; ++j) {
dst.qh[j] = qh8[j];
}
}
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
const float v = src[j];
amax = MAX(amax, fabs(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
for (int j = 0; j < QK8_0; ++j) {
const float x0 = src[j]*id;
dst.qs[j] = round(x0);
}
}
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < QK4_NL; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
max = v;
}
}
const float d = max / kvalues_iq4nl_f[0];
const float id = d ? 1.0f/d : 0.0f;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < QK4_NL/2; ++j) {
const float x0 = src[0 + j]*id;
const float x1 = src[QK4_NL/2 + j]*id;
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
dst.qs[j] = xi0 | (xi1 << 4);
const float v0 = kvalues_iq4nl_f[xi0];
const float v1 = kvalues_iq4nl_f[xi1];
const float w0 = src[0 + j]*src[0 + j];
const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
}
+389
View File
@@ -0,0 +1,389 @@
#include "common.h"
#include "dequantize.h"
#include "quantize.h"
template<typename T0, typename T1>
kernel void kernel_cpy_t_t(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig[2];
const int32_t i02 = tgpig[1];
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
if (i01 >= args.ne01) {
return;
}
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) {
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
dst_data[i00] = (T1) src[0];
break;
}
}
typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, float>;
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>;
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
template [[host_name("kernel_cpy_i32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
#endif
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<half, float>;
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<half, half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, float>;
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, bfloat>;
#endif
template<short QK,
typename block_q,
void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_cpy_f32_q(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig[2];
const int32_t i02 = tgpig[1];
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
if (i01 >= args.ne01) {
return;
}
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int32_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
quantize_func(src, dst_data[i00]);
break;
}
}
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;
template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
kernel void kernel_cpy_q_f32(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig[2];
const int32_t i02 = tgpig[1];
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
if (i01 >= args.ne01) {
return;
}
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
T4x4 temp;
dequantize_func(src_data + i00/nl, i00%nl, temp);
dst_data[i00] = temp;
break;
}
}
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>;
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>;
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
template<typename T>
kernel void kernel_concat(
constant ggml_metal_kargs_concat & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = ntg.y == 1 ? tgpig.x : tgpig.x*ntg.y + tpitg.y;
if (i1 >= args.ne1) {
return;
}
int o[4] = {0, 0, 0, 0};
o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
device const T * x;
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
x = (device const T *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00);
} else {
x = (device const T *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);
}
device T * y = (device T *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
*y = *x;
}
}
typedef decltype(kernel_concat<float>) kernel_concat_t;
template [[host_name("kernel_concat_f32")]] kernel kernel_concat_t kernel_concat<float>;
template [[host_name("kernel_concat_f16")]] kernel kernel_concat_t kernel_concat<half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_concat_bf16")]] kernel kernel_concat_t kernel_concat<bfloat>;
#endif
template [[host_name("kernel_concat_i8")]] kernel kernel_concat_t kernel_concat<char>;
template [[host_name("kernel_concat_i16")]] kernel kernel_concat_t kernel_concat<short>;
template [[host_name("kernel_concat_i32")]] kernel kernel_concat_t kernel_concat<int>;
template [[host_name("kernel_concat_i64")]] kernel kernel_concat_t kernel_concat<long>;
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows_q(
constant ggml_metal_kargs_get_rows & args,
device const void * src0,
device const void * src1,
device void * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg [[threads_per_threadgroup]]) {
const int32_t iw0 = tgpig.x/args.ne10;
const int32_t i10 = tgpig.x%args.ne10;
const int32_t i11 = tgpig.y;
const int32_t i12 = tgpig.z;
const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
const int32_t i02 = i11;
const int32_t i03 = i12;
auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
float4x4 temp;
dequantize_func(psrc + ind/nl, ind%nl, temp);
pdst[ind] = temp;
break;
}
}
template<typename T0, typename T>
kernel void kernel_get_rows_f(
constant ggml_metal_kargs_get_rows & args,
device const void * src0,
device const void * src1,
device void * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg [[threads_per_threadgroup]]) {
const int32_t iw0 = tgpig.x/args.ne10;
const int32_t i10 = tgpig.x%args.ne10;
const int32_t i11 = tgpig.y;
const int32_t i12 = tgpig.z;
const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
const int32_t i02 = i11;
const int32_t i03 = i12;
auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
pdst[ind] = psrc[ind];
break;
}
}
template<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_set_rows_q32(
constant ggml_metal_kargs_set_rows & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i12 = i03%args.ne12;
const int32_t i11 = i02%args.ne11;
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
if (i01 >= args.ne01) {
return;
}
const int32_t i10 = i01;
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
quantize_func(src_row + 32*ind, dst_row[ind]);
}
}
template<typename T, typename TI>
kernel void kernel_set_rows_f(
constant ggml_metal_kargs_set_rows & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i12 = i03%args.ne12;
const int32_t i11 = i02%args.ne11;
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
if (i01 >= args.ne01) {
return;
}
const int32_t i10 = i01;
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
dst_row[ind] = (T) src_row[ind];
}
}
//
// get rows
//
typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float, float>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half, float>;
template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>;
#endif
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
//
// set rows
//
typedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t;
template [[host_name("kernel_set_rows_f32_i64")]] kernel set_rows_f_t kernel_set_rows_f<float, int64_t>;
template [[host_name("kernel_set_rows_f32_i32")]] kernel set_rows_f_t kernel_set_rows_f<float, int32_t>;
template [[host_name("kernel_set_rows_f16_i64")]] kernel set_rows_f_t kernel_set_rows_f<half, int64_t>;
template [[host_name("kernel_set_rows_f16_i32")]] kernel set_rows_f_t kernel_set_rows_f<half, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int64_t>;
template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int32_t>;
#endif
typedef decltype(kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>) set_rows_q32_t;
template [[host_name("kernel_set_rows_q8_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_set_rows_q8_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_set_rows_q4_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_set_rows_q4_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_set_rows_q4_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_set_rows_q4_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_set_rows_q5_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_set_rows_q5_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_set_rows_q5_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>;
template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>;
+228
View File
@@ -0,0 +1,228 @@
#include "common.h"
kernel void kernel_op_sum_f32(
constant ggml_metal_kargs_sum & args,
device const float * src0,
device float * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
if (args.np == 0) {
return;
}
// TODO: become function constant
const uint nsg = (ntg.x + 31) / 32;
float sumf = 0;
for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
sumf += src0[i0];
}
sumf = simd_sum(sumf);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float total = 0;
if (sgitg == 0) {
float v = 0;
if (tpitg.x < nsg) {
v = shmem_f32[tpitg.x];
}
total = simd_sum(v);
if (tpitg.x == 0) {
dst[0] = total;
}
}
}
constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
template <typename T0, typename T>
kernel void kernel_sum_rows_impl(
constant ggml_metal_kargs_sum_rows & args,
device const char * src0,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
#define FC_OP FC_sum_rows_op
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
if (sgitg == 0) {
shmem_t[tiisg] = 0.0f;
}
device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
T0 sumf = T0(0.0f);
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
sumf += src_row[i0];
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_t[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_t[tiisg];
sumf = simd_sum(sumf);
if (tpitg.x == 0) {
if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
if (is_same<float4, T0>::value) {
dst_row[0] = sum(sumf) / (4*args.ne00);
} else {
dst_row[0] = sum(sumf) / args.ne00;
}
} else {
dst_row[0] = sum(sumf);
}
}
#undef FC_OP
}
typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
template<typename T>
kernel void kernel_cumsum_blk(
constant ggml_metal_kargs_cumsum_blk & args,
device const char * src0,
device char * tmp,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int ib = tgpig[0]/args.ne01;
const int i00 = ib*ntg.x;
const int i01 = tgpig[0]%args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * src0_row = (device const float *) (src0 +
args.nb01*i01 +
args.nb02*i02 +
args.nb03*i03);
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
float v = 0.0f;
if (i00 + tpitg.x < args.ne00) {
v = src0_row[i00 + tpitg.x];
}
float s = simd_prefix_inclusive_sum(v);
if (tiisg == N_SIMDWIDTH - 1) {
shmem_f32[sgitg] = s;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
s += shmem_f32[sgitg];
device float * dst_row = (device float *) dst +
args.ne00*i01 +
args.ne00*args.ne01*i02 +
args.ne00*args.ne01*args.ne02*i03;
if (i00 + tpitg.x < args.ne00) {
dst_row[i00 + tpitg.x] = s;
}
if (args.outb && tpitg.x == ntg.x - 1) {
device float * tmp_row = (device float *) tmp +
args.net0*i01 +
args.net0*args.net1*i02 +
args.net0*args.net1*args.net2*i03;
tmp_row[ib] = s;
}
}
typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;
template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;
template<typename T>
kernel void kernel_cumsum_add(
constant ggml_metal_kargs_cumsum_add & args,
device const char * tmp,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int ib = tgpig[0]/args.ne01;
if (ib == 0) {
return;
}
const int i00 = ib*ntg.x;
const int i01 = tgpig[0]%args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * tmp_row = (device const float *) (tmp +
args.nbt1*i01 +
args.nbt2*i02 +
args.nbt3*i03);
device float * dst_row = (device float *) dst +
args.ne00*i01 +
args.ne00*args.ne01*i02 +
args.ne00*args.ne01*args.ne02*i03;
if (i00 + tpitg.x < args.ne00) {
dst_row[i00 + tpitg.x] += tmp_row[ib - 1];
}
}
typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
+318
View File
@@ -0,0 +1,318 @@
#include "common.h"
constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
constant bool FC_rope_is_back [[function_constant(FC_ROPE + 1)]];
static float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn(
float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,
thread float * cos_theta, thread float * sin_theta) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
}
*cos_theta = cos(theta) * mscale;
*sin_theta = sin(theta) * mscale;
if (FC_rope_is_back) {
*sin_theta *= -1.0f;
}
}
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
}
static void rope_yarn_corr_dims(
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
) {
// start and end correction dims
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
}
template<typename T>
kernel void kernel_rope_norm(
constant ggml_metal_kargs_rope & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2];
const int i2 = tgpig[1];
const int i1 = tgpig[0];
float corr_dims[2];
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
device const int32_t * pos = (device const int32_t *) src1;
const float theta_base = (float) pos[i2];
const float inv_ndims = -1.f/args.n_dims;
float cos_theta;
float sin_theta;
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
if (i0 < args.n_dims) {
const int ic = i0/2;
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
const float x0 = src[0];
const float x1 = src[1];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta;
} else {
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
template<typename T>
kernel void kernel_rope_neox(
constant ggml_metal_kargs_rope & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2];
const int i2 = tgpig[1];
const int i1 = tgpig[0];
float corr_dims[2];
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
device const int32_t * pos = (device const int32_t *) src1;
const float theta_base = (float) pos[i2];
const float inv_ndims = -1.f/args.n_dims;
float cos_theta;
float sin_theta;
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
if (i0 < args.n_dims) {
const int ic = i0/2;
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
const float x0 = src[0];
const float x1 = src[args.n_dims/2];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
template<typename T>
kernel void kernel_rope_multi(
constant ggml_metal_kargs_rope & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2];
const int i2 = tgpig[1];
const int i1 = tgpig[0];
float corr_dims[2];
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
device const int32_t * pos = (device const int32_t *) src1;
const float inv_ndims = -1.f/args.n_dims;
float cos_theta;
float sin_theta;
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
if (i0 < args.n_dims) {
const int ic = i0/2;
// mrope theta calculations
// note: the rest is the same as kernel_rope_neox
const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
const int sector = ic % sect_dims;
float theta_base;
if (FC_rope_is_imrope) {
if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
theta_base = (float) pos[i2 + args.ne02 * 1];
} else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
theta_base = (float) pos[i2 + args.ne02 * 2];
} else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
theta_base = (float) pos[i2 + args.ne02 * 0];
} else { // e
theta_base = (float) pos[i2 + args.ne02 * 3];
}
} else {
if (sector < args.sect_0) {
theta_base = (float) pos[i2];
} else if (sector < sec_w01) {
theta_base = (float) pos[i2 + args.ne02 * 1];
} else if (sector < sec_w012) {
theta_base = (float) pos[i2 + args.ne02 * 2];
} else {
theta_base = (float) pos[i2 + args.ne02 * 3];
}
}
// end of mrope
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
const float x0 = src[0];
const float x1 = src[args.n_dims/2];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
template<typename T>
kernel void kernel_rope_vision(
constant ggml_metal_kargs_rope & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2];
const int i2 = tgpig[1];
const int i1 = tgpig[0];
float corr_dims[2];
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
device const int32_t * pos = (device const int32_t *) src1;
const float inv_ndims = -1.f/args.n_dims;
float cos_theta;
float sin_theta;
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
const int ic = i0/2;
// mrope theta calculations (only support 2 dimensions)
const int sect_dims = args.sect_0 + args.sect_1;
const int sector = ic % sect_dims;
float p;
float theta_base;
if (sector < args.sect_1) {
p = (float) sector;
theta_base = (float) pos[i2];
} else {
p = (float) sector - args.sect_0;
theta_base = (float) pos[i2 + args.ne02];
}
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
// end of mrope
const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
const float x0 = src[0];
const float x1 = src[args.n_dims]; // different from kernel_rope_multi
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
} else {
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
+223
View File
@@ -0,0 +1,223 @@
#include "common.h"
template<typename T>
kernel void kernel_soft_max(
constant ggml_metal_kargs_soft_max & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
threadgroup float * buf [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint3 tptg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i01 = tgpig.x;
const int32_t i13 = i03%args.ne13;
const int32_t i12 = i02%args.ne12;
const int32_t i11 = i01;
device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
device const float * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr;
device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
// ALiBi
if (args.max_bias > 0.0f) {
const int32_t h = i02;
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exp);
}
// parallel max
float lmax = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
}
// find the max value in the block
float max_val = simd_max(lmax);
if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = -INFINITY;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = max_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = buf[tiisg];
max_val = simd_max(max_val);
}
// parallel sum
float lsum = 0.0f;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
lsum += exp_psrc0;
pdst[i00] = exp_psrc0;
}
// This barrier fixes a failing test
// ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
threadgroup_barrier(mem_flags::mem_none);
float sum = simd_sum(lsum);
if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum = buf[tiisg];
sum = simd_sum(sum);
}
if (psrc2) {
sum += exp(psrc2[i02] - max_val);
}
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
pdst[i00] *= inv_sum;
}
}
template<typename T>
kernel void kernel_soft_max_4(
constant ggml_metal_kargs_soft_max & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
threadgroup float * buf [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint3 tptg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i01 = tgpig.x;
const int32_t i13 = i03%args.ne13;
const int32_t i12 = i02%args.ne12;
const int32_t i11 = i01;
device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
device const float * psrc2 = src2 != src0 ? (device const float * ) (src2) : nullptr;
device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
if (args.max_bias > 0.0f) {
const int32_t h = i02;
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exp);
}
// parallel max
float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
}
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
float max_val = simd_max(lmax);
if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = -INFINITY;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = max_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = buf[tiisg];
max_val = simd_max(max_val);
}
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
// This barrier fixes a failing test
// ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
threadgroup_barrier(mem_flags::mem_none);
float sum = simd_sum(lsum);
if (tptg.x > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sum = buf[tiisg];
sum = simd_sum(sum);
}
if (psrc2) {
sum += exp(psrc2[i02] - max_val);
}
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
pdst4[i00] *= inv_sum;
}
}
typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
@@ -0,0 +1,75 @@
#include "common.h"
constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]];
constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]];
kernel void kernel_solve_tri_f32(
constant ggml_metal_kargs_solve_tri & args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
ushort3 tgpig[[threadgroup_position_in_grid]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
constexpr short NW = N_SIMDWIDTH;
const short NSG = FC_solve_tri_nsg;
const short N = FC_solve_tri_n;
const short K = FC_solve_tri_k;
const short NP = PAD2(N, NW);
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i01 = tgpig.x*NSG + sgitg;
threadgroup float * sh0 = (threadgroup float *) shmem;
device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N;
device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01;
device float * dst_ptr = (device float *)(dst + i02 * args.nb2 + i03 * args.nb3) + i01;
for (short rr = 0; rr < N; rr += NSG) {
threadgroup_barrier(mem_flags::mem_threadgroup);
{
threadgroup float * sh0_cur = sh0 + sgitg*NP;
for (short t = 0; t*NW < N; ++t) {
const short idx = t*NW + tiisg;
sh0_cur[idx] = src0_ptr[idx];
}
src0_ptr += NSG*N;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (i01 >= args.ne10) {
continue;
}
for (short ir = 0; ir < NSG && rr + ir < N; ++ir) {
const short r = rr + ir;
threadgroup float * sh0_cur = sh0 + ir*NP;
float sum = 0.0f;
for (short t = 0; t*NW < r; ++t) {
const short idx = t*NW + tiisg;
sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r);
}
sum = simd_sum(sum);
if (tiisg == 0) {
const float diag = sh0_cur[r];
dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag;
}
}
}
}
+279
View File
@@ -0,0 +1,279 @@
#include "common.h"
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
kernel void kernel_ssm_conv_f32_f32(
constant ggml_metal_kargs_ssm_conv & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t ir = tgpig.x;
const int64_t i2 = tgpig.y;
const int64_t i3 = tgpig.z;
const int64_t nc = args.ne10;
//const int64_t ncs = args.ne00;
//const int64_t nr = args.ne01;
//const int64_t n_t = args.ne1;
//const int64_t n_s = args.ne2;
device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) {
sumf += s[i0] * c[i0];
}
x[0] = sumf;
}
kernel void kernel_ssm_conv_f32_f32_4(
constant ggml_metal_kargs_ssm_conv & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t ir = tgpig.x;
const int64_t i2 = tgpig.y;
const int64_t i3 = tgpig.z;
const int64_t nc = args.ne10;
//const int64_t ncs = args.ne00;
//const int64_t nr = args.ne01;
//const int64_t n_t = args.ne1;
//const int64_t n_s = args.ne2;
device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc/4; ++i0) {
sumf += dot(s[i0], c[i0]);
}
x[0] = sumf;
}
constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]];
// Batched version: each threadgroup processes multiple tokens for better efficiency
// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
kernel void kernel_ssm_conv_f32_f32_batched(
constant ggml_metal_kargs_ssm_conv & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
// tgpig.x = row index (ir)
// tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
// tgpig.z = sequence index (i3)
// tpitg.x = thread within batch (0..BATCH_SIZE-1)
const short BATCH_SIZE = FC_ssm_conv_bs;
const int64_t ir = tgpig.x;
const int64_t i2_base = tgpig.y * BATCH_SIZE;
const int64_t i3 = tgpig.z;
const int64_t i2_off = tpitg.x;
const int64_t i2 = i2_base + i2_off;
const int64_t nc = args.ne10; // conv kernel size (typically 4)
const int64_t n_t = args.ne1; // number of tokens
// Bounds check for partial batches at the end
if (i2 >= n_t) {
return;
}
// Load conv weights (shared across all tokens for this row)
device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
// Load source for this specific token
device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
// Output location for this token
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) {
sumf += s[i0] * c[i0];
}
x[0] = sumf;
}
kernel void kernel_ssm_conv_f32_f32_batched_4(
constant ggml_metal_kargs_ssm_conv & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
// tgpig.x = row index (ir)
// tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
// tgpig.z = sequence index (i3)
// tpitg.x = thread within batch (0..BATCH_SIZE-1)
const short BATCH_SIZE = FC_ssm_conv_bs;
const int64_t ir = tgpig.x;
const int64_t i2_base = tgpig.y * BATCH_SIZE;
const int64_t i3 = tgpig.z;
const int64_t i2_off = tpitg.x;
const int64_t i2 = i2_base + i2_off;
const int64_t nc = args.ne10; // conv kernel size (typically 4)
const int64_t n_t = args.ne1; // number of tokens
// Bounds check for partial batches at the end
if (i2 >= n_t) {
return;
}
// Load conv weights (shared across all tokens for this row)
device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
// Load source for this specific token
device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
// Output location for this token
device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc/4; ++i0) {
sumf += dot(s[i0], c[i0]);
}
x[0] = sumf;
}
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
// Optimized version: reduces redundant memory loads by having one thread load shared values
kernel void kernel_ssm_scan_f32(
constant ggml_metal_kargs_ssm_scan & args,
device const void * src0,
device const void * src1,
device const void * src2,
device const void * src3,
device const void * src4,
device const void * src5,
device const void * src6,
device float * dst,
threadgroup float * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgptg[[simdgroups_per_threadgroup]],
uint3 tgpg[[threadgroups_per_grid]]) {
constexpr short NW = N_SIMDWIDTH;
// Shared memory layout:
// [0..sgptg*NW-1]: partial sums for reduction (existing)
// [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
// [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch
threadgroup float * shared_sums = shared;
threadgroup float * shared_x_dt = shared + sgptg * NW;
threadgroup float * shared_dA = shared + sgptg * NW + sgptg;
shared_sums[tpitg.x] = 0.0f;
const int32_t i0 = tpitg.x;
const int32_t i1 = tgpig.x;
const int32_t ir = tgpig.y; // current head
const int32_t i3 = tgpig.z; // current seq
const int32_t nc = args.d_state;
const int32_t nr = args.d_inner;
const int32_t nh = args.n_head;
const int32_t ng = args.n_group;
const int32_t n_t = args.n_seq_tokens;
const int32_t s_off = args.s_off;
device const int32_t * ids = (device const int32_t *) src6;
device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
const int32_t i = i0 + i1*nc;
const int32_t g = ir / (nh / ng); // repeat_interleave
float s0 = s0_buff[i];
float s = 0.0f;
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}
const float A0 = A[i0%args.ne30];
device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns}
device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns}
device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns}
device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}
for (int i2 = 0; i2 < n_t; i2 += sgptg) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Pre-compute x_dt and dA for this batch of tokens
// Only first sgptg threads do the loads and expensive math
if (i0 < sgptg && i2 + i0 < n_t) {
// ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
device const float * x_t = x + i0 * args.ns12;
device const float * dt_t = dt + i0 * args.ns21;
const float dt0 = dt_t[0];
const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
shared_x_dt[i0] = x_t[0] * dtsp;
shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
const float x_dt = shared_x_dt[t];
const float dA = exp(shared_dA[t] * A0);
s = (s0 * dA) + (B[i0] * x_dt);
const float sumf = simd_sum(s * C[i0]);
if (tiisg == 0) {
shared_sums[t*NW + sgitg] = sumf;
}
// recurse
s0 = s;
B += args.ns42;
C += args.ns52;
}
// Advance pointers for next batch
x += sgptg * args.ns12;
dt += sgptg * args.ns21;
threadgroup_barrier(mem_flags::mem_threadgroup);
const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
if (tiisg == 0 && i2 + sgitg < n_t) {
y[sgitg*nh*nr] = sumf;
}
y += sgptg*nh*nr;
}
s_buff[i] = s;
}
+69
View File
@@ -0,0 +1,69 @@
#include "common.h"
template<uint32_t ttype>
bool _ggml_vec_tri_cmp(const int i, const int r);
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {
return i < r;
}
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {
return i <= r;
}
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {
return i > r;
}
template<>
bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {
return i >= r;
}
template<typename T, int ttype>
kernel void kernel_tri(
constant ggml_metal_kargs_tri & args,
device const char * src0,
device const char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
// Each thread is a single element of the row if ne00 < max threads per
// threadgroup, so this will loop once for each index that this thread is
// responsible for
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
// Use the comparison as a mask for branchless
dst_row[i0] = static_cast<T>(_ggml_vec_tri_cmp<ttype>(i0, i1)) * src_row[i0];
}
}
typedef decltype(kernel_tri<float, 0>) kernel_tri_t;
template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri<float, 0>;
template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri<float, 1>;
template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri<float, 2>;
template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri<float, 3>;
template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri<half, 0>;
template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri<half, 1>;
template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri<half, 2>;
template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri<half, 3>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;
template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;
template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;
#endif
+360
View File
@@ -0,0 +1,360 @@
#include "common.h"
constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
template <typename T0, typename T, typename TC>
kernel void kernel_unary_impl(
constant ggml_metal_kargs_unary & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
#define FC_OP FC_unary_op
#define FC_CNT FC_unary_cnt
device const T0 * src0_ptr;
device T * dst_ptr;
int i0;
if (FC_CNT) {
i0 = tgpig.x;
src0_ptr = (device const T0 *) (src0);
dst_ptr = (device T *) (dst);
} else {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int k0 = tgpig.x/args.ne01;
const int i01 = tgpig.x - k0*args.ne01;
i0 = k0*ntg.x + tpitg.x;
src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 );
}
{
//threadgroup_barrier(mem_flags::mem_none);
if (!FC_CNT) {
if (i0 >= args.ne0) {
return;
}
}
const TC x = (TC) src0_ptr[i0];
if (FC_OP == OP_UNARY_NUM_SCALE) {
dst_ptr[i0] = (T) (args.scale * x + args.bias);
}
if (FC_OP == OP_UNARY_NUM_FILL) {
dst_ptr[i0] = (T) args.val;
}
if (FC_OP == OP_UNARY_NUM_CLAMP) {
dst_ptr[i0] = (T) clamp(x, args.min, args.max);
}
if (FC_OP == OP_UNARY_NUM_SQR) {
dst_ptr[i0] = (T) (x * x);
}
if (FC_OP == OP_UNARY_NUM_SQRT) {
dst_ptr[i0] = (T) sqrt(x);
}
if (FC_OP == OP_UNARY_NUM_SIN) {
dst_ptr[i0] = (T) sin(x);
}
if (FC_OP == OP_UNARY_NUM_COS) {
dst_ptr[i0] = (T) cos(x);
}
if (FC_OP == OP_UNARY_NUM_LOG) {
dst_ptr[i0] = (T) log(x);
}
if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
}
if (FC_OP == OP_UNARY_NUM_TANH) {
dst_ptr[i0] = (T) precise::tanh(x);
}
if (FC_OP == OP_UNARY_NUM_RELU) {
dst_ptr[i0] = (T) fmax(0, x);
}
if (FC_OP == OP_UNARY_NUM_SIGMOID) {
dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
}
if (FC_OP == OP_UNARY_NUM_GELU) {
dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
}
if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
}
if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
}
if (FC_OP == OP_UNARY_NUM_SILU) {
dst_ptr[i0] = (T) (x / (1 + exp(-x)));
}
if (FC_OP == OP_UNARY_NUM_ELU) {
dst_ptr[i0] = (T) elu_approx(x);
}
if (FC_OP == OP_UNARY_NUM_NEG) {
dst_ptr[i0] = (T) -x;
}
if (FC_OP == OP_UNARY_NUM_ABS) {
dst_ptr[i0] = (T) fabs(x);
}
if (FC_OP == OP_UNARY_NUM_SGN) {
dst_ptr[i0] = T(x > 0) - T(x < 0);
}
if (FC_OP == OP_UNARY_NUM_STEP) {
dst_ptr[i0] = T(x > 0);
}
if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
}
if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
}
if (FC_OP == OP_UNARY_NUM_EXP) {
dst_ptr[i0] = (T) exp(x);
}
if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
}
if (FC_OP == OP_UNARY_NUM_EXPM1) {
// TODO: precise implementation
dst_ptr[i0] = (T) (exp(x) - 1);
}
if (FC_OP == OP_UNARY_NUM_FLOOR) {
dst_ptr[i0] = (T) floor(x);
}
if (FC_OP == OP_UNARY_NUM_CEIL) {
dst_ptr[i0] = (T) ceil(x);
}
if (FC_OP == OP_UNARY_NUM_ROUND) {
dst_ptr[i0] = (T) round(x);
}
if (FC_OP == OP_UNARY_NUM_TRUNC) {
dst_ptr[i0] = (T) trunc(x);
}
if (FC_OP == OP_UNARY_NUM_XIELU) {
const TC xi = x;
const TC gate = TC(xi > TC(0.0f));
const TC clamped = fmin(xi, TC(args.val));
const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi;
const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi;
dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg);
}
}
#undef FC_OP
#undef FC_CNT
}
typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>;
template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>;
template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>;
template<typename T>
kernel void kernel_reglu(
constant ggml_metal_kargs_glu & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
dst_row[i0] = (T)(x0*x1*(x0 > 0.0f));
}
}
typedef decltype(kernel_reglu<float>) kernel_reglu_t;
template [[host_name("kernel_reglu_f32")]] kernel kernel_reglu_t kernel_reglu<float>;
template [[host_name("kernel_reglu_f16")]] kernel kernel_reglu_t kernel_reglu<half>;
template<typename T>
kernel void kernel_geglu(
constant ggml_metal_kargs_glu & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
dst_row[i0] = (T)(gelu*x1);
}
}
typedef decltype(kernel_geglu<float>) kernel_geglu_t;
template [[host_name("kernel_geglu_f32")]] kernel kernel_geglu_t kernel_geglu<float>;
template [[host_name("kernel_geglu_f16")]] kernel kernel_geglu_t kernel_geglu<half>;
template<typename T>
kernel void kernel_swiglu(
constant ggml_metal_kargs_glu & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float silu = x0 / (1.0f + exp(-x0));
dst_row[i0] = (T)(silu*x1);
}
}
typedef decltype(kernel_swiglu<float>) kernel_swiglu_t;
template [[host_name("kernel_swiglu_f32")]] kernel kernel_swiglu_t kernel_swiglu<float>;
template [[host_name("kernel_swiglu_f16")]] kernel kernel_swiglu_t kernel_swiglu<half>;
template<typename T>
kernel void kernel_swiglu_oai(
constant ggml_metal_kargs_glu & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
float x0 = src0_row[i0];
float x1 = src1_row[i0];
x0 = min(x0, args.limit);
x1 = max(min(x1, args.limit), -args.limit);
float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
out_glu = out_glu * (1.0f + x1);
dst_row[i0] = (T)out_glu;
}
}
typedef decltype(kernel_swiglu_oai<float>) kernel_swiglu_oai_t;
template [[host_name("kernel_swiglu_oai_f32")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<float>;
template [[host_name("kernel_swiglu_oai_f16")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<half>;
template<typename T>
kernel void kernel_geglu_erf(
constant ggml_metal_kargs_glu & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
dst_row[i0] = (T)(gelu_erf*x1);
}
}
typedef decltype(kernel_geglu_erf<float>) kernel_geglu_erf_t;
template [[host_name("kernel_geglu_erf_f32")]] kernel kernel_geglu_erf_t kernel_geglu_erf<float>;
template [[host_name("kernel_geglu_erf_f16")]] kernel kernel_geglu_erf_t kernel_geglu_erf<half>;
template<typename T>
kernel void kernel_geglu_quick(
constant ggml_metal_kargs_glu & args,
device const char * src0,
device const char * src1,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
dst_row[i0] = (T)(gelu_quick*x1);
}
}
typedef decltype(kernel_geglu_quick<float>) kernel_geglu_quick_t;
template [[host_name("kernel_geglu_quick_f32")]] kernel kernel_geglu_quick_t kernel_geglu_quick<float>;
template [[host_name("kernel_geglu_quick_f16")]] kernel kernel_geglu_quick_t kernel_geglu_quick<half>;
+179
View File
@@ -0,0 +1,179 @@
#include "common.h"
constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]];
kernel void kernel_upscale_nearest_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3/args.sf3;
const int64_t i02 = i2/args.sf2;
const int64_t i01 = i1/args.sf1;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int64_t i00 = i0/args.sf0;
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_ptr[0] = src0_ptr[0];
}
}
static inline float bilinear_tri(float x) {
return MAX(0.0f, 1.0f - fabs(x));
}
kernel void kernel_upscale_bilinear_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3 / args.sf3;
const int64_t i02 = i2 / args.sf2;
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
const int64_t i01 = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01)));
const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1));
const float fd1 = MAX(0.0f, MIN(1.0f, f01 - (float)i01));
src0 += i03*args.nb03 + i02*args.nb02;
device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
if (FC_upscale_aa) {
const float support0 = MAX(1.0f, 1.0f / args.sf0);
const float invscale0 = 1.0f / support0;
const float support1 = MAX(1.0f, 1.0f / args.sf1);
const float invscale1 = 1.0f / support1;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs));
int64_t x_max = MIN(args.ne00, (int64_t)ceil (f00 + support0 + args.poffs));
int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs));
int64_t y_max = MIN(args.ne01, (int64_t)ceil (f01 + support1 + args.poffs));
float sum = 0.0f;
float wsum = 0.0f;
for (int64_t sy = y_min; sy < y_max; ++sy) {
const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1);
for (int64_t sx = x_min; sx < x_max; ++sx) {
const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
const float w = wx * wy;
device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
sum += (*src_ptr) * w;
wsum += w;
}
}
const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f;
dst_ptr[i0] = v;
}
} else {
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
const int64_t i00 = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00)));
const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1));
const float fd0 = MAX(0.0f, MIN(1.0f, f00 - (float)i00));
device const float * src00 = (device const float *)(src0 + i01*args.nb01 + i00*args.nb00);
device const float * src10 = (device const float *)(src0 + i01*args.nb01 + i00p*args.nb00);
device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00);
device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00);
const float v =
(*src00) * (1.0f - fd0) * (1.0f - fd1) +
(*src10) * fd0 * (1.0f - fd1) +
(*src01) * (1.0f - fd0) * fd1 +
(*src11) * fd0 * fd1;
dst_ptr[i0] = v;
}
}
}
static inline float bicubic_weight1(float x) {
const float a = -0.75f;
return ((a + 2) * x - (a + 3)) * x * x + 1;
}
static inline float bicubic_weight2(float x) {
const float a = -0.75f;
return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
}
kernel void kernel_upscale_bicubic_f32(
constant ggml_metal_kargs_upscale & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3 / args.sf3;
const int64_t i02 = i2 / args.sf2;
const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
const int64_t i01 = (int64_t)floor(f01);
const float fd1 = f01 - (float)i01;
const float w_y0 = bicubic_weight2(fd1 + 1.0f);
const float w_y1 = bicubic_weight1(fd1);
const float w_y2 = bicubic_weight1(1.0f - fd1);
const float w_y3 = bicubic_weight2(2.0f - fd1);
const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02;
device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
const int64_t i00 = (int64_t)floor(f00);
const float fd0 = f00 - (float)i00;
const float w_x0 = bicubic_weight2(fd0 + 1.0f);
const float w_x1 = bicubic_weight1(fd0);
const float w_x2 = bicubic_weight1(1.0f - fd0);
const float w_x3 = bicubic_weight2(2.0f - fd0);
float sum = 0.0f;
for (int dy = -1; dy <= 2; ++dy) {
const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy));
const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3;
for (int dx = -1; dx <= 2; ++dx) {
const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
sum += (*src_ptr) * wx * wy;
}
}
dst_ptr[i0] = sum;
}
}
+179
View File
@@ -0,0 +1,179 @@
#include "common.h"
kernel void kernel_rwkv_wkv6_f32(
device const float * k,
device const float * v,
device const float * r,
device const float * tf,
device const float * td,
device const float * state_in,
device float * dst,
constant uint & B,
constant uint & T,
constant uint & C,
constant uint & H,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const uint head_size = 64; // TODO: support head_size = 128
const uint batch_id = tgpig.x / H;
const uint head_id = tgpig.x % H;
const uint tid = tpitg.x;
if (batch_id >= B || head_id >= H) {
return;
}
const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;
threadgroup float _k[head_size];
threadgroup float _r[head_size];
threadgroup float _tf[head_size];
threadgroup float _td[head_size];
float state[head_size];
for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
_tf[tid] = tf[head_id * head_size + tid];
threadgroup_barrier(mem_flags::mem_threadgroup);
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
for (uint t = start_t; t < end_t; t += C) {
threadgroup_barrier(mem_flags::mem_threadgroup);
_k[tid] = k[t];
_r[tid] = r[t];
_td[tid] = td[t];
threadgroup_barrier(mem_flags::mem_threadgroup);
const float v_val = v[t];
float y = 0.0;
for (uint j = 0; j < head_size; j += 4) {
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
float4 kv = k_vec * v_val;
float4 temp = tf_vec * kv + s_vec;
y += dot(r_vec, temp);
s_vec = s_vec * td_vec + kv;
state[j] = s_vec[0];
state[j+1] = s_vec[1];
state[j+2] = s_vec[2];
state[j+3] = s_vec[3];
}
dst[t] = y;
}
for (uint i = 0; i < head_size; i++) {
dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid] = state[i];
}
}
kernel void kernel_rwkv_wkv7_f32(
device const float * r,
device const float * w,
device const float * k,
device const float * v,
device const float * a,
device const float * b,
device const float * state_in,
device float * dst,
constant uint & B,
constant uint & T,
constant uint & C,
constant uint & H,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const uint head_size = 64; // TODO: support head_size = 128
const uint batch_id = tgpig.x / H;
const uint head_id = tgpig.x % H;
const uint tid = tpitg.x;
if (batch_id >= B || head_id >= H) {
return;
}
const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;
threadgroup float _r[head_size];
threadgroup float _w[head_size];
threadgroup float _k[head_size];
threadgroup float _a[head_size];
threadgroup float _b[head_size];
float state[head_size];
for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ tid * head_size + i];
}
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
for (uint t = start_t; t < end_t; t += C) {
threadgroup_barrier(mem_flags::mem_threadgroup);
_r[tid] = r[t];
_w[tid] = w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
threadgroup_barrier(mem_flags::mem_threadgroup);
const float v_val = v[t];
float y = 0.0, sa = 0.0;
float4 sa_vec(0.0);
for (uint j = 0; j < head_size; j += 4) {
float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
sa_vec += a_vec * s_vec;
}
sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
for (uint j = 0; j < head_size; j += 4) {
float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
float4 kv = k_vec * v_val;
s_vec = s_vec * w_vec + kv + sa * b_vec;
y += dot(s_vec, r_vec);
state[j] = s_vec[0];
state[j+1] = s_vec[1];
state[j+2] = s_vec[2];
state[j+3] = s_vec[3];
}
dst[t] = y;
}
for (uint i = 0; i < head_size; i++) {
dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ tid * head_size + i] = state[i];
}
}
+5 -5
View File
@@ -39,8 +39,8 @@ if (WIN32)
set(CMAKE_CXX_COMPILER "icx")
set(CMAKE_CXX_COMPILER_ID "IntelLLVM")
endif()
# Level Zero SDK path for Windows (only when GGML_SYCL_SUPPORT_LEVEL_ZERO is enabled)
if(GGML_SYCL_SUPPORT_LEVEL_ZERO)
# Level Zero SDK path for Windows (only when GGML_SYCL_SUPPORT_LEVEL_ZERO_API is enabled)
if(GGML_SYCL_SUPPORT_LEVEL_ZERO_API)
if(DEFINED ENV{LEVEL_ZERO_V1_SDK_PATH})
set(LEVEL_ZERO_V1_SDK_PATH $ENV{LEVEL_ZERO_V1_SDK_PATH})
if(EXISTS "${LEVEL_ZERO_V1_SDK_PATH}")
@@ -105,8 +105,8 @@ endif()
target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")
message(STATUS "GGML_SYCL_SUPPORT_LEVEL_ZERO ${GGML_SYCL_SUPPORT_LEVEL_ZERO}")
if (GGML_SYCL_SUPPORT_LEVEL_ZERO)
message(STATUS "GGML_SYCL_SUPPORT_LEVEL_ZERO_API ${GGML_SYCL_SUPPORT_LEVEL_ZERO_API}")
if (GGML_SYCL_SUPPORT_LEVEL_ZERO_API)
# Link against Level Zero loader for direct device memory allocation.
# Avoids sycl::malloc_device triggering DMA-buf/TTM system RAM staging
# in the xe kernel driver during multi-GPU inference.
@@ -114,7 +114,7 @@ if (GGML_SYCL_SUPPORT_LEVEL_ZERO)
find_library(ZE_LOADER_LIB ze_loader HINTS ${ONEAPI_ROOT}/lib ${LEVEL_ZERO_V1_SDK_LIB_PATH} ENV LD_LIBRARY_PATH)
if(ZE_LOADER_LIB AND LEVEL_ZERO_INCLUDE_DIR)
target_link_libraries(ggml-sycl PRIVATE ${ZE_LOADER_LIB})
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_SUPPORT_LEVEL_ZERO)
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_SUPPORT_LEVEL_ZERO_API)
message(STATUS "Level Zero loader found: ${ZE_LOADER_LIB}")
message(STATUS "Level Zero headers found: ${LEVEL_ZERO_INCLUDE_DIR}")
else()
+1
View File
@@ -17,6 +17,7 @@
#include "common.hpp"
#include "concat.hpp"
#include "conv.hpp"
#include "conv3d.hpp"
#include "convert.hpp"
#include "count-equal.hpp"
#include "cpy.hpp"
+5 -5
View File
@@ -12,7 +12,7 @@
#include "common.hpp"
#include <sycl/backend.hpp>
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO_API
#include <level_zero/ze_api.h>
#endif
@@ -84,9 +84,9 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block
return sycl_down_blk_size;
}
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO_API
static bool ggml_sycl_use_level_zero_device_alloc(sycl::queue &q) {
return g_ggml_sycl_enable_level_zero &&
return g_ggml_sycl_use_level_zero_api &&
q.get_device().is_gpu() &&
q.get_backend() == sycl::backend::ext_oneapi_level_zero;
}
@@ -95,7 +95,7 @@ static bool ggml_sycl_use_level_zero_device_alloc(sycl::queue &q) {
// Use Level Zero zeMemAllocDevice to avoid sycl::malloc_device triggering
// DMA-buf/TTM system RAM staging in the xe kernel driver during multi-GPU inference.
void * ggml_sycl_malloc_device(size_t size, sycl::queue &q) {
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO_API
if (ggml_sycl_use_level_zero_device_alloc(q)) {
void *ptr = nullptr;
auto ze_ctx = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(q.get_context());
@@ -127,7 +127,7 @@ void * ggml_sycl_malloc_device(size_t size, sycl::queue &q) {
void ggml_sycl_free_device(void *ptr, sycl::queue &q) {
if (!ptr) return;
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO_API
if (ggml_sycl_use_level_zero_device_alloc(q)) {
auto ze_ctx = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(q.get_context());
zeMemFree(ze_ctx, ptr);
+7 -1
View File
@@ -62,6 +62,7 @@ extern int g_ggml_sycl_debug;
extern int g_ggml_sycl_disable_optimize;
extern int g_ggml_sycl_prioritize_dmmv;
extern int g_ggml_sycl_enable_flash_attention;
extern int g_ggml_sycl_dev2dev_memcpy;
#if defined(__clang__) && __has_builtin(__builtin_expect)
@@ -126,6 +127,11 @@ enum ggml_sycl_backend_gpu_mode {
SYCL_MUL_GPU_MODE
};
enum ggml_sycl_dev2dev_memcpy_mode {
DEV2DEV_MEMCPY_SYCL = 0,
DEV2DEV_MEMCPY_L0 = 1,
};
static_assert(sizeof(sycl::half) == sizeof(ggml_fp16_t), "wrong fp16 size");
static void crash() {
@@ -318,7 +324,7 @@ struct ggml_tensor_extra_gpu {
optimize_feature optimized_feature;
};
extern int g_ggml_sycl_enable_level_zero;
extern int g_ggml_sycl_use_level_zero_api;
void * ggml_sycl_malloc_device(size_t size, sycl::queue &q);
void ggml_sycl_free_device(void *ptr, sycl::queue &q);
+158
View File
@@ -0,0 +1,158 @@
#include "conv2d-dw.hpp"
struct conv2d_dw_params {
int in_w, in_h;
int out_w, out_h;
int kernel_w, kernel_h;
int stride_x, stride_y;
int padding_x, padding_y;
int dilation_x, dilation_y;
int channels, batches;
};
struct conv2d_dw_kernel_bounds {
int y_min, y_max;
int x_min, x_max;
};
static inline conv2d_dw_kernel_bounds dw_calculate_kernel_bounds(int out_x, int out_y,
const conv2d_dw_params & p) {
conv2d_dw_kernel_bounds bounds;
bounds.y_min = sycl::max(0, (p.padding_y - out_y * p.stride_y + p.dilation_y - 1) / p.dilation_y);
bounds.y_max = sycl::min(p.kernel_h,
(p.in_h + p.padding_y - out_y * p.stride_y + p.dilation_y - 1) / p.dilation_y);
bounds.x_min = sycl::max(0, (p.padding_x - out_x * p.stride_x + p.dilation_x - 1) / p.dilation_x);
bounds.x_max = sycl::min(p.kernel_w,
(p.in_w + p.padding_x - out_x * p.stride_x + p.dilation_x - 1) / p.dilation_x);
return bounds;
}
static inline int dw_calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
return out_coord * stride + kern_coord * dilation - padding;
}
// whcn layout: input/output stored as [N, C, H, W]
struct dw_whcn_layout {
static int input_index(int n, int c, int y, int x, const conv2d_dw_params & p) {
return n * (p.channels * p.in_w * p.in_h) + c * p.in_w * p.in_h + y * p.in_w + x;
}
static int kernel_index(int c, int ky, int kx, const conv2d_dw_params & p) {
return c * p.kernel_h * p.kernel_w + ky * p.kernel_w + kx;
}
static int output_index(int n, int c, int y, int x, const conv2d_dw_params & p) {
return n * (p.channels * p.out_w * p.out_h) + c * p.out_w * p.out_h + y * p.out_w + x;
}
static void unpack_indices(int global_idx, const conv2d_dw_params & p,
int & n, int & c, int & out_y, int & out_x) {
out_x = global_idx % p.out_w;
out_y = (global_idx / p.out_w) % p.out_h;
c = (global_idx / (p.out_w * p.out_h)) % p.channels;
n = global_idx / (p.out_w * p.out_h * p.channels);
}
};
// cwhn layout: input/output stored as [N, H, W, C]
struct dw_cwhn_layout {
static int input_index(int n, int c, int y, int x, const conv2d_dw_params & p) {
return n * (p.channels * p.in_w * p.in_h) + (y * p.in_w + x) * p.channels + c;
}
static int kernel_index(int c, int ky, int kx, const conv2d_dw_params & p) {
return (ky * p.kernel_w + kx) * p.channels + c;
}
static int output_index(int n, int c, int y, int x, const conv2d_dw_params & p) {
return n * (p.channels * p.out_w * p.out_h) + y * (p.out_w * p.channels) + x * p.channels + c;
}
static void unpack_indices(int global_idx, const conv2d_dw_params & p,
int & n, int & c, int & out_y, int & out_x) {
c = global_idx % p.channels;
out_x = (global_idx / p.channels) % p.out_w;
out_y = (global_idx / (p.channels * p.out_w)) % p.out_h;
n = global_idx / (p.channels * p.out_w * p.out_h);
}
};
template <typename Layout>
static void conv2d_dw_kernel(const float * input, const float * kernel, float * output,
const conv2d_dw_params p, const sycl::nd_item<3> & item_ct1) {
const int global_idx = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
const int total_elements = p.batches * p.channels * p.out_h * p.out_w;
if (global_idx >= total_elements) {
return;
}
int n, c, out_y, out_x;
Layout::unpack_indices(global_idx, p, n, c, out_y, out_x);
float acc = 0.0f;
const conv2d_dw_kernel_bounds bounds = dw_calculate_kernel_bounds(out_x, out_y, p);
for (int ky = bounds.y_min; ky < bounds.y_max; ++ky) {
const int in_y = dw_calculate_input_coord(out_y, ky, p.stride_y, p.dilation_y, p.padding_y);
for (int kx = bounds.x_min; kx < bounds.x_max; ++kx) {
const int in_x = dw_calculate_input_coord(out_x, kx, p.stride_x, p.dilation_x, p.padding_x);
acc += input[Layout::input_index(n, c, in_y, in_x, p)] *
kernel[Layout::kernel_index(c, ky, kx, p)];
}
}
output[Layout::output_index(n, c, out_y, out_x, p)] = acc;
}
template <typename Layout>
static void conv2d_dw_sycl(const float * x_d, const float * w_d, float * y_d,
const conv2d_dw_params p, const queue_ptr & stream) {
const int total = p.batches * p.channels * p.out_h * p.out_w;
const int num_blocks = (total + SYCL_CONV2D_DW_BLOCK_SIZE - 1) / SYCL_CONV2D_DW_BLOCK_SIZE;
const sycl::range<3> block_dims(1, 1, SYCL_CONV2D_DW_BLOCK_SIZE);
const sycl::range<3> block_nums(1, 1, num_blocks);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
conv2d_dw_kernel<Layout>(x_d, w_d, y_d, p, item_ct1);
});
}
void ggml_sycl_op_conv2d_dw(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
const ggml_tensor * kernel = dst->src[0];
const ggml_tensor * input = dst->src[1];
GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
const float * w_d = (const float *) kernel->data;
const float * x_d = (const float *) input->data;
float * y_d = (float *) dst->data;
const int32_t * p = (const int32_t *) dst->op_params;
const int stride_x = p[0];
const int stride_y = p[1];
const int padding_x = p[2];
const int padding_y = p[3];
const int dilation_x = p[4];
const int dilation_y = p[5];
const int in_w = input->ne[0];
const int in_h = input->ne[1];
const int kernel_w = kernel->ne[0];
const int kernel_h = kernel->ne[1];
const int out_w = dst->ne[0];
const int out_h = dst->ne[1];
const int channels = dst->ne[2];
const int batches = dst->ne[3];
const conv2d_dw_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h,
stride_x, stride_y, padding_x, padding_y,
dilation_x, dilation_y, channels, batches };
const queue_ptr stream = ctx.stream();
if (ggml_is_contiguous(input)) {
conv2d_dw_sycl<dw_whcn_layout>(x_d, w_d, y_d, params, stream);
} else if (ggml_is_contiguous_channels(input)) {
conv2d_dw_sycl<dw_cwhn_layout>(x_d, w_d, y_d, params, stream);
} else {
GGML_ABORT("Unsupported memory layout for conv2d_dw");
}
}
+10
View File
@@ -0,0 +1,10 @@
#ifndef GGML_SYCL_CONV2D_DW_HPP
#define GGML_SYCL_CONV2D_DW_HPP
#include "common.hpp"
#define SYCL_CONV2D_DW_BLOCK_SIZE 256
void ggml_sycl_op_conv2d_dw(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_CONV2D_DW_HPP
+125
View File
@@ -0,0 +1,125 @@
#include "conv2d-transpose.hpp"
#include "convert.hpp"
template <typename kernel_t>
static void conv2d_transpose_kernel(const float * input, const kernel_t * kernel, float * output,
const int in_w, const int in_h,
const int out_w, const int out_h,
const int kernel_w, const int kernel_h,
const int stride,
const int c_in, const int c_out, const int batches,
const sycl::nd_item<3> & item_ct1) {
const int global_idx = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
const int total_elements = out_w * out_h * c_out * batches;
if (global_idx >= total_elements) {
return;
}
const int out_x = global_idx % out_w;
const int out_y = (global_idx / out_w) % out_h;
const int c_idx = (global_idx / (out_w * out_h)) % c_out;
const int n_idx = global_idx / (out_w * out_h * c_out);
float acc = 0.0f;
for (int c_in_idx = 0; c_in_idx < c_in; ++c_in_idx) {
for (int kh = 0; kh < kernel_h; ++kh) {
int in_y = out_y - kh;
if (in_y < 0 || in_y % stride) {
continue;
}
in_y /= stride;
if (in_y >= in_h) {
continue;
}
for (int kw = 0; kw < kernel_w; ++kw) {
int in_x = out_x - kw;
if (in_x < 0 || in_x % stride) {
continue;
}
in_x /= stride;
if (in_x >= in_w) {
continue;
}
const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + in_w * in_y + in_x;
const int kernel_idx = (kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx +
kernel_w * kh + kw;
acc += input[input_idx] * ggml_sycl_cast<float>(kernel[kernel_idx]);
}
}
}
output[(out_w * out_h * c_out) * n_idx + (out_w * out_h) * c_idx + out_w * out_y + out_x] = acc;
}
template <typename kernel_t>
static void conv2d_transpose_sycl(const float * input_d, const kernel_t * kernel_d, float * output_d,
const int in_w, const int in_h,
const int out_w, const int out_h,
const int kernel_w, const int kernel_h,
const int stride,
const int c_in, const int c_out, const int batches,
const queue_ptr & stream) {
const int total = out_w * out_h * c_out * batches;
const int num_blocks = (total + SYCL_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / SYCL_CONV2D_TRANSPOSE_BLOCK_SIZE;
const sycl::range<3> block_dims(1, 1, SYCL_CONV2D_TRANSPOSE_BLOCK_SIZE);
const sycl::range<3> block_nums(1, 1, num_blocks);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
conv2d_transpose_kernel<kernel_t>(input_d, kernel_d, output_d,
in_w, in_h, out_w, out_h, kernel_w, kernel_h,
stride, c_in, c_out, batches, item_ct1);
});
}
// input: (W, H, C_in, N)
// kernel: (W, H, C_out, C_in)
// output: (W, H, C_out, N)
void ggml_sycl_op_conv2d_transpose(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
const ggml_tensor * kernel = dst->src[0];
const ggml_tensor * input = dst->src[1];
GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
GGML_ASSERT(input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(input));
GGML_ASSERT(ggml_is_contiguous(kernel));
GGML_ASSERT(ggml_is_contiguous(dst));
const float * input_d = (const float *) input->data;
float * output_d = (float *) dst->data;
const void * kernel_d = kernel->data;
const int input_w = input->ne[0];
const int input_h = input->ne[1];
const int channels_in = input->ne[2];
const int batches = input->ne[3];
const int output_w = dst->ne[0];
const int output_h = dst->ne[1];
const int channels_out = kernel->ne[2];
const int kernel_w = kernel->ne[0];
const int kernel_h = kernel->ne[1];
const int stride = dst->op_params[0];
GGML_ASSERT(channels_in == kernel->ne[3]);
GGML_ASSERT(stride > 0);
const queue_ptr stream = ctx.stream();
if (kernel->type == GGML_TYPE_F16) {
conv2d_transpose_sycl<sycl::half>(input_d, (const sycl::half *) kernel_d, output_d,
input_w, input_h, output_w, output_h, kernel_w, kernel_h,
stride, channels_in, channels_out, batches, stream);
} else {
conv2d_transpose_sycl<float>(input_d, (const float *) kernel_d, output_d,
input_w, input_h, output_w, output_h, kernel_w, kernel_h,
stride, channels_in, channels_out, batches, stream);
}
}
+10
View File
@@ -0,0 +1,10 @@
#ifndef GGML_SYCL_CONV2D_TRANSPOSE_HPP
#define GGML_SYCL_CONV2D_TRANSPOSE_HPP
#include "common.hpp"
#define SYCL_CONV2D_TRANSPOSE_BLOCK_SIZE 256
void ggml_sycl_op_conv2d_transpose(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_CONV2D_TRANSPOSE_HPP
+150
View File
@@ -0,0 +1,150 @@
#include "conv2d.hpp"
#include "convert.hpp"
struct conv2d_params {
const int64_t IW, IH;
const int64_t OW, OH;
const int64_t KW, KH;
const int64_t ST_X, ST_Y;
const int64_t PD_X, PD_Y;
const int64_t DL_X, DL_Y;
const int64_t IC, OC;
const int64_t B;
const int64_t TOTAL;
};
struct conv2d_kernel_bounds {
int64_t y_min, y_max;
int64_t x_min, x_max;
};
static inline int64_t conv2d_max64(int64_t a, int64_t b) {
return (a > b) ? a : b;
}
static inline int64_t conv2d_min64(int64_t a, int64_t b) {
return (a < b) ? a : b;
}
static inline conv2d_kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv2d_params & P) {
conv2d_kernel_bounds bounds;
bounds.y_min = conv2d_max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
bounds.y_max = conv2d_min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
bounds.x_min = conv2d_max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
bounds.x_max = conv2d_min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
return bounds;
}
static inline int calculate_input_coord(int64_t out_coord, int64_t kern_coord, int64_t stride,
int64_t dilation, int64_t padding) {
return out_coord * stride + kern_coord * dilation - padding;
}
// whcn layout helpers (matching ggml tensor memory order)
static inline int64_t whcn_input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv2d_params & P) {
return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;
}
static inline int64_t whcn_kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv2d_params & P) {
return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;
}
static inline int64_t whcn_output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv2d_params & P) {
return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;
}
template <typename T>
static void conv2d_kernel(const float * input, const T * kernel, float * output,
const conv2d_params P, const sycl::nd_item<3> & item_ct1) {
const int64_t global_idx = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
if (global_idx >= P.TOTAL) {
return;
}
const int64_t out_x = global_idx % P.OW;
const int64_t out_y = (global_idx / P.OW) % P.OH;
const int64_t c_out = (global_idx / (P.OW * P.OH)) % P.OC;
const int64_t n = global_idx / (P.OW * P.OH * P.OC);
float acc = 0.0f;
const conv2d_kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {
const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);
for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
const float input_val = input[whcn_input_index(n, c_in, in_y, in_x, P)];
const T kernel_val = kernel[whcn_kernel_index(c_out, c_in, ky, kx, P)];
acc += input_val * ggml_sycl_cast<float>(kernel_val);
}
}
}
output[whcn_output_index(n, c_out, out_y, out_x, P)] = acc;
}
template <typename T>
static void conv2d_sycl(const float * X_D, const T * K_D, float * Y_D,
const conv2d_params P, const queue_ptr & stream) {
const int num_blocks = (P.TOTAL + SYCL_CONV2D_BLOCK_SIZE - 1) / SYCL_CONV2D_BLOCK_SIZE;
const sycl::range<3> block_dims(1, 1, SYCL_CONV2D_BLOCK_SIZE);
const sycl::range<3> block_nums(1, 1, num_blocks);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
conv2d_kernel<T>(X_D, K_D, Y_D, P, item_ct1);
});
}
void ggml_sycl_op_conv2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
const ggml_tensor * kernel = dst->src[0];
const ggml_tensor * input = dst->src[1];
const float * K_D = (const float *) kernel->data;
const float * X_D = (const float *) input->data;
float * Y_D = (float *) dst->data;
GGML_ASSERT(ggml_is_contiguous(kernel));
GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
GGML_ASSERT(input->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
// same number of input channels
GGML_ASSERT(input->ne[2] == kernel->ne[2]);
const queue_ptr stream = ctx.stream();
const int32_t * p = (const int32_t *) dst->op_params;
const int ST_X = p[0];
const int ST_Y = p[1];
const int PD_X = p[2];
const int PD_Y = p[3];
const int DL_X = p[4];
const int DL_Y = p[5];
// no cwhn layout support
GGML_ASSERT(p[6] == 0);
const int IW = input->ne[0];
const int IH = input->ne[1];
const int OW = dst->ne[0];
const int OH = dst->ne[1];
const int KW = kernel->ne[0];
const int KH = kernel->ne[1];
const int IC = input->ne[2];
const int OC = kernel->ne[3];
const int B = input->ne[3];
const int64_t total = (int64_t) B * OC * OH * OW;
const conv2d_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
if (kernel->type == GGML_TYPE_F16) {
conv2d_sycl<sycl::half>(X_D, (const sycl::half *) K_D, Y_D, params, stream);
} else {
conv2d_sycl<float>(X_D, K_D, Y_D, params, stream);
}
}
+10
View File
@@ -0,0 +1,10 @@
#ifndef GGML_SYCL_CONV2D_HPP
#define GGML_SYCL_CONV2D_HPP
#include "common.hpp"
#define SYCL_CONV2D_BLOCK_SIZE 256
void ggml_sycl_op_conv2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_CONV2D_HPP
+218
View File
@@ -0,0 +1,218 @@
#include "conv3d.hpp"
static inline int64_t ggml_sycl_conv3d_calc_patch_total(const ggml_tensor * dst, int32_t n) {
return (int64_t) n * dst->ne[0] * dst->ne[1] * dst->ne[2];
}
static inline int64_t ggml_sycl_conv3d_calc_knl_n_total(const ggml_tensor * src0, int32_t c) {
return (int64_t) src0->ne[0] * src0->ne[1] * src0->ne[2] * c;
}
static inline void ggml_sycl_conv3d_write_output(
const ggml_tensor * dst,
const float * src, float * dst_data,
int64_t patch_total, int64_t oc,
int64_t dst_w, int64_t dst_h, int64_t dst_d,
dpct::queue_ptr stream) {
const int64_t dst_nb0 = dst->nb[0];
const int64_t dst_nb1 = dst->nb[1];
const int64_t dst_nb2 = dst->nb[2];
const int64_t dst_nb3 = dst->nb[3];
const int64_t total = patch_total * oc;
const int64_t block_size = 256;
const int64_t num_work_items = ((total + block_size - 1) / block_size) * block_size;
stream->parallel_for(sycl::range<1>(num_work_items), [=](sycl::id<1> id) {
const int64_t i = id[0];
if (i >= total) {
return;
}
const int64_t patch_idx = i / oc;
const int64_t out_ch = i % oc;
const int64_t p_in_batch = patch_idx % (dst_w * dst_h * dst_d);
const int64_t batch_idx = patch_idx / (dst_w * dst_h * dst_d);
const int64_t dst_z = p_in_batch / (dst_w * dst_h);
const int64_t dst_y = (p_in_batch % (dst_w * dst_h)) / dst_w;
const int64_t dst_x = p_in_batch % dst_w;
const int64_t ocn_idx = batch_idx * oc + out_ch;
const int64_t dst_offset = dst_x * dst_nb0 + dst_y * dst_nb1 + dst_z * dst_nb2 + ocn_idx * dst_nb3;
// `src` is a column-major (m x n) GEMM output where m == patch_total, n == oc.
// GEMM stores element (row, col) at index `row + col*m`, so compute index accordingly.
const int64_t src_index = patch_idx + out_ch * patch_total;
const float value = src[src_index];
*(float *)((char *)dst_data + dst_offset) = value;
});
}
void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
const int32_t * opts = (const int32_t *) dst->op_params;
const int32_t s0 = opts[0];
const int32_t s1 = opts[1];
const int32_t s2 = opts[2];
const int32_t p0 = opts[3];
const int32_t p1 = opts[4];
const int32_t p2 = opts[5];
const int32_t d0 = opts[6];
const int32_t d1 = opts[7];
const int32_t d2 = opts[8];
const int32_t c = opts[9];
const int32_t n = opts[10];
const int32_t oc = opts[11];
const int64_t knl_w = src0->ne[0];
const int64_t knl_h = src0->ne[1];
const int64_t knl_d = src0->ne[2];
const int64_t patch_total = ggml_sycl_conv3d_calc_patch_total(dst, n);
const int64_t knl_n_total = ggml_sycl_conv3d_calc_knl_n_total(src0, c);
const size_t kernel_type_size = ggml_element_size(src0);
ggml_sycl_pool_alloc<float> gemm_output(ctx.pool());
gemm_output.alloc((size_t) patch_total * oc);
ggml_tensor dst_mat = {};
dst_mat.type = GGML_TYPE_F32;
dst_mat.ne[0] = patch_total;
dst_mat.ne[1] = oc;
dst_mat.ne[2] = 1;
dst_mat.ne[3] = 1;
dst_mat.nb[0] = sizeof(float);
dst_mat.nb[1] = dst_mat.nb[0] * dst_mat.ne[0];
dst_mat.nb[2] = dst_mat.nb[1];
dst_mat.nb[3] = dst_mat.nb[2];
dst_mat.data = gemm_output.get();
dst_mat.buffer = dst->buffer;
dst_mat.extra = dst->extra;
dpct::queue_ptr stream = ctx.stream();
// allocate packed arrays: A_packed (k x m), B_packed (k x n)
ggml_sycl_pool_alloc<float> A_packed_alloc(ctx.pool());
ggml_sycl_pool_alloc<float> B_packed_alloc(ctx.pool());
A_packed_alloc.alloc((size_t) knl_n_total * patch_total * sizeof(float));
B_packed_alloc.alloc((size_t) knl_n_total * oc * sizeof(float));
float * A_packed = A_packed_alloc.get();
float * B_packed = B_packed_alloc.get();
const int m = (int) patch_total;
const int n_gemm = (int) oc;
const int k = (int) knl_n_total;
// Combined kernel: im2col -> pack A, and pack B simultaneously
const char * src1_base = (const char *) src1->data;
const int64_t src1_nb0 = src1->nb[0];
const int64_t src1_nb1 = src1->nb[1];
const int64_t src1_nb2 = src1->nb[2];
const int64_t src1_nb3 = src1->nb[3];
// Compute correct strides for src0 as (knl_n_total, oc) matrix
const int64_t src0_packed_nb0 = kernel_type_size;
const int64_t src0_packed_nb1 = kernel_type_size * knl_n_total;
const int64_t KW = knl_w;
const int64_t KH = knl_h;
const int64_t KD = knl_d;
const int64_t PW = dst->ne[0];
const int64_t PH = dst->ne[1];
const int64_t PD = dst->ne[2];
// Pack A (with inline im2col): for each (row, col) in k x m matrix
const int64_t A_total = (int64_t)k * m;
const int64_t A_block_size = 256;
const int64_t A_num_work = ((A_total + A_block_size - 1) / A_block_size) * A_block_size;
stream->parallel_for(sycl::range<1>(A_num_work), [=](sycl::id<1> id) {
const int64_t t = id[0];
if (t >= A_total) return;
const int64_t row = t % k;
const int64_t col = t / k;
// Inline im2col for this element
const int64_t k_index = row;
const int64_t patch_idx = col;
const int64_t ic = k_index / (KD * KH * KW);
const int64_t rem = k_index - ic * (KD * KH * KW);
const int64_t kz = rem / (KH * KW);
const int64_t rem2 = rem - kz * (KH * KW);
const int64_t ky = rem2 / KW;
const int64_t kx = rem2 % KW;
const int64_t p_in_batch = patch_idx % (PW * PH * PD);
const int64_t batch_idx = patch_idx / (PW * PH * PD);
const int64_t dst_z = p_in_batch / (PW * PH);
const int64_t dst_y = (p_in_batch % (PW * PH)) / PW;
const int64_t dst_x = p_in_batch % PW;
const int64_t sx = dst_x * s0 + kx * d0 - p0;
const int64_t sy = dst_y * s1 + ky * d1 - p1;
const int64_t sz = dst_z * s2 + kz * d2 - p2;
float val = 0.0f;
if (sx >= 0 && sx < src1->ne[0] && sy >= 0 && sy < src1->ne[1] && sz >= 0 && sz < src1->ne[2]) {
const int64_t channel_idx = batch_idx * c + ic;
const char * ptr = src1_base + sx * src1_nb0 + sy * src1_nb1 + sz * src1_nb2 + channel_idx * src1_nb3;
val = *(const float *) ptr;
}
A_packed[row + col * (int64_t)k] = val;
});
// Pack B: for each (row, col) in k x n_gemm matrix
const int64_t B_total = (int64_t)k * n_gemm;
const int64_t B_block_size = 256;
const int64_t B_num_work = ((B_total + B_block_size - 1) / B_block_size) * B_block_size;
stream->parallel_for(sycl::range<1>(B_num_work), [=](sycl::id<1> id) {
const int64_t t = id[0];
if (t >= B_total) return;
const int64_t row = t % k;
const int64_t col = t / k;
const char * src_ptr = (const char *) src0->data + row * src0_packed_nb0 + col * src0_packed_nb1;
float v;
if (src0->type == GGML_TYPE_F32) {
v = *(const float *) src_ptr;
} else {
v = sycl::vec<sycl::half, 1>(*(const sycl::half *) src_ptr).convert<float, sycl::rounding_mode::automatic>()[0];
}
B_packed[row + col * (int64_t)k] = v;
});
// GEMM: C = A^T * B where A is (k x m), B is (k x n), C is (m x n)
const float alpha = 1.0f;
const float beta = 0.0f;
const int lda = k;
const int ldb = k;
const int ldc = m;
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
*stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans,
m, n_gemm, k,
dpct::get_value(&alpha, *stream),
(const float *) A_packed, lda,
(const float *) B_packed, ldb,
dpct::get_value(&beta, *stream),
(float *) dst_mat.data, ldc)));
const float * gemm_data = (const float *) dst_mat.data;
float * dst_data = (float *) dst->data;
ggml_sycl_conv3d_write_output(dst, gemm_data, dst_data, patch_total, oc,
dst->ne[0], dst->ne[1], dst->ne[2], stream);
}
+8
View File
@@ -0,0 +1,8 @@
#ifndef GGML_SYCL_CONV3D_HPP
#define GGML_SYCL_CONV3D_HPP
#include "common.hpp"
void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_CONV3D_HPP
+6
View File
@@ -642,6 +642,8 @@ static void convert_unary_sycl(const void * vx, dst_t * y, const int64_t k, dpct
to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
switch (type) {
case GGML_TYPE_Q1_0:
return dequantize_block_sycl<QK1_0, QR1_0, dequantize_q1_0>;
case GGML_TYPE_Q4_0:
if (dst->src[0]->extra &&
((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
@@ -724,6 +726,8 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
switch (type) {
case GGML_TYPE_Q1_0:
return dequantize_block_sycl<QK1_0, QR1_0, dequantize_q1_0>;
case GGML_TYPE_Q4_0:
if (dst->src[0]->extra &&
((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
@@ -830,6 +834,8 @@ to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type) {
case GGML_TYPE_BF16:
return convert_unary_nc_sycl<sycl::ext::oneapi::bfloat16>;
#endif
case GGML_TYPE_Q1_0:
return dequantize_block_nc_sycl<QK1_0, QR1_0, dequantize_q1_0>;
case GGML_TYPE_Q4_0:
return dequantize_block_nc_sycl<QK4_0, QR4_0, dequantize_q4_0>;
case GGML_TYPE_Q4_1:
+15
View File
@@ -70,6 +70,21 @@ static __dpct_inline__ void dequantize_q4_0_reorder(const void *d_ptr, const int
#endif // GGML_SYCL_F16
}
static __dpct_inline__ void dequantize_q1_0_reorder(const void *d_ptr, const int64_t ib, const void *qs,
const int iqs, dfloat2 &v) {
// Q1_0 reorder layout: scale values followed by quantized bits
const dfloat d = (const dfloat)*((const sycl::half*)d_ptr+ib);
const int bit_index_0 = iqs + 0;
const int bit_index_1 = iqs + 1;
const int bit_0 = (*((const uint8_t *)qs + bit_index_0 / 8) >> (bit_index_0 % 8)) & 1;
const int bit_1 = (*((const uint8_t *)qs + bit_index_1 / 8) >> (bit_index_1 % 8)) & 1;
v.x() = (2 * bit_0 - 1) * d;
v.y() = (2 * bit_1 - 1) * d;
}
static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib,
const int iqs, dfloat2 &v) {
const block_q4_1 * x = (const block_q4_1 *) vx;
+53
View File
@@ -1423,6 +1423,50 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
}
}
static void dequantize_mul_mat_vec_q1_0_sycl_reorder(const void *vx, const dfloat *y,
float *dst, const int ncols,
const int nrows,
dpct::queue_ptr stream) {
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
// the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec_reorder<QK1_0, QR1_0, dequantize_q1_0_reorder>(
vx, y, dst, ncols, nrows, item_ct1);
});
}
}
static void dequantize_mul_mat_vec_q1_0_sycl(const void *vx, const dfloat *y,
float *dst, const int ncols,
const int nrows,
dpct::queue_ptr stream) {
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
// the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK1_0, QR1_0, dequantize_q1_0>(
vx, y, dst, ncols, nrows, item_ct1);
});
}
}
static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
float *dst, const int ncols,
const int nrows,
@@ -1759,6 +1803,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
sycl::half *src1_dfloat = nullptr; // dfloat == half
bool src1_convert_f16 =
src0->type == GGML_TYPE_Q1_0 ||
src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16 ||
@@ -1777,6 +1822,14 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
#endif // GGML_SYCL_F16
switch (src0->type) {
case GGML_TYPE_Q1_0:
if ((ggml_tensor_extra_gpu*)dst->src[0]->extra &&
((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
dequantize_mul_mat_vec_q1_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
} else {
dequantize_mul_mat_vec_q1_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
}
break;
case GGML_TYPE_Q4_0:
if ((ggml_tensor_extra_gpu*)dst->src[0]->extra &&
((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
+15 -7
View File
@@ -13,14 +13,14 @@
#ifndef GGML_SYCL_DPCT_HELPER_HPP
#define GGML_SYCL_DPCT_HELPER_HPP
#include <cstdlib>
#include <iostream>
#include <map>
#include <sycl/sycl.hpp>
#include <sycl/half_type.hpp>
#include <oneapi/mkl.hpp>
#include <map>
#include "ggml.h"
#if defined(__linux__)
#include <sys/mman.h>
#elif defined(_WIN64)
@@ -43,6 +43,7 @@
#include <windows.h>
#endif
#define DPCT_COMPATIBILITY_TEMP (900)
#if defined(_MSC_VER)
@@ -59,6 +60,13 @@
#define __dpct_noinline__ __attribute__((noinline))
#endif
#define DPCT_UNUSED(x) (void)(x)
inline void _abort(const char * str) {
std::cerr << str << std::endl;
std::abort();
}
inline std::string get_device_type_name(const sycl::device &Device) {
auto DeviceType = Device.get_info<sycl::info::device::device_type>();
switch (DeviceType) {
@@ -1017,7 +1025,7 @@ namespace dpct
if (backend == "opencl:cpu") return 4;
if (backend == "opencl:acc") return 5;
printf("convert_backend_index: can't handle backend=%s\n", backend.c_str());
GGML_ABORT("fatal error");
_abort("fatal error");
}
static bool compare_backend(std::string &backend1, std::string &backend2) {
return convert_backend_index(backend1) < convert_backend_index(backend2);
@@ -1426,7 +1434,7 @@ namespace dpct
if (!size)
return sycl::event{};
return q.memcpy(to_ptr, from_ptr, size, dep_events);
GGML_UNUSED(direction);
DPCT_UNUSED(direction);
}
// Get actual copy range and make sure it will not exceed range.
@@ -2092,7 +2100,7 @@ namespace dpct
if (!size)
return sycl::event{};
return q.memcpy(to_ptr, from_ptr, size, dep_events);
GGML_UNUSED(direction);
DPCT_UNUSED(direction);
}
// Get actual copy range and make sure it will not exceed range.
+102 -38
View File
@@ -32,7 +32,7 @@
#include <sycl/sycl.hpp>
#include <sycl/backend.hpp>
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO_API
#include <level_zero/ze_api.h>
#endif
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
@@ -62,6 +62,9 @@
#include "ggml-sycl/repeat_back.hpp"
#include "ggml-sycl/set_rows.hpp"
#include "ggml-sycl/set.hpp"
#include "ggml-sycl/conv2d.hpp"
#include "ggml-sycl/conv2d-dw.hpp"
#include "ggml-sycl/conv2d-transpose.hpp"
#include "ggml-sycl/ssm_conv.hpp"
#include "ggml-sycl/sycl_hw.hpp"
#include "ggml-sycl/ssm_scan.hpp"
@@ -84,8 +87,9 @@ int g_ggml_sycl_enable_vmm = 1;
int g_ggml_sycl_prioritize_dmmv = 0;
int g_ggml_sycl_use_async_mem_op = 0;
int g_ggml_sycl_use_async_mem_op_requested = 1;
int g_ggml_sycl_enable_level_zero = 0;
int g_ggml_sycl_use_level_zero_api = 0;
int g_ggml_sycl_enable_flash_attention = 1;
int g_ggml_sycl_dev2dev_memcpy = DEV2DEV_MEMCPY_SYCL;
int g_ggml_sycl_usm_system = 0;
static ggml_sycl_device_info ggml_sycl_init() {
@@ -153,7 +157,7 @@ static ggml_sycl_device_info ggml_sycl_init() {
info.ext_oneapi_level_zero = false;
}
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO_API
if (info.ext_oneapi_level_zero && device.is_gpu() && device.default_queue().get_backend() == sycl::backend::ext_oneapi_level_zero) {
ze_device_handle_t ze_dev = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(device.default_queue().get_device());
ze_device_properties_t props = {};
@@ -168,13 +172,13 @@ static ggml_sycl_device_info ggml_sycl_init() {
info.default_tensor_split[id] /= total_vram;
}
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO_API
// Large buffers can be allocated before ggml_check_sycl() initializes other
// g_ggml_sycl_enable_* globals, so initialize this one as early as we can.
g_ggml_sycl_enable_level_zero =
info.ext_oneapi_level_zero && ggml_sycl_get_env("GGML_SYCL_ENABLE_LEVEL_ZERO", 1);
g_ggml_sycl_use_level_zero_api =
info.ext_oneapi_level_zero && ggml_sycl_get_env("GGML_SYCL_USE_LEVEL_ZERO_API", 1);
#else
g_ggml_sycl_enable_level_zero = 0;
g_ggml_sycl_use_level_zero_api = 0;
#endif
return info;
@@ -272,6 +276,11 @@ static void ggml_check_sycl() try {
g_ggml_sycl_enable_vmm = ggml_sycl_get_env("GGML_SYCL_ENABLE_VMM", 1);
g_ggml_sycl_prioritize_dmmv = ggml_sycl_get_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
g_ggml_sycl_dev2dev_memcpy = ggml_sycl_get_env("GGML_SYCL_DEV2DEV_MEMCPY", DEV2DEV_MEMCPY_SYCL);
if (g_ggml_sycl_use_level_zero_api == 0) {
g_ggml_sycl_dev2dev_memcpy = DEV2DEV_MEMCPY_SYCL;
}
#ifdef SYCL_FLASH_ATTN
g_ggml_sycl_enable_flash_attention = ggml_sycl_get_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1);
#else
@@ -303,10 +312,10 @@ static void ggml_check_sycl() try {
#else
GGML_LOG_INFO(" GGML_SYCL_DNNL: no\n");
#endif
#if defined(GGML_SYCL_SUPPORT_LEVEL_ZERO)
GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO: yes\n");
#if defined(GGML_SYCL_SUPPORT_LEVEL_ZERO_API)
GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO_API: yes\n");
#else
GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO: no\n");
GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO_API: no\n");
#endif
#if defined(GGML_SYCL_USE_VMM)
GGML_LOG_INFO(" GGML_SYCL_USE_VMM: yes\n");
@@ -322,10 +331,13 @@ static void ggml_check_sycl() try {
#else
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
#endif
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO
GGML_LOG_INFO(" GGML_SYCL_ENABLE_LEVEL_ZERO: %d\n", g_ggml_sycl_enable_level_zero);
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO_API
GGML_LOG_INFO(" GGML_SYCL_USE_LEVEL_ZERO_API: %d\n", g_ggml_sycl_use_level_zero_api);
GGML_LOG_INFO(" GGML_SYCL_DEV2DEV_MEMCPY: %d\n", g_ggml_sycl_dev2dev_memcpy);
#else
GGML_LOG_INFO(" GGML_SYCL_ENABLE_LEVEL_ZERO: Level Zero disabled by compile flag\n");
GGML_LOG_INFO(" GGML_SYCL_USE_LEVEL_ZERO_API: Disable Level Zero API usage by compile flag\n");
GGML_LOG_INFO(" GGML_SYCL_DEV2DEV_MEMCPY: %d, enable to SYCL API since missing GGML_SYCL_SUPPORT_LEVEL_ZERO_API\n",
g_ggml_sycl_dev2dev_memcpy);
#endif
#if GGML_SYCL_DNNL
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
@@ -590,7 +602,7 @@ catch (sycl::exception const &exc) {
std::exit(1);
}
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO_API
static bool ggml_sycl_is_l0_discrete_gpu(int device) {
return ggml_sycl_info().devices[device].l0_discrete_gpu;
}
@@ -598,27 +610,42 @@ static bool ggml_sycl_is_l0_discrete_gpu(int device) {
static void dev2dev_memcpy(int device_dst, sycl::queue &q_dst, int device_src, sycl::queue &q_src, void *ptr_dst,
const void *ptr_src, size_t size) {
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO
// Use Level Zero direct copy for dGPU-to-dGPU transfers.
const bool l0_copy_supported = g_ggml_sycl_enable_level_zero &&
ggml_sycl_is_l0_discrete_gpu(device_dst) && ggml_sycl_is_l0_discrete_gpu(device_src);
if (l0_copy_supported) {
auto ze_ctx = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(q_dst.get_context());
auto ze_dev = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(q_dst.get_device());
ze_command_queue_desc_t cq_desc = {ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC, nullptr, 0, 0,
0, ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS, ZE_COMMAND_QUEUE_PRIORITY_NORMAL};
ze_command_list_handle_t cl;
ze_result_t r = zeCommandListCreateImmediate(ze_ctx, ze_dev, &cq_desc, &cl);
if (r == ZE_RESULT_SUCCESS) {
r = zeCommandListAppendMemoryCopy(cl, ptr_dst, ptr_src, size, nullptr, 0, nullptr);
zeCommandListDestroy(cl);
#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO_API
if (g_ggml_sycl_dev2dev_memcpy == DEV2DEV_MEMCPY_L0) {
// Use Level Zero direct copy for dGPU-to-dGPU transfers.
const bool l0_copy_supported =
ggml_sycl_is_l0_discrete_gpu(device_dst) && ggml_sycl_is_l0_discrete_gpu(device_src);
if (g_ggml_sycl_use_level_zero_api && l0_copy_supported) {
auto ze_ctx = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(q_dst.get_context());
auto ze_dev = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(q_dst.get_device());
ze_command_queue_desc_t cq_desc = {ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC, nullptr, 0, 0,
0, ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS, ZE_COMMAND_QUEUE_PRIORITY_NORMAL};
ze_command_list_handle_t cl;
ze_result_t r = zeCommandListCreateImmediate(ze_ctx, ze_dev, &cq_desc, &cl);
if (r == ZE_RESULT_SUCCESS) {
return;
GGML_SYCL_DEBUG("[SYCL] dev2dev memcpy by L0\n");
r = zeCommandListAppendMemoryCopy(cl, ptr_dst, ptr_src, size, nullptr, 0, nullptr);
zeCommandListDestroy(cl);
if (r == ZE_RESULT_SUCCESS) {
return;
}
}
}
}
#endif
if (g_ggml_sycl_dev2dev_memcpy == DEV2DEV_MEMCPY_SYCL) {
if (q_dst.get_device().ext_oneapi_can_access_peer(q_src.get_device(),
sycl::ext::oneapi::peer_access::access_supported)) {
GGML_SYCL_DEBUG("[SYCL] dev2dev memcpy by SYCL\n");
SYCL_CHECK(CHECK_TRY_ERROR(q_dst.memcpy(ptr_dst, ptr_src, size).wait()));
return;
}
}
// Host-staged copy
GGML_SYCL_DEBUG("[SYCL] dev2dev memcpy by host forward\n");
char *host_buf = (char *)malloc(size);
q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();
q_dst.memcpy((char *)ptr_dst, host_buf, size).wait();
@@ -949,6 +976,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYC
}
switch(type) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
return max_compute_capability >= VER_GEN9 ? 128 : 64;
@@ -3480,6 +3508,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
return true;
@@ -3495,6 +3524,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
return true;
@@ -3505,6 +3535,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q3_K:
@@ -3519,6 +3550,7 @@ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -4572,6 +4604,11 @@ static void ggml_sycl_im2col_3d(ggml_backend_sycl_context & ctx, ggml_tensor * d
ggml_sycl_op_im2col_3d(ctx, dst);
}
static void ggml_sycl_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
ggml_sycl_op_conv_3d(ctx, dst);
}
static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
@@ -4635,9 +4672,21 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_ARGMAX:
ggml_sycl_argmax(ctx, dst);
break;
case GGML_OP_CONV_2D:
ggml_sycl_op_conv2d(ctx, dst);
break;
case GGML_OP_CONV_2D_DW:
ggml_sycl_op_conv2d_dw(ctx, dst);
break;
case GGML_OP_CONV_3D:
ggml_sycl_conv_3d(ctx, dst);
break;
case GGML_OP_CONV_TRANSPOSE_1D:
ggml_sycl_op_conv_transpose_1d(ctx, dst);
break;
case GGML_OP_CONV_TRANSPOSE_2D:
ggml_sycl_op_conv2d_transpose(ctx, dst);
break;
case GGML_OP_REPEAT:
ggml_sycl_repeat(ctx, dst);
break;
@@ -5341,7 +5390,7 @@ static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_
return nullptr;
}
static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
static bool do_ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
ggml_backend_sycl_device_context *sycl_ctx =
(ggml_backend_sycl_device_context *)dev->context;
int device = sycl_ctx->device;
@@ -5355,6 +5404,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
}
return false;
}
case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
return true;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_SGN:
@@ -5402,19 +5455,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
struct ggml_tensor * a = op->src[0];
struct ggml_tensor * b = op->src[1];
// disable Q1_0 until implementation
if (a->type == GGML_TYPE_Q1_0 || b->type == GGML_TYPE_Q1_0) {
return false;
}
if (a->ne[3] != b->ne[3]) {
return false;
}
ggml_type src0_type = op->src[0]->type;
// TODO: The configuration below needs more work to be supported with oneDNN
if (ggml_is_permuted(a) && !ggml_is_contiguous(a) &&
a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) {
@@ -5424,12 +5470,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
// TODO: This specific configuration can fail with oneDNN and needs more debugging
if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 &&
a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) {
printf("zjy 2\n");
return false;
}
return true;
}
case GGML_OP_OUT_PROD:
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
return op->type == GGML_TYPE_F32 &&
(op->src[0]->type == GGML_TYPE_F32 ||
(op->src[0]->type == GGML_TYPE_Q1_0 && op->src[0]->ne[2] == op->src[1]->ne[2] &&
op->src[0]->ne[3] == op->src[1]->ne[3])) &&
op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_GET_ROWS:
{
switch (op->src[0]->type) {
@@ -5615,6 +5666,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_IM2COL_3D:
case GGML_OP_UPSCALE:
return true;
case GGML_OP_CONV_3D:
return op->type == GGML_TYPE_F32 &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->src[1]->type == GGML_TYPE_F32 &&
ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op->src[1]);
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
@@ -5680,6 +5737,13 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
GGML_UNUSED(dev);
}
static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
bool res = do_ggml_backend_sycl_device_supports_op(dev, op);
GGML_SYCL_DEBUG("[SYCL] call %s op->op=%s op->type=%s -> %s\n", __func__, ggml_op_name(op->op),
ggml_type_name(op->type), res ? "true" : "false");
return res;
}
static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
if (buft->iface.get_name != ggml_backend_sycl_buffer_type_get_name) {
return false;
+74
View File
@@ -1194,6 +1194,66 @@ static void mul_mat_vec_q8_0_q8_1_sycl_switch_ncols(
}
}
static void mul_mat_vec_q1_0_q8_1_sycl(const void * vx, const void * vy,
float * dst, const int ncols,
const int nrows,
dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK1_0 == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK1_0, QI1_0, block_q1_0,
VDR_Q1_0_Q8_1_MMVQ, vec_dot_q1_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
});
}
template <int ncols_dst>
static void mul_mat_vec_q1_0_q8_1_sycl_ncols(
const void * vx, const void * vy, float * dst,
const int ncols, const int nrows,
const int stride_col_y, const int stride_col_dst,
dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK1_0 == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_ncols<QK1_0, QI1_0, block_q1_0,
VDR_Q1_0_Q8_1_MMVQ, vec_dot_q1_0_q8_1, ncols_dst>(
vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1);
});
});
}
static void mul_mat_vec_q1_0_q8_1_sycl_switch_ncols(
const void * vx, const void * vy, float * dst,
const int ncols, const int nrows, const int ncols_dst,
const int stride_col_y, const int stride_col_dst,
dpct::queue_ptr stream) {
switch (ncols_dst) {
case 1: mul_mat_vec_q1_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break;
case 2: mul_mat_vec_q1_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
case 3: mul_mat_vec_q1_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
case 4: mul_mat_vec_q1_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
case 5: mul_mat_vec_q1_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
case 6: mul_mat_vec_q1_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
case 7: mul_mat_vec_q1_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
case 8: mul_mat_vec_q1_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break;
default: GGML_ABORT("unsupported ncols_dst=%d for Q1_0 multi-col MMVQ", ncols_dst);
}
}
static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
float *dst, const int ncols,
const int nrows,
@@ -2120,6 +2180,20 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
}
break;
case GGML_TYPE_Q1_0:
if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
const int stride_col_y = src1_padded_col_size / QK8_1;
const int stride_col_dst = dst->ne[0];
GGML_SYCL_DEBUG("Calling mul_mat_vec_q1_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols);
mul_mat_vec_q1_0_q8_1_sycl_switch_ncols(
src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
src1_ncols, stride_col_y, stride_col_dst, stream);
return;
} else if (i == 0 || src1_ncols == 1) {
GGML_SYCL_DEBUG("Calling mul_mat_vec_q1_0_q8_1_sycl\n");
mul_mat_vec_q1_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
}
break;
case GGML_TYPE_Q2_K:
if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) {
const int stride_col_y = src1_padded_col_size / QK8_1;
+45 -9
View File
@@ -1,11 +1,12 @@
#include "outprod.hpp"
#include "convert.hpp"
void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_Q1_0);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
@@ -20,11 +21,31 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
GGML_ASSERT(ne01 == ne11); // Inner dimensions must match
GGML_ASSERT(ne0 == ne00); // Output rows match src0 rows
GGML_ASSERT(ne1 == ne10); // Output cols match src1 cols
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
GGML_ASSERT(ne2 % ne02 == 0);
GGML_ASSERT(ne3 % ne03 == 0);
// Get data pointers
const float* src0_d = (const float*)src0->data;
const float* src1_d = (const float*)src1->data;
float* dst_d = (float*)dst->data;
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
ggml_sycl_pool_alloc<float> src0_as_f32(ctx.pool());
int64_t src0_nb02 = nb02;
int64_t src0_nb03 = nb03;
if (src0->type == GGML_TYPE_Q1_0) {
scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
" : converting src0 Q1_0 to fp32");
src0_d = src0_as_f32.alloc(ne00 * ne01 * ne02 * ne03);
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type, dst);
GGML_ASSERT(to_fp32_sycl != nullptr);
to_fp32_sycl(src0->data, const_cast<float *>(src0_d), ne00 * ne01 * ne02 * ne03, stream);
// Dequantized src0 buffer is contiguous fp32 [ne00, ne01, ne02, ne03].
src0_nb02 = ne00 * ne01 * (int64_t) sizeof(float);
src0_nb03 = ne00 * ne01 * ne02 * (int64_t) sizeof(float);
}
// GEMM parameters
const float alpha = 1.0f;
@@ -35,12 +56,27 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const oneapi::mkl::transpose src1_op = src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
const int64_t r2 = ne2 / ne02;
const int64_t r3 = ne3 / ne03;
try {
// Perform matrix multiplication using oneMKL GEMM
oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op,
ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
}
catch (sycl::exception const& exc) {
// OUT_PROD applies independently to each (i2, i3) destination plane.
for (int64_t i3 = 0; i3 < ne3; ++i3) {
for (int64_t i2 = 0; i2 < ne2; ++i2) {
const int64_t i03 = i3 / r3;
const int64_t i02 = i2 / r2;
const float * src0_plane = (const float *) ((const char *) src0_d + i02 * src0_nb02 + i03 * src0_nb03);
const float * src1_plane = (const float *) ((const char *) src1_d + i2 * nb12 + i3 * nb13);
float * dst_plane = (float *) ((char *) dst_d + i2 * nb2 + i3 * nb3);
// Perform matrix multiplication using oneMKL GEMM
oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op,
ne0, ne1, ne01, alpha, src0_plane, ne00,
src1_plane, ldb, beta, dst_plane, ne0);
}
}
} catch (sycl::exception const& exc) {
std::cerr << exc.what() << std::endl;
GGML_ASSERT(false);
}
+35
View File
@@ -309,6 +309,41 @@ vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh,
vl, vh, u[0], u[1], scales[0], scales[4], d, d8[0], d8[1]);
}
#define VDR_Q1_0_Q8_1_MMVQ 1
#define VDR_Q1_0_Q8_1_MMQ 4
static __dpct_inline__ float
vec_dot_q1_0_q8_1(const void *__restrict__ vbq,
const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
const block_q1_0 * bq1_0 = (const block_q1_0 *) vbq;
const block_q8_1 * bq8_1_chunk = bq8_1 + iqs;
const float d1 = bq1_0->d;
const int v = get_int_from_uint8_aligned(bq1_0->qs, iqs);
int vi_bytes[8];
#pragma unroll
for (int j = 0; j < 8; ++j) {
const int shift = j * 4;
const int bits4 = (v >> shift) & 0x0F;
const int b0 = (bits4 & 0x01) ? 1 : -1;
const int b1 = (bits4 & 0x02) ? 1 : -1;
const int b2 = (bits4 & 0x04) ? 1 : -1;
const int b3 = (bits4 & 0x08) ? 1 : -1;
vi_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24);
}
int sumi = 0;
#pragma unroll
for (int j = 0; j < 8; ++j) {
const int u = get_int_from_int8_aligned(bq8_1_chunk->qs, j);
sumi = ggml_sycl_dp4a(vi_bytes[j], u, sumi);
}
return d1 * bq8_1_chunk->ds[0] * sumi;
}
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
+270 -27
View File
@@ -6,6 +6,7 @@ import re
import argparse
import statistics
import logging
from typing import Any, Dict, List, Optional
from collections import defaultdict
@@ -25,12 +26,47 @@ COL_MAP = {
}
op_pattern = re.compile(
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+.*?\s+(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?"
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+.*?\s+(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?"
)
trace_pattern = re.compile(
r"trace-op\s+(?P<op_name>[A-Z_0-9+]+):\s+thread\s+(?P<thread>\d+)\s+event\s+(?P<event>[A-Z_0-9\-]+)\s+info\s+(?P<info>\d+)\s+(?P<state>start|stop)\s+(?P<cycles>\d+)"
)
logger = logging.getLogger("ggml-hexagon-profile")
def normalize_event_name(evt_type):
if evt_type == "HVX_COMP":
return "V-COMP"
if evt_type == "HMX_COMP":
return "M-COMP"
# Strip HVX_ or HMX_ prefixes
name = evt_type
if name.startswith("HVX_") or name.startswith("HMX_"):
name = name[4:]
return name.replace("_", "-")
class CycleUnwrapper:
def __init__(self):
self.last_raw = None
self.high_part = 0
def unwrap(self, raw):
if self.last_raw is None:
self.last_raw = raw
return raw
diff = raw - self.last_raw
if diff < -0x80000000:
self.high_part += 0x100000000
elif diff > 0x80000000:
self.high_part -= 0x100000000
self.last_raw = raw
return raw + self.high_part
def parse_log(file_path, pmu_index=None):
try:
if file_path != "-":
@@ -41,35 +77,211 @@ def parse_log(file_path, pmu_index=None):
logger.error(f"file '{file_path}' not found.")
sys.exit(1)
all_ops = []
all_ops: List[Dict[str, Any]] = []
current_op: Optional[Dict[str, Any]] = None
timestamp_pattern = re.compile(r"^(?P<min>\d+)\.(?P<sec>\d+)\.(?P<ms>\d+)\.(?P<us>\d+)\s+[A-Z]\s+")
unwrapper = CycleUnwrapper()
for line in f:
match = op_pattern.search(line)
if not match: continue
ts_match = timestamp_pattern.match(line)
abs_usec = 0
if ts_match:
abs_usec = (
(int(ts_match.group('min')) * 60 + int(ts_match.group('sec'))) * 1000000
+ int(ts_match.group('ms')) * 1000
+ int(ts_match.group('us'))
)
pmu_raw = match.group('pmu')
pmu_val = None
if pmu_raw and pmu_index is not None:
try:
pmu_list = [int(x.strip()) for x in pmu_raw.split(',')]
if len(pmu_list) > pmu_index:
pmu_val = pmu_list[pmu_index]
except (ValueError, IndexError):
pmu_val = None
op_match = op_pattern.search(line)
if op_match:
pmu_raw = op_match.group('pmu')
pmu_val = None
if pmu_raw and pmu_index is not None:
try:
pmu_list = [int(x.strip()) for x in pmu_raw.split(',')]
if len(pmu_list) > pmu_index:
pmu_val = pmu_list[pmu_index]
except (ValueError, IndexError):
pmu_val = None
all_ops.append({
'name': match.group('op_name'),
'dims': match.group('dims').strip(),
'types': match.group('types').strip(),
'usec': int(match.group('usec')),
'cycles': int(match.group('cycles')),
'pmu_val': pmu_val
})
evt_raw = op_match.group('evt')
evt_val = None
if evt_raw:
try:
evt_val = [int(x.strip()) for x in evt_raw.split(',')]
except ValueError:
evt_val = None
cycles_start_raw = op_match.group('start')
unwrapped_cycles_start = None
if cycles_start_raw:
unwrapped_cycles_start = unwrapper.unwrap(int(cycles_start_raw))
idx = line.find("profile-op ")
op_text = line[idx + 11:].strip() if idx != -1 else line.strip()
current_op = {
'name': op_match.group('op_name'),
'dims': op_match.group('dims').strip(),
'types': op_match.group('types').strip(),
'op_text': op_text,
'usec': int(op_match.group('usec')),
'cycles': int(op_match.group('cycles')),
'cycles_start': int(cycles_start_raw) if cycles_start_raw else None,
'unwrapped_cycles_start': unwrapped_cycles_start,
'pmu_val': pmu_val,
'evt_val': evt_val,
'abs_usec': abs_usec,
'trace_events': []
}
all_ops.append(current_op)
continue
trace_match = trace_pattern.search(line)
if trace_match and current_op:
if trace_match.group('op_name') == current_op['name']:
raw_cyc = int(trace_match.group('cycles'))
current_op['trace_events'].append({
'thread': int(trace_match.group('thread')),
'event': trace_match.group('event'),
'info': int(trace_match.group('info')),
'cycles': raw_cyc,
'unwrapped_cycles': unwrapper.unwrap(raw_cyc),
'state': trace_match.group('state')
})
f.close()
return all_ops
def print_ascii_timeline(op_name, dims, types, usec, cycles, events, evt_val=None):
evt_str = ""
if evt_val:
evt_str = " - evt [" + ",".join(str(x) for x in evt_val) + "]"
logger.info("=" * 100)
logger.info(f"{op_name} ({dims} : {types}) - {usec} usec {cycles} cycles{evt_str}")
logger.info("=" * 100)
events = sorted(events, key=lambda e: e['cycles'])
if not events:
logger.info(" No trace events recorded.")
return
min_cycles = events[0]['cycles']
logger.info("Cycles %-30s" % "EventDetails" + " ".join(f"T{i:<2}" for i in range(10)) + " HMX")
logger.info("-" * 100)
thread_stacks = [[] for _ in range(11)]
for e in events:
t = e['thread']
if t < 0 or t > 10:
continue
if e['cycles'] >= min_cycles:
rel_cycles = e['cycles'] - min_cycles
else:
rel_cycles = (e['cycles'] + 0x100000000) - min_cycles
state = e['state']
evt_type = e['event']
# Determine char representing the event
norm_evt = normalize_event_name(evt_type)
char = '?'
if norm_evt == 'V-COMP':
char = 'V'
elif norm_evt == 'M-COMP':
char = 'H'
elif norm_evt == 'A-QUANT':
char = 'Q'
elif norm_evt == 'A-PREP':
char = 'A'
elif norm_evt == 'W-DEQUANT':
char = 'D'
elif norm_evt == 'O-PROC':
char = 'O'
elif norm_evt == 'W-PREP':
char = 'P'
elif norm_evt == 'DMA':
char = 'M'
if state == 'start':
thread_stacks[t].append(char)
elif state == 'stop':
if thread_stacks[t]:
if thread_stacks[t][-1] == char:
thread_stacks[t].pop()
elif char in thread_stacks[t]:
thread_stacks[t].remove(char)
else:
thread_stacks[t].pop()
cols = []
for i in range(11):
if thread_stacks[i]:
cols.append(f"[{thread_stacks[i][-1]}]")
else:
cols.append(" | ")
evt_desc = f"T{t}: {evt_type} {state} ({e['info']})"
logger.info(f"{rel_cycles:10d} %-30s" % evt_desc + " ".join(cols[:10]) + " " + cols[10])
logger.info("-" * 100)
def print_ascii_summary(op_name, dims, types, usec, cycles, events, evt_val=None):
evt_str = ""
if evt_val:
evt_str = " - evt [" + ",".join(str(x) for x in evt_val) + "]"
logger.info("=" * 100)
logger.info(f"{op_name} ({dims} : {types}) - {usec} usec {cycles} cycles{evt_str}")
logger.info("=" * 100)
events = sorted(events, key=lambda e: e['cycles'])
if not events:
logger.info(" No trace events recorded.")
return
active_starts = {}
thread_totals = defaultdict(lambda: defaultdict(int))
for e in events:
t = e['thread']
evt = e['event']
info = e['info']
cyc = e['cycles']
state = e['state']
key = (t, evt, info)
if state == 'start':
active_starts[key] = cyc
elif state == 'stop':
if key in active_starts:
start_cyc = active_starts[key]
del active_starts[key]
if cyc >= start_cyc:
dur = cyc - start_cyc
else:
dur = (cyc + 0x100000000) - start_cyc
norm_evt = normalize_event_name(evt)
thread_totals[t][norm_evt] += dur
for t in sorted(thread_totals.keys()):
thread_name = f"Thread {t} (HVX)" if t != 10 else "Thread 10 (HMX)"
sorted_evts = sorted(thread_totals[t].items(), key=lambda item: item[0])
evt_strs = []
for evt, dur in sorted_evts:
pct = (dur / cycles * 100) if cycles > 0 else 0
evt_strs.append(f"{evt} {dur} ({pct:.1f}%)")
logger.info(f" {thread_name:<16}: " + " | ".join(evt_strs))
def generate_report(ops, top_n, width_overrides, sort_col, pmu_name=None):
if not ops:
logger.info("No valid records found.")
@@ -115,7 +327,6 @@ def generate_report(ops, top_n, width_overrides, sort_col, pmu_name=None):
# Sorting logic
actual_sort_key = COL_MAP[sort_col][2]
# We sort numeric fields descending, strings (op/dims) ascending
is_numeric = actual_sort_key.startswith("_") or actual_sort_key == "count"
sorted_groups = sorted(group_stats, key=lambda x: x[actual_sort_key], reverse=is_numeric)[:top_n]
@@ -132,7 +343,7 @@ def generate_report(ops, top_n, width_overrides, sort_col, pmu_name=None):
if "pmu" in col_name and pmu_name:
header_text = header_text.replace("PMU", pmu_name)
natural_width = max([len(row[data_key]) for row in sorted_groups] + [len(header_text)])
natural_width = max([len(str(row[data_key])) for row in sorted_groups] + [len(header_text)])
target_width = width_overrides.get(col_name, natural_width)
if target_width == 0:
@@ -152,7 +363,7 @@ def generate_report(ops, top_n, width_overrides, sort_col, pmu_name=None):
for group in sorted_groups:
row_vals = []
for i, key in enumerate(final_keys):
val = group[key]
val = str(group[key])
if len(val) > final_widths[i]:
val = val[:final_widths[i] - 3] + "..."
row_vals.append(f"{val:<{final_widths[i]}}")
@@ -167,12 +378,18 @@ def main():
parser.add_argument("--pmu-index", type=int)
parser.add_argument("--pmu-name", type=str)
parser.add_argument("--width", action='append', default=['dims:40'], help="Override column width, e.g. --width dims:50")
parser.add_argument("--timeline", type=str, nargs='?', const='summary', choices=["summary", "diagram"],
help="Output ASCII art event summary or timing diagram (default: summary)")
parser.add_argument("--filter", type=str, help="Regex filter matching against the original profile-op line")
group = parser.add_mutually_exclusive_group()
group.add_argument("--head", type=int, help="Limit to first N ops")
group.add_argument("--tail", type=int, help="Limit to last N ops")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format='%(message)s')
# Sort validation: can't sort by PMU if index isn't provided
if "pmu" in args.sort and args.pmu_index is None:
logger.error(f"Cannot sort by '{args.sort}' without --pmu-index.")
sys.exit(1)
@@ -188,7 +405,33 @@ def main():
final_pmu_name = (args.pmu_name or f"#{args.pmu_index}") if args.pmu_index is not None else None
ops = parse_log(args.logfile, pmu_index=args.pmu_index)
generate_report(ops, args.top, overrides, args.sort, pmu_name=final_pmu_name)
if args.filter:
try:
filter_re = re.compile(args.filter)
except re.error as e:
logger.error(f"Invalid regex filter: {e}")
sys.exit(1)
ops = [op for op in ops if filter_re.search(op['op_text'])]
if args.head is not None:
ops = ops[:args.head]
elif args.tail is not None:
ops = ops[-args.tail:]
if args.timeline:
logger.info(f"\n# ASCII Timing {args.timeline.capitalize()}\n")
printed_cnt = 0
for op in ops:
if args.timeline == "summary":
print_ascii_summary(op['name'], op['dims'], op['types'], op['usec'], op['cycles'], op['trace_events'], op.get('evt_val'))
elif args.timeline == "diagram":
print_ascii_timeline(op['name'], op['dims'], op['types'], op['usec'], op['cycles'], op['trace_events'], op.get('evt_val'))
printed_cnt += 1
if printed_cnt >= args.top:
break
else:
generate_report(ops, args.top, overrides, args.sort, pmu_name=final_pmu_name)
if __name__ == "__main__":
+463
View File
@@ -0,0 +1,463 @@
#!/usr/bin/env python3
import sys
import os
import re
import argparse
import statistics
import logging
from typing import Any, Dict, List, Optional
from collections import defaultdict
logger = logging.getLogger("ggml-hexagon-trace")
op_pattern = re.compile(
r"profile-op\s+(?P<op_name>[A-Z_0-9+]+):\s+.*?\s+:\s+(?P<dims>[\d:x\s\->!]+)\s+:\s+(?P<types>[a-z\d_\s\->x]+)\s+:\s+(?P<strides>[\d:x\s\->!]+)\s+:\s+(?:op-)?usec\s+(?P<usec>\d+)\s+(?:op-)?cycles\s+(?P<cycles>\d+)(?:\s+start\s+(?P<start>\d+))?(?:\s+mhz\s+(?P<mhz>[\d.]+))?(?:\s+pmu\s+\[(?P<pmu>[\d,\s]+)\])?(?:\s+evt\s+\[(?P<evt>[\d,\s]+)\])?"
)
trace_pattern = re.compile(
r"trace-op\s+(?P<op_name>[A-Z_0-9+]+):\s+thread\s+(?P<thread>\d+)\s+event\s+(?P<event>[A-Z_0-9\-]+)\s+info\s+(?P<info>\d+)\s+(?P<state>start|stop)\s+(?P<cycles>\d+)"
)
def normalize_event_name(evt_type):
if evt_type == "HVX_COMP":
return "V-COMP"
if evt_type == "HMX_COMP":
return "M-COMP"
name = evt_type
if name.startswith("HVX_") or name.startswith("HMX_"):
name = name[4:]
return name.replace("_", "-")
class CycleUnwrapper:
def __init__(self):
self.last_raw = None
self.high_part = 0
def unwrap(self, raw):
if self.last_raw is None:
self.last_raw = raw
return raw
diff = raw - self.last_raw
if diff < -0x80000000:
self.high_part += 0x100000000
elif diff > 0x80000000:
self.high_part -= 0x100000000
self.last_raw = raw
return raw + self.high_part
def parse_log(file_path):
try:
if file_path != "-":
f = open(file_path, 'r', encoding='utf-8', errors='ignore')
else:
f = os.fdopen(0, 'r', encoding='utf-8', errors='ignore')
except FileNotFoundError:
logger.error(f"file '{file_path}' not found.")
sys.exit(1)
all_ops: List[Dict[str, Any]] = []
current_op: Optional[Dict[str, Any]] = None
unwrapper = CycleUnwrapper()
line_idx = 0
for line in f:
line_idx += 1
op_match = op_pattern.search(line)
if op_match:
cycles_start_raw = op_match.group('start')
unwrapped_cycles_start = None
if cycles_start_raw:
unwrapped_cycles_start = unwrapper.unwrap(int(cycles_start_raw))
idx = line.find("profile-op ")
op_text = line[idx + 11:].strip() if idx != -1 else line.strip()
current_op = {
'name': op_match.group('op_name'),
'dims': op_match.group('dims').strip() if op_match.group('dims') else '',
'types': op_match.group('types').strip() if op_match.group('types') else '',
'strides': op_match.group('strides').strip() if op_match.group('strides') else '',
'op_text': op_text,
'usec': int(op_match.group('usec')),
'cycles': int(op_match.group('cycles')),
'cycles_start': int(cycles_start_raw) if cycles_start_raw else None,
'unwrapped_cycles_start': unwrapped_cycles_start,
'trace_events': [],
'line_num': line_idx
}
all_ops.append(current_op)
continue
trace_match = trace_pattern.search(line)
if trace_match and current_op:
if trace_match.group('op_name') == current_op['name']:
raw_cyc = int(trace_match.group('cycles'))
current_op['trace_events'].append({
'thread': int(trace_match.group('thread')),
'event': trace_match.group('event'),
'info': int(trace_match.group('info')),
'cycles': raw_cyc,
'unwrapped_cycles': unwrapper.unwrap(raw_cyc),
'state': trace_match.group('state')
})
f.close()
return all_ops
# --- Simple protobuf encoder ---
def write_varint(val):
if val < 0:
val = (1 << 64) + val
res = bytearray()
while True:
towrite = val & 0x7f
val >>= 7
if val > 0:
res.append(towrite | 0x80)
else:
res.append(towrite)
break
return bytes(res)
def pb_field(num, wire, data):
return write_varint((num << 3) | wire) + data
def pb_varint(num, val):
return pb_field(num, 0, write_varint(val))
def pb_length_delimited(num, data):
return pb_field(num, 2, write_varint(len(data)) + data)
def pb_string(num, text):
return pb_length_delimited(num, text.encode('utf-8'))
# Message Encoders
def make_process_descriptor(pid, name):
return pb_varint(1, pid) + pb_string(6, name)
def make_thread_descriptor(pid, tid, name, sort_index=None):
payload = pb_varint(1, pid) + pb_varint(2, tid) + pb_string(5, name)
if sort_index is not None:
payload += pb_varint(3, sort_index)
return payload
def make_track_descriptor(uuid, name=None, parent_uuid=None, thread=None, process=None, sibling_merge_behavior=None, child_ordering=None, sibling_order_rank=None):
payload = pb_varint(1, uuid)
if name is not None:
payload += pb_string(2, name)
if parent_uuid is not None:
payload += pb_varint(5, parent_uuid)
if process is not None:
payload += pb_length_delimited(3, process)
if thread is not None:
payload += pb_length_delimited(4, thread)
if sibling_merge_behavior is not None:
payload += pb_varint(15, sibling_merge_behavior)
if child_ordering is not None:
payload += pb_varint(11, child_ordering)
if sibling_order_rank is not None:
payload += pb_varint(12, sibling_order_rank)
return payload
def make_debug_annotation(name, string_val=None, int_val=None):
payload = pb_string(10, name)
if string_val is not None:
payload += pb_string(6, string_val)
elif int_val is not None:
payload += pb_varint(4, int_val)
return payload
def make_track_event(event_type, track_uuid, name=None, category=None, debug_annotations=None):
payload = pb_varint(9, event_type)
payload += pb_varint(11, track_uuid)
if name is not None:
payload += pb_string(23, name)
if category is not None:
payload += pb_string(22, category)
if debug_annotations is not None:
for da in debug_annotations:
payload += pb_length_delimited(4, da)
return payload
def make_trace_packet(timestamp, track_event=None, track_descriptor=None, seq_id=1):
payload = pb_varint(8, timestamp)
payload += pb_varint(10, seq_id)
if track_event is not None:
payload += pb_length_delimited(11, track_event)
if track_descriptor is not None:
payload += pb_length_delimited(60, track_descriptor)
return payload
def write_trace_packet_to_file(f, packet_bytes):
# Write as field 1 of top-level Trace message
f.write(pb_length_delimited(1, packet_bytes))
# --- End Protobuf Encoder ---
def generate_perfetto_trace(filtered_ops, output_path):
if not filtered_ops:
logger.warning("No operators found after filtering.")
return
# Compute average frequency
frequencies = []
for op in filtered_ops:
if op['usec'] > 0 and op['cycles'] > 0:
frequencies.append(op['cycles'] / op['usec'])
avg_freq_mhz = statistics.mean(frequencies) if frequencies else 1000.0
if avg_freq_mhz <= 0:
avg_freq_mhz = 1000.0
# Assign start and end cycles to each operator
for op in filtered_ops:
op['start_cycles'] = op['unwrapped_cycles_start']
op['end_cycles'] = op['start_cycles'] + op['cycles']
global_min_cyc = min(op['start_cycles'] for op in filtered_ops if op['start_cycles'] is not None)
# Process events
completed_events = []
for op in filtered_ops:
events = op['trace_events']
if not events:
continue
events = sorted(events, key=lambda e: e['unwrapped_cycles'])
active_starts = {}
for e in events:
t = e['thread']
evt = e['event']
info = e['info']
state = e['state']
cyc = e['unwrapped_cycles']
key = (t, evt, info)
if state == 'start':
active_starts[key] = cyc
elif state == 'stop':
if key in active_starts:
start_cyc = active_starts[key]
del active_starts[key]
completed_events.append({
'thread': t,
'event': evt,
'info': info,
'start_cyc': start_cyc,
'end_cyc': cyc,
'op_name': op['name']
})
completed_events.sort(key=lambda e: e['start_cyc'])
# Convert event times to microseconds and apply clamp rounded to 1ns resolution (3 decimals)
for e in completed_events:
start_us = (e['start_cyc'] - global_min_cyc) / avg_freq_mhz
dur_us = (e['end_cyc'] - e['start_cyc']) / avg_freq_mhz
e['ts_ns'] = int(round(start_us * 1000))
e['dur_ns'] = int(round(max(dur_us, 0.1) * 1000))
# Allocate slots (sub-tracks) to prevent overlaps on same virtual track
active_slots = defaultdict(list)
for e in completed_events:
t = e['thread']
evt = e['event']
ts = e['ts_ns']
dur = e['dur_ns']
norm_evt = normalize_event_name(evt)
if norm_evt == "DMA":
track_key = (t, "DMA")
elif t == 10:
track_key = (t, "HMX")
else:
track_key = (t, "HVX")
slots = active_slots[track_key]
allocated_slot = -1
for idx, slot_end_ns in enumerate(slots):
if ts >= slot_end_ns:
slots[idx] = ts + dur
allocated_slot = idx
break
if allocated_slot == -1:
slots.append(ts + dur)
allocated_slot = len(slots) - 1
e['slot'] = allocated_slot
# Generate Track IDs and track definitions
used_tracks = {}
for e in completed_events:
t = e['thread']
evt = e['event']
slot = e['slot']
norm_evt = normalize_event_name(evt)
if norm_evt == "DMA":
track_evt = "DMA"
evt_id = 1
elif t == 10:
track_evt = "HMX"
evt_id = 3
else:
track_evt = "HVX"
evt_id = 2
t_sort = 1 if t == 10 else t + 2
# Unique UUID for each sub-track
if t == 10:
uuid = 20 # HMX thread track UUID
else:
uuid = int(t_sort * 1000000 + evt_id * 1000 + slot)
e['uuid'] = uuid
used_tracks[uuid] = (t, track_evt, slot)
with open(output_path, "wb") as f:
# Define Process with EXPLICIT child sorting
proc_desc = make_process_descriptor(1, "HTP NPU")
proc_packet = make_trace_packet(0, track_descriptor=make_track_descriptor(1, process=proc_desc, child_ordering=3))
write_trace_packet_to_file(f, proc_packet)
# Define Operators Track (UUID = 2) as a thread track at rank 1, tid 8
op_thread_desc = make_thread_descriptor(1, 8, "Ops", sort_index=1)
op_packet = make_trace_packet(0, track_descriptor=make_track_descriptor(2, parent_uuid=1, thread=op_thread_desc))
write_trace_packet_to_file(f, op_packet)
# Define HMX Thread Track (UUID = 20) at rank 2, tid 9
hmx_thread_desc = make_thread_descriptor(1, 9, "HMX", sort_index=2)
hmx_packet = make_trace_packet(0, track_descriptor=make_track_descriptor(20, parent_uuid=1, thread=hmx_thread_desc))
write_trace_packet_to_file(f, hmx_packet)
# Define Thread Tracks (T0, T1, ..., T9)
unique_threads = sorted(list(set(t for (t, _, _) in used_tracks.values() if t != 10)))
for t in unique_threads:
thread_uuid = 10 + t
thread_name = f"T{t}"
# Sort order starts from index 3 (T0 -> 3, T1 -> 4, etc.)
sort_index = 3 + t
tid = 10 + t
thread_desc = make_thread_descriptor(1, tid, thread_name, sort_index=sort_index)
thread_packet = make_trace_packet(0, track_descriptor=make_track_descriptor(
thread_uuid,
parent_uuid=1,
thread=thread_desc,
sibling_order_rank=sort_index,
child_ordering=3 # Explicit child sorting for sub-tracks
))
write_trace_packet_to_file(f, thread_packet)
# Define Track descriptors for sub-tracks parented to thread tracks
for uuid in sorted(used_tracks.keys()):
if uuid == 20:
continue
t, evt, slot = used_tracks[uuid]
name = f"T{t} {evt}"
rank = 0 if evt == "HVX" else 1
parent_thread_uuid = 10 + t
# Sibling merge behavior: 1 (SIBLING_MERGE_BEHAVIOR_BY_TRACK_NAME)
track_desc = make_track_descriptor(
uuid=uuid,
name=name,
parent_uuid=parent_thread_uuid,
sibling_merge_behavior=1,
sibling_order_rank=rank
)
track_packet = make_trace_packet(0, track_descriptor=track_desc)
write_trace_packet_to_file(f, track_packet)
# Emit Operators
last_op_end_ns = 0
for op in filtered_ops:
op_start_ns = int(round(((op['start_cycles'] - global_min_cyc) / avg_freq_mhz) * 1000))
op_dur_ns = int(round((op['cycles'] / avg_freq_mhz) * 1000))
if op_start_ns < last_op_end_ns:
op_start_ns = last_op_end_ns
clamped_dur = max(op_dur_ns, 100) # Clamp to 100ns (0.1us)
# Debug annotations for Ops
debug_annots = []
if 'line_num' in op:
debug_annots.append(make_debug_annotation("line", int_val=op['line_num']))
if 'strides' in op and op['strides']:
debug_annots.append(make_debug_annotation("strides", string_val=op['strides']))
# Slice Begin
evt_begin = make_track_event(1, 2, name=f"{op['name']} ({op['dims']})", category="operator", debug_annotations=debug_annots)
packet_begin = make_trace_packet(op_start_ns, track_event=evt_begin)
write_trace_packet_to_file(f, packet_begin)
# Slice End
evt_end = make_track_event(2, 2)
packet_end = make_trace_packet(op_start_ns + clamped_dur, track_event=evt_end)
write_trace_packet_to_file(f, packet_end)
last_op_end_ns = op_start_ns + clamped_dur
# Emit Thread Trace Events
for e in completed_events:
norm_name = normalize_event_name(e['event'])
name = f"DMA {e['info']}" if norm_name == "DMA" else norm_name
# Slice Begin
evt_begin = make_track_event(1, e['uuid'], name=name, category="trace")
packet_begin = make_trace_packet(e['ts_ns'], track_event=evt_begin)
write_trace_packet_to_file(f, packet_begin)
# Slice End
evt_end = make_track_event(2, e['uuid'])
packet_end = make_trace_packet(e['ts_ns'] + e['dur_ns'], track_event=evt_end)
write_trace_packet_to_file(f, packet_end)
logger.info(f"Successfully generated Perfetto trace at {output_path}")
def main():
parser = argparse.ArgumentParser(description="Convert Hexagon Op profile logs to native Perfetto Protobuf traces.")
parser.add_argument("logfile", help="Path to hex-log profile file")
parser.add_argument("-o", "--output", default="optrace.perfetto-trace", help="Output trace file path (default: optrace.perfetto-trace)")
parser.add_argument("--filter", type=str, help="Regex filter matching against the original profile-op line")
group = parser.add_mutually_exclusive_group()
group.add_argument("--head", type=int, help="Limit to first N ops")
group.add_argument("--tail", type=int, help="Limit to last N ops")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format='%(message)s')
ops = parse_log(args.logfile)
if args.filter:
try:
filter_re = re.compile(args.filter)
except re.error as e:
logger.error(f"Invalid regex filter: {e}")
sys.exit(1)
ops = [op for op in ops if filter_re.search(op['op_text'])]
if args.head is not None:
ops = ops[:args.head]
elif args.tail is not None:
ops = ops[-args.tail:]
generate_perfetto_trace(ops, args.output)
if __name__ == "__main__":
main()
+1 -1
View File
@@ -1 +1 @@
3af5f5760e19a96427f5f7a93b79cbdf3d4b265b
707321c4cf6d21cb4bc831aa8b687dbf01a521ce
+23 -4
View File
@@ -20,6 +20,7 @@ set(LLAMA_UI_GZIP "" CACHE STRING "Apply gzip compress to assets to save ban
set(DIST_DIR "${UI_BINARY_DIR}/dist")
set(SRC_DIST_DIR "${UI_SOURCE_DIR}/dist")
set(WORK_DIR "${UI_BINARY_DIR}/ui-src")
set(STAMP_FILE "${UI_BINARY_DIR}/.ui-stamp")
set(UI_CPP "${UI_BINARY_DIR}/ui.cpp")
set(UI_H "${UI_BINARY_DIR}/ui.h")
@@ -64,6 +65,22 @@ function(npm_build_should_skip out_var)
set(${out_var} TRUE PARENT_SCOPE)
endfunction()
function(stage_sources)
if(EXISTS "${WORK_DIR}")
file(GLOB staged RELATIVE "${WORK_DIR}" "${WORK_DIR}/*")
list(REMOVE_ITEM staged "node_modules")
foreach(entry ${staged})
file(REMOVE_RECURSE "${WORK_DIR}/${entry}")
endforeach()
endif()
file(COPY "${UI_SOURCE_DIR}/"
DESTINATION "${WORK_DIR}"
NO_SOURCE_PERMISSIONS
PATTERN "node_modules" EXCLUDE
)
endfunction()
function(npm_build out_var)
set(${out_var} FALSE PARENT_SCOPE)
@@ -89,14 +106,16 @@ function(npm_build out_var)
return()
endif()
stage_sources()
# npm writes node_modules/.package-lock.json on every successful install,
# so a package-lock.json newer than this marker means node_modules is stale
set(NPM_MARKER "${UI_SOURCE_DIR}/node_modules/.package-lock.json")
set(NPM_MARKER "${WORK_DIR}/node_modules/.package-lock.json")
set(need_install FALSE)
if(NOT EXISTS "${NPM_MARKER}")
set(need_install TRUE)
else()
file(TIMESTAMP "${UI_SOURCE_DIR}/package-lock.json" lock_ts)
file(TIMESTAMP "${WORK_DIR}/package-lock.json" lock_ts)
file(TIMESTAMP "${NPM_MARKER}" marker_ts)
if(lock_ts STRGREATER marker_ts)
set(need_install TRUE)
@@ -107,7 +126,7 @@ function(npm_build out_var)
message(STATUS "UI: running npm install")
execute_process(
COMMAND ${NPM_EXECUTABLE} install
WORKING_DIRECTORY "${UI_SOURCE_DIR}"
WORKING_DIRECTORY "${WORK_DIR}"
RESULT_VARIABLE rc
ERROR_VARIABLE err
)
@@ -124,7 +143,7 @@ function(npm_build out_var)
execute_process(
COMMAND ${CMAKE_COMMAND} -E env "LLAMA_UI_OUT_DIR=${DIST_DIR}" "LLAMA_UI_VERSION=${HF_VERSION}" "LLAMA_BUILD_NUMBER=${LLAMA_BUILD_NUMBER}"
${NPM_EXECUTABLE} run build
WORKING_DIRECTORY "${UI_SOURCE_DIR}"
WORKING_DIRECTORY "${WORK_DIR}"
RESULT_VARIABLE rc
ERROR_VARIABLE err
)
+1 -1
View File
@@ -1382,7 +1382,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
const auto & hparams = model.hparams;
// eagle3/DFlash: features as encoder input, and non-draft paths fall back to model's input dim
const int64_t n_embd = hparams.n_embd_inp();
const int64_t n_embd = hparams.n_embd_inp_enc();
const int64_t n_vocab = model.vocab.n_tokens();
// note: during encode, we always pass the full sequence starting from pos = 0
+4
View File
@@ -104,6 +104,10 @@ uint32_t llama_hparams::n_embd_inp() const {
return n_embd_inp;
}
uint32_t llama_hparams::n_embd_inp_enc() const {
return n_embd_inp_enc_impl > 0 ? n_embd_inp_enc_impl : n_embd_inp();
}
uint32_t llama_hparams::n_embd_out() const {
return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd;
}
+7
View File
@@ -189,6 +189,10 @@ struct llama_hparams {
// input embedding dimension (0 = use n_embd)
uint32_t n_embd_inp_impl = 0;
// encoder input embedding dimension (0 = use n_embd_inp())
// e.g. the eagle3 encoder fuses target_layers * target_hidden features
uint32_t n_embd_inp_enc_impl = 0;
// output embedding dimension (0 = use n_embd)
uint32_t n_embd_out_impl = 0;
@@ -305,6 +309,9 @@ struct llama_hparams {
// dimension of main + auxiliary input embeddings
uint32_t n_embd_inp() const;
// dimension of the encoder input embeddings
uint32_t n_embd_inp_enc() const;
// dimension of output embeddings
uint32_t n_embd_out() const;
+1 -1
View File
@@ -249,7 +249,7 @@ static bool llama_prepare_model_devices(const llama_model_params & params, llama
}
// if using single GPU mode, remove all except the main GPU
if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
if (params.split_mode == LLAMA_SPLIT_MODE_NONE && !model->devices.empty()) {
if (params.main_gpu < 0) {
model->devices.clear();
} else {
+4 -4
View File
@@ -19,7 +19,7 @@ void llama_model_eagle3::load_arch_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_TARGET_HIDDEN_SIZE, n_embd_tgt);
LLAMA_LOG_INFO("%s: EAGLE3 n_embd_tgt = %u (draft n_embd = %u)\n", __func__, n_embd_tgt, hparams.n_embd);
hparams.n_embd_inp_impl = (uint32_t) target_layer_ids.size() * n_embd_tgt;
hparams.n_embd_inp_enc_impl = (uint32_t) target_layer_ids.size() * n_embd_tgt;
// eagle3 norm_before_residual (optional, default false)
// compatible with Readhat eagle3 speculator model
@@ -34,7 +34,7 @@ void llama_model_eagle3::load_arch_hparams(llama_model_loader & ml) {
void llama_model_eagle3::load_arch_tensors(llama_model_loader &) {
LLAMA_LOAD_LOCALS;
const int64_t n_embd_inp = hparams.n_embd_inp();
const int64_t n_embd_inp = hparams.n_embd_inp_enc();
const int64_t n_embd_attn_input = 2 * n_embd;
// Get vocab size from the d2t tensor in the GGUF file (optional - only needed if eagle3 has different vocab_size than target)
@@ -109,8 +109,8 @@ ggml_tensor * llama_model_eagle3::graph<true>::build_inp_embd_enc() const {
// Input: Target model features (3 layers concatenated: low, mid, high)
// Data will be provided via ubatch->embd in encode_eagle3_features()
auto inp_target = std::make_unique<llm_graph_input_embd>(hparams.n_embd_inp());
inp_target->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32,hparams.n_embd_inp(), n_tokens);
auto inp_target = std::make_unique<llm_graph_input_embd>(hparams.n_embd_inp_enc());
inp_target->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd_inp_enc(), n_tokens);
ggml_set_input(inp_target->embd);
cur = inp_target->embd;
+6 -8
View File
@@ -6,11 +6,10 @@ Apply LORA adapters to base model and export the resulting model.
usage: llama-export-lora [options]
options:
-m, --model model path from which to load base model (default '')
--lora FNAME path to LoRA adapter (can be repeated to use multiple adapters)
--lora-scaled FNAME S path to LoRA adapter with user defined scaling S (can be repeated to use multiple adapters)
-t, --threads N number of threads to use during computation (default: 4)
-o, --output FNAME output file (default: 'ggml-lora-merged-f16.gguf')
-m, --model FNAME model path from which to load base model
--lora FNAME path to LoRA adapter (use comma-separated values to load multiple adapters)
--lora-scaled FNAME:SCALE,... path to LoRA adapter with user defined scaling (format: FNAME:SCALE,...)
-o, --output, --output-file FNAME output file (default: 'ggml-lora-merged-f16.gguf')
```
For example:
@@ -22,12 +21,11 @@ For example:
--lora lora-open-llama-3b-v2-english2tokipona-chat-LATEST.gguf
```
Multiple LORA adapters can be applied by passing multiple `--lora FNAME` or `--lora-scaled FNAME S` command line parameters:
Multiple LORA adapters can be applied by passing comma-separated values to `--lora FNAME` or `--lora-scaled FNAME:SCALE,...`:
```bash
./bin/llama-export-lora \
-m your_base_model.gguf \
-o your_merged_model.gguf \
--lora-scaled lora_task_A.gguf 0.5 \
--lora-scaled lora_task_B.gguf 0.5
--lora-scaled lora_task_A.gguf:0.5,lora_task_B.gguf:0.5
```

Some files were not shown because too many files have changed in this diff Show More