Compare commits

...

79 Commits

Author SHA1 Message Date
YiChen Lv bec3083830 metal : per-op source split + parallel compile (#24021)
* preliminary extract common header

* op source split

* split metallib into 8 libs && load in parallel

* derive kernel->library routing from functionNames

* x-macro lib list + underscore filenames, dedup QK_NL, MRC fixes

* op source split 8 to 20

* improve robustness of source fallback

* clean up

* change bool -> atomic_bool

* only prepend headers that source actually includes

* no semaphore, use GCD global queue

* dedup library compile path, fix NSError lifetime, rename gla

* relocate upstream concat/rope_back/repeat kernel changes into split files

* move ggml-common.h from common.h into dequantize.h to shrink binary size

---------

Co-authored-by: lvyichen <lvyichen@stepfun.com>
2026-06-22 14:15:48 +03:00
Neo Zhang f8cc15f163 [SYCL] support bf16 on bin_bcast OP and unary OPs (#24838)
* support bf16 on bin_bcast OP and unary OPs

* support the older Intel compiler than 2026.0
2026-06-22 14:09:02 +03:00
Tim Neumann 37957e8531 sampling : remove unconditional softmax+sort in top-n-sigma sampler (#22645) 2026-06-22 14:08:32 +03:00
Pascal d0f9d2e5ac server: fix edit_file crash on append at end of file (line_start -1) (#24893)
line_start -1 normalized to n+1, so append inserted at lines.begin() + n + 1,
one past end() -> heap-buffer-overflow in vector::_M_range_insert.

Normalize -1 to n (insert at end()), restrict -1 to append mode and reject it
for replace/delete instead of silently clobbering the last line. Parenthesize
the insert offset so empty-file append computes the position as int first,
avoiding a transient begin() - 1 on a null vector data pointer.
2026-06-22 10:55:28 +02:00
aafsmarak 0ef6f06d55 docs/android.md: Add dependency libandroid-spawn for building in termux (#21812)
Fixes https://github.com/ggml-org/llama.cpp/issues/18615
2026-06-22 05:48:31 +02:00
Aldehir Rojas 52b3df0023 common/peg : implement ac parser for stricter grammar generation (#24869)
* common/peg : implement ac parser

* cont : extract functions

* cont : tidy up

* cont : remove a test

* cont : move ac() def
2026-06-21 16:20:58 -05:00
Xuan-Son Nguyen 7c082bc417 server: fix report progress for loading spec models, add "stages" list (#24870)
* server: fix report progress for loading spec models, add "stages" list

* improve

* nits

* nits 2
2026-06-21 17:36:52 +02:00
Xuan-Son Nguyen bddfd2b113 server: refactor batch construction (#24843)
* server: refactor batch construction

* wip

* wip 2

* wip 3

* wip 4

* add abort_all_slots

* handle batch full more carefully

* fix assert

* rm debug log

* small nits

* (debug) add timings

* debug: force llama_synchronize for accurate timings

* address comments

* disable DEBUG_TIMINGS
2026-06-21 14:16:11 +02:00
Xuan-Son Nguyen 0d135df48c mtmd: fix mtmd_get_memory_usage (#24867) 2026-06-21 14:12:15 +02:00
Sigbjørn Skjæret bf533823cd jinja : implement call statement (#24847)
* implement call statement

* undo unintended change

* de-lambda

* simplify

* move caller context inside function handler
2026-06-21 14:04:52 +02:00
Xuan-Son Nguyen 2f89acc2bc mtmd: add load progress callback (#24865) 2026-06-21 13:40:52 +02:00
Xuan-Son Nguyen bfa3219177 server: add "verbose" field to schema (#24864) 2026-06-21 13:03:14 +02:00
Xuan-Son Nguyen d6d899580d server: real-time model load progress tracking via /models/sse (#24828)
* server: real-time model load progress tracking via /models/sse

* update docs

* add mutex for notify_to_router

* correct docs
2026-06-21 11:58:14 +02:00
Georgi Gerganov 8a118ee86c minor : clean-up whitespaces (#24862)
[no ci]
2026-06-21 11:37:12 +03:00
YiChen Lv d789527482 spec : Support Step3.5/3.7 flash mtp3 (#24340)
* add mtp_layer_offset + include nextn flags in graph reuse

* add llama_set_mtp_layer_offset + llama_model_n_nextn_layer API

* offset head select + require all MTP blocks

* speculative multi-head process()

* speculative multi-head draft()

* gather outputs via inp_out_ids

* cleanup

* fix core

* minor cleanup

* merged draft_multi_head into draft()

* mtp rename nextn

* Apply suggestions from code review

Co-authored-by: Aman Gupta <amangupta052@gmail.com>

* clean-up comments

* fix for multi seq

* apply suggestions && chain-heads comment

* add a reference for chain_heads discussion

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
2026-06-21 11:33:18 +03:00
Aldehir Rojas 063d9c156e common/peg : refactor until gbnf grammar generation (#24839)
* common/peg : refactor until gbnf grammar into an ac automaton

* cont : add a test with multiple strings

* cont : pad state with 0s so rules line up

* cont : clean up comments

* cont : use set everywhere

* cont : inline state num string padding

* cont : add a ref to PR

* cont : fix regression in server-tools.cpp
2026-06-20 21:15:06 -05:00
Aldehir Rojas c57607016a common/json-schema-to-grammar : align spacing rules with parsers (#24835) 2026-06-20 17:43:04 -05:00
Guanhuai Zhang 4a80943174 fix(hexagon): use padded stride for ssm-conv weights (#24470) 2026-06-20 14:58:49 -07:00
Adrien Gallouët 84de01a1f1 llama : use LLM_KV for quantization_version & file_type (#24802)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-20 20:07:01 +02:00
Xuan-Son Nguyen 75f460ac28 arg: try fixing test-args-parser randomly fails (#24826)
* arg: try fixing test-args-parser randomly fails

* return ref

* try triggering the workflow

* exception wrapper

* wip

* test

* test 2

* arg: guard win32 utf8 argv override

make_utf8_argv rebuilds argv from GetCommandLineW to fix utf8 handling of
non ascii arguments on windows. the override runs unconditionally inside
common_params_parse, so it also clobbers a programmatic argv passed by a
caller. test-arg-parser builds a synthetic argv but then sees the real
process command line instead, the model argument is never parsed, and the
assert that expects success aborts via fastfail (0xC0000409). this shows up
as a random failure in the openvino windows workflow.

only override argv when its length matches the caller argc, so the utf8
repair still applies to real binaries while a programmatic argv stays intact.

---------

Co-authored-by: Pascal <admin@serveurperso.com>
2026-06-20 19:45:27 +02:00
Muhammad Salem 8452824611 release: add missing link for win opencl adreno arm64 (#24809) 2026-06-20 23:08:59 +08:00
Matti4 e27f308597 server: avoid forwarding auth headers in CORS proxy (#24373)
* server: avoid forwarding auth headers in CORS proxy

* format

* fix test

* fix e2e test

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2026-06-20 15:34:47 +02:00
Aldehir Rojas 67e9fd3b74 docker : prebuild web UI for s390x build [no release] (#24829) 2026-06-20 05:54:42 -05:00
davidrhodus 796f41bedc model : glm-dsa load DSA indexer tensors as optional (#24770)
GLM-5.2 ships the DSA "lightning indexer" on only a subset of layers (the
"full" layers; others omit it), but the GLM_DSA loader created the five
indexer tensors on every layer as required, so loading any GLM-5.2 GGUF
failed with e.g. `missing tensor 'blk.3.indexer.k_norm.weight'`.

GLM_DSA's graph is llama_model_deepseek2::graph (plain MLA) and does not use
the indexer tensors (indexer runtime not yet implemented), so they are
loaded-but-unused. Marking them TENSOR_NOT_REQUIRED lets layers without an
indexer load as nullptr and the model runs as full MLA attention.

DeepSeek-V3.2 (uniform indexer on all layers) is unaffected.
2026-06-20 13:48:24 +03:00
Adrien Gallouët 37a77fb057 ggml : optimize AMX (#24806)
Flatten the partition over n_batch * M so every thread participates in
the quantization

    | CPU                             | Model                         | Test   |   t/s OLD |   t/s NEW |   Speedup |
    |:--------------------------------|:------------------------------|:-------|----------:|----------:|----------:|
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B IQ4_NL - 4.5 bpw  | pp512  |    730.71 |    779.86 |      1.07 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B IQ4_NL - 4.5 bpw  | tg128  |     87.88 |     86.79 |      0.99 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B IQ4_XS - 4.25 bpw | pp512  |    725.09 |   1023.31 |      1.41 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B IQ4_XS - 4.25 bpw | tg128  |     83.64 |     83.62 |      1.00 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_0              | pp512  |    820.51 |    924.05 |      1.13 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_0              | tg128  |     90.59 |     92.46 |      1.02 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_1              | pp512  |    776.88 |    872.79 |      1.12 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_1              | tg128  |     89.39 |     90.94 |      1.02 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_K_M            | pp512  |    719.28 |   1009.27 |      1.40 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_K_M            | tg128  |     80.62 |     80.86 |      1.00 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_K_S            | pp512  |    732.29 |   1077.29 |      1.47 |
    | Intel(R) Xeon(R) Platinum 8488C | qwen35 0.8B Q4_K_S            | tg128  |     86.42 |     83.53 |      0.97 |

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-20 13:43:06 +03:00
Sigbjørn Skjæret f4043fec01 convert : more consistent handling of rope_parameters (#24833) 2026-06-20 13:42:36 +03:00
Masashi Yoshimura f449e05537 ggml-webgpu: add adapter toggles for F16 on Vulkan + NVIDIA 2026-06-20 08:12:32 +09:00
Xuan-Son Nguyen 2b686a9120 server: refactor child --> router communication (#24821)
* server: refactor child --> router communication

* fix wakeup case

* add docs

* improve update_status()

* nits
2026-06-20 01:02:26 +02:00
Adrien Gallouët 4b48a53b6c server : optimize get_token_probabilities (#24796)
Use std::partial_sort to order only the requested top-n tokens instead
of the full vocabulary

    logprobs sort: vocab=128000 n_top=0 iters=100
    full    sort:   8555.6 us/op
    partial sort:    704.3 us/op

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-19 23:26:54 +02:00
Xuan-Son Nguyen e475fa2b5f mtmd, arg: fix utf8 handling on windows (#24779)
* mtmd, arg: fix utf8 handling on windows

* also fix ggml_fopen

* fix build fail

* also fix CLI
2026-06-19 22:28:38 +02:00
Xuan-Son Nguyen 175147e8f6 server: remove all internal mentions about "webui" (#24817) 2026-06-19 22:12:46 +02:00
Mikolaj Kucharski fabde3bf51 arg: Add comment line support to --api-key-file (#23168) 2026-06-19 17:33:54 +02:00
Alessandro de Oliveira Faria (A.K.A.CABELO) 0d2d9ccbf6 vendor : update cpp-httplib to 0.48.0 (#24787) 2026-06-19 22:16:35 +08:00
Xuan-Son Nguyen 8c2d6f6475 server: add --agent arg, remove redundant webui naming compat (#24801)
* server: add --agent arg, remove redundant webui naming compat

* corrent env

* fix the test

* llama-gen-docs

* nits: wordings
2026-06-19 16:06:13 +02:00
Aldehir Rojas 38724ab593 docker : build the UI (#24794)
* docker : build the UI

* cont : use existing APP_VERSION
2026-06-19 15:32:31 +02:00
Xuan-Son Nguyen e2e7a9b2d0 mtmd: several bug fixes (#24784)
* mtmd: several bug fixes

* fix build

* fix gemma4ua

* add sanity check in get_u32()

* fix build (2)

* area() avoid overflow
2026-06-19 12:18:36 +02:00
Ruixiang Wang b14e3fb90c spec: support eagle3 for qwen3.5 & 3.6 (#24593)
* spec: support qwen3.5 & 3.6 eagle3 draft

* eagle3: Add deferred boundary checkpoints restore support for hybrid models

* apply suggestions

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* spec: adapt to API change

* spec: fix naming

* cont : add TODO

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-06-19 13:08:50 +03:00
Xuan-Son Nguyen 159d093a43 server: fix non-bound n_discard value (ctx shifting) (#24786)
* server: fix non-bound n_discard value

* Update tools/server/server-context.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-06-19 10:53:44 +02:00
Georgi Gerganov 5fd2dc2c41 sync : ggml 2026-06-19 10:19:14 +03:00
Georgi Gerganov 1868af13ac ggml : bump version to 0.15.2 (ggml/1548) 2026-06-19 10:19:14 +03:00
Georgi Gerganov 5bd21b8555 pi : remove docs from system prompt (#24791) 2026-06-19 09:34:00 +03:00
Georgi Gerganov 80452d65b9 server : consolidate slot selection into get_available_slot (#24755)
Absorb get_slot_by_id logic into get_available_slot so slot selection
is handled by a single function call. When a specific slot id is
requested, the LCP similarity check still runs to enable proper
prompt cache updates.

Assisted-by: pi:llama.cpp/Qwen3.6-27B
2026-06-19 09:22:34 +03:00
shalinib-ibm 8141e730f1 ggml-cpu: support K tails in power10 Q8/Q4 MMA matmul (#24753)
* ggml-cpu: support K tails in Power10 MMA Q8/Q4 matmul

This patch removes the requirement that K be divisible by kc in the tinyBlas_Q0_PPC tiled matmul path. Process the final K panel using its actual depth and pass the reduced panel size through packing and kernel execution.  This allows more workloads to use the MMA kernel and reduces fallback to mnpack.

* Apply suggestion from @taronaeo

Co-authored-by: Aaron Teo <taronaeo@gmail.com>

---------

Co-authored-by: Aaron Teo <taronaeo@gmail.com>
2026-06-19 08:55:38 +03:00
Xuan-Son Nguyen db52540f73 mtmd: add batching support for internvl (#24775) 2026-06-19 01:16:16 +02:00
Pascal 3a3edc9ac6 Ggml/cuda col2im 1d (#24417)
* cuda: add GGML_OP_COL2IM_1D, follow-up to the CPU op

* cuda: col2im_1d use fast_div_modulo for the index decomposition

* cuda: col2im_1d tighten supports_op, type match and contiguous dst
2026-06-18 22:23:01 +02:00
Reguna 40f3aafc45 server: add "X-Accel-Buffering": "no" header to streaming endpoints (#24774)
* server: add "X-Accel-Buffering": "no" header to streaming endpoints

This header tells Nginx (as a reverse proxy) to NOT buffer responses. (only affects streaming endpoints)
Without it, Nginx will break streaming with certain applications (notably the Pi coding harness).
2026-06-18 22:01:24 +02:00
Xuan-Son Nguyen a6b3260a42 mtmd: add batching for mtmd-cli, add video tests (#24778) 2026-06-18 21:55:04 +02:00
o7si 32eddaf2ea cmake : fix ui build with read-only source (#24752) 2026-06-18 18:59:18 +02:00
Xuan-Son Nguyen 060ce1bf72 mtmd: refactor llava-uhd overview image handling (always use ov_img_first) (#24769)
* add dedicated "overview" for mtmd_image_preproc_out

* corrections

* correct (again)

* nits

* nits (2)
2026-06-18 18:53:49 +02:00
Max Krasnyansky d2c67959b3 hexagon: support for op-trace (fine-grain tracing of HVX/HMX/DMA events) (#24592)
* hex-optrace: add support for optrace and instrument matmul and flash-atten code

* hex-trace: improve trace event and prefetto generator

* hex-trace: add new script dedicated to handling traces, specifically perfetto traces

* hex-trace: add --head/--tail options to profile and trace tools

* hex-trace: fix whitespaces

* hex-trace: fix flake8 warnings

* hex-trace: fix flake8 warnings

* hmx-fa: restore q_tiles clearing

* hex-profile: remove circular dep in includes

* hex-trace: simplify trace sizing check

* hex-profile: sort events in the summary by name
2026-06-18 08:35:02 -07:00
Kangjia Gao 7b6c5a2aed docs: fix export-lora --lora-scaled syntax [no release] (#24703)
Assisted-by: Codex
2026-06-18 16:46:17 +02:00
Xuan-Son Nguyen fe7c8b2414 server: (router) fix stopping_thread potentially hang (#24728)
* server: (router) fix stopping_thread potentially hang

* fix windows build
2026-06-18 15:41:09 +02:00
Xuan-Son Nguyen e1efd0991d server: add "schema" and validation (#24150)
* wip

* working

* correct some limits

* add field name to error message
2026-06-18 15:40:58 +02:00
Aarni Koskela 08023072ef server : add last-5-seconds generation speed display (#24291)
* server : add last-5-seconds generation speed display

* cont : clean-up

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-06-18 14:02:20 +02:00
Amos Wong 20832179e2 ui: provide touch accessible model selection UI (#24604)
* ui : add model selector storybook stories

Covers list, favorites, single-model, all status states
(loading/loaded/sleeping/failed/idle), and selection states.

* ui : improve model selector mobile UX with hover media queries

Use @media (hover:none) to show action buttons directly on touch
devices and color-code them by model status (amber=sleeping,
green=loaded, muted=idle). Status dots hidden on touch. Desktop
hover behavior unchanged.
2026-06-18 13:14:20 +02:00
Anuj Attri 10786217e9 server : return HTTP 400 on invalid grammar (#24144) (#24154)
Throw on grammar parse failure so the server returns HTTP 400
instead of silently dropping the constraint.
Add a regression test for the invalid-grammar response.

Fixes #24144
2026-06-18 12:49:14 +02:00
Xuan-Son Nguyen 552258c535 server: (router) rework -hf preset repo (#24739)
* server: temporary remove HF remote preset

* rework remove preset.ini support

* rm unused get_remote_preset_whitelist()

* print warning

* add docs

* rm stray file
2026-06-18 12:45:23 +02:00
Xuan-Son Nguyen 968c43891a server: fix router args not being forwarded to child instances (#24760) 2026-06-18 12:15:46 +02:00
Xuan-Son Nguyen 24bba7b98e mtmd: refactor preprocessor, add mtmd_image_preproc_out (#24736)
* add mtmd_image_preproc_out

* add dev docs

* remove unused clip API

* rm unused clip_image_f32_batch::grid

* change preprocess() call signature
2026-06-18 12:04:39 +02:00
Neo Zhang 9724f664e8 [SYCL] rename GGML_SYCL_SUPPORT_LEVEL_ZERO (#24719)
* rename GGML_SYCL_SUPPORT_LEVEL_ZERO to GGML_SYCL_SUPPORT_LEVEL_ZERO_API, and GGML_SYCL_ENABLE_LEVEL_ZERO to  GGML_SYCL_USE_LEVEL_ZERO_API

* fix code format

* fix error when rebase
2026-06-18 11:18:26 +03:00
Neo Zhang dd69db2924 sycl : support MUL_MAT and OUT_PROD with Q1_0 (#24721) 2026-06-18 11:17:37 +03:00
Adrien Gallouët 6ec59ddaea app : enable self-update only when built with llama-install.sh (#24754)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-18 09:57:59 +02:00
Sigbjørn Skjæret 32e806b9c1 ci : fix check-release message parsing (#24751) 2026-06-18 09:32:56 +02:00
Neo Zhang 6f1034b32a [SYCL] support OPs: conv_2d, conv_2d_dw, conv2d_transpose (#24600)
* fix conflict

* fix format issue, rename

* rm debug code

* correct the file name
2026-06-18 09:40:03 +03:00
Aleksander Grygier 0b73fc79fe ui: Update code formatting command in pre-commit hook (#24685) 2026-06-18 08:33:50 +02:00
Ravi Panchumarthy 4a79037b8b ci : fix Windows x64 (OpenVINO) release link (#24731) 2026-06-18 08:30:08 +02:00
Georgi Gerganov cae0a3b0b0 metal : check for BF16 support in concat kernel (#24747) 2026-06-18 09:16:06 +03:00
Xuan-Son Nguyen f3e1828164 mtmd: llava_uhd should no longer use batch dim (#24732) 2026-06-17 22:40:50 +02:00
shalinib-ibm 2e88c49c90 ggml-cpu: Conditionally enable power11 backend based on compiler support (#24687)
* ggml: Conditionally enable power11 backend based on compiler support

Guard POWER11 backend creation behind a compiler flag check for -mcpu=power11. This avoids build failures on current GCC/Clang toolchains while preserving forward compatibility once POWER11 support becomes available.

* Update CMakeLists.txt

ggml-cpu: Use -mcpu=power10 for P10 and P11
2026-06-18 02:45:19 +08:00
Georgi Gerganov 0843245cb1 metal : implement rope_back operator (#24725)
Reuse existing rope kernels with a function constant to toggle forward/backward
rotation, avoiding duplicate kernel code.

Assisted-by: pi:llama.cpp/Qwen3.6-27B
2026-06-17 20:36:05 +03:00
Georgi Gerganov 8d2e580632 metal : add f16 and bf16 support for concat operator (#24724)
* metal : add f16 and bf16 support for concat operator

Extend the Metal backend concat operator to support f16 and bf16 tensor
types in addition to the existing f32 and i32 support.

- Template kernel_concat on type T with specializations for float, half,
  bfloat, and int
- Add type-specific pipeline getter ggml_metal_library_get_pipeline_concat()
- Update device support check to allow f16 unconditionally and bf16 when
  device supports bfloat16
- Update dispatch to select the correct kernel specialization by type

Assisted-by: pi:llama.cpp/Qwen3.6-27B

* metal : extend concat operator to support f16, bf16, i8, i16 and i64

Assisted-by: pi:llama.cpp/Qwen3.6-27B
2026-06-17 19:38:55 +03:00
Xuan-Son Nguyen 4b4d13ae72 server: (router) add model management API (#23976)
* wip

* server: (router) add SSE realtime updates API

* nits

* wip

* add download API

* add download api

* update docs

* add delete endpoint

* fix std::terminate

* fix crash

* fix 2

* add tests

* nits
2026-06-17 18:04:58 +02:00
Dev-iL b4024af6c2 llama : skip main_gpu validation when no devices are available (#23405) 2026-06-17 17:30:26 +03:00
Ruixiang Wang 1a2dea29b9 spec: fix segfault error on long prompts for eagle3 (#24707) 2026-06-17 17:29:49 +03:00
Neo Zhang 74a80dd9c0 [SYCL] add dev2dev memcpy by SYCL API (#24476)
* add dev2dev memcpy by SYCL API

* mv GGML_SYCL_DEV2DEV_MEMCPY to runntime table

* update the detect method for p2p comm

* fix the erro created during fix confilct

---------

Co-authored-by: Neo Zhang <NA>
2026-06-17 17:21:34 +03:00
Neo Zhang d1759e4156 [SYCL] Add conv_3d (#24691)
* add conv_3d

* optimize

* update ops.md

* restore test script

* rm unused code

* rm copyright notes
2026-06-17 17:20:01 +03:00
Julien Chaumond 8086439a4c webui: export conversations as jsonl (#24688)
* webui: export conversations as jsonl

each session is one jsonl file, a session header line followed by one line per message
exporting multiple conversations bundles them into a zip, one jsonl file each

* webui: import jsonl and zip conversation exports

parse the new jsonl session format and zip archives on import
keep supporting the legacy json format
2026-06-17 13:25:47 +02:00
Winston Ma 558e221b70 vulkan: record actual memory properties during buffer creation (#24326) 2026-06-17 11:14:48 +02:00
Ruben Ortlam ea21e03955 Revert "cuda: reset cuda context after reading memory size (#23935)" (#24715)
This reverts commit 0f7fada56b.
2026-06-17 10:59:35 +02:00
225 changed files with 21511 additions and 15829 deletions
+16
View File
@@ -13,6 +13,20 @@ ARG APP_REVISION=N/A
# BUILD STAGE
# Compile all binary files and libraries
# ==============================================================================
ARG NODE_VERSION=24
FROM docker.io/node:$NODE_VERSION AS web
ARG APP_VERSION
WORKDIR /app/tools/ui
COPY tools/ui/package.json tools/ui/package-lock.json ./
RUN npm ci
COPY tools/ui/ ./
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
FROM ${CANN_BASE_IMAGE} AS build
# -- Install build dependencies --
@@ -26,6 +40,8 @@ WORKDIR /app
# -- Copy project files --
COPY . .
COPY --from=web /app/tools/ui/dist tools/ui/dist
# -- Set CANN environment variables (required for compilation) --
# Using ENV instead of `source` allows environment variables to persist across the entire image layer
ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest
+16
View File
@@ -3,6 +3,20 @@ ARG BUILD_DATE=N/A
ARG APP_VERSION=N/A
ARG APP_REVISION=N/A
ARG NODE_VERSION=24
FROM docker.io/node:$NODE_VERSION AS web
ARG APP_VERSION
WORKDIR /app/tools/ui
COPY tools/ui/package.json tools/ui/package-lock.json ./
RUN npm ci
COPY tools/ui/ ./
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
FROM docker.io/ubuntu:$UBUNTU_VERSION AS build
ARG TARGETARCH
@@ -16,6 +30,8 @@ WORKDIR /app
COPY . .
COPY --from=web /app/tools/ui/dist tools/ui/dist
RUN if [ "$TARGETARCH" = "amd64" ] || [ "$TARGETARCH" = "arm64" ]; then \
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON; \
else \
+16
View File
@@ -11,6 +11,20 @@ ARG BUILD_DATE=N/A
ARG APP_VERSION=N/A
ARG APP_REVISION=N/A
ARG NODE_VERSION=24
FROM docker.io/node:$NODE_VERSION AS web
ARG APP_VERSION
WORKDIR /app/tools/ui
COPY tools/ui/package.json tools/ui/package-lock.json ./
RUN npm ci
COPY tools/ui/ ./
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
FROM ${BASE_CUDA_DEV_CONTAINER} AS build
ARG GCC_VERSION
@@ -26,6 +40,8 @@ WORKDIR /app
COPY . .
COPY --from=web /app/tools/ui/dist tools/ui/dist
RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \
fi && \
+16
View File
@@ -5,6 +5,20 @@ ARG APP_REVISION=N/A
## Build Image
ARG NODE_VERSION=24
FROM docker.io/node:$NODE_VERSION AS web
ARG APP_VERSION
WORKDIR /app/tools/ui
COPY tools/ui/package.json tools/ui/package-lock.json ./
RUN npm ci
COPY tools/ui/ ./
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
FROM docker.io/intel/deep-learning-essentials:$ONEAPI_VERSION AS build
ARG GGML_SYCL_F16=ON
@@ -22,6 +36,8 @@ WORKDIR /app
COPY . .
COPY --from=web /app/tools/ui/dist tools/ui/dist
RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \
echo "GGML_SYCL_F16 is set" \
&& export OPT_SYCL_F16="-DGGML_SYCL_F16=ON" \
+16
View File
@@ -10,6 +10,20 @@ ARG BUILD_DATE=N/A
ARG APP_VERSION=N/A
ARG APP_REVISION=N/A
ARG NODE_VERSION=24
FROM docker.io/node:$NODE_VERSION AS web
ARG APP_VERSION
WORKDIR /app/tools/ui
COPY tools/ui/package.json tools/ui/package-lock.json ./
RUN npm ci
COPY tools/ui/ ./
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
FROM ${BASE_MUSA_DEV_CONTAINER} AS build
# MUSA architecture to build for (defaults to all supported archs)
@@ -29,6 +43,8 @@ WORKDIR /app
COPY . .
COPY --from=web /app/tools/ui/dist tools/ui/dist
RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \
export CMAKE_ARGS="-DMUSA_ARCHITECTURES=${MUSA_DOCKER_ARCH}"; \
fi && \
+16
View File
@@ -22,6 +22,20 @@ ARG BUILD_DATE=N/A
ARG APP_VERSION=N/A
ARG APP_REVISION=N/A
ARG NODE_VERSION=24
FROM docker.io/node:$NODE_VERSION AS web
ARG APP_VERSION
WORKDIR /app/tools/ui
COPY tools/ui/package.json tools/ui/package-lock.json ./
RUN npm ci
COPY tools/ui/ ./
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
## Build Image
FROM docker.io/ubuntu:${UBUNTU_VERSION} AS build
@@ -69,6 +83,8 @@ WORKDIR /app
COPY . .
COPY --from=web /app/tools/ui/dist tools/ui/dist
# Build Stage
RUN bash -c "source ${OpenVINO_DIR}/setupvars.sh && \
cmake -B build/ReleaseOV -G Ninja \
+16
View File
@@ -11,6 +11,20 @@ ARG BUILD_DATE=N/A
ARG APP_VERSION=N/A
ARG APP_REVISION=N/A
ARG NODE_VERSION=24
FROM docker.io/node:$NODE_VERSION AS web
ARG APP_VERSION
WORKDIR /app/tools/ui
COPY tools/ui/package.json tools/ui/package-lock.json ./
RUN npm ci
COPY tools/ui/ ./
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
### Build image
FROM ${BASE_ROCM_DEV_CONTAINER} AS build
@@ -38,6 +52,8 @@ WORKDIR /app
COPY . .
COPY --from=web /app/tools/ui/dist tools/ui/dist
RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
cmake -S . -B build \
-DGGML_HIP=ON \
+16
View File
@@ -3,6 +3,20 @@ ARG BUILD_DATE=N/A
ARG APP_VERSION=N/A
ARG APP_REVISION=N/A
ARG NODE_VERSION=24
FROM docker.io/node:$NODE_VERSION AS web
ARG APP_VERSION
WORKDIR /app/tools/ui
COPY tools/ui/package.json tools/ui/package-lock.json ./
RUN npm ci
COPY tools/ui/ ./
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
FROM docker.io/ubuntu:$UBUNTU_VERSION AS build
# Install build tools
@@ -17,6 +31,8 @@ WORKDIR /app
COPY . .
COPY --from=web /app/tools/ui/dist tools/ui/dist
RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=ON -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON && \
cmake --build build --config Release -j$(nproc)
+16
View File
@@ -3,6 +3,20 @@ ARG BUILD_DATE=N/A
ARG APP_VERSION=N/A
ARG APP_REVISION=N/A
ARG NODE_VERSION=24
FROM docker.io/node:$NODE_VERSION AS web
ARG APP_VERSION
WORKDIR /app/tools/ui
COPY tools/ui/package.json tools/ui/package-lock.json ./
RUN npm ci
COPY tools/ui/ ./
RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build
FROM docker.io/ubuntu:$UBUNTU_VERSION AS build
RUN apt-get update && \
@@ -14,6 +28,8 @@ WORKDIR /app
COPY . .
COPY --from=web /app/tools/ui/dist tools/ui/dist
RUN cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_ZENDNN=ON && \
cmake --build build -j $(nproc)
+2
View File
@@ -10,6 +10,8 @@
build*/
tools/ui/node_modules/
models/*
/llama-cli
+16 -2
View File
@@ -58,6 +58,13 @@ jobs:
git tag ${{ steps.srctag.outputs.name }} || exit 0
git push origin ${{ steps.srctag.outputs.name }} || exit 0
build_ui:
name: Build UI
needs: create_tag
uses: ./.github/workflows/ui-build.yml
with:
hf_ui_version: ${{ needs.create_tag.outputs.source_tag }}
prepare_matrices:
name: Prepare Docker matrices
runs-on: ubuntu-24.04
@@ -79,7 +86,7 @@ jobs:
[
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
{ "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x" },
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x", "prebuilt_ui": true },
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.3.0", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
@@ -135,7 +142,7 @@ jobs:
push_to_registry:
name: Push Docker image to Docker Registry
needs: [prepare_matrices, create_tag]
needs: [prepare_matrices, create_tag, build_ui]
runs-on: ${{ matrix.config.runs_on }}
strategy:
@@ -150,6 +157,13 @@ jobs:
fetch-depth: 0
ref: ${{ needs.create_tag.outputs.source_tag }}
- name: Download prebuilt UI
if: ${{ matrix.config.prebuilt_ui == true }}
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
with:
name: ui-build
path: tools/ui/dist
- name: Set up QEMU
if: ${{ contains(matrix.config.platforms, 'linux/amd64') }}
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4
+5 -1
View File
@@ -46,11 +46,13 @@ jobs:
steps:
- id: check
env:
COMMIT_MESSAGE: ${{ github.event.head_commit.message }}
run: |
if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
echo "should_release=true" >> $GITHUB_OUTPUT
elif [[ "${{ github.event_name }}" == "push" && "${{ github.ref }}" == "refs/heads/master" ]]; then
if echo "${{ github.event.head_commit.message }}" | grep -q '\[no release\]'; then
if echo "$COMMIT_MESSAGE" | grep -q '\[no release\]'; then
echo "should_release=false" >> $GITHUB_OUTPUT
else
echo "should_release=true" >> $GITHUB_OUTPUT
@@ -542,6 +544,7 @@ jobs:
steps:
- name: Set OpenVINO version output
id: openvino_version
shell: bash
run: echo "value=${{ env.OPENVINO_VERSION_MAJOR }}" >> $GITHUB_OUTPUT
- name: Clone
@@ -1624,6 +1627,7 @@ jobs:
**Windows:**
- [Windows x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-x64.zip)
- [Windows arm64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-arm64.zip)
- [Windows arm64 (OpenCL Adreno)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-opencl-adreno-arm64.zip)
- [Windows x64 (CUDA 12)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip) - [CUDA 12.4 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-12.4-x64.zip)
- [Windows x64 (CUDA 13)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-13.3-x64.zip) - [CUDA 13.3 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-13.3-x64.zip)
- [Windows x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-vulkan-x64.zip)
-10
View File
@@ -25,13 +25,3 @@ Commits:
- Do not explicitly set the git author in commits - rely on the default git config
- Always use `--no-gpg-sign` when committing
- Never `git push` without explicit confirmation from the user
Resources (read on demand):
- [CONTRIBUTING.md](CONTRIBUTING.md)
- [Build documentation](docs/build.md)
- [Server usage documentation](tools/server/README.md)
- [Server development documentation](tools/server/README-dev.md)
- [PEG parser](docs/development/parsing.md)
- [Auto parser](docs/autoparser.md)
- [Jinja engine](common/jinja/README.md)
- [PR template](.github/pull_request_template.md)
+26 -13
View File
@@ -20,16 +20,21 @@ int llama_fit_params(int argc, char ** argv);
int llama_quantize(int argc, char ** argv);
int llama_perplexity(int argc, char ** argv);
// hands the update over to the install script, which downloads and swaps the binary
// Self-update is only supported for binaries built with llama-install.sh
static int llama_update(int argc, char ** argv) {
(void) argc;
(void) argv;
#ifdef LLAMA_INSTALL_BUILD
#if defined(_WIN32)
return system("powershell -NoProfile -ExecutionPolicy Bypass -Command \"irm https://llama.app/install.ps1 | iex\"");
#else
return system("curl -fsSL https://llama.app/install.sh | sh");
#endif
#else
printf("Updates are available only when installed from https://llama.app\n");
return 1;
#endif
}
static const char * progname;
@@ -46,21 +51,29 @@ struct command {
int (*func)(int, char **);
};
#ifdef LLAMA_INSTALL_BUILD
#define UPDATE_HIDDEN false
#else
#define UPDATE_HIDDEN true
#endif
static const command cmds[] = {
{"serve", "HTTP API server", {"server"}, false, llama_server },
{"cli", "Command-line interactive interface", {"client"}, false, llama_cli },
{"update", "Update llama to the latest release", {}, false, llama_update },
{"completion", "Text completion", {"complete"}, true, llama_completion },
{"bench", "Benchmark prompt processing and text generation", {}, true, llama_bench },
{"batched-bench", "Benchmark batched decoding performance", {}, true, llama_batched_bench},
{"fit-params", "Compute parameters to fit a model in device memory", {}, true, llama_fit_params },
{"quantize", "Quantize a model", {}, true, llama_quantize },
{"perplexity", "Compute model perplexity and KL divergence", {}, true, llama_perplexity },
{"version", "Show version", {}, false, version },
{"licenses", "Show third-party licenses", {"credits"}, false, licenses },
{"help", "Show available commands", {}, false, help },
{"serve", "HTTP API server", {"server"}, false, llama_server },
{"cli", "Command-line interactive interface", {"client"}, false, llama_cli },
{"update", "Update llama to the latest release", {}, UPDATE_HIDDEN, llama_update },
{"completion", "Text completion", {"complete"}, true, llama_completion },
{"bench", "Benchmark prompt processing and text generation", {}, true, llama_bench },
{"batched-bench", "Benchmark batched decoding performance", {}, true, llama_batched_bench},
{"fit-params", "Compute parameters to fit a model in device memory", {}, true, llama_fit_params },
{"quantize", "Quantize a model", {}, true, llama_quantize },
{"perplexity", "Compute model perplexity and KL divergence", {}, true, llama_perplexity },
{"version", "Show version", {}, false, version },
{"licenses", "Show third-party licenses", {"credits"}, false, licenses },
{"help", "Show available commands", {}, false, help },
};
#undef UPDATE_HIDDEN
static int version(int argc, char ** argv) {
printf("%s\n", llama_build_info());
return 0;
+91 -133
View File
@@ -17,6 +17,7 @@
# define NOMINMAX
#endif
#include <windows.h>
#include <shellapi.h>
#endif
#define JSON_ASSERT GGML_ASSERT
@@ -285,58 +286,15 @@ static std::string clean_file_name(const std::string & fname) {
return clean_fname;
}
static bool common_params_handle_remote_preset(common_params & params, llama_example ex) {
GGML_ASSERT(!params.model.hf_repo.empty());
// the returned hf_repo is without tag
auto [hf_repo, hf_tag] = common_download_split_repo_tag(params.model.hf_repo);
// "latest" tag (default if not specified) is translated to "default" preset
if (hf_tag == "latest") {
hf_tag = "default";
}
std::string model_endpoint = common_get_model_endpoint();
auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini";
// prepare local path for caching
auto preset_fname = clean_file_name(hf_repo + "_preset.ini");
auto preset_path = fs_get_cache_file(preset_fname);
common_download_opts opts;
opts.bearer_token = params.hf_token;
opts.offline = params.offline;
LOG_TRC("%s: looking for remote preset at %s\n", __func__, preset_url.c_str());
const int status = common_download_file_single(preset_url, preset_path, opts);
const bool has_preset = status >= 200 && status < 400;
// remote preset is optional, so we don't error out if not found
if (has_preset) {
LOG_TRC("%s: applying remote preset from %s\n", __func__, preset_url.c_str());
common_preset_context ctx(ex, /* only_remote_allowed */ true);
common_preset global;
auto remote_presets = ctx.load_from_ini(preset_path, global);
remote_presets = ctx.cascade(global, remote_presets);
if (remote_presets.find(hf_tag) != remote_presets.end()) {
common_preset preset = remote_presets.at(hf_tag);
LOG_INF("\n%s", preset.to_ini().c_str()); // to_ini already added trailing newline
preset.apply_to_params(params);
} else {
throw std::runtime_error("Remote preset.ini does not contain [" + std::string(hf_tag) + "] section");
}
} else {
LOG_TRC("%s: no remote preset found, skipping\n", __func__);
}
return has_preset;
}
struct handle_model_result {
bool found_mmproj = false;
common_params_model mmproj;
bool found_mtp = false;
common_params_model mtp;
bool found_preset = false;
std::string preset_path;
};
static handle_model_result common_params_handle_model(struct common_params_model & model,
@@ -345,7 +303,6 @@ static handle_model_result common_params_handle_model(struct common_params_model
if (!model.docker_repo.empty()) {
model.path = common_docker_resolve_model(model.docker_repo);
model.name = model.docker_repo;
} else if (!model.hf_repo.empty()) {
// If -m was used with -hf, treat the model "path" as the hf_file to download
if (model.hf_file.empty() && !model.path.empty()) {
@@ -355,11 +312,16 @@ static handle_model_result common_params_handle_model(struct common_params_model
common_download_opts hf_opts = opts;
auto download_result = common_download_model(model, hf_opts);
if (!download_result.preset_path.empty()) {
result.found_preset = true;
result.preset_path = download_result.preset_path;
return result; // skip everything else if preset.ini is used
}
if (download_result.model_path.empty()) {
throw std::runtime_error("failed to download model from Hugging Face");
}
model.name = model.hf_repo;
model.path = download_result.model_path;
if (!download_result.mmproj_path.empty()) {
@@ -454,6 +416,17 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex)
try {
auto res = common_params_handle_model(params.model, opts);
if (res.found_preset) {
if (!params.models_preset.empty()) {
throw std::invalid_argument("cannot use both --models-preset and -hf with a preset.ini file");
}
// if HF repo is a preset repo, we simply run server in router mode with the preset.ini file
params.models_preset_hf = params.model.hf_repo; // only for showing a warning
params.models_preset = res.preset_path;
params.model = common_params_model{}; // make sure to clear model, so server starts in router mode
return true;
}
if (params.no_mmproj) {
params.mmproj = {};
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
@@ -601,30 +574,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
// parse the first time to get -hf option (used for remote preset)
parse_cli_args();
// export_graph_ops loads only metadata
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
// maybe handle remote preset
if (!params.model.hf_repo.empty() && !skip_model_download) {
std::string cli_hf_repo = params.model.hf_repo;
bool has_preset = common_params_handle_remote_preset(params, ctx_arg.ex);
// special case: if hf_repo explicitly set by preset, we need to preserve it (ignore CLI value)
// this is useful when we have one HF repo pointing to other HF repos (one model - multiple GGUFs)
std::string preset_hf_repo = params.model.hf_repo;
bool preset_has_hf_repo = preset_hf_repo != cli_hf_repo;
if (has_preset) {
// re-parse CLI args to override preset values
parse_cli_args();
}
// preserve hf_repo from preset if needed
if (preset_has_hf_repo) {
params.model.hf_repo = preset_hf_repo;
}
}
postprocess_cpu_params(params.cpuparams, nullptr);
postprocess_cpu_params(params.cpuparams_batch, &params.cpuparams);
@@ -635,15 +584,21 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
}
// handle model and download
if (!skip_model_download) {
common_params_handle_models(params, ctx_arg.ex);
}
// export_graph_ops loads only metadata
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !skip_model_download && !params.usage && !params.completion) {
throw std::invalid_argument("error: --model is required\n");
if (!skip_model_download) {
// handle model and download
common_params_handle_models(params, ctx_arg.ex);
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
if (params.model.path.empty()
&& ctx_arg.ex != LLAMA_EXAMPLE_SERVER
&& !params.usage
&& !params.completion) {
throw std::invalid_argument("error: --model is required\n");
}
}
if (params.escape) {
@@ -937,7 +892,44 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
return true;
}
#ifdef _WIN32
struct utf8_argv {
std::vector<std::string> buf;
std::vector<char*> ptrs;
};
static utf8_argv make_utf8_argv() {
utf8_argv out;
int wargc = 0;
LPWSTR* wargv = CommandLineToArgvW(GetCommandLineW(), &wargc);
if (!wargv) return out;
out.buf.reserve(wargc);
for (int i = 0; i < wargc; ++i) {
int n = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS, wargv[i], -1, nullptr, 0, nullptr, nullptr);
if (n <= 0) { out.buf.emplace_back(); continue; }
auto& s = out.buf.emplace_back();
s.resize(static_cast<size_t>(n - 1));
(void)WideCharToMultiByte(CP_UTF8, 0, wargv[i], -1, s.data(), n, nullptr, nullptr);
}
LocalFree(wargv);
out.ptrs.reserve(out.buf.size() + 1);
for (auto& s : out.buf) out.ptrs.push_back(s.data());
out.ptrs.push_back(nullptr);
return out;
}
#endif
bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
#ifdef _WIN32
auto utf8 = make_utf8_argv();
// repair argv only when it matches the process command line
if (static_cast<int>(utf8.buf.size()) == argc) {
argv = utf8.ptrs.data();
}
#endif
auto ctx_arg = common_params_parser_init(params, ex, print_usage);
const common_params params_org = ctx_arg.params; // the example can modify the default params
@@ -2874,62 +2866,26 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.api_prefix = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_API_PREFIX"));
// Deprecated: use --ui-config instead (kept for backward compat)
add_opt(common_arg(
{"--webui-config"}, "JSON",
"[DEPRECATED: use --ui-config] JSON that provides default WebUI settings (overrides WebUI defaults)",
[](common_params & params, const std::string & value) {
params.ui_config_json = value;
params.webui_config_json = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_CONFIG"));
add_opt(common_arg(
{"--ui-config"}, "JSON",
{"--ui-config", "--webui-config"}, "JSON",
"JSON that provides default UI settings (overrides UI defaults)",
[](common_params & params, const std::string & value) {
params.ui_config_json = value;
params.webui_config_json = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI_CONFIG"));
// Deprecated: use --ui-config-file instead (kept for backward compat)
add_opt(common_arg(
{"--webui-config-file"}, "PATH",
"[DEPRECATED: use --ui-config-file] JSON file that provides default WebUI settings (overrides WebUI defaults)",
[](common_params & params, const std::string & value) {
params.ui_config_json = read_file(value);
params.webui_config_json = params.ui_config_json;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_CONFIG_FILE"));
add_opt(common_arg(
{"--ui-config-file"}, "PATH",
{"--ui-config-file", "--webui-config-file"}, "PATH",
"JSON file that provides default UI settings (overrides UI defaults)",
[](common_params & params, const std::string & value) {
params.ui_config_json = read_file(value);
params.webui_config_json = params.ui_config_json;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI_CONFIG_FILE"));
// Deprecated: use --ui-mcp-proxy instead (kept for backward compat)
add_opt(common_arg(
{"--webui-mcp-proxy"},
{"--no-webui-mcp-proxy"},
"[DEPRECATED: use --ui-mcp-proxy/--no-ui-mcp-proxy] experimental: whether to enable MCP CORS proxy",
[](common_params & params, bool value) {
params.ui_mcp_proxy = value;
params.webui_mcp_proxy = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_MCP_PROXY"));
add_opt(common_arg(
{"--ui-mcp-proxy"},
{"--no-ui-mcp-proxy"},
{"--ui-mcp-proxy", "--webui-mcp-proxy"},
{"--no-ui-mcp-proxy", "--no-webui-mcp-proxy"},
"experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)",
[](common_params & params, bool value) {
params.ui_mcp_proxy = value;
params.webui_mcp_proxy = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI_MCP_PROXY"));
add_opt(common_arg(
@@ -2941,24 +2897,26 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.server_tools = parse_csv_row(value);
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TOOLS"));
// Deprecated: use --ui/--no-ui instead (kept for backward compat)
add_opt(common_arg(
{"--webui"},
{"--no-webui"},
"[DEPRECATED: use --ui/--no-ui] whether to enable the Web UI",
{"-ag", "--agent"},
{"-no-ag", "--no-agent"},
"whether to enable CORS proxy and all built-in tools - do not enable in untrusted environments (default: disabled)",
[](common_params & params, bool value) {
params.ui = value;
params.webui = value;
if (value) {
params.server_tools = {"all"};
params.ui_mcp_proxy = true;
} else {
params.server_tools.clear();
params.ui_mcp_proxy = false;
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI"));
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_AGENT"));
add_opt(common_arg(
{"--ui"},
{"--no-ui"},
{"--ui", "--webui"},
{"--no-ui", "--no-webui"},
string_format("whether to enable the Web UI (default: %s)", params.ui ? "enabled" : "disabled"),
[](common_params & params, bool value) {
params.ui = value;
params.webui = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI"));
add_opt(common_arg(
@@ -2989,7 +2947,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY"));
add_opt(common_arg(
{"--api-key-file"}, "FNAME",
"path to file containing API keys (default: none)",
"path to file containing API keys, one per line; lines starting with a hash are treated as comments (default: none)",
[](common_params & params, const std::string & value) {
std::ifstream key_file(value);
if (!key_file) {
@@ -2997,7 +2955,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
std::string key;
while (std::getline(key_file, key)) {
if (!key.empty()) {
if (!key.empty() && key[0] != '#') {
params.api_keys.push_back(key);
}
}
+5 -4
View File
@@ -395,10 +395,11 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
arguments.name_suffix) +
arguments.value_prefix +
(schema_info.resolves_to_string(param_schema) ?
p.tool_arg_string_value(until_suffix) :
p.tool_arg_json_value(p.schema(
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false))) +
p.tool_arg_close(p.literal(arguments.value_suffix)));
p.ac(p.tool_arg_string_value(until_suffix) +
p.tool_arg_close(p.literal(arguments.value_suffix)), arguments.value_suffix) :
(p.tool_arg_json_value(p.schema(
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
p.tool_arg_close(p.literal(arguments.value_suffix)))));
auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg);
if (is_required) {
+15 -1
View File
@@ -1074,6 +1074,18 @@ std::vector<common_file_info> fs_list(const std::string & path, bool include_dir
return files;
}
std::ifstream fs_open_ifstream(const std::string & fname, std::ios_base::openmode mode) {
#ifdef _WIN32
int wlen = MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, NULL, 0);
if (!wlen) { return std::ifstream(); }
std::vector<wchar_t> wfname(wlen);
(void)MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, wfname.data(), wlen);
return std::ifstream(wfname.data(), mode);
#else
return std::ifstream(fname, mode);
#endif
}
//
// TTY utils
//
@@ -2034,7 +2046,7 @@ bool common_prompt_batch_decode(
}
size_t common_prompt_checkpoint::size() const {
return data_tgt.size() + data_dft.size();
return data_tgt.size() + data_dft.size() + data_spec.size();
}
bool common_prompt_checkpoint::empty() const {
@@ -2049,6 +2061,7 @@ void common_prompt_checkpoint::clear() {
data_tgt.clear();
data_dft.clear();
data_spec.clear();
}
void common_prompt_checkpoint::update_pos(
@@ -2138,4 +2151,5 @@ void common_prompt_checkpoint::clear_tgt() {
void common_prompt_checkpoint::clear_dft() {
data_dft.clear();
data_spec.clear();
}
+23 -12
View File
@@ -295,7 +295,16 @@ struct common_params_model {
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
std::string docker_repo = ""; // Docker repo // NOLINT
std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
std::string get_name() {
if (!hf_repo.empty()) {
return hf_repo;
}
if (!docker_repo.empty()) {
return docker_repo;
}
return path;
}
};
// draft-model-based speculative decoding parameters
@@ -363,7 +372,7 @@ struct common_params_speculative {
uint32_t need_n_rs_seq() const {
bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) {
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP;
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3;
});
return needs_rs_seq ? draft.n_max : 0u;
@@ -624,12 +633,6 @@ struct common_params {
// UI configs
bool ui = true;
// Deprecated: use ui, ui_mcp_proxy, ui_config_json instead
bool webui = ui;
bool webui_mcp_proxy = false;
std::string webui_config_json;
bool ui_mcp_proxy = false;
std::string ui_config_json;
@@ -642,10 +645,11 @@ struct common_params {
std::vector<std::string> server_tools;
// router server configs
std::string models_dir = ""; // directory containing models for the router server
std::string models_preset = ""; // directory containing model presets for the router server
int models_max = 4; // maximum number of models to load simultaneously
bool models_autoload = true; // automatically load models when requested via the router server
std::string models_dir = ""; // directory containing models for the router server
std::string models_preset = ""; // directory containing model presets for the router server
int models_max = 4; // maximum number of models to load simultaneously
bool models_autoload = true; // automatically load models when requested via the router server
std::string models_preset_hf = ""; // show a warning about remote presets on router loaded (if not empty)
bool log_json = false;
@@ -847,6 +851,9 @@ struct common_file_info {
};
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories);
// fs open, also handle UTF8 on Windows
std::ifstream fs_open_ifstream(const std::string & fname, std::ios_base::openmode mode);
//
// TTY utils
//
@@ -1064,6 +1071,10 @@ struct common_prompt_checkpoint {
std::vector<uint8_t> data_tgt;
std::vector<uint8_t> data_dft;
// (optional) speculative-decoding implementation state stashed with the checkpoint
// (e.g. eagle3's deferred-boundary g_embd row)
std::vector<uint8_t> data_spec;
size_t size() const;
bool empty() const;
+120 -17
View File
@@ -696,6 +696,7 @@ struct hf_plan {
hf_cache::hf_files model_files;
hf_cache::hf_file mmproj;
hf_cache::hf_file mtp;
hf_cache::hf_file preset; // if set, only this file is downloaded
};
static hf_plan get_hf_plan(const common_params_model & model,
@@ -717,6 +718,14 @@ static hf_plan get_hf_plan(const common_params_model & model,
return plan;
}
// if preset.ini exists in the repo root, download only that file
for (const auto & f : all) {
if (f.path == "preset.ini") {
plan.preset = f;
return plan;
}
}
hf_cache::hf_file primary;
if (!model.hf_file.empty()) {
@@ -794,14 +803,19 @@ common_download_model_result common_download_model(const common_params_model &
if (is_hf) {
hf = get_hf_plan(model, opts, download_mmproj, download_mtp);
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
if (!hf.mmproj.path.empty()) {
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
}
if (!hf.mtp.path.empty()) {
tasks.push_back({hf.mtp.url, hf.mtp.local_path});
if (!hf.preset.path.empty()) {
// if preset.ini exists, only download that file alone
tasks.push_back({hf.preset.url, hf.preset.local_path});
} else {
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
if (!hf.mmproj.path.empty()) {
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
}
if (!hf.mtp.path.empty()) {
tasks.push_back({hf.mtp.url, hf.mtp.local_path});
}
}
} else if (!model.url.empty()) {
tasks = get_url_tasks(model);
@@ -835,17 +849,22 @@ common_download_model_result common_download_model(const common_params_model &
}
if (is_hf) {
for (const auto & f : hf.model_files) {
hf_cache::finalize_file(f);
}
result.model_path = hf.primary.final_path;
if (!hf.preset.path.empty()) {
// if preset.ini is used, do not set other paths
result.preset_path = hf_cache::finalize_file(hf.preset);
} else {
for (const auto & f : hf.model_files) {
hf_cache::finalize_file(f);
}
result.model_path = hf.primary.final_path;
if (!hf.mmproj.path.empty()) {
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
}
if (!hf.mmproj.path.empty()) {
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
}
if (!hf.mtp.path.empty()) {
result.mtp_path = hf_cache::finalize_file(hf.mtp);
if (!hf.mtp.path.empty()) {
result.mtp_path = hf_cache::finalize_file(hf.mtp);
}
}
} else {
result.model_path = model.path;
@@ -997,3 +1016,87 @@ std::vector<common_cached_model_info> common_list_cached_models() {
return result;
}
bool common_download_remove(const std::string & hf_repo_with_tag) {
namespace fs = std::filesystem;
auto [repo_id, tag] = common_download_split_repo_tag(hf_repo_with_tag);
if (tag.empty()) {
return hf_cache::remove_cached_repo(repo_id);
}
std::string tag_upper = tag;
for (char & c : tag_upper) {
c = (char) std::toupper((unsigned char) c);
}
auto files = hf_cache::get_cached_files(repo_id);
if (files.empty()) {
return false;
}
// collect snapshot entries whose tag matches
std::vector<fs::path> to_remove;
for (const auto & f : files) {
auto split = get_gguf_split_info(f.path);
if (split.tag == tag_upper) {
to_remove.emplace_back(f.local_path);
}
}
if (to_remove.empty()) {
return false;
}
// resolve blob paths from symlinks before deleting snapshot entries
std::vector<fs::path> blobs_to_check;
for (const auto & p : to_remove) {
std::error_code ec;
if (fs::is_symlink(p, ec)) {
auto target = fs::read_symlink(p, ec);
if (!ec) {
blobs_to_check.push_back((p.parent_path() / target).lexically_normal());
}
}
}
// remove snapshot entries
for (const auto & p : to_remove) {
std::error_code ec;
fs::remove(p, ec);
if (ec) {
LOG_WRN("%s: failed to remove %s: %s\n", __func__, p.string().c_str(), ec.message().c_str());
}
}
if (blobs_to_check.empty()) {
return true;
}
// collect blobs still referenced by remaining snapshot entries
std::unordered_set<std::string> still_referenced;
for (const auto & f : hf_cache::get_cached_files(repo_id)) {
fs::path p(f.local_path);
std::error_code ec;
if (fs::is_symlink(p, ec)) {
auto target = fs::read_symlink(p, ec);
if (!ec) {
still_referenced.insert((p.parent_path() / target).lexically_normal().string());
}
}
}
// remove orphaned blobs
for (const auto & blob : blobs_to_check) {
if (still_referenced.find(blob.string()) == still_referenced.end()) {
std::error_code ec;
fs::remove(blob, ec);
if (ec) {
LOG_WRN("%s: failed to remove blob %s: %s\n", __func__, blob.string().c_str(), ec.message().c_str());
}
}
}
return true;
}
+8
View File
@@ -63,6 +63,7 @@ struct common_download_model_result {
std::string model_path;
std::string mmproj_path;
std::string mtp_path;
std::string preset_path;
};
// throw if the file is missing or invalid (e.g. ETag check failed)
@@ -115,3 +116,10 @@ int common_download_file_single(const std::string & url,
// resolve and download model from Docker registry
// return local path to downloaded model file
std::string common_docker_resolve_model(const std::string & docker);
// Remove a cached model from disk
// input format: "user/model" or "user/model:tag"
// - if tag is omitted, removes the entire repo cache directory
// - if tag is present, removes only files matching that tag (and orphaned blobs)
// returns true if anything was removed
bool common_download_remove(const std::string & hf_repo_with_tag);
+15
View File
@@ -495,4 +495,19 @@ std::string finalize_file(const hf_file & file) {
return file.final_path;
}
bool remove_cached_repo(const std::string & repo_id) {
if (!is_valid_repo_id(repo_id)) {
LOG_WRN("%s: invalid repository: %s\n", __func__, repo_id.c_str());
return false;
}
fs::path repo_path = get_repo_path(repo_id);
std::error_code ec;
auto removed = fs::remove_all(repo_path, ec);
if (ec) {
LOG_ERR("%s: failed to remove repo cache %s: %s\n", __func__, repo_path.string().c_str(), ec.message().c_str());
return false;
}
return removed > 0;
}
} // namespace hf_cache
+3
View File
@@ -29,4 +29,7 @@ hf_files get_cached_files(const std::string & repo_id = {});
// Create snapshot path (link or move/copy) and return it
std::string finalize_file(const hf_file & file);
// Remove the entire cached directory for a repo, returns true if removed
bool remove_cached_repo(const std::string & repo_id);
} // namespace hf_cache
+89 -46
View File
@@ -686,59 +686,62 @@ value set_statement::execute_impl(context & ctx) {
return mk_val<value_undefined>();
}
static inline void bind_parameters(const std::string & name, const statements & this_args, const func_args & args, context & ctx) {
const size_t expected_count = this_args.size();
const size_t input_count = args.count();
JJ_DEBUG("Invoking '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count);
for (size_t i = 0; i < expected_count; ++i) {
if (i < input_count) {
if (is_stmt<identifier>(this_args[i])) {
// normal parameter
std::string param_name = cast_stmt<identifier>(this_args[i])->val;
value param_value = args.get_kwarg_or_pos(param_name, i);
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
ctx.set_val(param_name, param_value);
} else if (is_stmt<keyword_argument_expression>(this_args[i])) {
// default argument used as normal parameter
auto kwarg = cast_stmt<keyword_argument_expression>(this_args[i]);
if (!is_stmt<identifier>(kwarg->key)) {
throw std::runtime_error("Keyword argument key must be an identifier in '" + name + "'");
}
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
value param_value = args.get_kwarg_or_pos(param_name, i);
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
ctx.set_val(param_name, param_value);
} else {
throw std::runtime_error("Invalid parameter type in '" + name + "'");
}
} else {
auto & default_arg = this_args[i];
if (is_stmt<keyword_argument_expression>(default_arg)) {
auto kwarg = cast_stmt<keyword_argument_expression>(default_arg);
if (!is_stmt<identifier>(kwarg->key)) {
throw std::runtime_error("Keyword argument key must be an identifier in '" + name + "'");
}
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str());
ctx.set_val(param_name, kwarg->val->execute(args.ctx));
} else {
throw std::runtime_error("Not enough arguments provided to '" + name + "'");
}
//std::string param_name = cast_stmt<identifier>(default_args[i])->val;
//JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str());
//ctx.var[param_name] = default_args[i]->execute(ctx);
}
}
}
value macro_statement::execute_impl(context & ctx) {
if (!is_stmt<identifier>(this->name)) {
throw std::runtime_error("Macro name must be an identifier");
}
std::string name = cast_stmt<identifier>(this->name)->val;
const func_handler func = [this, name, &ctx](const func_args & args) -> value {
size_t expected_count = this->args.size();
size_t input_count = args.count();
const func_handler func = [this, name](const func_args & args) -> value {
context macro_ctx(args.ctx); // new scope for macro execution
JJ_DEBUG("Invoking macro '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count);
context macro_ctx(ctx); // new scope for macro execution
// bind parameters
for (size_t i = 0; i < expected_count; ++i) {
if (i < input_count) {
if (is_stmt<identifier>(this->args[i])) {
// normal parameter
std::string param_name = cast_stmt<identifier>(this->args[i])->val;
value param_value = args.get_kwarg_or_pos(param_name, i);
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
macro_ctx.set_val(param_name, param_value);
} else if (is_stmt<keyword_argument_expression>(this->args[i])) {
// default argument used as normal parameter
auto kwarg = cast_stmt<keyword_argument_expression>(this->args[i]);
if (!is_stmt<identifier>(kwarg->key)) {
throw std::runtime_error("Keyword argument key must be an identifier in macro '" + name + "'");
}
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
value param_value = args.get_kwarg_or_pos(param_name, i);
JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str());
macro_ctx.set_val(param_name, param_value);
} else {
throw std::runtime_error("Invalid parameter type in macro '" + name + "'");
}
} else {
auto & default_arg = this->args[i];
if (is_stmt<keyword_argument_expression>(default_arg)) {
auto kwarg = cast_stmt<keyword_argument_expression>(default_arg);
if (!is_stmt<identifier>(kwarg->key)) {
throw std::runtime_error("Keyword argument key must be an identifier in macro '" + name + "'");
}
std::string param_name = cast_stmt<identifier>(kwarg->key)->val;
JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str());
macro_ctx.set_val(param_name, kwarg->val->execute(ctx));
} else {
throw std::runtime_error("Not enough arguments provided to macro '" + name + "'");
}
//std::string param_name = cast_stmt<identifier>(default_args[i])->val;
//JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str());
//macro_ctx.var[param_name] = default_args[i]->execute(ctx);
}
}
bind_parameters(name, this->args, args, macro_ctx);
// execute macro body
JJ_DEBUG("Executing macro '%s' body with %zu statements", name.c_str(), this->body.size());
@@ -752,6 +755,46 @@ value macro_statement::execute_impl(context & ctx) {
return mk_val<value_undefined>();
}
value call_statement::execute_impl(context & ctx) {
auto call_expr = cast_stmt<call_expression>(this->call);
if (!call_expr) {
throw std::runtime_error("Call statement requires a valid call expression");
}
value callee_val = call_expr->callee->execute(ctx);
if (!is_val<value_func>(callee_val)) {
throw std::runtime_error("Callee is not a function: got " + callee_val->type());
}
auto * callee_func = cast_val<value_func>(callee_val);
context caller_ctx(ctx); // new scope for caller execution
const func_handler func = [this, caller_ctx = std::move(caller_ctx)](const func_args & args) -> value {
context block_ctx(caller_ctx); // new scope for block execution
bind_parameters("caller", this->caller_args, args, block_ctx);
JJ_DEBUG("Executing call body with %zu statements", this->body.size());
auto res = exec_statements(this->body, block_ctx);
JJ_DEBUG("Call body execution complete, result: %s", res->val_str.str().c_str());
return res;
};
context call_ctx(ctx);
call_ctx.set_val("caller", mk_val<value_func>("caller", func));
func_args args(call_ctx);
for (const auto & arg_expr : call_expr->args) {
auto arg_val = arg_expr->execute(ctx);
JJ_DEBUG(" Argument type: %s", arg_val->type().c_str());
args.push_back(arg_val);
}
JJ_DEBUG("Calling macro '%s' with %zu arguments", callee_func->name.c_str(), args.count());
return callee_func->invoke(args);
}
value member_expression::execute_impl(context & ctx) {
value object = this->object->execute(ctx);
+1
View File
@@ -552,6 +552,7 @@ struct call_statement : public statement {
for (const auto & arg : this->caller_args) chk_type<expression>(arg);
}
std::string type() const override { return "CallStatement"; }
value execute_impl(context & ctx) override;
};
struct ternary_expression : public expression {
+23 -23
View File
@@ -233,27 +233,27 @@ struct BuiltinRule {
};
static std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
{"boolean", {"(\"true\" | \"false\") space", {}}},
{"boolean", {"(\"true\" | \"false\")", {}}},
{"decimal-part", {"[0-9]{1,16}", {}}},
{"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}},
{"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}},
{"integer", {"(\"-\"? integral-part) space", {"integral-part"}}},
{"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)?", {"integral-part", "decimal-part"}}},
{"integer", {"(\"-\"? integral-part)", {"integral-part"}}},
{"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}},
{"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}},
{"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}},
{"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}},
{"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? space \"}\"", {"string", "value"}}},
{"array", {"\"[\" space ( value (\",\" space value)* )? space \"]\"", {"value"}}},
{"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\"", {}}},
{"char", {"[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}},
{"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}},
{"null", {"\"null\" space", {}}},
{"string", {"\"\\\"\" char* \"\\\"\"", {"char"}}},
{"null", {"\"null\"", {}}},
};
static std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
{"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}},
{"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}},
{"date-time", {"date \"T\" time", {"date", "time"}}},
{"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}},
{"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}},
{"date-time-string", {"\"\\\"\" date-time \"\\\"\" space", {"date-time"}}}
{"date-string", {"\"\\\"\" date \"\\\"\"", {"date"}}},
{"time-string", {"\"\\\"\" time \"\\\"\"", {"time"}}},
{"date-time-string", {"\"\\\"\" date-time \"\\\"\"", {"date-time"}}}
};
static bool is_reserved_name(const std::string & name) {
@@ -551,16 +551,16 @@ private:
}
return join_seq();
};
return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space");
return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\"");
}
/*
Returns a rule that matches a JSON string that is none of the provided strings
not_strings({"a"})
-> ["] ( [a] char+ | [^"a] char* )? ["] space
-> ["] ( [a] char+ | [^"a] char* )? ["]
not_strings({"and", "also"})
-> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
-> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["]
*/
std::string _not_strings(const std::vector<std::string> & strings) {
@@ -619,7 +619,7 @@ private:
if (!trie.is_end_of_string) {
out << "?";
}
out << " [\"] space";
out << " [\"]";
return out.str();
}
@@ -725,7 +725,7 @@ private:
rule += " )?";
}
rule += " \"}\" space";
rule += " space \"}\"";
return rule;
}
@@ -858,14 +858,14 @@ public:
return _add_rule(rule_name, _generate_union_rule(name, schema_types));
}
if (schema.contains("const")) {
return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space");
return _add_rule(rule_name, _generate_constant_rule(schema["const"]));
}
if (schema.contains("enum")) {
std::vector<std::string> enum_values;
for (const auto & v : schema["enum"]) {
enum_values.push_back(_generate_constant_rule(v));
}
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ")");
}
if ((schema_type.is_null() || schema_type == "object")
&& (schema.contains("properties") ||
@@ -933,7 +933,7 @@ public:
}
}
if (!enum_intersection.empty()) {
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space");
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ")");
}
}
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
@@ -948,7 +948,7 @@ public:
}
rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i));
}
rule += " \"]\" space";
rule += " space \"]\"";
return _add_rule(rule_name, rule);
}
std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
@@ -956,7 +956,7 @@ public:
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " space \"]\"");
}
if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
return _visit_pattern(schema["pattern"], rule_name);
@@ -972,7 +972,7 @@ public:
std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\"");
}
if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
int64_t min_value = std::numeric_limits<int64_t>::min();
@@ -990,7 +990,7 @@ public:
std::stringstream out;
out << "(";
build_min_max_int(min_value, max_value, out);
out << ") space";
out << ")";
return _add_rule(rule_name, out.str());
}
if (schema.empty() || schema_type == "object") {
+202 -89
View File
@@ -6,13 +6,14 @@
#include "unicode.h"
#include <algorithm>
#include <deque>
#include <initializer_list>
#include <map>
#include <memory>
#include <nlohmann/json.hpp>
#include <regex>
#include <set>
#include <stdexcept>
#include <unordered_set>
// Trick to catch missing branches
template <typename T>
@@ -88,40 +89,7 @@ struct trie {
return match_result{match_result::NO_MATCH};
}
struct prefix_and_next {
std::vector<uint32_t> prefix;
std::vector<uint32_t> next_chars;
};
std::vector<prefix_and_next> collect_prefix_and_next() {
std::vector<uint32_t> prefix;
std::vector<prefix_and_next> result;
collect_prefix_and_next(0, prefix, result);
return result;
}
private:
void collect_prefix_and_next(size_t index, std::vector<uint32_t> & prefix, std::vector<prefix_and_next> & out) {
if (!nodes[index].is_word) {
if (!nodes[index].children.empty()) {
std::vector<uint32_t> chars;
chars.reserve(nodes[index].children.size());
for (const auto & p : nodes[index].children) {
chars.push_back(p.first);
}
out.emplace_back(prefix_and_next{prefix, chars});
}
}
for (const auto & p : nodes[index].children) {
uint32_t ch = p.first;
auto child = p.second;
prefix.push_back(ch);
collect_prefix_and_next(child, prefix, out);
prefix.pop_back();
}
}
size_t create_node() {
size_t index = nodes.size();
nodes.emplace_back();
@@ -153,6 +121,65 @@ struct trie {
}
};
// Aho-Corasick automaton
struct aho_corasick {
trie t;
std::vector<size_t> fail; // failure links
std::vector<size_t> order; // states in BFS order
std::vector<bool> terminal; // match states (directly or via a suffix link)
std::set<uint32_t> alphabet; // every character with a transition
aho_corasick(const std::vector<std::string> & strings) : t(strings) {
const auto & nodes = t.nodes;
const size_t n = nodes.size();
fail.assign(n, 0);
order.reserve(n);
std::deque<size_t> queue{ 0 };
while (!queue.empty()) {
size_t u = queue.front();
queue.pop_front();
order.push_back(u);
for (const auto & [ch, v] : nodes[u].children) {
if (u != 0) {
size_t f = fail[u];
while (f && nodes[f].children.find(ch) == nodes[f].children.end()) {
f = fail[f];
}
auto it = nodes[f].children.find(ch);
fail[v] = (it != nodes[f].children.end() && it->second != v) ? it->second : 0;
}
queue.push_back(v);
}
}
terminal.assign(n, false);
for (size_t u : order) {
terminal[u] = nodes[u].is_word || (u != 0 && terminal[fail[u]]);
}
for (const auto & node : nodes) {
for (const auto & [ch, v] : node.children) {
alphabet.insert(ch);
}
}
}
size_t num_states() const { return t.nodes.size(); }
bool is_terminal(size_t s) const { return terminal[s]; }
// follow failure links until a transition on `ch` exists.
size_t next(size_t state, uint32_t ch) const {
const auto & nodes = t.nodes;
while (state && nodes[state].children.find(ch) == nodes[state].children.end()) {
state = fail[state];
}
auto it = nodes[state].children.find(ch);
return it != nodes[state].children.end() ? it->second : 0;
}
};
static std::pair<uint32_t, size_t> parse_hex_escape(const std::string & str, size_t pos, int hex_count) {
if (pos + hex_count > str.length()) {
return {0, 0};
@@ -894,6 +921,10 @@ struct parser_executor {
common_peg_parse_result operator()(const common_peg_gbnf_parser & p) {
return arena.parse(p.child, ctx, start_pos);
}
common_peg_parse_result operator()(const common_peg_ac_parser & p) {
return arena.parse(p.child, ctx, start_pos);
}
};
common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const {
@@ -962,7 +993,8 @@ void common_peg_arena::resolve_refs() {
std::is_same_v<T, common_peg_not_parser> ||
std::is_same_v<T, common_peg_tag_parser> ||
std::is_same_v<T, common_peg_atomic_parser> ||
std::is_same_v<T, common_peg_gbnf_parser>) {
std::is_same_v<T, common_peg_gbnf_parser> ||
std::is_same_v<T, common_peg_ac_parser>) {
p.child = resolve_ref(p.child);
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
p.child = resolve_ref(p.child);
@@ -992,12 +1024,12 @@ void common_peg_arena::resolve_refs() {
}
std::string common_peg_arena::dump(common_peg_parser_id id) const {
std::unordered_set<common_peg_parser_id> visited;
std::set<common_peg_parser_id> visited;
return dump_impl(id, visited);
}
std::string common_peg_arena::dump_impl(common_peg_parser_id id,
std::unordered_set<common_peg_parser_id> & visited) const {
std::set<common_peg_parser_id> & visited) const {
// Check for cycles
if (visited.count(id)) {
return "[cycle]";
@@ -1043,6 +1075,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
return "Atomic(" + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
return "Ac(" + string_join(p.delimiters, " | ") + ", " + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_any_parser>) {
return "Any";
} else if constexpr (std::is_same_v<T, common_peg_space_parser>) {
@@ -1342,7 +1376,7 @@ common_peg_parser common_peg_parser_builder::json_object() {
common_peg_parser common_peg_parser_builder::json_array() {
return rule("json-array", [this]() {
auto ws = space();
auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))});
auto elements = sequence({json(), zero_or_more(sequence({ws, literal(","), ws, json()}))});
return sequence({
literal("["),
ws,
@@ -1452,6 +1486,13 @@ common_peg_parser common_peg_parser_builder::json_member(const std::string & key
});
}
common_peg_parser common_peg_parser_builder::ac(const common_peg_parser & p, const std::vector<std::string> & delimiters) {
if (delimiters.empty()) {
throw std::runtime_error("ac parser requires at least one delimiter");
}
return add(common_peg_ac_parser{p, delimiters});
}
static std::string gbnf_escape_char_class(uint32_t c) {
if (c == '-' || c == ']' || c == '[' || c == '\\') {
return "\\" + std::string(1, (char) c);
@@ -1502,61 +1543,118 @@ static std::string gbnf_escape_char_class(uint32_t c) {
return std::string(buf);
}
static std::string gbnf_excluding_pattern(const std::vector<std::string> & strings) {
trie matcher(strings);
auto pieces = matcher.collect_prefix_and_next();
std::string pattern;
std::string trailing; // optional proper-prefix of a delimiter, allowed only at the very end
for (size_t i = 0; i < pieces.size(); ++i) {
if (i > 0) {
pattern += " | ";
}
const auto & pre = pieces[i].prefix;
const auto & chars = pieces[i].next_chars;
std::string cls;
cls.reserve(chars.size());
for (uint32_t ch : chars) {
cls += gbnf_escape_char_class(ch);
}
if (!pre.empty()) {
std::string pre_literal = gbnf_format_literal(common_unicode_cpts_to_utf8(pre));
pattern += pre_literal + " [^" + cls + "]";
// Each interior alternative consumes a delimiter-prefix plus a disambiguating
// char, so the repetition alone cannot match a value that *ends* on a proper
// prefix of a delimiter (e.g. a trailing "\n" when the delimiter is
// "\n</parameter>\n"). The runtime until() (greedy first-match) accepts such
// values, so without this the grammar would reject input the parser accepts.
// Allow the value to terminate on any proper prefix as an optional tail.
// This makes the grammar a slight superset of the runtime language (a value
// may end on the longest prefix, which greedy first-match would not itself
// produce); harmless for constrained generation, which only needs to admit
// every runtime-valid string.
if (!trailing.empty()) {
trailing += " | ";
}
trailing += pre_literal;
} else {
pattern += "[^" + cls + "]";
}
static std::string gbnf_char_class(const std::vector<uint32_t> & chars, bool negate) {
std::string s = negate ? "[^" : "[";
for (uint32_t ch : chars) {
s += gbnf_escape_char_class(ch);
}
std::string result = "(" + pattern + ")*";
if (!trailing.empty()) {
result += " (" + trailing + ")?";
}
return result;
return s + "]";
}
static std::unordered_set<std::string> collect_reachable_rules(
static std::string gbnf_ac_grammar(
const common_grammar_builder & builder,
const std::string & prefix,
const std::vector<std::string> & strings,
const std::function<std::string(const std::vector<uint32_t> &,
const std::map<size_t, std::vector<uint32_t>> &,
const std::vector<uint32_t> &,
const std::function<std::string(size_t)> &)> & build_rule) {
aho_corasick ac(strings);
auto state_name = [&](size_t s) -> std::string {
if (s == 0) {
return prefix;
}
std::string num = std::to_string(s);
num = num.size() == 1 ? ("0" + num) : num;
return prefix + "-" + num;
};
for (size_t q = 0; q < ac.num_states(); q++) {
if (ac.is_terminal(q)) {
continue; // match states
}
std::map<size_t, std::vector<uint32_t>> buckets;
std::vector<uint32_t> completing; // chars that complete a delimiter
std::vector<uint32_t> specific; // chars with an explicit transition
for (uint32_t c : ac.alphabet) {
size_t d = ac.next(q, c);
if (ac.is_terminal(d)) {
completing.push_back(c);
specific.push_back(c);
} else if (d != 0) {
buckets[d].push_back(c); // specific non-root destination
specific.push_back(c);
}
}
builder.add_rule(state_name(q), build_rule(completing, buckets, specific, state_name));
}
// An empty delimiter makes the start state terminal. Emit an entry rule
// that matches the empty string so the returned reference stays valid.
if (ac.is_terminal(0)) {
builder.add_rule(prefix, "|");
}
return state_name(0);
}
// GBNF grammar matching strings that contain no string in `strings` as a
// substring. Emits the complement of an Aho-Corasick automaton DFA and returns
// the start state rule name.
//
// ref: https://github.com/ggml-org/llama.cpp/pull/24839
static std::string gbnf_excluding_grammar(const common_grammar_builder & builder,
const std::string & prefix,
const std::vector<std::string> & strings) {
return gbnf_ac_grammar(builder, prefix, strings,
[](const std::vector<uint32_t> & /*completing*/,
const std::map<size_t, std::vector<uint32_t>> & buckets,
const std::vector<uint32_t> & specific,
const std::function<std::string(size_t)> & state_name) {
// every state is accepting and completing chars get no
// alternative, so a forbidden string can never be matched
std::string rhs = "|";
for (const auto & [d, chars] : buckets) {
rhs += " " + gbnf_char_class(chars, false) + " " + state_name(d) + " |";
}
rhs += " " + gbnf_char_class(specific, true) + " " + state_name(0);
return rhs;
});
}
// GBNF grammar matching everything up to and including the first occurrence of
// any string in `strings`. Emits the Aho-Corasick automaton DFA and returns
// the start state rule name.
static std::string gbnf_including_grammar(const common_grammar_builder & builder,
const std::string & prefix,
const std::vector<std::string> & strings) {
return gbnf_ac_grammar(builder, prefix, strings,
[](const std::vector<uint32_t> & completing,
const std::map<size_t, std::vector<uint32_t>> & buckets,
const std::vector<uint32_t> & specific,
const std::function<std::string(size_t)> & state_name) {
std::vector<std::string> alts;
if (!completing.empty()) {
alts.push_back(gbnf_char_class(completing, false)); // terminate on match
}
for (const auto & [d, chars] : buckets) {
alts.push_back(gbnf_char_class(chars, false) + " " + state_name(d));
}
// every other character keeps scanning from the start state
alts.push_back(gbnf_char_class(specific, true) + " " + state_name(0));
return string_join(alts, " | ");
});
}
static std::set<std::string> collect_reachable_rules(
const common_peg_arena & arena,
const common_peg_parser_id & rule
) {
std::unordered_set<std::string> reachable;
std::unordered_set<std::string> visited;
std::set<std::string> reachable;
std::set<std::string> visited;
std::function<void(common_peg_parser_id)> visit = [&](common_peg_parser_id id) {
const auto & parser = arena.get(id);
@@ -1588,6 +1686,7 @@ static std::unordered_set<std::string> collect_reachable_rules(
std::is_same_v<T, common_peg_tag_parser> ||
std::is_same_v<T, common_peg_atomic_parser> ||
std::is_same_v<T, common_peg_gbnf_parser> ||
std::is_same_v<T, common_peg_ac_parser> ||
std::is_same_v<T, common_peg_schema_parser>) {
visit(p.child);
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
@@ -1765,7 +1864,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
if (p.delimiters.empty()) {
return ".*";
}
return gbnf_excluding_pattern(p.delimiters);
return gbnf_excluding_grammar(builder, "until-" + std::to_string(id), p.delimiters);
} else if constexpr (std::is_same_v<T, common_peg_schema_parser>) {
if (schema_delegates(p)) {
return to_gbnf(p.child);
@@ -1782,6 +1881,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
return to_gbnf(p.child);
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
return p.grammar;
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
return gbnf_including_grammar(builder, "ac-" + std::to_string(id), p.delimiters);
} else {
static_assert(is_always_false_v<T>);
}
@@ -1789,7 +1890,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
};
// Collect reachable rules
std::unordered_set<std::string> reachable_rules;
std::set<std::string> reachable_rules;
if (lazy) {
// Collect rules reachable from trigger rules
@@ -1918,6 +2019,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
};
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}};
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
return json{{"type", "ac"}, {"child", p.child}, {"delimiters", p.delimiters}};
}
}, variant);
}
@@ -2090,6 +2193,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
};
}
if (type == "ac") {
if (!j.contains("child") || !j.contains("delimiters") || !j["delimiters"].is_array() || j["delimiters"].empty()) {
throw std::runtime_error("ac parser requires 'child' and a non-empty 'delimiters' array");
}
return common_peg_ac_parser{
j["child"].get<common_peg_parser_id>(),
j["delimiters"].get<std::vector<std::string>>(),
};
}
throw std::runtime_error("Unknown parser type: " + type);
}
+16 -3
View File
@@ -3,8 +3,8 @@
#include <nlohmann/json_fwd.hpp>
#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <string>
#include <string_view>
#include <functional>
@@ -275,6 +275,11 @@ struct common_peg_gbnf_parser {
std::string grammar;
};
struct common_peg_ac_parser {
common_peg_parser_id child;
std::vector<std::string> delimiters;
};
// Variant holding all parser types
using common_peg_parser_variant = std::variant<
common_peg_epsilon_parser,
@@ -296,7 +301,8 @@ using common_peg_parser_variant = std::variant<
common_peg_ref_parser,
common_peg_atomic_parser,
common_peg_tag_parser,
common_peg_gbnf_parser
common_peg_gbnf_parser,
common_peg_ac_parser
>;
class common_peg_arena {
@@ -335,7 +341,7 @@ class common_peg_arena {
friend class common_peg_parser_builder;
private:
std::string dump_impl(common_peg_parser_id id, std::unordered_set<common_peg_parser_id> & visited) const;
std::string dump_impl(common_peg_parser_id id, std::set<common_peg_parser_id> & visited) const;
common_peg_parser_id add_parser(common_peg_parser_variant parser);
void add_rule(const std::string & name, common_peg_parser_id id);
@@ -514,6 +520,13 @@ class common_peg_parser_builder {
// the child's grammar. Parsing delegates entirely to the child.
common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); }
// Wraps a child parser but emits a GBNF grammar built from the Aho-Corasick
// automaton of `delimiters`, matching everything up to and including the
// first delimiter. Parsing delegates entirely to the child, which is
// responsible for consuming the delimiter (e.g. until(D) + literal(D)).
common_peg_parser ac(const common_peg_parser & p, const std::vector<std::string> & delimiters);
common_peg_parser ac(const common_peg_parser & p, const std::string & delimiter) { return ac(p, std::vector<std::string>{delimiter}); }
void set_root(const common_peg_parser & p);
common_peg_arena build();
+1 -49
View File
@@ -16,48 +16,6 @@ static std::string rm_leading_dashes(const std::string & str) {
return str.substr(pos);
}
// only allow a subset of args for remote presets for security reasons
// do not add more args unless absolutely necessary
// args that output to files are strictly prohibited
static std::set<std::string> get_remote_preset_whitelist(const std::map<std::string, common_arg> & key_to_opt) {
static const std::set<std::string> allowed_options = {
"model-url",
"hf-repo",
"hf-repo-draft",
"hf-repo-v", // vocoder
"hf-file-v", // vocoder
"mmproj-url",
"pooling",
"jinja",
"batch-size",
"ubatch-size",
"cache-reuse",
"chat-template-kwargs",
"mmap",
// note: sampling params are automatically allowed by default
// negated args will be added automatically if the positive arg is specified above
};
std::set<std::string> allowed_keys;
for (const auto & it : key_to_opt) {
const std::string & key = it.first;
const common_arg & opt = it.second;
if (allowed_options.find(key) != allowed_options.end() || opt.is_sampling) {
allowed_keys.insert(key);
// also add variant keys (args without leading dashes and env vars)
for (const auto & arg : opt.get_args()) {
allowed_keys.insert(rm_leading_dashes(arg));
}
for (const auto & env : opt.get_env()) {
allowed_keys.insert(env);
}
}
}
return allowed_keys;
}
std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
std::vector<std::string> args;
@@ -300,16 +258,10 @@ static std::string parse_bool_arg(const common_arg & arg, const std::string & ke
return value;
}
common_preset_context::common_preset_context(llama_example ex, bool only_remote_allowed)
common_preset_context::common_preset_context(llama_example ex)
: ctx_params(common_params_parser_init(default_params, ex)) {
common_params_add_preset_options(ctx_params.options);
key_to_opt = get_map_key_opt(ctx_params);
// setup allowed keys if only_remote_allowed is true
if (only_remote_allowed) {
filter_allowed_keys = true;
allowed_keys = get_remote_preset_whitelist(key_to_opt);
}
}
common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const {
+1 -1
View File
@@ -60,7 +60,7 @@ struct common_preset_context {
std::set<std::string> allowed_keys;
// if only_remote_allowed is true, only accept whitelisted keys
common_preset_context(llama_example ex, bool only_remote_allowed = false);
common_preset_context(llama_example ex);
// load presets from INI file
common_presets load_from_ini(const std::string & path, common_preset & global) const;
+3
View File
@@ -259,6 +259,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
}
}
if (!grmr && !grammar_str.empty()) {
throw std::runtime_error("failed to parse grammar");
}
// Compute prefill tokens from the generation prompt
std::vector<llama_token> prefill_tokens;
+174 -35
View File
@@ -161,6 +161,10 @@ struct common_speculative_impl {
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) = 0;
// (optional) serialize/restore per-seq internal state (e.g. eagle3's deferred boundary).
virtual bool get_state(llama_seq_id /*seq_id*/, std::vector<uint8_t> & /*data*/) const { return false; }
virtual void set_state(llama_seq_id /*seq_id*/, const std::vector<uint8_t> & /*data*/) {}
// true if this implementation requires the target context to extract post-norm embeddings
virtual bool need_embd() const = 0;
@@ -841,6 +845,49 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
(size_t) n_embd_dec * sizeof(float));
}
// we only need to stash the deferred boundary's g_embd row for recurrent/hybrid targets:
// their single-position checkpoints drop it on restore
bool need_boundary_stash() const {
const llama_model * model_tgt = llama_get_model(params.ctx_tgt);
return llama_model_is_recurrent(model_tgt) || llama_model_is_hybrid(model_tgt);
}
bool get_state(llama_seq_id seq_id, std::vector<uint8_t> & data) const override {
if (!need_boundary_stash()) {
return false;
}
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq || pending_pos_last[seq_id] < 0) {
return false;
}
const llama_pos pos = pending_pos_last[seq_id];
const std::vector<float> & g = pending_g_last[seq_id];
data.resize(sizeof(llama_pos) + g.size() * sizeof(float));
std::memcpy(data.data(), &pos, sizeof(llama_pos));
std::memcpy(data.data() + sizeof(llama_pos), g.data(), g.size() * sizeof(float));
return true;
}
void set_state(llama_seq_id seq_id, const std::vector<uint8_t> & data) override {
if (!need_boundary_stash()) {
return;
}
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
return;
}
if (data.size() != sizeof(llama_pos) + (size_t) n_embd_dec * sizeof(float)) {
return;
}
llama_pos pos = -1;
std::memcpy(&pos, data.data(), sizeof(llama_pos));
pending_pos_last[seq_id] = pos;
pending_g_last[seq_id].resize(n_embd_dec);
std::memcpy(pending_g_last[seq_id].data(), data.data() + sizeof(llama_pos), (size_t) n_embd_dec * sizeof(float));
}
bool need_embd() const override {
return false;
}
@@ -858,7 +905,13 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
int32_t n_embd = 0;
bool is_mem_shared = false;
// One MTP draft driver, three modes (set once in the ctor):
// is_mem_shared (gemma4): shares the target KV, runs all heads in one graph.
// chain_heads (step35): n_mtp_layers trained heads, one per draft step.
// neither (qwen35 / qwen35moe): a single trained MTP head.
int32_t n_mtp_layers = 1;
bool is_mem_shared = false; // gemma4
bool chain_heads = false; // derived in the ctor: n_mtp_layers > 1 && !is_mem_shared
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
// The last h-row of one process() call needs the first token of the NEXT
@@ -873,10 +926,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
std::vector<std::vector<float>> verify_h;
std::vector<int32_t> verify_h_rows;
// Per-seq draft length from the last draft() call, used in accept() to
// roll back ctx_dft's recurrent state past the AR draft's redundant
// pre-advancement before process() mirrored the verify batch.
std::vector<uint16_t> last_n_drafted;
std::vector<int> i_last;
std::vector<std::vector<float>> chain_h;
common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq)
@@ -889,6 +940,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft));
GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) &&
"MTP input row width must match the target h_nextn width");
n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft)));
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
@@ -935,16 +987,25 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt;
chain_heads = n_mtp_layers > 1 && !is_mem_shared;
if (chain_heads) {
this->params.n_max = std::min(this->params.n_max, n_mtp_layers);
chain_h.assign(n_seq, {});
for (auto & c : chain_h) {
c.reserve((size_t) (this->params.n_max + 1) * n_embd);
}
}
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
i_last.assign(n_seq, -1);
i_batch_beg.assign(n_seq, -1);
i_batch_end.assign(n_seq, -1);
verify_h.assign(n_seq, {});
verify_h_rows.assign(n_seq, 0);
last_n_drafted.assign(n_seq, 0);
}
~common_speculative_impl_draft_mtp() override {
@@ -1050,9 +1111,34 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
auto * mem_dft = llama_get_memory(ctx_dft);
bool ok = true;
for (int head = 0; head < n_mtp_layers; ++head) {
if (chain_heads) {
// ref: https://github.com/ggml-org/llama.cpp/pull/24340/changes#r3413498544
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
}
llama_memory_seq_rm(mem_dft, seq_id, batch_in.pos[i_batch_beg[seq_id]], -1);
}
llama_set_nextn_layer_offset(ctx_dft, head);
}
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
__func__, head, (int) rc, (int) batch_in.pos[0]);
ok = false;
break;
}
}
if (chain_heads) {
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
}
if (!ok) {
return false;
}
}
@@ -1087,7 +1173,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
int n_drafting = 0;
std::vector<bool> drafting(n_seq);
const float * h_row = nullptr;
const size_t row_bytes = (size_t) n_embd * sizeof(float);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
@@ -1102,22 +1187,43 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
common_sampler_reset(smpls[seq_id].get());
common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, pending_h[seq_id].data(), row_bytes);
h_row = pending_h[seq_id].data();
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
}
i_last[seq_id] = batch.n_tokens - 1;
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
return;
if (chain_heads) {
chain_h[seq_id].assign(pending_h[seq_id].begin(), pending_h[seq_id].end());
}
}
int i = 0;
while (n_drafting > 0) {
int i_batch = 0;
// each step decodes under a different head, i.e. a different decoder layer, and
// KV is per layer. process() filled this layer's KV only for positions < n_past
// (prompt + accepted prefix) — nothing in the draft region yet. so reset the
// draft region (the seq_rm lower bound is n_past, leaving the prompt KV intact)
// and select head i so it rebuilds its own layer's KV there; decoding just the
// latest token would leave its attention reading cells only another head wrote.
if (chain_heads) {
auto * mem_dft = llama_get_memory(ctx_dft);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (drafting[seq_id]) {
llama_memory_seq_rm(mem_dft, seq_id, dparams[seq_id].n_past, -1);
}
}
llama_set_nextn_layer_offset(ctx_dft, i);
}
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
break;
}
// rebuild the batch for the next step: the growing-KV paths re-add only the
// new token (the KV already holds the prefix), while chained heads re-add the
// whole prefix at the next head. dropped sequences are simply not re-added.
common_batch_clear(batch);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
@@ -1127,9 +1233,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
auto * smpl = smpls[seq_id].get();
common_sampler_sample(smpl, ctx_dft, i_batch, true);
h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_batch);
++i_batch;
common_sampler_sample(smpl, ctx_dft, i_last[seq_id], true);
const float * h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_last[seq_id]);
const auto * cur_p = common_sampler_get_candidates(smpl, true);
@@ -1163,30 +1268,41 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
continue;
}
if (is_mem_shared) {
if (chain_heads) {
// ref: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448031546
chain_h[seq_id].insert(chain_h[seq_id].end(), h_row, h_row + n_embd);
const int n_rows = (int) result.size() + 1; // id_last + tokens drafted so far
for (int t = 0; t < n_rows; ++t) {
const llama_token tok = (t == 0) ? dp.id_last : result[t - 1];
common_batch_add(batch, tok, dp.n_past + t, { seq_id }, t == n_rows - 1);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd,
chain_h[seq_id].data() + (size_t) t * n_embd, row_bytes);
}
} else if (is_mem_shared) {
// note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens
// ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37
common_batch_add(batch, id, dp.n_past, { seq_id }, true);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
} else {
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
}
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
i_last[seq_id] = batch.n_tokens - 1;
}
if (batch.n_tokens == 0) {
break;
}
// evaluate the drafted tokens on the draft model
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
break;
}
++i;
}
if (chain_heads) {
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
}
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {
@@ -1196,8 +1312,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
if (dp.result->size() < (size_t) params.n_min) {
dp.result->clear();
}
last_n_drafted[seq_id] = (uint16_t) dp.result->size();
}
}
@@ -1810,7 +1924,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr;
bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
@@ -1848,7 +1962,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
if (has_draft_eagle3) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params));
}
if (has_mtp) {
if (has_draft_mtp) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
}
}
@@ -2118,6 +2232,31 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u
}
}
// TODO: support the case of more than one speculative implementations having a state
bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector<uint8_t> & data) {
if (spec == nullptr) {
return false;
}
for (auto & impl : spec->impls) {
if (impl->get_state(seq_id, data)) {
return true;
}
}
return false;
}
void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector<uint8_t> & data) {
if (spec == nullptr) {
return;
}
for (auto & impl : spec->impls) {
impl->set_state(seq_id, data);
}
}
void common_speculative_print_stats(const common_speculative * spec) {
if (spec == nullptr) {
return;
+4
View File
@@ -68,6 +68,10 @@ void common_speculative_draft(common_speculative * spec);
// informs the speculative context that n_accepted tokens were accepted by the target model
void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted);
// (optional) get/set internal state
bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector<uint8_t> & data);
void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector<uint8_t> & data);
// print statistics about the speculative decoding
void common_speculative_print_stats(const common_speculative * spec);
+1 -1
View File
@@ -126,7 +126,7 @@ class BailingMoeV2Model(TextModel):
if (rope_dim := hparams.get("head_dim")) is None:
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5)))
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
+7 -1
View File
@@ -1119,8 +1119,10 @@ class TextModel(ModelBase):
rope_theta = self.find_hparam(["global_rope_theta", "rope_global_theta", "rope_theta_global", "rope_theta", "rotary_emb_base"], optional=True)
local_rope_theta = self.find_hparam(["local_rope_theta", "rope_local_theta", "rope_theta_local", "swa_rope_theta", "rope_local_base_freq"], optional=True)
partial_rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"], optional=True)
original_max_position_embeddings = self.find_hparam(["original_max_position_embeddings"], optional=True)
# Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters
# Ensure global params are mirrored in rope_parameters
if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters:
if local_rope_theta is not None:
self.rope_parameters["sliding_attention"] = {"rope_theta": local_rope_theta}
@@ -1128,6 +1130,10 @@ class TextModel(ModelBase):
self.rope_parameters["rope_theta"] = rope_theta
if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None:
self.rope_parameters["rope_type"] = rope_type
if "partial_rotary_factor" not in self.rope_parameters and partial_rotary_factor is not None:
self.rope_parameters["partial_rotary_factor"] = partial_rotary_factor
if "original_max_position_embeddings" not in self.rope_parameters and original_max_position_embeddings is not None:
self.rope_parameters["original_max_position_embeddings"] = original_max_position_embeddings
@classmethod
def __init_subclass__(cls):
+1 -1
View File
@@ -148,7 +148,7 @@ class ChatGLMModel(TextModel):
rope_dim = self.hparams["attention_dim"]
else:
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5)))
self.gguf_writer.add_add_bos_token(False)
rope_freq = 10000
if "rope_ratio" in self.hparams:
+1 -1
View File
@@ -161,7 +161,7 @@ class DeciModel(TextModel):
factor = rope_params.get("factor", 8.0)
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
+3 -3
View File
@@ -24,7 +24,7 @@ class ExaoneModel(TextModel):
assert (hparams["activation_function"] == "silu")
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True)
rotary_factor = self.rope_parameters.get("partial_rotary_factor")
rotary_factor = rotary_factor if rotary_factor is not None else 1.0
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
@@ -39,7 +39,7 @@ class ExaoneModel(TextModel):
factor = rope_params.get("factor", 8.0)
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
@@ -104,7 +104,7 @@ class Exaone4Model(TextModel):
factor = rope_params.get("factor", 16.0)
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
+1 -1
View File
@@ -693,7 +693,7 @@ class Gemma4Model(Gemma3Model):
self.gguf_writer.add_head_count_kv(value_arr)
# handle n_rot differently for global vs swa layers
partial_rotary_factor_swa = self.hparams.get("partial_rotary_factor", 1.0)
partial_rotary_factor_swa = self.rope_parameters.get("partial_rotary_factor", 1.0)
n_rot_full = int(head_dim_full) # "proportional" is used, see generate_extra_tensors
n_rot_swa = int(head_dim_swa * partial_rotary_factor_swa)
self.gguf_writer.add_rope_dimension_count(n_rot_full)
+2 -2
View File
@@ -124,7 +124,7 @@ class Glm4MoeModel(TextModel):
self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
)
self.gguf_writer.add_rope_dimension_count(
int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))
int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5))
)
# MoE parameters - Use only routed expert count (shared experts handled separately)
@@ -226,7 +226,7 @@ class GlmMoeDsaModel(DeepseekV2Model):
super().set_gguf_parameters()
rope_dim = self.hparams["qk_rope_head_dim"]
partial_rotary_factor = self.hparams.get("partial_rotary_factor", 1.0)
partial_rotary_factor = self.rope_parameters.get("partial_rotary_factor", 1.0)
self.gguf_writer.add_rope_dimension_count(int(rope_dim * partial_rotary_factor))
# NextN/MTP prediction layers
+1 -1
View File
@@ -289,7 +289,7 @@ class LlamaModel(TextModel):
factor = rope_params.get("factor", 8.0)
low_freq_factor = rope_params.get("low_freq_factor", 1.0)
high_freq_factor = rope_params.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
old_context_len = rope_params.get("original_max_position_embeddings", 8192)
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
+1 -1
View File
@@ -154,7 +154,7 @@ class MimoV2Model(TextModel):
self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"])
self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
rope_dim = int(self.hparams["head_dim"] * self.hparams["partial_rotary_factor"])
rope_dim = int(self.hparams["head_dim"] * self.rope_parameters["partial_rotary_factor"])
self.gguf_writer.add_rope_dimension_count(rope_dim)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon", 1e-5))
+6 -10
View File
@@ -32,11 +32,9 @@ class MiniCPMModel(TextModel):
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
rope_dims = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
rope_scaling = self.find_hparam(['rope_scaling'], True)
if rope_scaling is not None:
long_factors = rope_scaling.get('long_factor', None)
short_factors = rope_scaling.get('short_factor', None)
long_factors = self.rope_parameters.get('long_factor')
short_factors = self.rope_parameters.get('short_factor')
if long_factors or short_factors:
if long_factors is None or short_factors is None:
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
@@ -85,13 +83,11 @@ class MiniCPM3Model(TextModel):
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
rope_scaling = self.find_hparam(['rope_scaling'], True)
if rope_scaling is not None:
long_factors = self.rope_parameters.get('long_factor')
short_factors = self.rope_parameters.get('short_factor')
if long_factors or short_factors:
rope_dims = self.hparams["qk_rope_head_dim"]
long_factors = rope_scaling.get('long_factor', None)
short_factors = rope_scaling.get('short_factor', None)
if long_factors is None or short_factors is None:
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
+4 -3
View File
@@ -125,17 +125,18 @@ class NemotronModel(TextModel):
self.gguf_writer.add_layer_norm_eps(f_norm_eps)
# * Partial RoPE
rot_pct = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"])
rot_pct = self.rope_parameters["partial_rotary_factor"]
n_embd = self.find_hparam(["hidden_size", "n_embd"])
n_head = self.find_hparam(["num_attention_heads", "n_head"])
self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
# * RopeScaling for Nemotron
if "rope_scaling" not in self.hparams or self.hparams["rope_scaling"] is None:
factor = self.hparams.get("factor") or self.rope_parameters.get("factor")
if factor is None:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
else:
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(self.hparams["factor"])
self.gguf_writer.add_rope_scaling_factor(factor)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# * Adding +1 to LayerNorm's weights here to implement layernorm1p w/o changing anything on the GGML engine side
+9 -11
View File
@@ -18,7 +18,7 @@ class Phi2Model(TextModel):
model_arch = gguf.MODEL_ARCH.PHI2
def set_gguf_parameters(self):
rot_pct = self.find_hparam(["partial_rotary_factor"])
rot_pct = self.rope_parameters["partial_rotary_factor"]
n_embd = self.find_hparam(["hidden_size", "n_embd"])
n_head = self.find_hparam(["num_attention_heads", "n_head"])
@@ -149,8 +149,8 @@ class Phi3MiniModel(TextModel):
n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
rms_eps = self.find_hparam(["rms_norm_eps"])
max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
rot_pct = self.hparams.get("partial_rotary_factor", 1.0)
orig_max_pos_embds = self.rope_parameters["original_max_position_embeddings"]
rot_pct = self.rope_parameters.get("partial_rotary_factor", 1.0)
rope_dims = int(rot_pct * n_embd) // n_head
self.gguf_writer.add_context_length(max_pos_embds)
@@ -174,18 +174,19 @@ class Phi3MiniModel(TextModel):
n_embd = self.find_hparam(["hidden_size", "n_embd"])
n_head = self.find_hparam(["num_attention_heads", "n_head"])
max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
rot_pct = self.hparams.get("partial_rotary_factor", 1.0)
orig_max_pos_embds = self.rope_parameters["original_max_position_embeddings"]
rot_pct = self.rope_parameters.get("partial_rotary_factor", 1.0)
rope_dims = int(rot_pct * n_embd) // n_head
# write rope scaling for long context (128k) model
rope_scaling = self.find_hparam(['rope_scaling'], True)
if rope_scaling is None:
long_factors = self.rope_parameters.get('long_factor')
short_factors = self.rope_parameters.get('short_factor')
if not long_factors:
return
scale = max_pos_embds / orig_max_pos_embds
rope_scaling_type = rope_scaling.get('rope_type', rope_scaling.get('type', '')).lower()
rope_scaling_type = self.rope_parameters.get('rope_type', '').lower()
if len(rope_scaling_type) == 0:
raise KeyError('Missing the required key rope_scaling.type')
@@ -198,9 +199,6 @@ class Phi3MiniModel(TextModel):
self.gguf_writer.add_rope_scaling_attn_factors(attn_factor)
long_factors = rope_scaling.get('long_factor', None)
short_factors = rope_scaling.get('short_factor', None)
if long_factors is None or short_factors is None:
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
+1 -1
View File
@@ -280,7 +280,7 @@ class Qwen3NextModel(Qwen2MoeModel):
self.gguf_writer.add_full_attention_interval(self.hparams.get("full_attention_interval", 4))
if (rope_dim := self.hparams.get("head_dim")) is None:
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25)))
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.25)))
@classmethod
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
+1 -1
View File
@@ -28,7 +28,7 @@ class StableLMModel(TextModel):
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"])
rotary_factor = self.rope_parameters["partial_rotary_factor"]
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])
+1 -1
View File
@@ -314,7 +314,7 @@ class Step35Model(TextModel):
factor = float(rope_params.get("factor", 8.0))
low_freq_factor = float(rope_params.get("low_freq_factor", 1.0))
high_freq_factor = float(rope_params.get("high_freq_factor", 4.0))
old_context_len = int(rope_params.get("original_max_position_embeddings", self.hparams.get("original_max_position_embeddings", 8192)))
old_context_len = int(rope_params.get("original_max_position_embeddings", 8192))
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
+1 -1
View File
@@ -29,7 +29,7 @@ With Termux, you can install and run `llama.cpp` as if the environment were Linu
```
$ apt update && apt upgrade -y
$ apt install git cmake
$ apt install git cmake libandroid-spawn
```
Then, follow the [build instructions](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md), specifically for CMake.
+62 -2
View File
@@ -161,6 +161,64 @@ You could update your test result in it directly.
Please refer to [Docker with SYCL](../docker.md#docker-with-sycl) for details.
## Quick Development WOW
This chapter is for quick development & try with SYCL backend on Intel GPU.
You need to install following sofeware before development:
- Intel GPU driver
- oneAPI package
- other development tools.
Please refer to [Linux](#linux) or [Windows](#windows-1) for above installation and resolve the trouble in usage. There are the detailed guide.
- Linux
```
## build from source code
./examples/sycl/build.sh
## run CONV_2D_DW unit test cases
./build/bin/test-backend-ops -b SYCL0 -o CONV_2D_DW
## run all unit test cases
./build/bin/test-backend-ops -b SYCL0
## run with LLM on the first GPU
./examples/sycl/test.sh -mg 0 -m xxxx.gguf
## run service with LLM on the first GPU
export ONEAPI_DEVICE_SELECTOR="level_zero:0"
./examples/sycl/start-svr.sh -m xxxx.gguf
## update the docs/ops.md for new/update OPs
./examples/sycl/update-ops-doc.sh
```
- Windows
```
## build from source code
examples\sycl\win-build-sycl.bat
## run CONV_2D_DW unit test cases
build\bin\test-backend-ops.exe -b SYCL0 -o CONV_2D_DW
## run all unit test cases
build\bin\test-backend-ops.exe -b SYCL0
## run LLM on the first GPU
examples\sycl\win-test.bat -mg 0 -m xxxx.gguf
## run service with LLM on the first GPU
set ONEAPI_DEVICE_SELECTOR="level_zero:0"
examples\sycl\win-start-svr.bat -m xxxx.gguf
## update the docs/ops.md for new/update OPs
examples\sycl\win-update-ops-doc.bat
```
## Linux
### I. Setup Environment
@@ -701,7 +759,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| GGML_SYCL_GRAPH | ON *(default)* \|OFF *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). |
| GGML_SYCL_DNN | ON *(default)* \|OFF *(Optional)* | Enable build with oneDNN. |
| GGML_SYCL_HOST_MEM_FALLBACK | ON *(default)* \|OFF *(Optional)* | Allow host memory fallback when device memory is full during quantized weight reorder. Enables inference to continue at reduced speed (reading over PCIe) instead of failing. Requires Linux kernel 6.8+. |
| GGML_SYCL_SUPPORT_LEVEL_ZERO | ON *(default)* \|OFF *(Optional)* | Enable Level Zero API for device memory allocation. Requires Level Zero headers/library at build time and Intel GPU driver (Level Zero runtime) at run time. Reduces system RAM usage during multi-GPU inference. |
| GGML_SYCL_SUPPORT_LEVEL_ZERO_API | ON *(default)* \|OFF *(Optional)* | Support to use Level Zero API for device memory allocation. Requires Level Zero headers/library at build time and Intel GPU driver (Level Zero runtime) at run time. Reduces system RAM usage during multi-GPU inference. SYCL backend always runs on Level Zero running time even if it's set as OFF (The SYCL api will be usage for memory allocation).|
| CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. |
| CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. |
@@ -712,10 +770,11 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| Name | Value | Function |
|-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
| GGML_SYCL_DEV2DEV_MEMCPY | 0 (default) or 1 | Choose the SYCL or L0 API in dev2dev memory copy.<br>Value: <br>* 0: SYCL API (default)<br>* 1: L0 API -- L0 API is found to lead to abnormal crash in some case. This debug flag is used to check the issue.|
| GGML_SYCL_ENABLE_FLASH_ATTN | 1 (default) or 0| Enable Flash-Attention. It can reduce memory usage. The performance impact depends on the LLM.|
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features for Intel GPUs. (Recommended to 1 for Intel devices older than Gen 10) |
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because SYCL Graph is still on development, no better performance. |
| GGML_SYCL_ENABLE_LEVEL_ZERO | 1 (default) or 0 | Use Level Zero API for device memory allocation instead of SYCL. Reduces system RAM usage on Intel dGPUs by avoiding DMA-buf/TTM host memory staging. Requires GGML_SYCL_SUPPORT_LEVEL_ZERO=ON at build time. |
| GGML_SYCL_USE_LEVEL_ZERO_API | 1 (default) or 0 | Use Level Zero API for device memory allocation instead of SYCL. Reduces system RAM usage on Intel dGPUs by avoiding DMA-buf/TTM host memory staging. Requires GGML_SYCL_SUPPORT_LEVEL_ZERO_API=ON at build time. SYCL backend always runs on Level Zero running time even if it's set as OFF (The SYCL api will be usage for memory allocation).|
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
| GGML_SYCL_ENABLE_VMM | 0 or 1 (default) | Enable the virtual-memory device pool. |
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
@@ -731,6 +790,7 @@ Pass these via `CXXFLAGS` or add a one-off `#define` to enable a flag on the spo
| DEBUG_SYCL_POOL | Enable device memory pool logging on teardown. Useful for profiling allocations. |
| DEBUG_SYCL_MALLOC | Enable verbose per-call logging of device pool alloc/free operations. |
## Design Rule
- Open to all contributors.
+3 -2
View File
@@ -1,10 +1,11 @@
# Multimodal
llama.cpp supports multimodal input via `libmtmd`. Currently, there are 2 tools support this feature:
- [llama-mtmd-cli](../tools/mtmd/README.md)
- [llama-cli](../tools/cli/README.md)
- [llama-server](../tools/server/README.md) via OpenAI-compatible `/chat/completions` API
- [llama-mtmd-cli](../tools/mtmd/README.md), for testing and development
Currently, we support **image** and **audio** input. Audio is highly experimental and may have reduced quality.
Currently, we support **image**, **audio** and **video** input.
To enable it, you can use one of the 2 methods below:
+4 -4
View File
@@ -27,11 +27,11 @@ Legend:
| COL2IM_1D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ | ❌ | ❌ |
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | | ❌ | ❌ | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ | ❌ | ❌ |
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
+1840 -1840
View File
File diff suppressed because it is too large Load Diff
+36 -38
View File
@@ -8,55 +8,53 @@ The INI preset feature, introduced in [PR#17859](https://github.com/ggml-org/lla
When running multiple models on the server (router mode), INI preset files can be used to configure model-specific parameters. Please refer to the [server documentation](../tools/server/README.md) for more details.
### Using a Remote Preset
### Using a Hugging Face Preset
> [!NOTE]
> [!IMPORTANT]
>
> This feature is currently only supported via the `-hf` option.
> Please only use presets that you can trust! Unknown presets may be unsafe
For GGUF models hosted on Hugging Face, you can include a `preset.ini` file in the root directory of the repository to define specific configurations for that model.
You can push your preset to Hugging Face Hub and share with other users by:
1. Creating an empty model repository on Hugging Face
2. Creating a `preset.ini` file in the root directory of the repository
Example:
Example of a `preset.ini`:
```ini
hf-repo-draft = username/my-draft-model-GGUF
temp = 0.5
top-k = 20
top-p = 0.95
[*]
ctx-size = 0
mmap = 1
kv-unified = 1
parallel = 4
spec-default = 1
[Qwen3.5-4B]
hf = unsloth/Qwen3.5-4B-GGUF:Q4_K_M
ctx-size = 262144
batch-size = 2048
ubatch-size = 2048
top-p = 1.0
top-k = 0
min-p = 0.01
temp = 1.0
[gpt-oss-120b-hf]
hf = ggml-org/gpt-oss-120b-GGUF
ctx-size = 262144
batch-size = 2048
ubatch-size = 2048
top-p = 1.0
top-k = 0
min-p = 0.01
temp = 1.0
chat-template-kwargs = {"reasoning_effort": "high"}
```
For security reasons, only certain options are allowed. Please refer to [preset.cpp](../common/preset.cpp) for the complete list of permitted options.
Example usage:
Assuming your repository `username/my-model-with-preset` contains a `preset.ini` with the configuration above:
```sh
llama-cli -hf username/my-model-with-preset
# This is equivalent to:
llama-cli -hf username/my-model-with-preset \
--hf-repo-draft username/my-draft-model-GGUF \
--temp 0.5 \
--top-k 20 \
--top-p 0.95
```
You can also override preset arguments by specifying them on the command line:
The preset will be loaded similarly to the `--models-preset` option. Therefore, you can also override certain params via CLI arguments:
```sh
# Force temp = 0.1, overriding the preset value
llama-cli -hf username/my-model-with-preset --temp 0.1
```
If you want to define multiple preset configurations for one or more GGUF models, you can create a blank HF repo for each preset. Each HF repo should contain a `preset.ini` file that references the actual model(s):
```ini
hf-repo = user/my-model-main
hf-repo-draft = user/my-model-draft
temp = 0.8
ctx-size = 1024
; (and other configurations)
llama-cli -hf username/my-preset --temp 0.1
```
### Named presets
+21 -21
View File
@@ -198,18 +198,18 @@ class BuiltinRule:
SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}'
PRIMITIVE_RULES = {
'boolean' : BuiltinRule('("true" | "false") space', []),
'boolean' : BuiltinRule('("true" | "false")', []),
'decimal-part' : BuiltinRule('[0-9]{1,16}', []),
'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []),
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']),
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?', ['integral-part', 'decimal-part']),
'integer' : BuiltinRule('("-"? integral-part)', ['integral-part']),
'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []),
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? space "}"', ['string', 'value']),
'array' : BuiltinRule('"[" space ( value ("," space value)* )? space "]"', ['value']),
'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\""', []),
'char' : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []),
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
'null' : BuiltinRule('"null" space', []),
'string' : BuiltinRule(r'"\"" char* "\""', ['char']),
'null' : BuiltinRule('"null"', []),
}
# TODO: support "uri", "email" string formats
@@ -217,9 +217,9 @@ STRING_FORMAT_RULES = {
'date' : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
'date-time' : BuiltinRule('date "T" time', ['date', 'time']),
'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']),
'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']),
'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
'date-string' : BuiltinRule('"\\"" date "\\""', ['date']),
'time-string' : BuiltinRule('"\\"" time "\\""', ['time']),
'date-time-string': BuiltinRule('"\\"" date-time "\\""', ['date-time']),
}
DOTALL = '[\\U00000000-\\U0010FFFF]'
@@ -319,7 +319,7 @@ class SchemaConverter:
out.append(f'[^"{"".join(rejects)}] {char_rule}*')
visit(trie)
out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space')
out.append(f' ){"" if trie.is_end_of_string else "?"} ["]')
return ''.join(out)
def _add_rule(self, name, rule):
@@ -549,7 +549,7 @@ class SchemaConverter:
return self._add_rule(
name,
to_rule(transform()) if self._raw_pattern \
else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space")
else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\"")
def _resolve_ref(self, ref):
@@ -580,10 +580,10 @@ class SchemaConverter:
return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type]))
elif 'const' in schema:
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space')
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))
elif 'enum' in schema:
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space'
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ')'
return self._add_rule(rule_name, rule)
elif schema_type in (None, 'object') and \
@@ -624,7 +624,7 @@ class SchemaConverter:
enum_intersection &= s
if enum_intersection:
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ')'
return self._add_rule(rule_name, rule)
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
@@ -638,12 +638,12 @@ class SchemaConverter:
' "," space '.join(
self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
for i, item in enumerate(items)) +
' "]" space')
' space "]"')
else:
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
min_items = schema.get("minItems", 0)
max_items = schema.get("maxItems")
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' space "]"')
elif schema_type in (None, 'string') and 'pattern' in schema:
return self._visit_pattern(schema['pattern'], rule_name)
@@ -663,7 +663,7 @@ class SchemaConverter:
min_len = schema.get('minLength', 0)
max_len = schema.get('maxLength')
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\""')
elif schema_type in (None, 'integer') and \
('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema):
@@ -680,7 +680,7 @@ class SchemaConverter:
out = ["("]
_generate_min_max_int(min_value, max_value, out)
out.append(") space")
out.append(")")
return self._add_rule(rule_name, ''.join(out))
elif (schema_type == 'object') or (len(schema) == 0):
@@ -765,7 +765,7 @@ class SchemaConverter:
rule += ' )'
rule += ' )?'
rule += ' "}" space'
rule += ' space "}"'
return rule
+9
View File
@@ -0,0 +1,9 @@
#!/bin/bash
# MIT license
# Copyright (C) 2026 Intel Corporation
# SPDX-License-Identifier: MIT
./build/bin/test-backend-ops support --output csv > docs/ops/SYCL.csv
./scripts/create_ops_docs.py
+8
View File
@@ -0,0 +1,8 @@
@echo off
rem MIT license
rem Copyright (C) 2026 Intel Corporation
rem SPDX-License-Identifier: MIT
build\bin\test-backend-ops support --output csv > docs\ops\SYCL.csv
python scripts\create_ops_docs.py
+2 -5
View File
@@ -5,7 +5,7 @@ project("ggml" C CXX ASM)
### GGML Version
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 15)
set(GGML_VERSION_PATCH 1)
set(GGML_VERSION_PATCH 2)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
@@ -249,7 +249,7 @@ option(GGML_SYCL "ggml: use SYCL"
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
option(GGML_SYCL_HOST_MEM_FALLBACK "ggml: allow host memory fallback in SYCL reorder (requires kernel 6.8+)" ON)
option(GGML_SYCL_SUPPORT_LEVEL_ZERO "ggml: use Level Zero API in SYCL backend" ON)
option(GGML_SYCL_SUPPORT_LEVEL_ZERO_API "ggml: use Level Zero API in SYCL backend" ON)
option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON)
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
"ggml: sycl target device")
@@ -341,9 +341,6 @@ set(GGML_PUBLIC_HEADERS
include/gguf.h)
set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
#if (GGML_METAL)
# set_target_properties(ggml PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/src/ggml-metal.metal")
#endif()
install(TARGETS ggml LIBRARY PUBLIC_HEADER)
install(TARGETS ggml-base LIBRARY)
+8 -1
View File
@@ -438,7 +438,14 @@ if (GGML_CPU_ALL_VARIANTS)
ggml_add_cpu_backend_variant(power8_2 POWER8 VSX)
ggml_add_cpu_backend_variant(power9 POWER9 VSX)
ggml_add_cpu_backend_variant(power10 POWER10 VSX)
ggml_add_cpu_backend_variant(power11 POWER11 VSX)
# POWER11 backend: only if compiler supports -mcpu=power11
check_cxx_compiler_flag("-mcpu=power11" GGML_CXX_SUPPORTS_POWER11)
if (GGML_CXX_SUPPORTS_POWER11)
message(STATUS "Compiler supports -mcpu=power11, enabling POWER11 backend")
ggml_add_cpu_backend_variant(power11 POWER11 VSX)
else()
message(STATUS "Skipping POWER11 backend: compiler does not support -mcpu=power11")
endif()
else()
message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
endif()
+1 -1
View File
@@ -389,7 +389,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M_UPPER}")
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
if (EXTRACTED_NUMBER EQUAL 10 OR EXTRACTED_NUMBER EQUAL 11)
list(APPEND ARCH_FLAGS -mcpu=power10)
elseif (EXTRACTED_NUMBER EQUAL 9)
list(APPEND ARCH_FLAGS -mcpu=power9)
+5 -6
View File
@@ -2417,15 +2417,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
// Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size
GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
parallel_for_ggml(params, n_batch, [&](int begin, int end) {
for (int batch_idx = begin; batch_idx < end; ++batch_idx) {
parallel_for_ggml(params, n_batch * M, [&](int begin, int end) {
for (int idx = begin; idx < end; ++idx) {
int batch_idx = idx / M;
int m = idx % M;
int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
const float * A_data = (const float *)((const char *)src1->data + src1_offset);
char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A;
for (int m = 0; m < M; ++m) {
from_float<vec_dot_type>(A_data + m * K, wdata_batch + m * row_size_A, K);
}
from_float<vec_dot_type>(A_data + m * K, wdata_batch + m * row_size_A, K);
}
});
});
+6 -5
View File
@@ -2345,7 +2345,7 @@ class tinyBLAS_Q0_PPC {
else if (n_aligned % 16 == 0) nc = 16;
else nc = 8;
}
bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0);
bool can_use_tiled = n_aligned > 0 && (m % mc == 0);
if (can_use_tiled) {
matmul_tiled(m, n_aligned, mc, nc, kc);
if (n > n_aligned) {
@@ -3063,13 +3063,14 @@ class tinyBLAS_Q0_PPC {
int64_t ii = (job / xtiles) * mc;
int64_t jj = (job % xtiles) * nc;
for (int64_t kk = 0; kk < k; kk += kc) {
int64_t k_cur = MIN(kc, k - kk);
if constexpr(is_Ablock_q4) {
packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
packNormal_q4_fp16(A + ii * lda + kk, lda, mc, k_cur, (uint8_t *)A_pack);
} else {
packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
packNormal_q8_fp16(A + ii * lda + kk, lda, mc, k_cur, (uint8_t *)A_pack);
}
packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack);
KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack);
packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, k_cur, (uint8_t *)B_pack);
KERNEL_Q0(ii, jj, mc, nc, k_cur, kk, A_pack, B_pack);
}
}
}
+81
View File
@@ -0,0 +1,81 @@
#include "col2im-1d.cuh"
#include "convert.cuh"
// col2im_1d: scatter-add GEMM columns to 1D signal (gather approach)
// columns: [K*OC, T_in] -> output: [T_out, OC]
// Supports F32, F16, BF16 data with F32 accumulator.
template <typename T>
static __global__ void col2im_1d_kernel(
const T * __restrict__ col,
T * __restrict__ dst,
const int T_in, const uint3 T_out_fd,
const int OC, const int K, const int K_OC,
const int s0, const int p0, const int total) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= total) return;
// dst layout: [T_out, OC], ne[0]=T_out fastest
const uint2 qr = fast_div_modulo((uint32_t)idx, T_out_fd); // qr.x = idx / T_out, qr.y = idx % T_out
const int oc = (int)qr.x;
const int t_out = (int)qr.y;
const int t_abs = t_out + p0; // absolute position in uncropped signal
// Gather: find all (t_in, k) where t_in*s + k == t_abs, 0 <= k < K
int t_in_min = (t_abs - K + s0) / s0; // ceil((t_abs - K + 1) / s)
if (t_in_min < 0) t_in_min = 0;
int t_in_max = t_abs / s0;
if (t_in_max >= T_in) t_in_max = T_in - 1;
float sum = 0.0f;
for (int t_in = t_in_min; t_in <= t_in_max; t_in++) {
const int k = t_abs - t_in * s0;
// col layout: [K*OC, T_in], column index = oc * K + k
sum += ggml_cuda_cast<float>(col[(oc * K + k) + t_in * K_OC]);
}
dst[idx] = ggml_cuda_cast<T>(sum);
}
void ggml_cuda_op_col2im_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t OC = ((const int32_t *)(dst->op_params))[1];
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
const int K_OC = (int) src0->ne[0];
const int T_in = (int) src0->ne[1];
const int K = K_OC / OC;
const int T_out = (int) dst->ne[0];
const uint3 T_out_fd = init_fastdiv_values((uint32_t)T_out);
const int total = T_out * OC;
const int block_size = 256;
const int num_blocks = (total + block_size - 1) / block_size;
switch (src0->type) {
case GGML_TYPE_F32: {
col2im_1d_kernel<<<num_blocks, block_size, 0, stream>>>(
(const float *)src0->data, (float *)dst->data,
T_in, T_out_fd, OC, K, K_OC, s0, p0, total);
} break;
case GGML_TYPE_F16: {
col2im_1d_kernel<<<num_blocks, block_size, 0, stream>>>(
(const half *)src0->data, (half *)dst->data,
T_in, T_out_fd, OC, K, K_OC, s0, p0, total);
} break;
case GGML_TYPE_BF16: {
col2im_1d_kernel<<<num_blocks, block_size, 0, stream>>>(
(const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
T_in, T_out_fd, OC, K, K_OC, s0, p0, total);
} break;
default:
GGML_ABORT("col2im_1d: unsupported type");
}
}
+3
View File
@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_col2im_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+21 -66
View File
@@ -11,6 +11,7 @@
#include "ggml-cuda/argsort.cuh"
#include "ggml-cuda/binbcast.cuh"
#include "ggml-cuda/clamp.cuh"
#include "ggml-cuda/col2im-1d.cuh"
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/conv2d.cuh"
@@ -622,18 +623,6 @@ ggml_backend_cuda_context::~ggml_backend_cuda_context() {
// cuda buffer
struct ggml_backend_cuda_device_context {
int device;
std::string name;
std::string description;
std::string pci_bus_id;
int op_offload_min_batch_size;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
std::mutex device_mutex;
int active_count = 0;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
};
struct ggml_backend_cuda_buffer_context {
int device;
void * dev_ptr = nullptr;
@@ -651,13 +640,6 @@ struct ggml_backend_cuda_buffer_context {
static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count--;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
delete ctx;
}
@@ -810,12 +792,6 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count++;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
}
@@ -1515,12 +1491,6 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
}
static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count--;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
CUDA_CHECK(cudaFreeHost(buffer->context));
}
@@ -1529,8 +1499,6 @@ static void * ggml_cuda_host_malloc(size_t size) {
return nullptr;
}
ggml_cuda_set_device(0); // cudaMallocHost can create the implicit CUDA device context, make sure that this is consistently done on device 0.
void * ptr = nullptr;
cudaError_t err = cudaMallocHost((void **) &ptr, size);
if (err != cudaSuccess) {
@@ -1556,12 +1524,6 @@ static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggm
buffer->buft = buft;
buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count++;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
return buffer;
}
@@ -3090,6 +3052,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CONV_TRANSPOSE_1D:
ggml_cuda_op_conv_transpose_1d(ctx,dst);
break;
case GGML_OP_COL2IM_1D:
ggml_cuda_op_col2im_1d(ctx, dst);
break;
case GGML_OP_POOL_2D:
ggml_cuda_op_pool2d(ctx, dst);
break;
@@ -3179,12 +3144,6 @@ static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) {
static void ggml_backend_cuda_free(ggml_backend_t backend) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) backend->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count--;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
delete cuda_ctx;
delete backend;
}
@@ -4916,6 +4875,14 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
// backend device
struct ggml_backend_cuda_device_context {
int device;
std::string name;
std::string description;
std::string pci_bus_id;
int op_offload_min_batch_size;
};
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
return ctx->name.c_str();
@@ -5004,11 +4971,6 @@ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_k
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
std::lock_guard<std::mutex> lock(ctx->device_mutex);
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_cuda_set_device(ctx->device);
CUDA_CHECK(cudaMemGetInfo(free, total));
@@ -5035,13 +4997,6 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
}
#endif // defined(__linux__)
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
// If no backends or buffers are active, the cudaMemGetInfo call above lazily created a CUDA
// context that permanently consumes VRAM. Reset the device to free it.
if (ctx->active_count == 0) {
CUDA_CHECK(cudaDeviceReset());
}
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}
static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
@@ -5365,6 +5320,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
}
return false;
} break;
case GGML_OP_COL2IM_1D:
{
ggml_type src0_type = op->src[0]->type;
return (src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_F16 || src0_type == GGML_TYPE_BF16) &&
op->type == src0_type &&
ggml_is_contiguous(op->src[0]) &&
ggml_is_contiguous(op);
} break;
case GGML_OP_SILU_BACK:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
break;
@@ -5745,21 +5708,13 @@ ggml_backend_t ggml_backend_cuda_init(int device) {
return nullptr;
}
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device);
ggml_backend_t cuda_backend = new ggml_backend {
/* .guid = */ ggml_backend_cuda_guid(),
/* .iface = */ ggml_backend_cuda_interface,
/* .device = */ dev,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
/* .context = */ ctx,
};
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count++;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
return cuda_backend;
}
+104 -12
View File
@@ -69,6 +69,7 @@ static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE;
static int opt_opbatch = 1024; // max number of ops in a batch
static int opt_opqueue = 16; // max number of pending batches
static int opt_oppoll = 0; // polling for batch completions
static int opt_optrace = 0; // trace buffer size per thread (0 means default)
static std::regex* opt_opfilter = NULL; // regex of ops to not claim
@@ -118,20 +119,39 @@ static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct
ggml_op_desc(op), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, supp ? "yes" : "no");
}
static const char * htp_event_name(uint16_t id) {
switch (id) {
case HTP_TRACE_EVT_DMA: return "DMA";
case HTP_TRACE_EVT_HVX_COMP: return "HVX_COMP";
case HTP_TRACE_EVT_HVX_A_QUANT: return "HVX_A_QUANT";
case HTP_TRACE_EVT_HVX_A_PREP: return "HVX_A_PREP";
case HTP_TRACE_EVT_HVX_W_DEQUANT: return "HVX_W_DEQUANT";
case HTP_TRACE_EVT_HVX_W_PREP: return "HVX_W_PREP";
case HTP_TRACE_EVT_HVX_O_PROC: return "HVX_O_PROC";
case HTP_TRACE_EVT_HMX_COMP: return "HMX_COMP";
default: return "UNKNOWN";
}
}
static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const htp_opnode & node,
uint32_t op_usec, uint32_t op_cycles, const uint32_t pmu[]) {
const htp_prof_desc & pd) {
if (!opt_profile) return;
uint32_t op_usec = pd.usecs;
uint32_t op_cycles = pd.cycles_stop - pd.cycles_start;
const uint32_t * pmu = pd.pmu;
char pmu_str[256] = "";
if (opt_profile > 1) {
if (opt_profile == 2) {
static_assert(HTP_PROF_PMU_NCNT == 8, "current implementation assumes 8 PMU counters");
sprintf(pmu_str, " pmu [%u,%u,%u,%u,%u,%u,%u,%u]",
pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]);
}
htp_opformat fmt(node);
GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u%s\n", sess_name.c_str(),
node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, op_usec, op_cycles, pmu_str);
float mhz = op_usec > 0 ? (float) op_cycles / op_usec : 0.0f;
GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u start %u mhz %.1f%s\n", sess_name.c_str(),
node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, op_usec, op_cycles, pd.cycles_start, mhz, pmu_str);
}
// ** backend sessions
@@ -1995,10 +2015,16 @@ struct ggml_hexagon_opqueue {
size_t n_ops = batch_size;
size_t n_tensors = n_ops + n_ops * HTP_OP_MAX_INPUTS;
size_t tr_size = 0;
if (opt_profile == 3) {
tr_size = (HTP_MAX_NTHREADS + 1) * opt_optrace * sizeof(htp_trace_desc);
}
shm_blk_size = sizeof(htp_buf_desc) * n_bufs +
sizeof(htp_tensor) * n_tensors +
sizeof(htp_op_desc) * n_ops +
sizeof(htp_prof_desc) * n_ops;
sizeof(htp_prof_desc) * n_ops +
tr_size;
shm_buf = new ggml_hexagon_shared_buffer(sess, shm_blk_size * depth, true /* pinned */);
@@ -2042,11 +2068,19 @@ struct ggml_hexagon_opqueue {
const size_t o_size = sizeof(htp_op_desc) * req.n_ops;
const size_t p_size = sizeof(htp_prof_desc) * req.n_ops;
size_t tr_size = 0;
if (opt_profile == 3) {
req.n_traces = opt_optrace;
tr_size = (HTP_MAX_NTHREADS + 1) * req.n_traces * sizeof(htp_trace_desc);
} else {
req.n_traces = 0;
}
dbuf.ptr = shm_buf->base + (req.id * shm_blk_size);
dbuf.fd = shm_buf->fd;
dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT;
dbuf.offset = (uint8_t*) dbuf.ptr - (uint8_t*) shm_buf->base;
dbuf.size = b_size + t_size + o_size + p_size;
dbuf.size = b_size + t_size + o_size + p_size + tr_size;
GGML_ASSERT(dbuf.size <= shm_blk_size);
@@ -2092,7 +2126,14 @@ struct ggml_hexagon_opqueue {
const size_t o_size = sizeof(htp_op_desc) * rsp.n_ops;
const size_t p_size = sizeof(htp_prof_desc) * rsp.n_ops;
const size_t m_size = b_size + t_size + o_size + p_size;
size_t tr_size = 0;
uint32_t n_traces = 0;
if (opt_profile == 3) {
n_traces = opt_optrace;
tr_size = (HTP_MAX_NTHREADS + 1) * n_traces * sizeof(htp_trace_desc);
}
const size_t m_size = b_size + t_size + o_size + p_size + tr_size;
GGML_ASSERT(m_size <= shm_blk_size);
HEX_VERBOSE("ggml-hex: %s op-queue pop batch #%u : n-bufs %u n-tensors %u n-ops %u : m-size %zu b-size %zu t-size %zu o-size %zu\n",
@@ -2111,13 +2152,62 @@ struct ggml_hexagon_opqueue {
GGML_ASSERT(rsp.n_ops <= ops.size());
const htp_prof_desc * pd = (const htp_prof_desc *) p_ptr;
for (uint32_t i = 0; i < rsp.n_ops; i++) {
htp_usec += pd[i].usecs;
ggml_hexagon_dump_op_prof(shm_buf->sess->name, ops[i], pd[i].usecs, pd[i].cycles, pd[i].pmu);
const htp_trace_desc * trace_events = nullptr;
if (opt_profile == 3) {
trace_events = (const htp_trace_desc *) (p_ptr + p_size);
}
GGML_LOG_DEBUG("ggml-hex: %s profile-batch n-ops %u batch-dur-usec %lld htp-ops-usec %u\n",
shm_buf->sess->c_name(), rsp.n_ops, (long long) batch_usec, htp_usec);
uint32_t trace_idx[HTP_MAX_NTHREADS + 1] = {0};
uint32_t valid_cnt[HTP_MAX_NTHREADS + 1] = {0};
if (opt_profile == 3) {
for (uint32_t t = 0; t <= HTP_MAX_NTHREADS; t++) {
uint32_t count = rsp.n_traces[t];
valid_cnt[t] = count > n_traces ? n_traces : count;
}
}
for (uint32_t i = 0; i < rsp.n_ops; i++) {
htp_usec += pd[i].usecs;
ggml_hexagon_dump_op_prof(shm_buf->sess->name, ops[i], pd[i]);
if (opt_profile == 3) {
uint32_t op_duration = pd[i].cycles_stop - pd[i].cycles_start;
for (uint32_t t = 0; t <= HTP_MAX_NTHREADS; t++) {
while (trace_idx[t] < valid_cnt[t]) {
const auto & e = trace_events[t * n_traces + trace_idx[t]];
uint32_t offset = e.cycles - pd[i].cycles_start;
if (offset >= 0x80000000) {
trace_idx[t]++;
continue;
}
if (offset > op_duration) {
break;
}
bool is_stop = (e.info & 0x8000) != 0;
uint16_t info = e.info & 0x7FFF;
GGML_LOG_DEBUG("ggml-hex: %s trace-op %s: thread %u event %s info %u %s %u\n",
shm_buf->sess->c_name(), ops[i].op_name().c_str(), t, htp_event_name(e.id), info, is_stop ? "stop" : "start", e.cycles);
trace_idx[t]++;
}
}
}
}
char evt_str[256] = "";
if (opt_profile == 3) {
sprintf(evt_str, " evt [%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u]",
rsp.n_traces[0], rsp.n_traces[1], rsp.n_traces[2], rsp.n_traces[3],
rsp.n_traces[4], rsp.n_traces[5], rsp.n_traces[6], rsp.n_traces[7],
rsp.n_traces[8], rsp.n_traces[9], rsp.n_traces[10]);
}
GGML_LOG_DEBUG("ggml-hex: %s profile-batch n-ops %u batch-dur-usec %lld htp-ops-usec %u%s\n",
shm_buf->sess->c_name(), rsp.n_ops, (long long) batch_usec, htp_usec, evt_str);
}
}
};
@@ -3901,6 +3991,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH");
const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE");
const char * str_oppoll = getenv("GGML_HEXAGON_OPPOLL");
const char * str_optrace = getenv("GGML_HEXAGON_OPTRACE");
const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER");
const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
const char * str_etm = getenv("GGML_HEXAGON_ETM");
@@ -3939,6 +4030,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch;
opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue;
opt_oppoll = str_oppoll ? strtoul(str_oppoll, NULL, 0) : opt_oppoll;
opt_optrace = str_optrace ? strtoul(str_optrace, NULL, 0) : (opt_opbatch * 128);
opt_profile = str_profile ? atoi(str_profile) : 0;
opt_etm = str_etm ? atoi(str_etm) : 0;
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
+1 -1
View File
@@ -37,8 +37,8 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
if (_hmx_idx GREATER_EQUAL 0)
target_sources(${HTP_LIB} PRIVATE
hmx-matmul-ops.c
hmx-flash-attn-ops.c
hmx-matmul-ops.c
hmx-queue.c
)
@@ -339,6 +339,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
if (ir0 >= ir1) return;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
dma_queue * dma = octx->ctx->dma[ith];
const uint32_t DK = nek0;
@@ -615,6 +618,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
}
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
}
int op_flash_attn_ext(struct htp_ops_context * octx) {
+10 -3
View File
@@ -6,6 +6,8 @@
#include <stdbool.h>
#include <stdint.h>
#include "hex-profile.h"
#ifdef __cplusplus
extern "C" {
#endif
@@ -88,6 +90,7 @@ typedef struct {
uint32_t pop_idx;
uint32_t capacity;
uint32_t idx_mask;
struct htp_thread_trace * trace;
} dma_queue;
dma_queue * dma_queue_create(size_t capacity);
@@ -152,6 +155,7 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
q->dptr[q->push_idx] = dptr;
if (size) {
htp_trace_event_start(q->trace, HTP_TRACE_EVT_DMA, q->push_idx);
dmlink(q->tail, desc);
q->tail = (dma_descriptor_2d *) desc;
} else {
@@ -202,6 +206,7 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t
q->dptr[q->push_idx] = dptr;
if (nrows) {
htp_trace_event_start(q->trace, HTP_TRACE_EVT_DMA, q->push_idx);
dmlink(q->tail, desc);
q->tail = desc;
} else {
@@ -223,10 +228,12 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
dma_descriptor_2d * desc = &q->desc[q->pop_idx];
// Wait for desc to complete
while (!desc->done) {
// FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
dmpoll();
if (!desc->done) {
while (!desc->done) {
dmpoll();
}
}
htp_trace_event_stop(q->trace, HTP_TRACE_EVT_DMA, q->pop_idx);
dptr = q->dptr[q->pop_idx];
+64
View File
@@ -0,0 +1,64 @@
#ifndef HEX_PROFILE_H
#define HEX_PROFILE_H
#include <stdbool.h>
#include <stdint.h>
#include <qurt.h>
#include "hex-utils.h"
#include "htp-ops.h"
#define HTP_TRACE_EVT_START 0
#define HTP_TRACE_EVT_STOP 1
#ifndef HEX_NUM_PMU_COUNTERS
#define HEX_NUM_PMU_COUNTERS 8
#endif
static inline void hex_get_pmu(uint32_t counters[]) {
#if __HVX_ARCH__ >= 79
asm volatile("%0 = upmucnt0" : "=r"(counters[0]));
asm volatile("%0 = upmucnt1" : "=r"(counters[1]));
asm volatile("%0 = upmucnt2" : "=r"(counters[2]));
asm volatile("%0 = upmucnt3" : "=r"(counters[3]));
asm volatile("%0 = upmucnt4" : "=r"(counters[4]));
asm volatile("%0 = upmucnt5" : "=r"(counters[5]));
asm volatile("%0 = upmucnt6" : "=r"(counters[6]));
asm volatile("%0 = upmucnt7" : "=r"(counters[7]));
#else
counters[0] = qurt_pmu_get(QURT_PMUCNT0);
counters[1] = qurt_pmu_get(QURT_PMUCNT1);
counters[2] = qurt_pmu_get(QURT_PMUCNT2);
counters[3] = qurt_pmu_get(QURT_PMUCNT3);
counters[4] = qurt_pmu_get(QURT_PMUCNT4);
counters[5] = qurt_pmu_get(QURT_PMUCNT5);
counters[6] = qurt_pmu_get(QURT_PMUCNT6);
counters[7] = qurt_pmu_get(QURT_PMUCNT7);
#endif
}
struct htp_thread_trace {
uint32_t count;
uint32_t max_events;
struct htp_trace_desc * events;
};
static inline void htp_trace_event(struct htp_thread_trace * tr, uint16_t id, uint16_t info, uint32_t type) {
if (tr && tr->events && tr->count < tr->max_events) {
uint32_t idx = tr->count;
tr->events[idx].id = id;
tr->events[idx].info = info | (type == HTP_TRACE_EVT_STOP ? 0x8000 : 0);
tr->events[idx].cycles = (uint32_t) hex_get_cycles();
tr->count++;
}
}
static inline void htp_trace_event_start(struct htp_thread_trace * tr, uint16_t id, uint16_t info) {
htp_trace_event(tr, id, info, HTP_TRACE_EVT_START);
}
static inline void htp_trace_event_stop(struct htp_thread_trace * tr, uint16_t id, uint16_t info) {
htp_trace_event(tr, id, info, HTP_TRACE_EVT_STOP);
}
#endif /* HEX_PROFILE_H */
-27
View File
@@ -107,31 +107,4 @@ static inline void hex_pause() {
asm volatile(" pause(#255)\n");
}
#ifndef HEX_NUM_PMU_COUNTERS
#define HEX_NUM_PMU_COUNTERS 8
#endif
static inline void hex_get_pmu(uint32_t counters[]) {
#if __HVX_ARCH__ >= 79
asm volatile("%0 = upmucnt0" : "=r"(counters[0]));
asm volatile("%0 = upmucnt1" : "=r"(counters[1]));
asm volatile("%0 = upmucnt2" : "=r"(counters[2]));
asm volatile("%0 = upmucnt3" : "=r"(counters[3]));
asm volatile("%0 = upmucnt4" : "=r"(counters[4]));
asm volatile("%0 = upmucnt5" : "=r"(counters[5]));
asm volatile("%0 = upmucnt6" : "=r"(counters[6]));
asm volatile("%0 = upmucnt7" : "=r"(counters[7]));
#else
counters[0] = qurt_pmu_get(QURT_PMUCNT0);
counters[1] = qurt_pmu_get(QURT_PMUCNT1);
counters[2] = qurt_pmu_get(QURT_PMUCNT2);
counters[3] = qurt_pmu_get(QURT_PMUCNT3);
counters[4] = qurt_pmu_get(QURT_PMUCNT4);
counters[5] = qurt_pmu_get(QURT_PMUCNT5);
counters[6] = qurt_pmu_get(QURT_PMUCNT6);
counters[7] = qurt_pmu_get(QURT_PMUCNT7);
// qurt_pmu_get_pmucnt(counters);
#endif
}
#endif /* HEX_UTILS_H */
+26 -66
View File
@@ -18,7 +18,7 @@
#include "ggml-common.h"
#include "hex-dma.h"
#include "hex-fastdiv.h"
#include "hmx-profile.h"
#include "hex-profile.h"
#include "hmx-queue.h"
#include "hmx-utils.h"
#include "htp-ctx.h"
@@ -367,8 +367,11 @@ static void fa_k_interleave_thread(unsigned int n, unsigned int i, void * data)
return;
}
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
hmx_interleave_rows_to_tiles(factx->vtcm_k_tiles, factx->vtcm_k_fp16[args->buf_idx], total_rows, (int) factx->DK,
(int) args->src_stride, start, end);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
}
static void fa_phase_k_interleave(struct hmx_fa_context * factx, int kv_rows, size_t src_stride, size_t buf_idx) {
@@ -408,8 +411,11 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data)
__fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0];
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
hmx_interleave_cols_to_tiles(v_tiles_dest, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV,
(int) args->src_stride, (int) args->n_col_tiles, start, end);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
}
static void fa_phase_v_interleave(struct hmx_fa_context * factx,
@@ -462,6 +468,9 @@ static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) {
return;
}
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
const struct htp_tensor * q = args->q;
const uint32_t q_start = args->q_start;
const uint32_t kv_head = args->kv_head;
@@ -515,6 +524,7 @@ static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) {
}
}
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
}
static void fa_phase_q_load(struct hmx_fa_context * factx,
@@ -566,6 +576,9 @@ static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) {
return;
}
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
const struct htp_tensor * dst = args->dst;
const __fp16 * o_tile_src = args->o_tile_src;
const uint32_t q_start = args->q_start;
@@ -611,6 +624,7 @@ static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) {
}
}
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, start);
}
static void fa_phase_o_store(struct hmx_fa_context * factx,
@@ -680,6 +694,9 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
return;
}
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, vec_start);
// Per-thread row scratch: thread i uses bufs at offset i * 2 * stride
const size_t row_buf_stride = factx->row_buf_stride;
HVX_Vector * my_row_buf0 = factx->vtcm_row_bufs + i * 2 * row_buf_stride;
@@ -950,6 +967,7 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
factx->vtcm_s_rowmax[r_vec_idx] = rowmax_acc_v;
factx->vtcm_p_rowsum[r_vec_idx] = rowsum_acc_v;
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, vec_start);
}
// Serial m/l update + build_D. Must run after softmax barrier (s_rowmax written by all threads).
@@ -1245,6 +1263,7 @@ static __attribute__((noinline)) void fa_compute_slopes(
// ============================================================================
int hmx_flash_attn_ext(struct htp_ops_context * octx) {
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[HTP_MAX_NTHREADS] : NULL;
const struct htp_tensor * q = octx->src[0];
const struct htp_tensor * k = octx->src[1];
const struct htp_tensor * v = octx->src[2];
@@ -1422,19 +1441,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
return HTP_STATUS_OK;
}
// Profiling timers
TIMER_DEFINE(total);
TIMER_DEFINE(q_load);
TIMER_DEFINE(kv_dma);
TIMER_DEFINE(k_interleave);
TIMER_DEFINE(v_interleave);
TIMER_DEFINE(qk_dot);
TIMER_DEFINE(softmax);
TIMER_DEFINE(o_update);
TIMER_DEFINE(o_norm);
TIMER_DEFINE(o_store);
TIMER_START(total);
// ======== DMA setup ========
dma_queue * const dma = ctx->dma[0];
@@ -1474,12 +1480,10 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
const size_t n_row_tiles = g_br_actual / HMX_FP16_TILE_N_ROWS;
// ---- Load Q block [g_br, D] -> tiles, interleaving G heads ----
TIMER_START(q_load);
if (n_rows_g < g_br) {
hvx_splat_u8_a(factx.vtcm_q_tiles, 0, q_tile_bytes);
}
fa_phase_q_load(&factx, q, q_start, kv_head, ib3, n_rows_g);
TIMER_STOP(q_load);
// ---- Initialize per-block state ----
hvx_splat_u8_a(factx.vtcm_l_vec, 0, col_vec_bytes);
@@ -1558,10 +1562,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS);
// Wait for current KV DMA
TIMER_START(kv_dma);
dma_queue_pop(dma); // K
dma_queue_pop(dma); // V
TIMER_STOP(kv_dma);
// Push mask DMA for this block (single 2D DMA when broadcast)
bool has_mask_dma = false;
@@ -1583,10 +1585,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
ou_job.DV = DV;
hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job));
}
TIMER_START(k_interleave);
fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx);
TIMER_STOP(k_interleave);
// ---- Phase 2: qk_dot(blk) on HMX ‖ V_int(blk) + DMA prefetch on HVX ----
qk_job.q_tiles = factx.vtcm_q_tiles;
@@ -1597,15 +1596,11 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
qk_job.n_dot_tiles = DK / 32;
qk_job.n_tiles_per_bc = n_tiles_per_bc;
qk_job.hmx_scales = factx.vtcm_hmx_scales_qk;
TIMER_START(qk_dot);
hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_qk_dot_worker, &qk_job));
// DMA push next block (non-blocking, before worker_pool)
DMA_PREFETCH_KV(kv_blk + 1);
TIMER_START(v_interleave);
fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc);
TIMER_STOP(v_interleave);
// Pop and swap previous block's output update (deferred HMX pop)
if (kv_blk > 0) {
@@ -1615,7 +1610,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
// Pop current block's dot product job
hmx_queue_pop(hmx_q);
TIMER_STOP(qk_dot);
// ---- Phase 3: softmax(blk) + build_D(blk) | HMX idle ----
// Pop mask DMA before softmax (ensures VTCM buffer is ready)
@@ -1641,10 +1635,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL;
sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride;
sargs.slopes = factx.vtcm_slopes;
TIMER_START(softmax);
fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br);
TIMER_STOP(softmax);
buf_idx = 1 - buf_idx;
} // end KV block loop (pipeline)
@@ -1664,11 +1655,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
ou_job.n_row_tiles_g_br = n_row_tiles_g_br;
ou_job.n_tiles_per_bc = n_tiles_per_bc;
ou_job.DV = DV;
TIMER_START(o_update);
hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job));
hmx_queue_pop(hmx_q);
TIMER_STOP(o_update);
hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev);
}
@@ -1683,23 +1671,14 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
const uint32_t kv_start = kv_blk * Bc;
const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start);
const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS);
TIMER_START(kv_dma);
dma_queue_pop(dma); // K
dma_queue_pop(dma); // V
TIMER_STOP(kv_dma);
bool has_mask_dma = false;
MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma);
DMA_PREFETCH_KV(kv_blk + 1);
// K interleave (multi-thread HVX)
TIMER_START(k_interleave);
fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx);
TIMER_STOP(k_interleave);
// QK dot (inline HMX on main thread)
TIMER_START(qk_dot);
{
const size_t n_dot_tiles = (size_t) (DK / 32);
const __fp16 * restrict q_base = factx.vtcm_q_tiles;
@@ -1709,6 +1688,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
__builtin_assume(n_col_tiles > 0);
__builtin_assume(n_dot_tiles > 0);
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_qk);
for (size_t r = 0; r < n_row_tiles; ++r) {
for (size_t c = 0; c < n_col_tiles; ++c) {
@@ -1724,8 +1704,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
Q6_mxmem_AR_after_hf(out_tile, 0);
}
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
}
TIMER_STOP(qk_dot);
// Pop mask DMA
MASK_DMA_POP(has_mask_dma);
@@ -1751,21 +1731,9 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL;
sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride;
sargs.slopes = factx.vtcm_slopes;
TIMER_START(softmax);
fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br);
TIMER_STOP(softmax);
// V interleave (multi-thread HVX)
TIMER_START(v_interleave);
// FIX(v-stride): use n_tiles_per_bc (block-invariant) as V tile layout
// stride to match o_update's v_tile access. Using per-block n_col_tiles
// misplaces DV_tile 1..3 in the last partial KV block.
fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc);
TIMER_STOP(v_interleave);
// O update (inline HMX on main thread)
TIMER_START(o_update);
{
const size_t DV_tiles = (size_t) (DV / 32);
const __fp16 * restrict d_base = factx.vtcm_d_tiles;
@@ -1777,6 +1745,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
__builtin_assume(n_col_tiles > 0);
__builtin_assume(DV_tiles > 0);
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id);
for (size_t r = 0; r < n_row_tiles; ++r) {
for (size_t c = 0; c < DV_tiles; ++c) {
@@ -1798,16 +1767,15 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
Q6_mxmem_AR_after_hf(o_tile_out, 0);
}
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev);
}
TIMER_STOP(o_update);
buf_idx = 1 - buf_idx;
} // end KV block loop (fallback)
}
// ---- Final normalization: O = diag(1/l) @ O ----
TIMER_START(o_norm);
{
fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br);
@@ -1830,6 +1798,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
__builtin_assume(n_row_tiles > 0);
__builtin_assume(DV_tiles > 0);
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id);
for (size_t r = 0; r < n_row_tiles; ++r) {
for (size_t c = 0; c < DV_tiles; ++c) {
@@ -1842,14 +1811,12 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
Q6_mxmem_AR_after_hf(o_out, 0);
}
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
}
}
TIMER_STOP(o_norm);
// ---- Store O block ----
TIMER_START(o_store);
fa_phase_o_store(&factx, dst, o_tile_curr, q_start, kv_head, ib3, n_rows_g);
TIMER_STOP(o_store);
#undef MASK_DMA_PUSH
#undef MASK_DMA_POP
@@ -1865,14 +1832,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
}
TIMER_STOP(total);
#if defined(ENABLE_PROFILE_TIMERS)
FARF(HIGH, "hmx-fa: %lld us, q_load=%lld kv_dma=%lld k_interleave=%lld v_interleave=%lld", TIMER_US(total),
TIMER_US(q_load), TIMER_US(kv_dma), TIMER_US(k_interleave), TIMER_US(v_interleave));
FARF(HIGH, " qk_dot=%lld softmax=%lld o_update=%lld o_norm=%lld o_store=%lld", TIMER_US(qk_dot), TIMER_US(softmax),
TIMER_US(o_update), TIMER_US(o_norm), TIMER_US(o_store));
#endif
return HTP_STATUS_OK;
}
+55 -41
View File
@@ -27,7 +27,7 @@
#include "hmx-ops.h"
#include "hmx-utils.h"
#include "hmx-queue.h"
#include "hmx-profile.h"
#include "hex-profile.h"
#include "vtcm-utils.h"
@@ -430,6 +430,7 @@ typedef struct {
int n_tasks;
int n_k_tiles;
struct fastdiv_values n_k_tiles_div;
struct htp_thread_trace * traces;
} x4x2_dequantize_state_t;
// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16.
@@ -533,11 +534,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(
\
static void dequantize_x4x2_worker_loop_##suffix(unsigned int n, unsigned int i, void *data) { \
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; \
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; \
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { \
int start = task_id * state->n_tiles_per_task; \
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); \
dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(state, start, end); \
} \
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \
}
DEFINE_DEQUANTIZE_Q4_TASK(q4_0, q4_0_to_fp16_lut, q4_0, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16))
@@ -657,11 +661,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(
static void dequantize_x4x2_worker_loop_mxfp4(unsigned int n, unsigned int i, void *data) {
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
int start = task_id * state->n_tiles_per_task;
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(state, start, end);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
}
static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(
@@ -717,11 +724,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(
static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, void *data) {
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
int start = task_id * state->n_tiles_per_task;
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(state, start, end);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
}
static void convert_f16_weight_to_fp16_tiles_task(
@@ -773,11 +783,14 @@ static void convert_f16_weight_to_fp16_tiles_task(
static void convert_f16_worker_loop(unsigned int n, unsigned int i, void *data) {
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
int start = task_id * state->n_tiles_per_task;
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
convert_f16_weight_to_fp16_tiles_task(state, start, end);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
}
static void quantize_f32_weight_to_fp16_tiles_task(
@@ -833,11 +846,14 @@ static void quantize_f32_weight_to_fp16_tiles_task(
static void quantize_f32_worker_loop(unsigned int n, unsigned int i, void *data) {
x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data;
struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) {
int start = task_id * state->n_tiles_per_task;
int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles);
quantize_f32_weight_to_fp16_tiles_task(state, start, end);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i);
}
@@ -868,6 +884,7 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
state.weight_type = weight_type;
state.n_k_tiles = n_k_tiles;
state.n_k_tiles_div = n_k_tiles_div;
state.traces = ctx ? ctx->trace : NULL;
if (state.n_tasks == 1 || n_threads == 1) {
dequant_worker_fn(1, 0, &state);
@@ -985,10 +1002,13 @@ typedef struct {
int n_chunks_per_task;
int n_cols;
int n; // DDR row stride (total output columns)
struct htp_thread_trace * traces;
} output_transfer_task_state_t;
static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void *data) {
output_transfer_task_state_t *st = (output_transfer_task_state_t *) data;
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) {
int chunk_idx = task_id * st->n_chunks_per_task;
@@ -998,6 +1018,7 @@ static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void
const __fp16 *vtcm_src = st->vtcm_src + chunk_idx * st->n_cols;
transfer_output_chunk_fp16_to_fp32(dst, vtcm_src, chunk_size, st->n_cols, st->n);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
}
static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src,
@@ -1015,6 +1036,7 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst,
state.vtcm_src = vtcm_src;
state.n_cols = n_cols;
state.n = n;
state.traces = ctx ? ctx->trace : NULL;
if (state.n_tasks == 1 || n_threads == 1) {
transfer_output_chunk_worker_fn(1, 0, &state);
@@ -1086,10 +1108,13 @@ typedef struct {
int n_chunks_per_task;
int k_block;
int k_stride;
struct htp_thread_trace * traces;
} activation_transfer_task_state_t;
static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) {
activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data;
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) {
// one chunk: one row
@@ -1100,6 +1125,7 @@ static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i,
const float *src = st->src + chunk_idx * st->k_stride;
transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
}
static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride, int n_threads) {
@@ -1117,6 +1143,7 @@ static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *
state.src = src;
state.k_block = k_block;
state.k_stride = k_stride;
state.traces = ctx ? ctx->trace : NULL;
if (state.n_tasks == 1 || n_threads == 1) {
transfer_activation_chunk_worker_fn(1, 0, &state);
@@ -1245,13 +1272,7 @@ int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float
FARF(HIGH, "hmx-mm-2d: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu",
m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget);
TIMER_DEFINE(activation_load);
TIMER_DEFINE(weight_load);
TIMER_DEFINE(hmx_core);
TIMER_DEFINE(output_store);
TIMER_DEFINE(total);
TIMER_START(total);
int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols);
@@ -1370,7 +1391,12 @@ int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads);
// C: HMX Compute (Synchronous)
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
{
struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
}
// D: Output Store
float *output_chunk = dst + (mr * n + nc);
@@ -1380,18 +1406,7 @@ int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
}
TIMER_STOP(total);
#if defined(ENABLE_PROFILE_TIMERS)
FARF(HIGH, "hex-mm-2d: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n);
if (!use_pipeline) {
FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us",
TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store));
size_t weight_size = (size_t)n * row_stride;
float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load);
FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth);
}
#endif
return 0;
}
@@ -1523,13 +1538,7 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
m_chunk_n_rows, n_chunk_n_cols,
(size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget);
TIMER_DEFINE(activation_load);
TIMER_DEFINE(weight_load);
TIMER_DEFINE(hmx_core);
TIMER_DEFINE(output_store);
TIMER_DEFINE(total);
TIMER_START(total);
const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16);
const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16);
@@ -1549,7 +1558,6 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
// contiguous rows into a VTCM scratch buffer first, then HVX
// converts from the contiguous VTCM buffer. This avoids L2 cache
// thrashing from HVX loads at large strides.
TIMER_START(activation_load);
for (int g = 0; g < group_size; ++g) {
const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride;
__fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride;
@@ -1569,7 +1577,6 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
params->k, params->act_stride, ctx->n_threads);
}
}
TIMER_STOP(activation_load);
void *buf_curr = vtcm_scratch0;
void *buf_next = vtcm_scratch1;
@@ -1584,7 +1591,6 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols);
const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS);
TIMER_START(weight_load);
{
dma_queue_pop(ctx->dma[0]);
@@ -1601,24 +1607,22 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
0, n_cols);
hex_swap_ptr(&buf_curr, &buf_next);
}
TIMER_STOP(weight_load);
// Reuse the interleaved weight for every q_head in this GQA group
for (int g = 0; g < group_size; ++g) {
TIMER_START(hmx_core);
{
const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride;
struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles,
params->k / 32);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
}
TIMER_STOP(hmx_core);
TIMER_START(output_store);
{
float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc;
transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride, ctx->n_threads);
}
TIMER_STOP(output_store);
}
}
}
@@ -1627,14 +1631,7 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
TIMER_STOP(total);
#if defined(ENABLE_PROFILE_TIMERS)
FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d group=%d", __func__, TIMER_US(total),
params->m, params->k, params->n, group_size);
FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us",
TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store));
#endif
return 0;
}
@@ -1668,6 +1665,7 @@ typedef struct {
size_t nb12;
int start_row;
int cne1;
struct htp_thread_trace *traces;
} activation_transfer_gathered_task_state_t;
typedef struct {
@@ -1684,6 +1682,7 @@ typedef struct {
size_t dst_nb2;
int start_row;
int cne1;
struct htp_thread_trace *traces;
} output_transfer_scattered_task_state_t;
static void transfer_activation_chunk_fp32_to_fp16_gathered(
@@ -1780,6 +1779,9 @@ static void transfer_activation_chunk_fp32_to_fp16_gathered(
static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigned int i, void *data) {
activation_transfer_gathered_task_state_t *st = data;
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
int chunk_idx = i;
int chunk_size = st->n_chunks_per_task;
int start_row = st->start_row + chunk_idx * chunk_size;
@@ -1791,6 +1793,7 @@ static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigne
st->matrix_rows, st->cur_a, st->mapping_stride,
st->ne11, &st->ne11_div, st->nb11, st->nb12, st->cne1);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, i);
}
static void transfer_activation_chunk_gathered_threaded(
@@ -1830,6 +1833,7 @@ static void transfer_activation_chunk_gathered_threaded(
.nb12 = nb12,
.start_row = start_row,
.cne1 = cne1,
.traces = ctx ? ctx->trace : NULL,
};
if (actual_threads <= 1) {
@@ -1895,6 +1899,9 @@ static void transfer_output_chunk_fp16_to_fp32_scattered(
static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned int i, void *data) {
output_transfer_scattered_task_state_t *st = data;
struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
int chunk_idx = i;
int chunk_size = st->n_chunks_per_task;
int start_row = st->start_row + chunk_idx * chunk_size;
@@ -1906,6 +1913,7 @@ static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned i
st->matrix_rows, st->cur_a, st->mapping_stride,
st->dst_nb1, st->dst_nb2, st->cne1);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, i);
}
static void transfer_output_chunk_scattered_threaded(
@@ -1942,6 +1950,7 @@ static void transfer_output_chunk_scattered_threaded(
.dst_nb2 = dst_nb2,
.start_row = start_row,
.cne1 = cne1,
.traces = ctx ? ctx->trace : NULL,
};
if (actual_threads <= 1) {
@@ -2053,7 +2062,12 @@ int hmx_matmul_id_2d_f32(struct htp_context *ctx,
dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads);
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
{
struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS);
}
transfer_output_chunk_scattered_threaded(
ctx, dst, vtcm_output, (int) mr, (int) n_rows, (int) n_cols,
-34
View File
@@ -1,34 +0,0 @@
// Conditional fine-grained profiling macros for HMX operations.
//
// Define ENABLE_PROFILE_TIMERS (via compiler flag or before including this
// header) to instrument sub-operation latencies with HAP qtimer. When the
// macro is not defined the TIMER_* helpers expand to nothing so there is zero
// overhead.
//
// Usage:
// TIMER_DEFINE(my_phase); // declare accumulator variable
// TIMER_START(my_phase); // snapshot start time
// ... work ...
// TIMER_STOP(my_phase); // accumulate elapsed ticks
// FARF(ALWAYS, "my_phase: %lld us", TIMER_US(my_phase));
#ifndef HMX_PROFILE_H
#define HMX_PROFILE_H
#include <HAP_perf.h>
// #define ENABLE_PROFILE_TIMERS
#if defined(ENABLE_PROFILE_TIMERS)
# define TIMER_DEFINE(name) int64_t name##_ticks = 0
# define TIMER_START(name) int64_t name##_t0 = HAP_perf_get_qtimer_count()
# define TIMER_STOP(name) name##_ticks += HAP_perf_get_qtimer_count() - name##_t0
# define TIMER_US(name) HAP_perf_qtimer_count_to_us(name##_ticks)
#else
# define TIMER_DEFINE(name)
# define TIMER_START(name)
# define TIMER_STOP(name)
# define TIMER_US(name) 0LL
#endif
#endif // HMX_PROFILE_H
+2
View File
@@ -44,7 +44,9 @@ static inline void hmx_queue_process(struct hmx_queue *q, bool* killed) {
case HMX_QUEUE_SUSPEND: hmx_unlock(q); break;
default:
hmx_lock(q);
htp_trace_event_start(q->trace, HTP_TRACE_EVT_HMX_COMP, ir);
d->func(d->data);
htp_trace_event_stop(q->trace, HTP_TRACE_EVT_HMX_COMP, ir);
break;
}
+2
View File
@@ -11,6 +11,7 @@
#include <HAP_farf.h>
#include "hex-utils.h"
#include "hex-profile.h"
#ifdef __cplusplus
extern "C" {
@@ -47,6 +48,7 @@ struct hmx_queue {
void * stack;
uint32_t hap_rctx;
bool hmx_locked;
struct htp_thread_trace * trace;
};
struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx);
+2
View File
@@ -4,6 +4,7 @@
#include "hex-dma.h"
#include "hmx-queue.h"
#include "htp-ops.h"
#include "hex-profile.h"
#include "worker-pool.h"
#include <assert.h>
@@ -70,6 +71,7 @@ struct htp_context {
bool hmx_enabled;
bool etm;
uint32_t profiler;
struct htp_thread_trace trace[HTP_MAX_NTHREADS + 1];
uint8_t * vtcm_base;
size_t vtcm_size;
+31 -4
View File
@@ -146,10 +146,36 @@ struct htp_op_desc {
uint16_t dst; // Output tensor index
};
#ifndef HTP_MAX_NTHREADS
#define HTP_MAX_NTHREADS 10
#endif
#define HTP_TRACE_MAX_EVENTS 256
enum htp_profiler_mode {
HTP_PROF_DISABLED = 0,
HTP_PROF_BASIC = 1,
HTP_PROF_PMU = 2,
HTP_PROF_TRACE = 3,
};
enum htp_trace_event_id {
HTP_TRACE_EVT_DMA = 0,
HTP_TRACE_EVT_HVX_COMP = 20,
HTP_TRACE_EVT_HVX_A_QUANT = 21,
HTP_TRACE_EVT_HVX_A_PREP = 22,
HTP_TRACE_EVT_HVX_W_DEQUANT = 23,
HTP_TRACE_EVT_HVX_W_PREP = 24,
HTP_TRACE_EVT_HVX_O_PROC = 25,
HTP_TRACE_EVT_HMX_COMP = 40,
};
struct htp_trace_desc {
uint32_t cycles; // lower 32-bits of cycle counter
uint16_t id; // Event ID
uint16_t info; // bit 15: is_stop. bits 14-0: tile/chunk index or other metadata.
};
#define HTP_PROF_PMU_NCNT 8
@@ -158,8 +184,8 @@ enum htp_profiler_mode {
struct htp_prof_desc {
uint32_t opcode; // GGML/HTP Op
uint32_t usecs; // Number of usec
uint32_t cycles; // Number of cycles
uint32_t pad; // Unused
uint32_t cycles_start; // Start cycle counter
uint32_t cycles_stop; // Stop cycle counter
uint32_t pmu[HTP_PROF_PMU_NCNT]; // PMU counters
};
@@ -168,7 +194,7 @@ struct htp_opbatch_req {
uint32_t n_bufs; // Number of buffers
uint32_t n_tensors; // Number of tensors
uint32_t n_ops; // Number of ops
uint32_t flags; // unused
uint32_t n_traces; // Number of trace descriptors per thread
uint32_t pad; // unused
// struct htp_buf_desc bufs[]; -- dspqueue buf 0
// struct htp_tensor tensors[]; -- dspqueue buf 0
@@ -181,7 +207,8 @@ struct htp_opbatch_rsp {
uint32_t n_bufs; // Number of buffers
uint32_t n_tensors; // Number of tensors
uint32_t n_ops; // Number of op profile descriptors
uint32_t pad; // unused
uint32_t n_traces[HTP_MAX_NTHREADS + 1];
uint8_t pad[8]; // align to 8 bytes
// struct htp_prof_desc profs[]; -- dspqueue buf 0
};
+41 -9
View File
@@ -400,7 +400,9 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
ctx->hmx_queue = NULL;
if (use_hmx) {
ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx);
if (!ctx->hmx_queue) {
if (ctx->hmx_queue) {
ctx->hmx_queue->trace = &ctx->trace[HTP_MAX_NTHREADS];
} else {
FARF(ERROR, "hmx-queue-create failed");
ctx->hmx_enabled = false;
}
@@ -425,6 +427,9 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
ctx->n_threads = n_hvx;
for (int i = 0; i < ctx->n_threads; i++) {
ctx->dma[i] = dma_queue_create(256); // queue depth
if (ctx->dma[i]) {
ctx->dma[i]->trace = &ctx->trace[i];
}
}
ctx->ddr_spad_size = 512 * 1024; // 512 KB
@@ -502,7 +507,8 @@ static void htp_error_callback(dspqueue_t queue, int error, void * context) {
struct profile_data {
uint64_t usecs;
uint64_t cycles;
uint64_t cycles_start;
uint64_t cycles_stop;
uint32_t pmu_counters[HEX_NUM_PMU_COUNTERS];
};
@@ -512,8 +518,9 @@ static inline void profile_start(uint32_t mode, struct profile_data * d) {
hex_get_pmu(d->pmu_counters);
// fallthrough
case HTP_PROF_BASIC:
case HTP_PROF_TRACE:
d->usecs = HAP_perf_get_qtimer_count();
d->cycles = hex_get_cycles();
d->cycles_start = hex_get_cycles();
break;
default:
break;
@@ -530,8 +537,9 @@ static inline void profile_stop(uint32_t mode, struct profile_data * d) {
}
// fallthrough
case HTP_PROF_BASIC:
case HTP_PROF_TRACE:
d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs);
d->cycles = hex_get_cycles() - d->cycles;
d->cycles_stop = hex_get_cycles();
break;
default:
break;
@@ -845,14 +853,15 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
const uint32_t t_size = sizeof(struct htp_tensor) * n_tens;
const uint32_t o_size = sizeof(struct htp_op_desc) * n_ops;
const uint32_t p_size = sizeof(struct htp_prof_desc) * n_ops;
const uint32_t tr_size = (HTP_MAX_NTHREADS + 1) * req.n_traces * sizeof(struct htp_trace_desc);
if (dbuf.size < b_size + t_size + o_size + p_size) {
FARF(ERROR, "invalid opbatch memory block size %u", dbuf.size);
if (dbuf.size < b_size + t_size + o_size + p_size + tr_size) {
FARF(ERROR, "invalid opbatch memory block size %u (req %u)", dbuf.size, b_size + t_size + o_size + p_size + tr_size);
break;
}
FARF(HIGH, "processing opbatch #%u: n-bufs %u n-tensors %u n-ops %u : m-size %u b-size %u t-size %u o-size %u", req.id,
n_bufs, n_tens, n_ops, dbuf.size, b_size, t_size, o_size);
FARF(HIGH, "processing opbatch #%u: n-bufs %u n-tensors %u n-ops %u n-traces %u : m-size %u b-size %u t-size %u o-size %u", req.id,
n_bufs, n_tens, n_ops, req.n_traces, dbuf.size, b_size, t_size, o_size);
// Setup descriptor pointers
uint8_t * m_ptr = dbuf.ptr;
@@ -869,6 +878,20 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
octx->n_threads = ctx->n_threads;
octx->ctx = ctx;
if (ctx->profiler == HTP_PROF_TRACE) {
memset(ctx->trace, 0, sizeof(ctx->trace));
struct htp_trace_desc * trace_events = (struct htp_trace_desc *) (m_ptr + p_size);
for (int t = 0; t <= HTP_MAX_NTHREADS; t++) {
ctx->trace[t].events = &trace_events[t * req.n_traces];
ctx->trace[t].max_events = req.n_traces;
}
} else {
for (int t = 0; t <= HTP_MAX_NTHREADS; t++) {
ctx->trace[t].events = NULL;
ctx->trace[t].max_events = 0;
}
}
for (uint32_t i=0; i < n_ops; i++) {
struct profile_data prof;
@@ -886,7 +909,8 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
if (ctx->profiler) {
pds[i].opcode = ops[i].opcode;
pds[i].usecs = prof.usecs;
pds[i].cycles = prof.cycles;
pds[i].cycles_start = prof.cycles_start;
pds[i].cycles_stop = prof.cycles_stop;
for (int j = 0; j < HEX_NUM_PMU_COUNTERS; j++) {
pds[i].pmu[j] = prof.pmu_counters[j];
}
@@ -899,6 +923,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
rsp.n_bufs = n_bufs;
rsp.n_tensors = n_tens;
rsp.n_ops = n_ops;
memset(rsp.pad, 0, sizeof(rsp.pad));
if (ctx->profiler == HTP_PROF_TRACE) {
for (int t = 0; t <= HTP_MAX_NTHREADS; t++) {
rsp.n_traces[t] = ctx->trace[t].count;
}
} else {
memset(rsp.n_traces, 0, sizeof(rsp.n_traces));
}
dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT;
+46
View File
@@ -3350,6 +3350,7 @@ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void *
static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
@@ -3411,10 +3412,12 @@ static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, iir0);
for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, iir0);
}
}
}
@@ -3430,6 +3433,7 @@ static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
// src1 tensor is already in VTCM spad
static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows
@@ -3477,6 +3481,8 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
// Process src1 columns in pairs (2×2 tiling)
uint32_t ir1 = 0;
for (; ir1 + 1 < src1_nrows; ir1 += 2) {
@@ -3494,6 +3500,8 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
// Prefetch next (n + spad_nrows) row
const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
@@ -3511,12 +3519,14 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
src0_stride, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
#pragma unroll(2)
for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
}
t2 = HAP_perf_get_qtimer_count();
@@ -3530,6 +3540,7 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
// q8x4x2 src1 tensor is already in VTCM spad
static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const uint32_t src0_nrows = ne01;
@@ -3581,7 +3592,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
// Process src0 rows
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_4x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, ss0 + 2 * src0_stride, ss0 + 3 * src0_stride, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
// Prefetch next (n + spad_nrows) row
const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
@@ -3599,7 +3612,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
src0_stride, src0_row_size, 2);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
ir0 += 2;
}
if (ir0 < src0_end_row) {
@@ -3607,7 +3622,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
src0_stride, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
ir0 += 1;
}
} else {
@@ -3627,7 +3644,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
// Process src0 rows
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
// Prefetch next (n + spad_nrows) row
const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
@@ -3645,7 +3664,9 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
src0_stride, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
}
}
@@ -3669,6 +3690,7 @@ struct mmid_row_mapping {
// src1 tensor is already in VTCM spad
static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * restrict ids = octx->src[2];
struct htp_spad * restrict src2_spad = &octx->src2_spad;
@@ -3735,6 +3757,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
for (uint32_t cid = 0; cid < cne1; ++cid) {
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
const int rm1 = row_mapping.i1; // expert idx
@@ -3746,6 +3769,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
// Prefetch next (n + spad_nrows) row
const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
@@ -3764,6 +3788,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
src0_row_size_padded, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
for (uint32_t cid = 0; cid < cne1; ++cid) {
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid);
const int rm1 = row_mapping.i1; // expert idx
@@ -3775,6 +3800,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
}
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
}
}
@@ -3789,6 +3815,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
// src1 tensor is already in VTCM spad
static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
htp_matmul_preamble;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * restrict ids = octx->src[2];
struct htp_spad * restrict src2_spad = &octx->src2_spad;
@@ -3847,7 +3874,9 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
// Process src0 rows
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
// Prefetch next (n + spad_nrows) row
const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
@@ -3865,7 +3894,9 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
src0_row_size_padded, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0);
}
}
@@ -4147,6 +4178,7 @@ static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, ui
static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) {
struct htp_matmul_context * mmctx = data;
struct htp_ops_context * octx = mmctx->octx;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * src = octx->src[1];
uint8_t * restrict dst = octx->src1_spad.data;
@@ -4163,6 +4195,7 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data)
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
const uint32_t ir_first = nrows_per_thread * ith; // first row
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
const size_t src_row_size = src->nb[1];
@@ -4189,6 +4222,7 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data)
FARF(HIGH, "quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
}
static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
@@ -4219,6 +4253,7 @@ static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y,
static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * data) {
struct htp_matmul_context * mmctx = data;
struct htp_ops_context * octx = mmctx->octx;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * src = octx->src[1];
uint8_t * restrict dst = octx->src1_spad.data;
@@ -4235,6 +4270,7 @@ static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * dat
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
const uint32_t ir_first = nrows_per_thread * ith; // first row
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
const size_t src_row_size = src->nb[1];
@@ -4260,11 +4296,13 @@ static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * dat
FARF(HIGH, "quantize-f32-q8_1x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
}
static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) {
struct htp_matmul_context * mmctx = data;
struct htp_ops_context * octx = mmctx->octx;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * src = octx->src[1];
uint8_t * restrict dst = octx->src1_spad.data;
@@ -4281,6 +4319,7 @@ static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) {
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
const uint32_t ir_first = nrows_per_thread * ith; // first row
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
const size_t src_row_size = ne0 * sizeof(float);
@@ -4301,11 +4340,13 @@ static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) {
FARF(HIGH, "quantize-f32-f32: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
}
static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
struct htp_matmul_context * mmctx = data;
struct htp_ops_context * octx = mmctx->octx;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * src = octx->src[1];
uint8_t * restrict dst = octx->src1_spad.data;
@@ -4322,6 +4363,7 @@ static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
const uint32_t ir_first = nrows_per_thread * ith; // first row
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
const size_t src_row_size = ne0 * sizeof(float);
@@ -4342,12 +4384,14 @@ static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
FARF(HIGH, "quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
}
// TODO just a plain copy that should be done via the DMA during the Op setup
static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
struct htp_matmul_context * mmctx = data;
struct htp_ops_context * octx = mmctx->octx;
struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL;
const struct htp_tensor * src = octx->src[1];
uint8_t * restrict dst = octx->src1_spad.data;
@@ -4364,6 +4408,7 @@ static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
const uint32_t ir_first = nrows_per_thread * ith; // first row
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
const size_t src_row_size = ne0 * sizeof(float);
@@ -4384,6 +4429,7 @@ static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
FARF(HIGH, "quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first);
}
+10 -9
View File
@@ -183,24 +183,25 @@ static inline void hvx_transpose_32x32_f32(HVX_Vector m[32]) {
// transposed into VTCM.
//
// VTCM layouts (per thread):
// src1_T : {d_inner_per_thread, d_conv} staged once per launch (small).
// src0_T : {d_inner_tile, ncs} staged per d_inner-tile.
// src1_T : {d_inner_stride, d_conv} - staged once per launch (small).
// src0_T : {d_inner_tile, ncs} - staged per d_inner-tile.
//
// d_inner_tile is chosen so that per-thread VTCM stays under the budget.
// Each thread iterates ceil(d_inner_per_thread d_inner_tile) tiles serially.
#define HTP_SSM_CONV_VTCM_BUDGET (1u << 20) // 1 MiB per thread
// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_per_thread, d_conv} (VTCM)
// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_stride, d_conv} (VTCM)
static inline void transpose_src1(const float * src1_data,
uint32_t src1_stride_inner,
uint32_t i1_off,
uint32_t d_inner_per_thread,
uint32_t d_inner_stride,
uint32_t d_conv,
float * src1_T) {
for (uint32_t i = 0; i < d_inner_per_thread; ++i) {
const float * src_row = src1_data + (i1_off + i) * src1_stride_inner;
for (uint32_t j = 0; j < d_conv; ++j) {
src1_T[j * d_inner_per_thread + i] = src_row[j];
src1_T[j * d_inner_stride + i] = src_row[j];
}
}
}
@@ -280,6 +281,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
}
const uint32_t d_inner_per_thread = ir1 - ir0;
const uint32_t d_inner_stride = scctx->nrows_per_thread;
const uint32_t d_inner_tile = scctx->d_inner_tile;
const float * src0_data = (const float *) src0->data;
@@ -290,8 +292,8 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
float * src0_T = (float *)(octx->src0_spad.data + ith * octx->src0_spad.size_per_thread);
float * src1_T = (float *)(octx->src1_spad.data + ith * octx->src1_spad.size_per_thread);
// Stage src1 weights once into VTCM in {d_inner_per_thread, d_conv} layout.
transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_conv, src1_T);
// Stage src1 weights once into VTCM in {d_inner_stride, d_conv} layout.
transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_inner_stride, d_conv, src1_T);
const uint32_t C_TILE = VLEN_FP32;
@@ -314,7 +316,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
HVX_Vector acc = hvx_vec_splat_f32(0.0f);
for (uint32_t j = 0; j < d_conv; ++j) {
HVX_Vector x = *(const HVX_Vector *) (src0_T + (t + j) * d_inner_tile + cb);
HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_per_thread + tile_off + cb);
HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_stride + tile_off + cb);
acc = Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(x, w));
}
HVX_Vector res = Q6_Vsf_equals_Vqf32(acc);
@@ -362,8 +364,7 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) {
use_hvx = 1;
}
scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads;
scctx.nrows_per_thread += (scctx.nrows_per_thread & 1);
scctx.nrows_per_thread = hex_round_up((d_inner + n_threads - 1) / n_threads, VLEN_FP32);
const uint32_t d_inner_per_thread = scctx.nrows_per_thread;
const uint32_t ncs = src0->ne[0];
+119 -51
View File
@@ -24,62 +24,119 @@ if (GGML_METAL_NDEBUG)
endif()
set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
set(METALLIB_KERNELS_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/kernels/common.h")
set(METALLIB_KERNELS_DEQUANTIZE "${CMAKE_CURRENT_SOURCE_DIR}/kernels/dequantize.h")
set(METALLIB_KERNELS_QUANTIZE "${CMAKE_CURRENT_SOURCE_DIR}/kernels/quantize.h")
set(METALLIB_KERNEL_SOURCES
kernels/fa.metal
kernels/mul_mv.metal
kernels/mul_mm.metal
kernels/quantize.metal
kernels/softmax.metal
kernels/norm.metal
kernels/unary.metal
kernels/binbcast.metal
kernels/reduce.metal
kernels/tri.metal
kernels/ssm.metal
kernels/wkv.metal
kernels/gated_delta_net.metal
kernels/solve_tri.metal
kernels/rope.metal
kernels/conv.metal
kernels/upscale.metal
kernels/argsort.metal
kernels/pool.metal
kernels/misc.metal
)
if (GGML_METAL_EMBED_LIBRARY)
enable_language(ASM)
add_compile_definitions(GGML_METAL_EMBED_LIBRARY)
set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
# merge ggml-common.h and ggml-metal.metal into a single file
set(METALLIB_EMBED_ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
set(METALLIB_SOURCE_EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
set(METALLIB_EMBED_ASM_FILES "")
foreach(src ${METALLIB_KERNEL_SOURCES})
get_filename_component(kind ${src} NAME_WE)
# symbol names must be valid C identifiers ('-' is not allowed)
string(REPLACE "-" "_" kind_sym ${kind})
add_custom_command(
OUTPUT "${METALLIB_EMBED_ASM}"
COMMAND echo "Embedding Metal library"
COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${METALLIB_SOURCE}" > "${METALLIB_SOURCE_EMBED_TMP}"
COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${METALLIB_SOURCE_EMBED_TMP}" > "${METALLIB_SOURCE_EMBED}"
COMMAND echo ".section __DATA,__ggml_metallib" > "${METALLIB_EMBED_ASM}"
COMMAND echo ".globl _ggml_metallib_start" >> "${METALLIB_EMBED_ASM}"
COMMAND echo "_ggml_metallib_start:" >> "${METALLIB_EMBED_ASM}"
COMMAND echo .incbin "\"${METALLIB_SOURCE_EMBED}\"" >> "${METALLIB_EMBED_ASM}"
COMMAND echo ".globl _ggml_metallib_end" >> "${METALLIB_EMBED_ASM}"
COMMAND echo "_ggml_metallib_end:" >> "${METALLIB_EMBED_ASM}"
DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h
COMMENT "Generate assembly for embedded Metal library"
VERBATIM
)
set(SRC "${CMAKE_CURRENT_SOURCE_DIR}/kernels/${kind}.metal")
set(EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed-${kind}.metal")
set(ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed-${kind}.s")
target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}")
# only prepend headers that this source actually includes
set(HEADERS_FOR_SRC ${METALLIB_KERNELS_COMMON})
file(STRINGS ${SRC} _has_dequantize REGEX "#include \"dequantize\\.h\"")
file(STRINGS ${SRC} _has_quantize REGEX "#include \"quantize\\.h\"")
if(_has_dequantize)
list(APPEND HEADERS_FOR_SRC ${METALLIB_KERNELS_DEQUANTIZE})
endif()
if(_has_quantize)
list(APPEND HEADERS_FOR_SRC ${METALLIB_KERNELS_QUANTIZE})
endif()
add_custom_command(
OUTPUT "${ASM}"
# Step 1: concatenate shared headers + this kernel source
COMMAND cat ${HEADERS_FOR_SRC} ${SRC} > "${EMBED}.tmp1"
# Step 2: remove internal #include and #pragma once
COMMAND sed -e "/\#include \"common.h\"/d" -e "/\#include \"dequantize.h\"/d" -e "/\#include \"quantize.h\"/d" -e "/\#pragma once/d" < "${EMBED}.tmp1" > "${EMBED}.tmp2"
# Step 3: inline ggml-common.h (replacing __embed_ggml-common.h__ sentinel)
COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${EMBED}.tmp2" > "${EMBED}.tmp3"
# Step 4: inline ggml-metal-impl.h
COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${EMBED}.tmp3" > "${EMBED}"
# Step 5: emit an asm chunk with kind-specific start/end symbols
# note: '-' is illegal in C symbols, so we use kind_sym; the macOS
# section name is limited to 16 chars so we keep it shared
# across kinds (__ggml_metallib) and only vary the global symbols.
COMMAND echo ".section __DATA,__ggml_metallib" > "${ASM}"
COMMAND echo ".globl _ggml_metallib_${kind_sym}_start" >> "${ASM}"
COMMAND echo "_ggml_metallib_${kind_sym}_start:" >> "${ASM}"
COMMAND echo .incbin "\"${EMBED}\"" >> "${ASM}"
COMMAND echo ".globl _ggml_metallib_${kind_sym}_end" >> "${ASM}"
COMMAND echo "_ggml_metallib_${kind_sym}_end:" >> "${ASM}"
DEPENDS ../ggml-common.h ggml-metal-impl.h
kernels/common.h kernels/dequantize.h kernels/quantize.h
kernels/${kind}.metal
COMMENT "Generate embedded Metal library for ${kind}"
VERBATIM
)
list(APPEND METALLIB_EMBED_ASM_FILES "${ASM}")
endforeach()
target_sources(ggml-metal PRIVATE ${METALLIB_EMBED_ASM_FILES})
else()
# copy metal files to bin directory
# copy header files to bin directory
configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
file(MAKE_DIRECTORY "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels")
configure_file(kernels/common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/common.h COPYONLY)
configure_file(kernels/dequantize.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/dequantize.h COPYONLY)
configure_file(kernels/quantize.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels/quantize.h COPYONLY)
foreach(src ${METALLIB_KERNEL_SOURCES})
configure_file(${src} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${src} COPYONLY)
endforeach()
if (GGML_METAL_SHADER_DEBUG)
# custom command to do the following:
# xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
# xcrun -sdk macosx metallib ggml-metal.air -o default.metallib
#
# note: this is the only way I found to disable fast-math in Metal. it's ugly, but at least it works
# disabling fast math is needed in order to pass tests/test-backend-ops
# note: disabling fast math is needed in order to pass tests/test-backend-ops
# note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1
# note: unfortunately, we have to call it default.metallib instead of ggml.metallib
# ref: https://github.com/ggml-org/whisper.cpp/issues/1720
# note: adding -g causes segmentation fault during compile
#set(XC_FLAGS -fno-fast-math -fno-inline -g)
set(XC_FLAGS -fno-fast-math -fno-inline)
else()
set(XC_FLAGS -O3)
endif()
# Append macOS metal versioning flags
if (GGML_METAL_MACOSX_VERSION_MIN)
message(STATUS "Adding -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN} flag to metal compilation")
list (APPEND XC_FLAGS -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN})
@@ -90,35 +147,46 @@ else()
list (APPEND XC_FLAGS -std=${GGML_METAL_STD})
endif()
# Compile each kernel source to .air, then link into default.metallib
set(AIR_FILES "")
foreach(src ${METALLIB_KERNEL_SOURCES})
get_filename_component(name ${src} NAME_WE)
set(AIR "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${name}.air")
list(APPEND AIR_FILES ${AIR})
add_custom_command(
OUTPUT ${AIR}
COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -I ${CMAKE_RUNTIME_OUTPUT_DIRECTORY} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${src} -o ${AIR}
DEPENDS ${src} kernels/common.h kernels/dequantize.h kernels/quantize.h ${METALLIB_COMMON} ggml-metal-impl.h
COMMENT "Compiling ${src}"
VERBATIM
)
endforeach()
add_custom_command(
OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o - |
xcrun -sdk macosx metallib - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
COMMAND xcrun -sdk macosx metallib ${AIR_FILES} -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal
DEPENDS ggml-metal.metal ${METALLIB_COMMON}
COMMENT "Compiling Metal kernels"
)
COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h
COMMAND rm -rf ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/kernels
DEPENDS ${AIR_FILES}
COMMENT "Linking Metal kernels into default.metallib"
)
# FIXME: only add to the ggml-metal target?
add_custom_target(
ggml-metal-lib ALL
DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
)
)
endif() # GGML_METAL_EMBED_LIBRARY
if (NOT GGML_METAL_EMBED_LIBRARY)
install(
FILES src/ggml-metal/ggml-metal.metal
PERMISSIONS
OWNER_READ
OWNER_WRITE
GROUP_READ
WORLD_READ
DESTINATION ${CMAKE_INSTALL_BINDIR})
DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/kernels/
DESTINATION ${CMAKE_INSTALL_BINDIR}/kernels
FILES_MATCHING PATTERN "*.metal" PATTERN "*.h"
)
install(
FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
DESTINATION ${CMAKE_INSTALL_BINDIR}
)
install(
FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
DESTINATION ${CMAKE_INSTALL_BINDIR}
)
endif()
+20 -3
View File
@@ -66,7 +66,6 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml
const char * op_str = "undefined";
switch (op) {
case GGML_OP_ADD_ID: op_str = "add_id"; break;
case GGML_OP_CONCAT: op_str = "concat"; break;
default: GGML_ABORT("fatal error");
};
@@ -211,6 +210,21 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_meta
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_concat(ggml_metal_library_t lib, ggml_type tsrc) {
char base[256];
char name[256];
snprintf(base, 256, "kernel_concat_%s", ggml_type_name(tsrc));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
char base[256];
char name[256];
@@ -1689,7 +1703,9 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_ROPE);
assert(op->op == GGML_OP_ROPE || op->op == GGML_OP_ROPE_BACK);
const bool is_back = op->op == GGML_OP_ROPE_BACK;
char base[256];
char name[256];
@@ -1713,13 +1729,14 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_
snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
}
snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
snprintf(name, 256, "%s_imrope=%d_is_back=%d", base, is_imrope ? 1 : 0, is_back ? 1 : 0);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
ggml_metal_cv_set_bool(cv, is_back, FC_ROPE + 1);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+1
View File
@@ -115,6 +115,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_concat (ggml_metal_library_t lib, enum ggml_type tsrc);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
+440 -133
View File
@@ -94,8 +94,63 @@ int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_wi
return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup;
}
//
// MTLLibrary collection (one library per op-source, compiled separately)
//
// Single source of truth for the per-kind metal libraries. The order here
// defines the enum values and every per-kind table below, so adding a library
// is a one-line change here (plus adding its source to CMakeLists.txt).
// X(suffix, name): name is both the kernels/<name>.metal basename and the
// ggml_metallib_<name>_{start,end} embed-symbol stem.
#define GGML_METAL_LIBS \
X(FA, fa) \
X(MUL_MV, mul_mv) \
X(MUL_MM, mul_mm) \
X(QUANTIZE, quantize) \
X(SOFTMAX, softmax) \
X(NORM, norm) \
X(UNARY, unary) \
X(BINBCAST, binbcast) \
X(REDUCE, reduce) \
X(TRI, tri) \
X(SSM, ssm) \
X(WKV, wkv) \
X(GATED_DELTA_NET, gated_delta_net)\
X(SOLVE_TRI, solve_tri) \
X(ROPE, rope) \
X(CONV, conv) \
X(UPSCALE, upscale) \
X(ARGSORT, argsort) \
X(POOL, pool) \
X(MISC, misc)
enum ggml_metal_lib_kind {
#define X(e, s) GGML_METAL_LIB_##e,
GGML_METAL_LIBS
#undef X
GGML_METAL_LIB_COUNT,
};
static const char * const k_lib_names[GGML_METAL_LIB_COUNT] = {
#define X(e, s) [GGML_METAL_LIB_##e] = #s,
GGML_METAL_LIBS
#undef X
};
struct ggml_metal_library {
id<MTLLibrary> obj;
// Per-kind compiled libraries. When single_library is true, the whole library
// (e.g. a pre-compiled default.metallib or a from-source build) lives at
// objs[0] and the remaining slots are nil.
id<MTLLibrary> objs[GGML_METAL_LIB_COUNT];
bool single_library; // true: combined library at objs[0]; false: per-kind libs in objs[*]
// Routing table: kernel function name -> objs[] index, populated from each
// compiled library's -[MTLLibrary functionNames]. The actual compiled
// libraries are the single source of truth for which library owns a kernel,
// so adding kernels later requires no manual routing maintenance.
// nil in single_library mode (everything resolves to objs[0]).
NSMutableDictionary<NSString *, NSNumber *> * fn_to_lib;
ggml_metal_device_t dev;
ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
@@ -103,160 +158,376 @@ struct ggml_metal_library {
NSLock * lock;
};
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
id<MTLLibrary> library = nil;
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
// Build the fn_to_lib routing table by querying each compiled library's public
// function names. Call once after all per-kind libraries have been compiled.
static void ggml_metal_library_build_index(ggml_metal_library_t lib) {
@autoreleasepool {
NSMutableDictionary<NSString *, NSNumber *> * index = [[NSMutableDictionary alloc] init];
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
for (NSString * fname in [lib->objs[kind] functionNames]) {
index[fname] = @(kind);
}
}
lib->fn_to_lib = index;
}
}
// load library
//
// - first check if the library is embedded
// - then check if the library is in the bundle
// - if not found, load the source and compile it
// - if that fails, return NULL
//
// TODO: move to a function
{
const int64_t t_start = ggml_time_us();
// Parse a `#include "name"` line. Returns the quoted name in *include_name on
// success. Whitespace-tolerant; ignores `#include <...>` (system headers).
static bool ggml_metal_library_parse_quoted_include(NSString * line, NSString ** include_name) {
NSScanner * scanner = [NSScanner scannerWithString:line];
scanner.charactersToBeSkipped = [NSCharacterSet whitespaceCharacterSet];
NSError * error = nil;
NSString * src = nil;
if (![scanner scanString:@"#" intoString:NULL] ||
![scanner scanString:@"include" intoString:NULL] ||
![scanner scanString:@"\"" intoString:NULL]) {
return false;
}
#if GGML_METAL_EMBED_LIBRARY
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
NSString * name = nil;
if (![scanner scanUpToString:@"\"" intoString:&name]) {
return false;
}
extern const char ggml_metallib_start[];
extern const char ggml_metallib_end[];
if (include_name) {
*include_name = name;
}
return true;
}
src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
#else
// Recursively inline `#include "name"` directives. System includes (<...>),
// `#if/#else/#endif`, and other preprocessor lines are passed through to the
// Metal compiler unchanged. `#pragma once` is dropped since `seen` already
// guards against double-inclusion.
static bool ggml_metal_library_flatten_file(NSMutableString * dst, NSString * path,
NSArray<NSString *> * search_paths,
NSMutableSet<NSString *> * seen, NSError ** error) {
NSString * key = [path stringByStandardizingPath];
if ([seen containsObject:key]) {
return true;
}
[seen addObject:key];
#ifdef SWIFT_PACKAGE
NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
#else
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
#endif
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:error];
if (!src) {
return false;
}
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
if (path_lib == nil) {
// Try to find the resource in the directory where the current binary located.
NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent];
NSFileManager * fm = [NSFileManager defaultManager];
for (NSString * line in [src componentsSeparatedByString:@"\n"]) {
NSString * trimmed = [line stringByTrimmingCharactersInSet:[NSCharacterSet whitespaceCharacterSet]];
if ([trimmed isEqualToString:@"#pragma once"]) {
continue;
}
NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
GGML_LOG_INFO("%s: found '%s'\n", __func__, [path_lib_default UTF8String]);
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error];
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
// Optionally, if this is a symlink, try to resolve it.
path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error];
if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@"/"]) {
// It is a relative path, adding the binary directory as directory prefix.
path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]];
}
if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
// Link to the resource could not be resolved.
path_lib_default = nil;
} else {
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [path_lib_default UTF8String]);
}
NSString * include_name = nil;
if (ggml_metal_library_parse_quoted_include(line, &include_name)) {
NSString * resolved = nil;
for (NSString * dir in search_paths) {
NSString * candidate = [dir stringByAppendingPathComponent:include_name];
if ([fm isReadableFileAtPath:candidate]) {
resolved = candidate;
break;
}
} else {
// The resource couldn't be found in the binary's directory.
path_lib_default = nil;
}
path_lib = path_lib_default;
if (!resolved) {
if (error) {
NSString * msg = [NSString stringWithFormat:@"could not resolve include \"%@\" from '%@'", include_name, path];
*error = [NSError errorWithDomain:@"ggml-metal-source-flatten" code:1
userInfo:@{NSLocalizedDescriptionKey: msg}];
}
return false;
}
if (!ggml_metal_library_flatten_file(dst, resolved, search_paths, seen, error)) {
return false;
}
continue;
}
if (path_lib != nil) {
// pre-compiled library found
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
[dst appendString:line];
[dst appendString:@"\n"];
}
library = [device newLibraryWithURL:libURL error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return nil;
}
} else {
GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
return true;
}
NSString * path_source;
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
static NSString * ggml_metal_library_flatten_source(NSString * path_source, NSError ** error) {
// Search paths cover both runtime layout (build/bin/kernels + build/bin)
// and source-tree layout (ggml/src/ggml-metal/kernels + ggml/src/ggml-metal + ggml/src).
NSString * path_kernels = [path_source stringByDeletingLastPathComponent];
NSString * path_base = [path_kernels stringByDeletingLastPathComponent];
NSArray<NSString *> * search_paths = @[
path_kernels,
path_base,
[path_base stringByDeletingLastPathComponent],
];
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
NSMutableString * src = [[NSMutableString alloc] init];
NSMutableSet<NSString *> * seen = [NSMutableSet set];
if (path_resource) {
path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
} else {
path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
if (!ggml_metal_library_flatten_file(src, path_source, search_paths, seen, error)) {
[src release];
return nil;
}
return src;
}
// Compile all per-kind libraries in parallel. `source_for_kind` returns the MSL
// source for a kind (the helper takes ownership and releases it), or nil with
// *err set on failure. On success the objs[] slots are populated and the routing
// index is built; on any failure every error is logged and false is returned
// (the caller is responsible for freeing `res`).
static bool ggml_metal_library_compile_all(
ggml_metal_library_t res,
id<MTLDevice> device,
NSDictionary * prep,
NSString * (^source_for_kind)(int kind, NSError ** err),
const char * origin) {
const int64_t t_start = ggml_time_us();
int64_t * t_per_lib = calloc(GGML_METAL_LIB_COUNT, sizeof(int64_t));
NSError ** err_per_lib = calloc(GGML_METAL_LIB_COUNT, sizeof(NSError *));
__block atomic_bool any_failure = false;
dispatch_group_t group = dispatch_group_create();
dispatch_queue_t queue = dispatch_get_global_queue(QOS_CLASS_USER_INITIATED, 0);
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
dispatch_group_async(group, queue, ^{
const int64_t t0 = ggml_time_us();
NSError * error = nil;
NSString * src = source_for_kind(kind, &error);
if (!src) {
err_per_lib[kind] = [error retain];
atomic_store(&any_failure, true);
return;
}
if (path_source == nil) {
GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
path_source = @"ggml-metal.metal";
}
id<MTLLibrary> lib = nil;
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return nil;
}
}
#endif
if (!library) {
@autoreleasepool {
// dictionary of preprocessor macros
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
if (ggml_metal_device_get_props(dev)->has_bfloat) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"];
}
if (ggml_metal_device_get_props(dev)->has_tensor) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"];
}
#if GGML_METAL_EMBED_LIBRARY
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
#endif
MTLCompileOptions * options = [MTLCompileOptions new];
options.preprocessorMacros = prep;
//[options setFastMathEnabled:false];
lib = [device newLibraryWithSource:src options:options error:&error];
library = [device newLibraryWithSource:src options:options error:&error];
if (error) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return nil;
}
#if !__has_feature(objc_arc)
[options release];
#endif
// retain the error before the autorelease pool drains it
if (!lib) {
err_per_lib[kind] = [error retain];
}
}
[src release];
t_per_lib[kind] = ggml_time_us() - t0;
if (!lib) {
atomic_store(&any_failure, true);
return;
}
res->objs[kind] = lib;
});
}
dispatch_group_wait(group, DISPATCH_TIME_FOREVER);
dispatch_release(group);
const bool ok = !atomic_load(&any_failure);
if (ok) {
const int64_t t_total = ggml_time_us() - t_start;
int64_t t_max = 0;
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
GGML_LOG_DEBUG("%s: compiled '%s' library in %.3f sec\n",
__func__, k_lib_names[kind], t_per_lib[kind] / 1e6);
if (t_per_lib[kind] > t_max) t_max = t_per_lib[kind];
}
GGML_LOG_INFO("%s: loaded %d libraries from %s in %.3f sec (max single = %.3f sec)\n",
__func__, GGML_METAL_LIB_COUNT, origin, t_total / 1e6, t_max / 1e6);
ggml_metal_library_build_index(res);
} else {
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
if (err_per_lib[kind]) {
GGML_LOG_ERROR("%s: failed to build '%s' library: %s\n", __func__,
k_lib_names[kind], [[err_per_lib[kind] description] UTF8String]);
[err_per_lib[kind] release];
}
}
#if GGML_METAL_EMBED_LIBRARY
[src release];
#endif // GGML_METAL_EMBED_LIBRARY
GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
}
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
free(err_per_lib);
free(t_per_lib);
res->obj = library;
return ok;
}
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
res->dev = dev;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];
// shared MTLCompileOptions preprocessor macros (matches the build-time defines)
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
if (ggml_metal_device_get_props(dev)->has_bfloat) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"];
}
if (ggml_metal_device_get_props(dev)->has_tensor) {
[prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"];
}
#if GGML_METAL_EMBED_LIBRARY
[prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
#endif
#if GGML_METAL_EMBED_LIBRARY
GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
// start/end symbols emitted by CMake (see CMakeLists.txt), one pair per kind
#define X(e, s) extern const char ggml_metallib_##s##_start[]; extern const char ggml_metallib_##s##_end[];
GGML_METAL_LIBS
#undef X
static const char * const lib_start[GGML_METAL_LIB_COUNT] = {
#define X(e, s) [GGML_METAL_LIB_##e] = ggml_metallib_##s##_start,
GGML_METAL_LIBS
#undef X
};
static const char * const lib_end[GGML_METAL_LIB_COUNT] = {
#define X(e, s) [GGML_METAL_LIB_##e] = ggml_metallib_##s##_end,
GGML_METAL_LIBS
#undef X
};
const bool ok = ggml_metal_library_compile_all(res, device, prep,
^NSString * (int kind, NSError ** err) {
(void) err;
return [[NSString alloc] initWithBytes:lib_start[kind]
length:(lib_end[kind] - lib_start[kind])
encoding:NSUTF8StringEncoding];
}, "embedded data");
if (!ok) {
ggml_metal_library_free(res);
return NULL;
}
return res;
#else
#ifdef SWIFT_PACKAGE
NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
#else
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
#endif
const int64_t t_start = ggml_time_us();
NSError * error = nil;
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
if (path_lib == nil) {
// Try to find the resource in the directory where the current binary located.
NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent];
NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
GGML_LOG_INFO("%s: found '%s'\n", __func__, [path_lib_default UTF8String]);
NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error];
if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
// Optionally, if this is a symlink, try to resolve it.
path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error];
if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@"/"]) {
// It is a relative path, adding the binary directory as directory prefix.
path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]];
}
if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
// Link to the resource could not be resolved.
path_lib_default = nil;
} else {
GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [path_lib_default UTF8String]);
}
}
} else {
// The resource couldn't be found in the binary's directory.
path_lib_default = nil;
}
path_lib = path_lib_default;
}
if (path_lib != nil) {
// pre-compiled library found: a single combined default.metallib
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
res->objs[0] = [device newLibraryWithURL:libURL error:&error];
res->single_library = true;
if (!res->objs[0]) {
GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
ggml_metal_library_free(res);
return NULL;
}
GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
return res;
}
// no pre-compiled metallib: fall back to compiling each kernel source separately
GGML_LOG_INFO("%s: default.metallib not found, loading kernel sources\n", __func__);
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
if (path_resource) {
GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, [path_resource UTF8String]);
}
// resolve each kind's source path up front (file lookup/logging stays on the calling thread)
NSString ** path_per_kind = calloc(GGML_METAL_LIB_COUNT, sizeof(NSString *));
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
NSString * rel = [NSString stringWithFormat:@"kernels/%s.metal", k_lib_names[kind]];
NSString * path_source = nil;
if (path_resource) {
path_source = [path_resource stringByAppendingPathComponent:rel];
} else {
NSString * stem = [NSString stringWithFormat:@"kernels/%s", k_lib_names[kind]];
path_source = [bundle pathForResource:stem ofType:@"metal"];
}
if (path_source == nil || ![[NSFileManager defaultManager] isReadableFileAtPath:path_source]) {
GGML_LOG_WARN("%s: could not locate %s in bundle, falling back to cwd\n", __func__, [rel UTF8String]);
path_source = rel;
}
GGML_LOG_DEBUG("%s: loading '%s'\n", __func__, [path_source UTF8String]);
path_per_kind[kind] = [path_source retain];
}
const bool ok = ggml_metal_library_compile_all(res, device, prep,
^NSString * (int kind, NSError ** err) {
return ggml_metal_library_flatten_source(path_per_kind[kind], err);
}, "source");
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
[path_per_kind[kind] release];
}
free(path_per_kind);
if (!ok) {
ggml_metal_library_free(res);
return NULL;
}
return res;
#endif
}
ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose) {
@@ -318,10 +589,11 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
return NULL;
}
res->obj = library;
res->dev = dev;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];
res->objs[0] = library;
res->single_library = true;
res->dev = dev;
res->pipelines = ggml_metal_pipelines_init();
res->lock = [NSLock new];
return res;
}
@@ -331,8 +603,14 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
return;
}
if (lib->obj) {
[lib->obj release];
for (int kind = 0; kind < GGML_METAL_LIB_COUNT; ++kind) {
if (lib->objs[kind]) {
[lib->objs[kind] release];
}
}
if (lib->fn_to_lib) {
[lib->fn_to_lib release];
}
ggml_metal_pipelines_free(lib->pipelines);
@@ -393,11 +671,28 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_
GGML_LOG_DEBUG("%s: compiling pipeline: base = '%s', name = '%s'\n", __func__, base, name);
// route to the library that actually defines this kernel; fn_to_lib is
// built from -[MTLLibrary functionNames] so it's always in sync
int lib_idx = 0;
if (!lib->single_library) {
NSNumber * idx = lib->fn_to_lib[base_func];
if (!idx) {
[lib->lock unlock];
GGML_LOG_ERROR("%s: kernel not found in any metal library: base = '%s', name = '%s'\n", __func__, base, name);
return res;
}
lib_idx = [idx intValue];
}
id<MTLLibrary> mtl_lib = lib->objs[lib_idx];
id<MTLFunction> mtl_function;
if (!cv) {
mtl_function = [lib->obj newFunctionWithName:base_func];
mtl_function = [mtl_lib newFunctionWithName:base_func];
} else {
mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
mtl_function = [mtl_lib newFunctionWithName:base_func constantValues:cv->obj error:&error];
}
if (!mtl_function) {
[lib->lock unlock];
@@ -1123,13 +1418,24 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return true;
case GGML_OP_CONCAT:
{
// kernel_concat copies one float-sized value per element.
// Other scalar types need a type-generic copy kernel first.
const enum ggml_type src0_type = op->src[0]->type;
const enum ggml_type src1_type = op->src[1]->type;
return src0_type == src1_type &&
src0_type == op->type &&
(src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_I32);
if (src0_type != src1_type || src0_type != op->type) {
return false;
}
switch (src0_type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
case GGML_TYPE_I64:
return true;
case GGML_TYPE_BF16:
return has_bfloat;
default:
return false;
}
}
case GGML_OP_ADD:
case GGML_OP_SUB:
@@ -1173,6 +1479,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_RMS_NORM:
return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
return true;
case GGML_OP_IM2COL:
return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
+2 -1
View File
@@ -375,6 +375,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
n_fuse = ggml_metal_op_norm(ctx, idx);
} break;
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
{
n_fuse = ggml_metal_op_rope(ctx, idx);
} break;
@@ -556,7 +557,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
/*.dim =*/ dim,
};
auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
auto pipeline = ggml_metal_library_get_pipeline_concat(lib, op->type);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
File diff suppressed because it is too large Load Diff
+232
View File
@@ -0,0 +1,232 @@
#include "common.h"
// bitonic sort implementation following the CUDA kernels as reference
typedef void (argsort_t)(
constant ggml_metal_kargs_argsort & args,
device const char * src0,
device int32_t * dst,
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]);
template<ggml_sort_order order>
kernel void kernel_argsort_f32_i32(
constant ggml_metal_kargs_argsort & args,
device const char * src0,
device int32_t * dst,
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
// bitonic sort
const int col = tpitg[0];
const int ib = tgpig[0] / args.ne01;
const int i00 = ib*ntg.x;
const int i01 = tgpig[0] % args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
// initialize indices
shmem_i32[col] = i00 + col;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int k = 2; k <= ntg.x; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (shmem_i32[col] >= args.ne00 ||
(shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
) {
SWAP(shmem_i32[col], shmem_i32[ixj]);
}
} else {
if (shmem_i32[ixj] >= args.ne00 ||
(shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
) {
SWAP(shmem_i32[col], shmem_i32[ixj]);
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
const int64_t i0 = ib*args.top_k;
// copy the result to dst without the padding
if (i0 + col < args.ne0 && col < args.top_k) {
dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
dst[col] = shmem_i32[col];
}
}
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
typedef void (argsort_merge_t)(
constant ggml_metal_kargs_argsort_merge & args,
device const char * src0,
device const int32_t * tmp,
device int32_t * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]);
template<ggml_sort_order order>
kernel void kernel_argsort_merge_f32_i32(
constant ggml_metal_kargs_argsort_merge & args,
device const char * src0,
device const int32_t * tmp,
device int32_t * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int im = tgpig[0] / args.ne01;
const int i01 = tgpig[0] % args.ne01;
const int i02 = tgpig[1];
const int i03 = tgpig[2];
const int start = im * (2 * args.len);
const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
const int total = len0 + len1;
device const int32_t * tmp0 = tmp + start
+ i01*args.ne0
+ i02*args.ne0*args.ne01
+ i03*args.ne0*args.ne01*args.ne02;
device const int32_t * tmp1 = tmp0 + args.len;
dst += start
+ i01*args.top_k
+ i02*args.top_k*args.ne01
+ i03*args.top_k*args.ne01*args.ne02;
device const float * src0_row = (device const float *)(src0
+ args.nb01*i01
+ args.nb02*i02
+ args.nb03*i03);
if (total == 0) {
return;
}
const int chunk = (total + ntg.x - 1) / ntg.x;
const int k0 = tpitg.x * chunk;
const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
if (k0 >= args.top_k) {
return;
}
if (k0 >= total) {
return;
}
int low = k0 > len1 ? k0 - len1 : 0;
int high = MIN(k0, len0);
// binary-search partition (i, j) such that i + j = k
while (low < high) {
const int mid = (low + high) >> 1;
const int32_t idx0 = tmp0[mid];
const int32_t idx1 = tmp1[k0 - mid - 1];
const float val0 = src0_row[idx0];
const float val1 = src0_row[idx1];
bool take_left;
if (order == GGML_SORT_ORDER_ASC) {
take_left = (val0 <= val1);
} else {
take_left = (val0 >= val1);
}
if (take_left) {
low = mid + 1;
} else {
high = mid;
}
}
int i = low;
int j = k0 - i;
// keep the merge fronts into registers
int32_t idx0 = 0;
float val0 = 0.0f;
if (i < len0) {
idx0 = tmp0[i];
val0 = src0_row[idx0];
}
int32_t idx1 = 0;
float val1 = 0.0f;
if (j < len1) {
idx1 = tmp1[j];
val1 = src0_row[idx1];
}
for (int k = k0; k < k1; ++k) {
int32_t out_idx;
if (i >= len0) {
while (k < k1) {
dst[k++] = tmp1[j++];
}
break;
} else if (j >= len1) {
while (k < k1) {
dst[k++] = tmp0[i++];
}
break;
} else {
bool take_left;
if (order == GGML_SORT_ORDER_ASC) {
take_left = (val0 <= val1);
} else {
take_left = (val0 >= val1);
}
if (take_left) {
out_idx = idx0;
++i;
if (i < len0) {
idx0 = tmp0[i];
val0 = src0_row[idx0];
}
} else {
out_idx = idx1;
++j;
if (j < len1) {
idx1 = tmp1[j];
val1 = src0_row[idx1];
}
}
}
dst[k] = out_idx;
}
}
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
+226
View File
@@ -0,0 +1,226 @@
#include "common.h"
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
constant bool FC_bin_cb [[function_constant(FC_BIN + 3)]];
template <typename T0, typename T1, typename T>
kernel void kernel_bin_fuse_impl(
constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
#define FC_OP FC_bin_op
#define FC_F FC_bin_f
#define FC_RB FC_bin_rb
#define FC_CB FC_bin_cb
if (FC_RB) {
// row broadcast
const uint i0 = tgpig.y*args.ne00 + tgpig.x;
const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x;
device const T0 * src0_row = (device const T0 *) (src0);
device T * dst_row = (device T *) (dst);
if (FC_F == 1) {
device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
if (FC_OP == 0) {
dst_row[i0] = src0_row[i0] + src1_row[i1];
}
if (FC_OP == 1) {
dst_row[i0] = src0_row[i0] - src1_row[i1];
}
if (FC_OP == 2) {
dst_row[i0] = src0_row[i0] * src1_row[i1];
}
if (FC_OP == 3) {
dst_row[i0] = src0_row[i0] / src1_row[i1];
}
} else {
T0 res = src0_row[i0];
if (FC_OP == 0) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res += ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
if (FC_OP == 1) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
if (FC_OP == 2) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
if (FC_OP == 3) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
}
}
dst_row[i0] = res;
}
} else {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int i01 = tgpig.x;
if (i01 >= args.ne01) {
return;
}
const int i13 = i03%args.ne13;
const int i12 = i02%args.ne12;
const int i11 = i01%args.ne11;
device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
if (FC_F == 1) {
device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = FC_CB ? i0%args.ne10 : i0;
if (FC_OP == 0) {
dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
}
if (FC_OP == 1) {
dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
}
if (FC_OP == 2) {
dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
}
if (FC_OP == 3) {
dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
}
}
} else {
device const T1 * src1_ptr[8];
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
}
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i10 = FC_CB ? i0%args.ne10 : i0;
T res = src0_ptr[i0];
if (FC_OP == 0) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res += src1_ptr[j][i10];
}
}
if (FC_OP == 1) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res -= src1_ptr[j][i10];
}
}
if (FC_OP == 2) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res *= src1_ptr[j][i10];
}
}
if (FC_OP == 3) {
FOR_UNROLL (short j = 0; j < FC_F; ++j) {
res /= src1_ptr[j][i10];
}
}
dst_ptr[i0] = res;
}
}
}
#undef FC_OP
#undef FC_F
#undef FC_RB
#undef FC_CB
}
typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>;
template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>;
kernel void kernel_add_id(
constant ggml_metal_kargs_add_id & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i1 = tgpig.x;
const int i2 = tgpig.y;
const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
const size_t nb1 = args.ne0 * sizeof(float);
const size_t nb2 = args.ne1 * nb1;
device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
dst_row[i0] = src0_row[i0] + src1_row[i0];
}
}
template<typename T>
kernel void kernel_repeat(
constant ggml_metal_kargs_repeat & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
const int i03 = i3%args.ne03;
const int i02 = i2%args.ne02;
const int i01 = i1%args.ne01;
device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
const int i00 = i0%args.ne00;
*((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
}
}
typedef decltype(kernel_repeat<float>) kernel_repeat_t;
template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_repeat_bf16")]] kernel kernel_repeat_t kernel_repeat<bfloat>;
#endif
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
+126
View File
@@ -0,0 +1,126 @@
#pragma once
#include "ggml-metal-impl.h"
#include <metal_stdlib>
#ifdef GGML_METAL_HAS_TENSOR
#include <metal_tensor>
#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
#endif
using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1))
#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
//
// cmd:
// .../usr/bin/metal -dM -E -c ggml/src/ggml-metal/kernels/<src>.metal
// .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/kernels/<src>.metal
//
#if __METAL_VERSION__ < 310 && defined(GGML_METAL_HAS_BF16)
#undef GGML_METAL_HAS_BF16
#endif
#if defined(GGML_METAL_HAS_BF16)
typedef matrix<bfloat, 4, 4> bfloat4x4;
typedef matrix<bfloat, 2, 4> bfloat2x4;
#endif
constexpr constant static float kvalues_iq4nl_f[16] = {
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
constexpr constant static float kvalues_mxfp4_f[16] = {
0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
};
static inline int best_index_int8(int n, constant float * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
static inline float e8m0_to_fp32(uint8_t x) {
uint32_t bits;
if (x == 0) {
bits = 0x00400000;
} else {
bits = (uint32_t) x << 23;
}
return as_type<float>(bits);
}
static inline float dot(float x, float y) {
return x*y;
}
static inline float sum(float x) {
return x;
}
static inline float sum(float4 x) {
return x[0] + x[1] + x[2] + x[3];
}
enum ggml_sort_order {
GGML_SORT_ORDER_ASC,
GGML_SORT_ORDER_DESC,
};
constant float GELU_COEF_A = 0.044715f;
constant float GELU_QUICK_COEF = -1.702f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
constant float p_erf = 0.3275911f;
constant float a1_erf = 0.254829592f;
constant float a2_erf = -0.284496736f;
constant float a3_erf = 1.421413741f;
constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;
template<typename T>
inline T erf_approx(T x) {
T sign_x = sign(x);
x = fabs(x);
T t = 1.0f / (1.0f + p_erf * x);
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
return sign_x * y;
}
template<typename T> T elu_approx(T x);
template<> inline float elu_approx<float>(float x) {
return (x > 0.f) ? x : (exp(x) - 1);
}
template<> inline float4 elu_approx<float4>(float4 x) {
float4 res;
res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
return res;
}
+485
View File
@@ -0,0 +1,485 @@
#include "common.h"
typedef void (im2col_t)(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template <typename T>
kernel void kernel_im2col(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
// const int64_t IC = tgpg[0];
const int64_t OH = tgpg[1];
const int64_t OW = tgpg[2];
const int64_t KH = ntg[1];
const int64_t KW = ntg[2];
int64_t in = tpitg[0];
const int64_t ikh = tpitg[1];
const int64_t ikw = tpitg[2];
const int64_t iic = tgpig[0];
const int64_t ioh = tgpig[1];
const int64_t iow = tgpig[2];
const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
device T * pdst = (device T *) (dst);
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
while (in < args.N) {
pdst[offset_dst] = 0.0f;
offset_dst += ntg[0]*args.CHW*OH*OW;
in += ntg[0];
}
} else {
int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
while (in < args.N) {
pdst[offset_dst] = x[offset_src];
offset_dst += ntg[0]*args.CHW*OH*OW;
offset_src += ntg[0]*args.ofs0;
in += ntg[0];
}
}
}
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
// TODO: optimize
typedef void (im2col_ext_t)(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template <typename T>
kernel void kernel_im2col_ext(
constant ggml_metal_kargs_im2col & args,
device const float * x,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
const int64_t KHW = (int64_t)args.KHW;
const int64_t d = tgpig[0] / args.CHW;
const int64_t chw = tgpig[0] % args.CHW;
const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
const int64_t HW = tgpig[0] % KHW;
const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
if (tpitg_0 >= args.N) {
return;
}
const int64_t tpitg_1 = HW / args.KW;
const int64_t tpitg_2 = HW % args.KW;
const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
const int64_t offset_dst =
(tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
(tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
device T * pdst = (device T *) (dst);
if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
pdst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
}
}
template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
template <typename TK>
kernel void kernel_conv_2d(
constant ggml_metal_kargs_conv_2d & args,
device const char * weights,
device const char * src,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const uint threads_per_tg = ntg.x * ntg.y * ntg.z;
const uint tg_index = (tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
const uint local_thread = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;
const uint thread_index = tg_index * threads_per_tg + local_thread;
const uint64_t total_threads = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z;
const uint64_t total_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW;
for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {
uint64_t tmp = index;
const int32_t ow = tmp % args.OW; tmp /= args.OW;
const int32_t oh = tmp % args.OH; tmp /= args.OH;
const int32_t oc = tmp % args.OC; tmp /= args.OC;
const int32_t n = tmp;
float acc = 0.0f;
const int32_t base_x = ow*args.s0 - args.p0;
const int32_t base_y = oh*args.s1 - args.p1;
int32_t ky_start = 0;
if (base_y < 0) {
ky_start = (-base_y + args.d1 - 1)/args.d1;
}
int32_t ky_end = args.KH;
const int32_t y_max = args.IH - 1 - base_y;
if (y_max < 0) {
ky_end = ky_start;
} else if (base_y + (args.KH - 1)*args.d1 >= args.IH) {
ky_end = min(ky_end, y_max/args.d1 + 1);
}
int32_t kx_start = 0;
if (base_x < 0) {
kx_start = (-base_x + args.d0 - 1)/args.d0;
}
int32_t kx_end = args.KW;
const int32_t x_max = args.IW - 1 - base_x;
if (x_max < 0) {
kx_end = kx_start;
} else if (base_x + (args.KW - 1)*args.d0 >= args.IW) {
kx_end = min(kx_end, x_max/args.d0 + 1);
}
if (ky_start < ky_end && kx_start < kx_end) {
const uint64_t src_base_n = (uint64_t) n * args.nb13;
const uint64_t w_base_oc = (uint64_t) oc * args.nb03;
for (int32_t ic = 0; ic < args.IC; ++ic) {
const uint64_t src_base_nc = src_base_n + (uint64_t) ic * args.nb12;
const uint64_t w_base_ocic = w_base_oc + (uint64_t) ic * args.nb02;
for (int32_t ky = ky_start; ky < ky_end; ++ky) {
const int32_t iy = base_y + ky*args.d1;
const uint64_t src_base_row = src_base_nc + (uint64_t) iy * args.nb11;
const uint64_t w_base_row = w_base_ocic + (uint64_t) ky * args.nb01;
for (int32_t kx = kx_start; kx < kx_end; ++kx) {
const int32_t ix = base_x + kx*args.d0;
const uint64_t src_offs = src_base_row + (uint64_t) ix * args.nb10;
const uint64_t w_offs = w_base_row + (uint64_t) kx * args.nb00;
const float x = *(device const float *)(src + src_offs);
const float w = (float) (*(device const TK *)(weights + w_offs));
acc += x * w;
}
}
}
}
const uint64_t dst_offs =
(uint64_t) n * args.nb3 +
(uint64_t) oc * args.nb2 +
(uint64_t) oh * args.nb1 +
(uint64_t) ow * args.nb0;
*(device float *)(dst + dst_offs) = acc;
}
}
template [[host_name("kernel_conv_2d_f32_f32")]]
kernel void kernel_conv_2d<float>(
constant ggml_metal_kargs_conv_2d & args,
device const char * weights,
device const char * src,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template [[host_name("kernel_conv_2d_f16_f32")]]
kernel void kernel_conv_2d<half>(
constant ggml_metal_kargs_conv_2d & args,
device const char * weights,
device const char * src,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
typedef void (conv_transpose_1d_t)(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
template <typename T>
kernel void kernel_conv_transpose_1d(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const T * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]) {
// For output position j on the time axis, only input positions
// i such that i*s0 <= j < i*s0 + K
// contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)]
// intersected with [0, IL-1]. That's at most ceil(K/s0) values
// (typically 2 for stride==K/2 transposed convs).
const int32_t j = tgpig[0];
const int32_t s0 = args.s0;
const int32_t K = args.K;
const int32_t IL = args.IL;
int32_t i_min;
{
int32_t a = j - K + 1;
i_min = a <= 0 ? 0 : (a + s0 - 1) / s0; // ceil(a/s0) for a>0
}
int32_t i_max = j / s0;
if (i_max > IL - 1) i_max = IL - 1;
float v = 0.0f;
if (i_min <= i_max) {
for (int64_t c = 0; c < args.IC; c++) {
const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
const int32_t input_offset = c * IL;
for (int32_t i = i_min; i <= i_max; i++) {
v += float(src0[kernel_offset + j - i * s0]) * src1[input_offset + i];
}
}
}
device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);
dst_ptr[0] = v;
}
template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
kernel void kernel_conv_transpose_1d<float>(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
kernel void kernel_conv_transpose_1d<half>(
constant ggml_metal_kargs_conv_transpose_1d & args,
device const half * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
typedef void (conv_transpose_2d_t)(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
template <typename T>
kernel void kernel_conv_transpose_2d(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const T * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t out_x = tgpig[0];
const int64_t out_y = tgpig[1];
const int64_t out_c = tgpig[2];
const int64_t kw = tpitg[0];
const int64_t kh = tpitg[1];
float v = 0.0f;
for (int64_t in_c = 0; in_c < args.IC; in_c++) {
int64_t in_y = out_y - kh;
if (in_y < 0 || in_y % args.s0) continue;
in_y /= args.s0;
if (in_y >= args.IH) continue;
int64_t in_x = out_x - kw;
if (in_x < 0 || in_x % args.s0) continue;
in_x /= args.s0;
if (in_x >= args.IW) continue;
const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
v += (float)src0[kernel_idx] * src1[input_idx];
}
const uint tid = tpitg.y * ntg.x + tpitg.x;
shared_sum[tid] = v;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
float total = 0.0f;
const uint num_threads = ntg.x * ntg.y;
for (uint i = 0; i < num_threads; i++) {
total += shared_sum[i];
}
device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
dst_ptr[0] = total;
}
}
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
kernel void kernel_conv_transpose_2d<float>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
kernel void kernel_conv_transpose_2d<half>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const half * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template <typename T>
kernel void kernel_conv_3d(
constant ggml_metal_kargs_conv_3d & args,
device const char * src0, // Weights [IC * OC, KD, KH, KW]
device const char * src1, // Inputs [IC * N, ID, IH, IW]
device char * dst, // Outputs [OC * N, OD, OH, OW]
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]) {
// 1. Un-flatten the spatial dimension from Grid X
int64_t spatial_idx = tgpig.x * 32 + tpitg.x;
if (spatial_idx >= args.OW * args.OH * args.OD) {
return; // Thread falls outside the spatial volume
}
int64_t od = spatial_idx / (args.OW * args.OH);
int64_t oh = (spatial_idx / args.OW) % args.OH;
int64_t ow = spatial_idx % args.OW;
// 2. Map Y to Channels, Z to Batch
int64_t oc = tgpig.y;
int64_t batch_idx = tgpig.z;
// 3. Calculate anchor coordinates in the Input volume
int64_t i_w_base = ow * args.s0 - args.p0;
int64_t i_h_base = oh * args.s1 - args.p1;
int64_t i_d_base = od * args.s2 - args.p2;
float sum = 0.0f;
// 4. Gather Loop (Iterate over Input Channels -> Depth -> Height -> Width)
for (int64_t ic = 0; ic < args.IC; ++ic) {
// ggml packs batch and channel together in the 4th dimension
int64_t src_cn_idx = batch_idx * args.IC + ic;
int64_t w_cn_idx = oc * args.IC + ic;
for (int64_t kz = 0; kz < args.KD; ++kz) {
int64_t id = i_d_base + kz * args.d2;
if (id < 0 || id >= args.ID) continue; // Boundary check (Padding)
for (int64_t ky = 0; ky < args.KH; ++ky) {
int64_t ih = i_h_base + ky * args.d1;
if (ih < 0 || ih >= args.IH) continue;
for (int64_t kx = 0; kx < args.KW; ++kx) {
int64_t iw = i_w_base + kx * args.d0;
if (iw < 0 || iw >= args.IW) continue;
// Convert multi-dimensional coordinates to flat byte offsets
int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03;
int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13;
// Dereference memory and cast weights to f32 if they were f16
float w_val = (float)*(device const T*)((device const char*)src0 + w_idx);
float i_val = *(device const float*)((device const char*)src1 + i_idx);
sum += w_val * i_val;
}
}
}
}
// 5. Write the accumulated value out to RAM
int64_t dst_cn_idx = batch_idx * args.OC + oc;
int64_t d_idx = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3;
*(device float*)(dst + d_idx) = sum;
}
// Explicit instantiations so the JIT compiler can find them by name
template [[host_name("kernel_conv_3d_f32_f32")]]
kernel void kernel_conv_3d<float>(
constant ggml_metal_kargs_conv_3d & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]);
// Explicit instantiation for f16 weights
template [[host_name("kernel_conv_3d_f16_f32")]]
kernel void kernel_conv_3d<half>(
constant ggml_metal_kargs_conv_3d & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]);
+686
View File
@@ -0,0 +1,686 @@
#pragma once
#include "common.h"
#define GGML_COMMON_DECL_METAL
#define GGML_COMMON_IMPL_METAL
#if defined(GGML_METAL_EMBED_LIBRARY)
__embed_ggml-common.h__
#else
#include "ggml-common.h"
#endif
#define QK_NL 16 // shared by mul_mm and get_rows_q instantiations
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
}
template <typename type4>
void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
reg = (type4)(*src);
}
template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
}
template <typename type4>
void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
reg = (type4)(*(src));
}
#if defined(GGML_METAL_HAS_BF16)
template <typename type4x4>
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
}
template <typename type4>
void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
reg = (type4)(*(src));
}
#endif
template <typename type4x4>
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
device const uint8_t * qs = xb->qs;
const float d = xb->d;
const float neg_d = -d;
const int byte_offset = il * 2; // il*16 bits = il*2 bytes
const uint8_t b0 = qs[byte_offset];
const uint8_t b1 = qs[byte_offset + 1];
float4x4 reg_f;
reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01));
reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02));
reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04));
reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08));
reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10));
reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20));
reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40));
reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80));
reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01));
reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02));
reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04));
reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08));
reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10));
reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20));
reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40));
reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80));
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
const float d = xb->d;
const float neg_d = -d;
const int base = il * 4;
const uint8_t byte = xb->qs[base / 8];
const int s = base % 8;
float4 reg_f;
reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1));
reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1));
reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1));
reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1));
reg = (type4) reg_f;
}
template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float md = -8.h * xb->d;
const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float md = -8.h * xb->d;
const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
for (int i = 0; i < 2; i++) {
reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
}
}
template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float m = xb->m;
const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float m = xb->m;
const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
for (int i = 0; i < 2; i++) {
reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
}
}
template <typename type4x4>
void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
const float d = xb->d;
const float md = -16.h * xb->d;
const ushort mask = il ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
const int x_mv = il ? 4 : 0;
const int gh_mv = il ? 12 : 0;
const int gh_bk = il ? 0 : 4;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
const float d = xb->d;
const float md = -16.h * xb->d;
const ushort mask = (il/4) ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
const int x_mv = (il/4) ? 4 : 0;
const int gh_mv = (il/4) ? 12 : 0;
const int gh_bk = (il/4) ? 0 : 4;
for (int ii = 0; ii < 2; ii++) {
int i = 2*(il%4) + ii;
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg[2*ii + 0] = d * x0 + md;
reg[2*ii + 1] = d * x1 + md;
}
}
template <typename type4x4>
void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
const float d = xb->d;
const float m = xb->m;
const ushort mask = il ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
const int x_mv = il ? 4 : 0;
const int gh_mv = il ? 12 : 0;
const int gh_bk = il ? 0 : 4;
float4x4 reg_f;
for (int i = 0; i < 8; i++) {
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
const float d = xb->d;
const float m = xb->m;
const ushort mask = (il/4) ? 0x00F0 : 0x000F;
const uint32_t qh = *((device const uint32_t *)xb->qh);
const int x_mv = (il/4) ? 4 : 0;
const int gh_mv = (il/4) ? 12 : 0;
const int gh_bk = (il/4) ? 0 : 4;
for (int ii = 0; ii < 2; ii++) {
int i = 2*(il%4) + ii;
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
// combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
reg[2*ii + 0] = d * x0 + m;
reg[2*ii + 1] = d * x1 + m;
}
}
template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
const float d = xb->d;
float4x4 reg_f;
for (int i = 0; i < 16; i++) {
reg_f[i/4][i%4] = (qs[i + 16*il] * d);
}
reg = (type4x4) reg_f;
}
template <typename type4>
void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
const float d = xb->d;
for (int i = 0; i < 4; i++) {
reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
}
}
template <typename type4x4>
void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
device const uint8_t * q2 = (device const uint8_t *)xb->qs;
const float d = e8m0_to_fp32(xb->e);
const uint8_t shr = il >= 1 ? 4 : 0;
for (int i = 0; i < 4; ++i) {
reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
}
}
template <typename type4>
void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
device const uint8_t * q2 = (device const uint8_t *)xb->qs;
const float d = e8m0_to_fp32(xb->e);
const short il4 = il%4;
const uint8_t shr = il >= 4 ? 4 : 0;
reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
}
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
const float d = xb->d;
const float min = xb->dmin;
device const uint8_t * q = (device const uint8_t *)xb->qs;
float dl, ml;
uint8_t sc = xb->scales[il];
q = q + 32*(il/8) + 16*(il&1);
il = (il/2)%4;
half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
}
}
template <typename type4x4>
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
const half d_all = xb->d;
device const uint8_t * q = (device const uint8_t *)xb->qs;
device const uint8_t * h = (device const uint8_t *)xb->hmask;
device const int8_t * scales = (device const int8_t *)xb->scales;
q = q + 32 * (il/8) + 16 * (il&1);
h = h + 16 * (il&1);
uint8_t m = 1 << (il/2);
uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
((il/4)>0 ? 12 : 3);
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
const float ml = 4.f * dl;
il = (il/2) & 3;
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
dl *= coef;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
}
}
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
: uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
}
template <typename type4x4>
void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
device const uchar * q = xb->qs;
short is = (il/4) * 2;
q = q + (il/4) * 32 + 16 * (il&1);
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
const float d = il < 2 ? xb->d : xb->d / 16.h;
const float min = xb->dmin;
const float dl = d * sc[0];
const float ml = min * sc[1];
const ushort mask = il < 2 ? 0x0F : 0xF0;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
}
}
template <typename type4x4>
void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
device const uint8_t * q = xb->qs;
device const uint8_t * qh = xb->qh;
short is = (il/4) * 2;
q = q + 32 * (il/4) + 16 * (il&1);
qh = qh + 16 * (il&1);
uint8_t ul = 1 << (il/2);
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
const float d = il < 2 ? xb->d : xb->d / 16.f;
const float min = xb->dmin;
const float dl = d * sc[0];
const float ml = min * sc[1];
const ushort mask = il<2 ? 0x0F : 0xF0;
const float qh_val = il<2 ? 16.f : 256.f;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
}
}
template <typename type4x4>
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
const half d_all = xb->d;
device const uint16_t * ql = (device const uint16_t *)xb->ql;
device const uint16_t * qh = (device const uint16_t *)xb->qh;
device const int8_t * scales = (device const int8_t *)xb->scales;
ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
qh = qh + 16*(il/8) + 8*(il&1);
float sc = scales[(il%2) + 2 * ((il/2))];
il = (il/2) & 3;
const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F;
const float ml = d_all * sc * 32.f;
const float dl0 = d_all * sc;
const float dl1 = dl0 / 256.f;
const float dl2 = dl0 / (256.f * 256.f);
const float dl3 = dl0 / (256.f * 256.f * 256.f);
const uint8_t shr_h = il>2 ? 2 : 0;
const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
const uint8_t shr_l = il>1 ? 4 : 0;
for (int i = 0; i < 4; ++i) {
const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml;
reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml;
reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml;
reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
}
}
template <typename type4x4>
void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
// each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
device const uint16_t * q2 = xb->qs + 4*ib32;
const uint32_t aux32_g = q2[0] | (q2[1] << 16);
const uint32_t aux32_s = q2[2] | (q2[3] << 16);
thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
for (int i = 0; i < 8; ++i) {
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
for (int i = 0; i < 8; ++i) {
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
}
template <typename type4x4>
void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint16_t * q2 = xb->qs + 4*ib32;
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
for (int i = 0; i < 8; ++i) {
reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
signs = ksigns_iq2xs[q2[2*il+1] >> 9];
for (int i = 0; i < 8; ++i) {
reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
}
}
template <typename type4x4>
void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint8_t * q3 = xb->qs + 8*ib32;
device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
const uint32_t aux32 = gas[0] | (gas[1] << 16);
const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
}
grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
for (int i = 0; i < 4; ++i) {
reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
}
}
template <typename type4x4>
void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint8_t * qs = xb->qs + 8*ib32;
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
const uint8_t qh = xb->qh[ib32] >> 4*il;
const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
}
grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
for (int i = 0; i < 4; ++i) {
reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
}
}
template <typename type4x4>
void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const float d = xb->d;
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
device const uint8_t * signs = qs + QK_K/8;
const uint8_t qh = xb->qh[ib32] >> 4*il;
const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
for (int i = 0; i < 8; ++i) {
reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
}
}
template <typename type4x4>
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const int ib32 = il/2;
il = il%2;
const float d = xb->d;
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
device const uint16_t * qh = xb->qh;
const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
const uint16_t h = qh[ib32] >> 6*il;
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * (grid1[i] & 0xf) + ml;
reg[1][i] = dl * (grid1[i] >> 4) + ml;
reg[2][i] = dl * (grid2[i] & 0xf) + ml;
reg[3][i] = dl * (grid2[i] >> 4) + ml;
}
}
template <typename type4x4>
void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const int ib32 = il/2;
il = il%2;
device const uint16_t * sc = (device const uint16_t *)xb->scales;
iq1m_scale_t scale;
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
const float d = scale.f16;
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
device const uint8_t * qh = xb->qh + 2*ib32 + il;
const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
for (int i = 0; i < 4; ++i) {
reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
reg[1][i] = dl * (grid1[i] >> 4) + ml1;
reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
reg[3][i] = dl * (grid2[i] >> 4) + ml2;
}
}
template <typename type4x4>
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
const float d = xb->d;
uint32_t aux32;
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
for (int i = 0; i < 4; ++i) {
aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
}
}
template <typename type4>
void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
const float d = xb->d;
uint32_t aux32;
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;
reg[0] = d * kvalues_iq4nl_f[q8[0]];
reg[1] = d * kvalues_iq4nl_f[q8[1]];
reg[2] = d * kvalues_iq4nl_f[q8[2]];
reg[3] = d * kvalues_iq4nl_f[q8[3]];
}
template <typename type4x4>
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
const int ib32 = il/2;
il = il%2;
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
const float d = (float)xb->d * (ls - 32);
uint32_t aux32;
thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
for (int i = 0; i < 4; ++i) {
aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
}
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,250 @@
#include "common.h"
constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]];
#if 1
template<short NSG>
kernel void kernel_gated_delta_net_impl(
constant ggml_metal_kargs_gated_delta_net & args,
device const char * q,
device const char * k,
device const char * v,
device const char * g,
device const char * b,
device const char * s,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
#define S_v FC_gated_delta_net_ne20
#define G FC_gated_delta_net_ne30
#define K FC_gated_delta_net_K
const uint tx = tpitg.x;
const uint ty = tpitg.y;
const uint i23 = tgpig.z; // B (n_seqs)
const uint i21 = tgpig.y; // H (head)
const uint i20 = tgpig.x*NSG + ty; // row within S_v
const uint i01 = i21 % args.ne01;
const uint i11 = i21 % args.ne11;
const float scale = 1.0f / sqrt((float)S_v);
// input state layout [S_v, S_v, H, n_seqs] (s0 only): per-seq stride is H*D.
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
const uint state_in_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
device const float * s_ptr = (device const float *) (s) + state_in_base;
float ls[NSG];
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] = s_ptr[is];
}
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
// When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned.
// output state base offset: after attention scores
const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23;
// output state per-slot size: S_v * S_v * H * n_seqs
const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23;
// per-(seq,head) offset within a slot
const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
for (short t = 0; t < args.ne22; t++) {
float s_k = 0.0f;
if (G == 1) {
const float g_exp = exp(g_ptr[0]);
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] *= g_exp;
s_k += ls[j]*k_ptr[is];
}
} else {
// KDA
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] *= exp(g_ptr[is]);
s_k += ls[j]*k_ptr[is];
}
}
s_k = simd_sum(s_k);
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
float y = 0.0f;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] += k_ptr[is]*d;
y += ls[j]*q_ptr[is];
}
y = simd_sum(y);
if (tx == 0) {
dst_attn[t*args.ne21*S_v] = y*scale;
}
q_ptr += args.ns02;
k_ptr += args.ns12;
v_ptr += args.ns22;
b_ptr += args.ne21;
g_ptr += args.ne21*G;
if (K > 1) {
const int target_slot = (int)args.ne22 - 1 - (int)t;
if (target_slot >= 0 && target_slot < (int)K) {
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is] = ls[j];
}
}
}
}
if (K == 1) {
device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is] = ls[j];
}
}
#undef S_v
#undef G
#undef K
}
typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;
#else
// a simplified version of the above
// no performance improvement, so keep the above version for now
template<typename T, short NSG>
kernel void kernel_gated_delta_net_impl(
constant ggml_metal_kargs_gated_delta_net & args,
device const char * q,
device const char * k,
device const char * v,
device const char * g,
device const char * b,
device const char * s,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
#define S_v FC_gated_delta_net_ne20
#define G FC_gated_delta_net_ne30
const uint tx = tpitg.x;
const uint ty = tpitg.y;
const uint i23 = tgpig.z; // B
const uint i21 = tgpig.y; // H
const uint i20 = tgpig.x*NSG + ty;
const uint i01 = i21 % args.ne01;
const uint i11 = i21 % args.ne11;
const float scale = 1.0f / sqrt((float)S_v);
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
float lsf[NSG];
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
lsf[j] = s_ptr[is*S_v];
}
thread T * ls = (thread T *) (lsf);
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
for (short t = 0; t < args.ne22; t++) {
device const T * qt_ptr = (device const T *) (q_ptr);
device const T * kt_ptr = (device const T *) (k_ptr);
device const T * gt_ptr = (device const T *) (g_ptr);
if (G == 1) {
*ls *= exp(g_ptr[0]);
} else {
// KDA
*ls *= exp(gt_ptr[tx]);
}
const float s_k = simd_sum(dot(*ls, kt_ptr[tx]));
const float d = (v_ptr[i20] - s_k)*b_ptr[0];
*ls += kt_ptr[tx]*d;
const float y = simd_sum(dot(*ls, qt_ptr[tx]));
if (tx == 0) {
*dst_attn = y*scale;
}
q_ptr += args.ns02;
k_ptr += args.ns12;
v_ptr += args.ns22;
b_ptr += args.ne21;
g_ptr += args.ne21*G;
dst_attn += args.ne21*S_v;
}
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
device T * dstt_state = (device T *) (dst_state);
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is*S_v] = lsf[j];
}
#undef S_v
#undef G
}
typedef decltype(kernel_gated_delta_net_impl<float4, 4>) kernel_gated_delta_net_t;
template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float, 1>;
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float2, 2>;
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float4, 4>;
#endif
+347
View File
@@ -0,0 +1,347 @@
#include "common.h"
kernel void kernel_argmax_f32(
constant ggml_metal_kargs_argmax & args,
device const char * src0,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const float * x_row = (device const float *) ((device const char *) src0 + tgpig * args.nb01);
float lmax = -INFINITY;
int32_t larg = -1;
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
if (x_row[i00] > lmax) {
lmax = x_row[i00];
larg = i00;
}
}
// find the argmax value in the block
float max_val = simd_max(lmax);
int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
device int32_t * dst_i32 = (device int32_t *) dst;
threadgroup float * shared_maxval = (threadgroup float *) shmem;
threadgroup int32_t * shared_argmax = (threadgroup int32_t *) shmem + N_SIMDWIDTH;
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
shared_maxval[tiisg] = -INFINITY;
shared_argmax[tiisg] = -1;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shared_maxval[sgitg] = max_val;
shared_argmax[sgitg] = arg_val;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_val = shared_maxval[tiisg];
arg_val = shared_argmax[tiisg];
float max_val_reduced = simd_max(max_val);
int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
dst_i32[tgpig] = arg_val_reduced;
return;
}
dst_i32[tgpig] = arg_val;
}
kernel void kernel_diag_f32(
constant ggml_metal_kargs_diag & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]]) {
constexpr short NW = N_SIMDWIDTH;
const int32_t i3 = tgpig.z;
const int32_t i2 = tgpig.y;
const int32_t i1 = tgpig.x;
device const float * src0_ptr = (device const float *)(src0 + i2*args.nb02 + i3*args.nb03);
device float * dst_ptr = (device float *)(dst + i1*args.nb01 + i2*args.nb2 + i3*args.nb3);
for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {
dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
}
}
kernel void kernel_roll_f32(
constant ggml_metal_kargs_roll & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
device const float * src0_ptr = (device const float *) src0;
device float * dst_ptr = (device float *) dst;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
// apply shifts and wrap around
int64_t i00 = i0 - args.s0;
int64_t i01 = i1 - args.s1;
int64_t i02 = i2 - args.s2;
int64_t i03 = i3 - args.s3;
if (i00 < 0) { i00 += args.ne00; } else if (i00 >= args.ne00) { i00 -= args.ne00; }
if (i01 < 0) { i01 += args.ne01; } else if (i01 >= args.ne01) { i01 -= args.ne01; }
if (i02 < 0) { i02 += args.ne02; } else if (i02 >= args.ne02) { i02 -= args.ne02; }
if (i03 < 0) { i03 += args.ne03; } else if (i03 >= args.ne03) { i03 -= args.ne03; }
int64_t src_idx = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00 + i00;
int64_t dst_idx = i3 *args.ne2 *args.ne1 *args.ne0 + i2 *args.ne1 *args.ne0 + i1 *args.ne0 + i0;
dst_ptr[dst_idx] = src0_ptr[src_idx];
}
}
template <typename T>
kernel void kernel_pad_impl(
constant ggml_metal_kargs_pad & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int32_t i3 = tgpig.z;
const int32_t i2 = tgpig.y;
const int32_t k0 = tgpig.x/args.ne1;
const int32_t i1 = tgpig.x - k0*args.ne1;
const int32_t i03 = i3;
const int32_t i02 = i2;
const int32_t i01 = i1;
device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) {
const int32_t i0 = k0*1024 + tpitg.x + l0;
if (i0 >= args.ne0) {
break;
}
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
dst_ptr[i0] = src0_ptr[i0];
} else {
dst_ptr[i0] = 0.0f;
}
}
}
typedef decltype(kernel_pad_impl<float>) kernel_pad_t;
template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl<float>;
template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl<float4>;
// TODO: this is slow - optimize
kernel void kernel_pad_reflect_1d_f32(
constant ggml_metal_kargs_pad_reflect_1d & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3;
const int64_t i02 = i2;
const int64_t i01 = i1;
device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
if (i0 < args.p0) {
dst_ptr[i0] = src0_ptr[args.p0 - i0];
} else if (i0 < args.ne0 - args.p1) {
dst_ptr[i0] = src0_ptr[i0 - args.p0];
} else {
dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
}
}
}
}
kernel void kernel_arange_f32(
constant ggml_metal_kargs_arange & args,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
device float * dst_ptr = (device float *) dst;
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
dst_ptr[i0] = args.start + args.step * i0;
}
}
kernel void kernel_timestep_embedding_f32(
constant ggml_metal_kargs_timestep_embedding & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
int i = tgpig.x;
device float * embed_data = (device float *)(dst + i*args.nb1);
int half_ = args.dim / 2;
for (int j = tpitg.x; j < half_; j += ntg.x) {
float timestep = ((device float *)src0)[i];
float freq = (float)exp(-log((float)args.max_period) * j / half_);
float arg = timestep * freq;
embed_data[j ] = cos(arg);
embed_data[j + half_] = sin(arg);
}
if (args.dim % 2 != 0 && tpitg.x == 0) {
embed_data[2 * half_] = 0.f;
}
}
kernel void kernel_opt_step_adamw_f32(
constant ggml_metal_kargs_opt_step_adamw & args,
device float * x,
device const float * g,
device float * g_m,
device float * g_v,
device const float * pars,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
const float alpha = pars[0];
const float beta1 = pars[1];
const float beta2 = pars[2];
const float eps = pars[3];
const float wd = pars[4];
const float beta1h = pars[5];
const float beta2h = pars[6];
const float gi = g[gid];
const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);
g_m[gid] = gmi;
g_v[gid] = gvi;
const float mh = gmi * beta1h;
const float vh = sqrt(gvi * beta2h) + eps;
x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
}
kernel void kernel_opt_step_sgd_f32(
constant ggml_metal_kargs_opt_step_sgd & args,
device float * x,
device const float * g,
device const float * pars,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
}
template<typename T>
kernel void kernel_memset(
constant ggml_metal_kargs_memset & args,
device T * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;
}
typedef decltype(kernel_memset<int64_t>) kernel_memset_t;
template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;
constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];
template<typename T>
kernel void kernel_count_equal(
constant ggml_metal_kargs_count_equal & args,
device const char * src0,
device const char * src1,
device atomic_int * dst,
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const short NSG = FC_count_equal_nsg;
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
int sum = 0;
device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
const T v0 = *(device const T *)(base0 + i0*args.nb00);
const T v1 = *(device const T *)(base1 + i0*args.nb10);
sum += (v0 == v1);
}
sum = simd_sum(sum);
if (tiisg == 0) {
shmem_i32[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
float v = 0.0f;
if (tpitg.x < NSG) {
v = shmem_i32[tpitg.x];
}
float total = simd_sum(v);
if (tpitg.x == 0) {
atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed);
}
}
}
typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;
template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;
+838
View File
@@ -0,0 +1,838 @@
#include "common.h"
#include "dequantize.h"
constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]];
constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]];
constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]];
constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]];
// each block_q contains 16*nl weights
#ifdef GGML_METAL_HAS_TENSOR
template<
typename SA, typename SA_4x4, typename SA_8x8,
typename SB, typename SB_2x4, typename SB_8x8,
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &),
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm(
constant ggml_metal_kargs_mul_mm & args,
device const char * srcA,
device const char * srcB,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig [[threadgroup_position_in_grid]],
ushort tiitg [[thread_index_in_threadgroup]],
ushort sgitg [[simdgroup_index_in_threadgroup]]) {
(void) sgitg;
// Matrix dimensions: A(M,K) x B(K,N) -> C(M,N)
const int K = args.ne00;
const int M = args.ne0;
const int N = args.ne1;
// Batch dimension handling
const int im = tgpig.z;
const int i12 = im % FC_mul_mm_ne12;
const int i13 = im / FC_mul_mm_ne12;
// Batch offsets for srcA and srcB
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
// Tile dimensions
constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X;
constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y;
// Tile offsets in output matrix
const int ra = tgpig.y * NRA;
const int rb = tgpig.x * NRB;
// Threadgroup memory for dequantized A tile only
threadgroup SA * sa = (threadgroup SA *)(shmem);
// Work-item count for A loading
constexpr int A_WORK_ITEMS = NRA * N_MM_NK;
constexpr int NUM_THREADS = N_SIMDWIDTH * N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y;
// tA wraps threadgroup memory
auto tA = tensor(sa, dextents<int32_t, 2>(N_MM_NK_TOTAL, NRA));
// tB wraps device memory directly
device T1 * ptrB = (device T1 *)(srcB + args.nb12*i12 + args.nb13*i13);
const int strideB = args.nb11 / sizeof(T1);
auto tB = tensor(ptrB, dextents<int32_t, 2>(K, N), array<int, 2>({1, strideB}));
// Configure matmul operation
mpp::tensor_ops::matmul2d<
mpp::tensor_ops::matmul2d_descriptor(
NRB, NRA, N_MM_NK_TOTAL, false, true, true,
mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
execution_simdgroups<N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y>> mm;
auto cT = mm.get_destination_cooperative_tensor<decltype(tB), decltype(tA), float>();
// Accumulate partial results over K dimension
for (int loop_k = 0; loop_k < K; loop_k += N_MM_NK_TOTAL) {
// === PHASE 1: Dequantization of A into threadgroup memory ===
for (int work = tiitg; work < A_WORK_ITEMS; work += NUM_THREADS) {
const int row = work / N_MM_NK;
const int k_chunk = work % N_MM_NK;
const int k_pos = loop_k + k_chunk * 16;
const short k_base = k_chunk * 16;
// Bounds check: skip device read if row is out of matrix bounds
if (ra + row < M) {
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
// Element-wise reads when K is not aligned (nb01 not aligned for half4x4/float4x4).
// MSL spec Table 2.5: half4x4 requires 8-byte alignment. When K is odd,
// nb01 = K*2 is not 8-byte aligned, so odd-row pointers are misaligned.
// Mirrors the legacy kernel's existing guard.
device const T0 * row_ptr = (device const T0 *)(srcA + args.nb01 * (ra + row) + offset0);
FOR_UNROLL (short i = 0; i < 16; i++) {
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? (SA) row_ptr[k_pos + i] : (SA)0;
}
} else {
const int block_idx = k_pos / (16 * nl);
const short il = (k_pos / 16) % nl;
device const block_q * row_ptr = (device const block_q *)(srcA + args.nb01 * (ra + row) + offset0);
SA_4x4 temp_a;
dequantize_func(row_ptr + block_idx, il, temp_a);
FOR_UNROLL (short i = 0; i < 16; i++) {
// Zero-pad A for K positions beyond valid range (handles partial K iterations)
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0;
}
}
} else {
// Zero-pad rows beyond matrix bounds
FOR_UNROLL (short i = 0; i < 16; i++) {
sa[row * N_MM_NK_TOTAL + (k_base + i)] = (SA)0;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// === PHASE 2: Tensor matmul ===
auto mA = tA.slice(0, 0);
auto mB = tB.slice(loop_k, rb);
mm.run(mB, mA, cT);
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// Store result tile to output matrix (with batch offset)
// cT.store handles bounds checking via tD's extents (M, N)
device float * dstBatch = (device float *)dst + im * N * M;
auto tD = tensor(dstBatch, dextents<int32_t, 2>(M, N), array<int, 2>({1, M}));
cT.store(tD.slice(ra, rb));
}
#else
template<
typename S0, typename S0_4x4, typename S0_8x8,
typename S1, typename S1_2x4, typename S1_8x8,
typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &),
typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm(
constant ggml_metal_kargs_mul_mm & args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
constexpr int NR0 = 64;
constexpr int NR1 = 32;
constexpr int NK = 32;
constexpr int NL0 = NK/16;
constexpr int NL1 = NK/8;
const int im = tgpig.z;
const int r0 = tgpig.y*NR0;
const int r1 = tgpig.x*NR1;
// if this block is of 64x32 shape or smaller
const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1;
// a thread shouldn't load data outside of the matrix
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
const short il0 = (tiitg % NL0);
short il = il0;
const int i12 = im % FC_mul_mm_ne12;
const int i13 = im / FC_mul_mm_ne12;
const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03;
const short offset1 = il0/nl;
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
const short iy = 8*(tiitg % NL1);
device const T1 * y = (device const T1 *)(src1
+ args.nb13*i13
+ args.nb12*i12
+ args.nb11*(r1 + lr1)
+ args.nb10*iy);
S0_8x8 ma[4];
S1_8x8 mb[2];
simdgroup_float8x8 mc[8];
for (short i = 0; i < 8; i++){
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
}
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
// load data and store to threadgroup memory
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// no need for dequantization
for (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
//const short lx = i%8;
//const short ly = (tiitg/NL0)%8;
const short lx = (tiitg/NL0)%8;
const short ly = i%8;
const short ib = 8*sx + sy;
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
}
} else {
S0_4x4 temp_a;
dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
FOR_UNROLL (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
//const short lx = i%8;
//const short ly = (tiitg/NL0)%8;
const short lx = (tiitg/NL0)%8;
const short ly = i%8;
const short ib = 8*sx + sy;
// NOTE: this is massively slower.. WTF?
//sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
*(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
}
}
if (FC_mul_mm_bc_inp) {
for (short i = 0; i < 8; ++i) {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short lx = i;
const short ly = (tiitg/NL1)%8;
//const short lx = (tiitg/NL1)%8;
//const short ly = i;
const short ib = 4*sx + sy;
*(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
}
} else {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
//const short dx = sx;
//const short dy = sy;
const short ly = (tiitg/NL1)%8;
const short ib = 4*sx + sy;
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
}
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
y += NK;
threadgroup_barrier(mem_flags::mem_threadgroup);
// load matrices from threadgroup memory and conduct outer products
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 4; i++) {
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
}
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 2; i++) {
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
}
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 8; i++){
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
}
lsma += 8*64;
lsmb += 4*64;
}
}
if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
// if no bounds checks on the output are needed, we can directly write to device memory
device float * C = (device float *) dst +
(r0 + 32*(sgitg & 1)) + \
(r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
}
} else {
// block is smaller than 64x32, we should avoid writing data outside of the matrix
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
for (int j = tiitg; j < nr1; j += NR1) {
device float * D = (device float *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0;
device float4 * D4 = (device float4 *) D;
threadgroup float * C = temp_str + (j*NR0);
threadgroup float4 * C4 = (threadgroup float4 *) C;
int i = 0;
for (; i < nr0/4; i++) {
*(D4 + i) = *(C4 + i);
}
i *= 4;
for (; i < nr0; i++) {
*(D + i) = *(C + i);
}
}
}
}
}
#endif // GGML_METAL_HAS_TENSOR
template<short ne20> // n_expert_used
kernel void kernel_mul_mm_id_map0(
constant ggml_metal_kargs_mul_mm_id_map0 & args,
device const char * src2,
device char * htpe,
device char * hids,
threadgroup char * shmem [[threadgroup(0)]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort ntg[[threads_per_threadgroup]]) {
const short ide = tpitg; // expert id
uint32_t n_all = 0;
device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;
for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
if (i21 + tpitg < args.ne21) {
device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;
#pragma unroll(ne20)
for (short i20 = 0; i20 < ne20; i20++) {
sids[i20] = src2_i32[i20];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (short t = 0; t < ntg; t++) {
if (i21 + t >= args.ne21) {
break;
}
threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;
short sel = 0;
#pragma unroll(ne20)
for (short i20 = 0; i20 < ne20; i20++) {
sel += (sids[i20] == ide)*(i20 + 1);
}
ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;
n_all += sel > 0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
tpe_u32[ide] = n_all;
}
typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;
template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(
constant ggml_metal_kargs_mul_mm_id & args,
device const char * src0,
device const char * src1,
device const char * htpe,
device const char * hids,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
#ifdef GGML_METAL_HAS_TENSOR
threadgroup float * sc = (threadgroup float *)(shmem);
#endif
constexpr int NR0 = 64;
constexpr int NR1 = 32;
constexpr int NK = 32;
constexpr int NL0 = NK/16;
constexpr int NL1 = NK/8;
const int im = tgpig.z; // expert
const int r0 = tgpig.y*NR0;
const int r1 = tgpig.x*NR1;
device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
device const int32_t * ids_i32 = (device const int32_t *) (hids);
const int32_t neh1 = tpe_u32[im];
if (r1 >= neh1) {
return;
}
// if this block is of 64x32 shape or smaller
const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1;
// a thread shouldn't load data outside of the matrix
const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
const short il0 = (tiitg % NL0);
short il = il0;
const int id = ids_i32[im*args.ne21 + r1 + lr1];
const short i11 = (id % args.ne20) % args.ne11;
const short i12 = (id / args.ne20);
const short i13 = 0;
const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
const short offset1 = il0/nl;
device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
const short iy = 8*(tiitg % NL1);
device const T1 * y = (device const T1 *)(src1
+ args.nb13*i13
+ args.nb12*i12
+ args.nb11*i11
+ args.nb10*iy);
#ifndef GGML_METAL_HAS_TENSOR
S0_8x8 ma[4];
S1_8x8 mb[2];
simdgroup_float8x8 mc[8];
for (short i = 0; i < 8; i++){
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
}
#else
auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
mpp::tensor_ops::matmul2d<
mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
execution_simdgroups<4>> mm;
auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
#endif
for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
#ifndef GGML_METAL_HAS_TENSOR
// load data and store to threadgroup memory
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// no need for dequantization
for (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
//const short lx = i%8;
//const short ly = (tiitg/NL0)%8;
const short lx = (tiitg/NL0)%8;
const short ly = i%8;
const short ib = 8*sx + sy;
*(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? (S0) *((device T0 *) x + i) : (S0) 0;
}
} else {
S0_4x4 temp_a;
dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
FOR_UNROLL (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
//const short lx = i%8;
//const short ly = (tiitg/NL0)%8;
const short lx = (tiitg/NL0)%8;
const short ly = i%8;
const short ib = 8*sx + sy;
// NOTE: this is massively slower.. WTF?
//sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
*(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
}
}
if (FC_mul_mm_bc_inp) {
for (short i = 0; i < 8; ++i) {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short lx = i;
const short ly = (tiitg/NL1)%8;
//const short lx = (tiitg/NL1)%8;
//const short ly = i;
const short ib = 4*sx + sy;
*(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
}
} else {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
//const short dx = sx;
//const short dy = sy;
const short ly = (tiitg/NL1)%8;
const short ib = 4*sx + sy;
*(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
}
#else
// load data and store to threadgroup memory
if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// no need for dequantization
for (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
const short lx = i%8;
const short ly = (tiitg/NL0)%8;
//const short lx = (tiitg/NL0)%8;
//const short ly = i%8;
*(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
}
} else {
S0_4x4 temp_a;
dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
FOR_UNROLL (short i = 0; i < 16; i++) {
const short sx = 2*il0 + i/8;
const short sy = (tiitg/NL0)/8;
const short lx = i%8;
const short ly = (tiitg/NL0)%8;
//const short lx = (tiitg/NL0)%8;
//const short ly = i%8;
*(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
}
}
if (FC_mul_mm_bc_inp) {
for (short i = 0; i < 8; ++i) {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
const short lx = i;
const short ly = (tiitg/NL1)%8;
//const short lx = (tiitg/NL1)%8;
//const short ly = i;
*(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
}
} else {
const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8;
//const short lx = i;
const short ly = (tiitg/NL1)%8;
//const short lx = (tiitg/NL1)%8;
//const short ly = i;
*(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
}
#endif
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
y += NK;
threadgroup_barrier(mem_flags::mem_threadgroup);
#ifndef GGML_METAL_HAS_TENSOR
// load matrices from threadgroup memory and conduct outer products
threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 4; i++) {
simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
}
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 2; i++) {
simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
}
simdgroup_barrier(mem_flags::mem_none);
FOR_UNROLL (short i = 0; i < 8; i++){
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
}
lsma += 8*64;
lsmb += 4*64;
}
#else
auto sA = tA.slice(0, 0);
auto sB = tB.slice(0, 0);
mm.run(sB, sA, cT);
#endif
}
// block is smaller than 64x32, we should avoid writing data outside of the matrix
threadgroup_barrier(mem_flags::mem_threadgroup);
#ifdef GGML_METAL_HAS_TENSOR
auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
cT.store(tC);
#else
threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
}
#endif
threadgroup_barrier(mem_flags::mem_threadgroup);
for (short j = sgitg; j < nr1; j += 4) {
const int id = ids_i32[im*args.ne21 + r1 + j];
const short ide = id % args.ne20;
const short idt = id / args.ne20;
device float * D = (device float *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0;
device float4 * D4 = (device float4 *) D;
threadgroup float * C = (threadgroup float *) shmem + j*NR0;
threadgroup float4 * C4 = (threadgroup float4 *) C;
int i = tiisg;
for (; i < nr0/4; i += 32) {
*(D4 + i) = *(C4 + i);
}
i = (4*(nr0/4)) + tiisg;
for (; i < nr0; i += 32) {
*(D + i) = *(C + i);
}
}
}
//
// matrix-matrix multiplication
//
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_t;
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
//
// indirect matrix-matrix multiplication
//
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_id;
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
#endif
template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
File diff suppressed because it is too large Load Diff
+308
View File
@@ -0,0 +1,308 @@
#include "common.h"
// F == 1 : norm (no fuse)
// F == 2 : norm + mul
// F == 3 : norm + mul + add
template <typename T, short F>
kernel void kernel_norm_fuse_impl(
constant ggml_metal_kargs_norm & args,
device const char * src0,
device const char * src1_0,
device const char * src1_1,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
const int i01 = tgpig.x;
const int i02 = tgpig.y;
const int i03 = tgpig.z;
device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
T sumft(0.0f);
float sumf = 0.0f;
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
sumft += x[i00];
}
sumf = dot(sumft, T(1.0f));
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float mean = sumf/args.ne00;
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
sumf = 0.0f;
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
y[i00] = x[i00] - mean;
sumf += dot(y[i00], y[i00]);
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float variance = sumf/args.ne00;
const float scale = 1.0f/sqrt(variance + args.eps);
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
if (F == 1) {
y[i00] = (y[i00]*scale);
}
if (F == 2) {
y[i00] = (y[i00]*scale)*f0[i00];
}
if (F == 3) {
y[i00] = (y[i00]*scale)*f0[i00] + f1[i00];
}
}
}
typedef decltype(kernel_norm_fuse_impl<float4, 1>) kernel_norm_fuse_t;
template [[host_name("kernel_norm_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 1>;
template [[host_name("kernel_norm_mul_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 2>;
template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 3>;
template [[host_name("kernel_norm_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 1>;
template [[host_name("kernel_norm_mul_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 3>;
// F == 1 : rms_norm (no fuse)
// F == 2 : rms_norm + mul
// F == 3 : rms_norm + mul + add
template <typename T, short F>
kernel void kernel_rms_norm_fuse_impl(
constant ggml_metal_kargs_norm & args,
device const char * src0,
device const char * src1_0,
device const char * src1_1,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
const int i01 = tgpig.x;
const int i02 = tgpig.y;
const int i03 = tgpig.z;
device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
float sumf = 0.0f;
// parallel sum
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float mean = sumf/args.ne00;
const float scale = 1.0f/sqrt(mean + args.eps);
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
if (F == 1) {
y[i00] = (x[i00]*scale);
}
if (F == 2) {
y[i00] = (x[i00]*scale)*f0[i00];
}
if (F == 3) {
y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
}
}
}
typedef decltype(kernel_rms_norm_fuse_impl<float4, 1>) kernel_rms_norm_fuse_t;
template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 1>;
template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 3>;
template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 1>;
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
template <typename T0, typename T>
kernel void kernel_l2_norm_impl(
constant ggml_metal_kargs_l2_norm & args,
device const char * src0,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int i01 = tgpig.x;
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
float sumf = 0.0f;
// parallel sum
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
const float scale = 1.0f/max(sqrt(sumf), args.eps);
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
y[i00] = x[i00] * scale;
}
}
typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,
device float * dst,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint ntg[[threads_per_threadgroup]]) {
const int64_t ne = args.ne00*args.ne01*args.ne02;
const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.ngrp - 1) / args.ngrp);
int start = tgpig * gs;
int end = start + gs;
start += tpitg;
if (end >= ne) {
end = ne;
}
float tmp = 0.0f; // partial sum for thread in warp
for (int j = start; j < end; j += ntg) {
tmp += src0[j];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
tmp = simd_sum(tmp);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = tmp;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
tmp = buf[tiisg];
tmp = simd_sum(tmp);
}
const float mean = tmp / gs;
tmp = 0.0f;
for (int j = start; j < end; j += ntg) {
float xi = src0[j] - mean;
dst[j] = xi;
tmp += xi * xi;
}
tmp = simd_sum(tmp);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
buf[sgitg] = tmp;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
tmp = buf[tiisg];
tmp = simd_sum(tmp);
}
const float variance = tmp / gs;
const float scale = 1.0f/sqrt(variance + args.eps);
for (int j = start; j < end; j += ntg) {
dst[j] *= scale;
}
}
+148
View File
@@ -0,0 +1,148 @@
#include "common.h"
kernel void kernel_pool_2d_max_f32(
constant ggml_metal_kargs_pool_2d & args,
device const float * src0,
device float * dst,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
const int idx = gid;
const int I_HW = args.IH * args.IW;
const int O_HW = args.OH * args.OW;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / args.OW;
const int cur_ow = idx % O_HW % args.OW;
device const float * i_ptr = src0 + nc * I_HW;
device float * o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * args.s1 - args.p1;
const int bh = MAX(0, start_h);
const int eh = MIN(args.IH, start_h + args.k1);
const int start_w = cur_ow * args.s0 - args.p0;
const int bw = MAX(0, start_w);
const int ew = MIN(args.IW, start_w + args.k0);
float res = -INFINITY;
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
res = MAX(res, i_ptr[i * args.IW + j]);
}
}
o_ptr[cur_oh * args.OW + cur_ow] = res;
}
kernel void kernel_pool_2d_avg_f32(
constant ggml_metal_kargs_pool_2d & args,
device const float * src0,
device float * dst,
uint gid[[thread_position_in_grid]]) {
if (gid >= args.np) {
return;
}
const int idx = gid;
const int I_HW = args.IH * args.IW;
const int O_HW = args.OH * args.OW;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / args.OW;
const int cur_ow = idx % O_HW % args.OW;
device const float * i_ptr = src0 + nc * I_HW;
device float * o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * args.s1 - args.p1;
const int bh = MAX(0, start_h);
const int eh = MIN(args.IH, start_h + args.k1);
const int start_w = cur_ow * args.s0 - args.p0;
const int bw = MAX(0, start_w);
const int ew = MIN(args.IW, start_w + args.k0);
// const float scale = 1. / ((eh - bh) * (ew - bw));
const float scale = 1. / (args.k0 * args.k1);
float res = 0;
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
float cur = i_ptr[i * args.IW + j];
res += cur * scale;
}
}
o_ptr[cur_oh * args.OW + cur_ow] = res;
}
kernel void kernel_pool_1d_max_f32(
constant ggml_metal_kargs_pool_1d & args,
device const float * src,
device float * dst,
uint gid [[thread_position_in_grid]]
) {
if (gid >= args.np) {
return;
}
const int ow = (int)gid % args.OW;
const int row = (int)gid / args.OW;
const int base = ow * args.s0 - args.p0;
float acc = -INFINITY;
const int src_off = row * args.IW;
const int dst_off = row * args.OW;
for (int ki = 0; ki < args.k0; ++ki) {
int j = base + ki;
if (j < 0 || j >= args.IW){
continue;
}
float v = src[src_off + j];
acc = max(acc, v);
}
dst[dst_off + ow] = acc;
}
kernel void kernel_pool_1d_avg_f32(
constant ggml_metal_kargs_pool_1d & args,
device const float * src,
device float * dst,
uint gid [[thread_position_in_grid]]
) {
if (gid >= args.np) {
return;
}
const int ow = (int)gid % args.OW;
const int row = (int)gid / args.OW;
const int base = ow * args.s0 - args.p0;
float acc = 0.0f;
int cnt = 0;
const int src_off = row * args.IW;
const int dst_off = row * args.OW;
for (int ki = 0; ki < args.k0; ++ki) {
const int j = base + ki;
if (j < 0 || j >= args.IW) {
continue;
}
acc += src[src_off + j];
cnt += 1;
}
dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
}
+213
View File
@@ -0,0 +1,213 @@
#pragma once
#include "common.h"
void quantize_q1_0(device const float * src, device block_q1_0 & dst) {
float sum_abs = 0.0f;
for (int j = 0; j < QK1_0; j++) {
sum_abs += fabs(src[j]);
}
dst.d = sum_abs / QK1_0;
for (int j = 0; j < QK1_0 / 8; j++) {
dst.qs[j] = 0;
}
for (int j = 0; j < QK1_0; j++) {
if (src[j] >= 0.0f) {
dst.qs[j / 8] |= (1 << (j % 8));
}
}
}
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < QK4_0; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
max = v;
}
}
const float d = max / -8;
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = src[0 + j]*id;
const float x1 = src[QK4_0/2 + j]*id;
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
dst.qs[j] = xi0;
dst.qs[j] |= xi1 << 4;
}
}
void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
#pragma METAL fp math_mode(safe)
float min = FLT_MAX;
float max = -FLT_MAX;
for (int j = 0; j < QK4_1; j++) {
const float v = src[j];
if (min > v) min = v;
if (max < v) max = v;
}
const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
dst.m = min;
for (int j = 0; j < QK4_1/2; ++j) {
const float x0 = (src[0 + j] - min)*id;
const float x1 = (src[QK4_1/2 + j] - min)*id;
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
dst.qs[j] = xi0;
dst.qs[j] |= xi1 << 4;
}
}
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < QK5_0; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
max = v;
}
}
const float d = max / -16;
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
uint32_t qh = 0;
for (int j = 0; j < QK5_0/2; ++j) {
const float x0 = src[0 + j]*id;
const float x1 = src[QK5_0/2 + j]*id;
const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
}
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
for (int j = 0; j < 4; ++j) {
dst.qh[j] = qh8[j];
}
}
void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
#pragma METAL fp math_mode(safe)
float max = src[0];
float min = src[0];
for (int j = 1; j < QK5_1; j++) {
const float v = src[j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
dst.m = min;
uint32_t qh = 0;
for (int j = 0; j < QK5_1/2; ++j) {
const float x0 = (src[0 + j] - min)*id;
const float x1 = (src[QK5_1/2 + j] - min)*id;
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
}
thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
for (int j = 0; j < 4; ++j) {
dst.qh[j] = qh8[j];
}
}
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
const float v = src[j];
amax = MAX(amax, fabs(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
dst.d = d;
for (int j = 0; j < QK8_0; ++j) {
const float x0 = src[j]*id;
dst.qs[j] = round(x0);
}
}
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < QK4_NL; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
max = v;
}
}
const float d = max / kvalues_iq4nl_f[0];
const float id = d ? 1.0f/d : 0.0f;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < QK4_NL/2; ++j) {
const float x0 = src[0 + j]*id;
const float x1 = src[QK4_NL/2 + j]*id;
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
dst.qs[j] = xi0 | (xi1 << 4);
const float v0 = kvalues_iq4nl_f[xi0];
const float v1 = kvalues_iq4nl_f[xi1];
const float w0 = src[0 + j]*src[0 + j];
const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
}
+389
View File
@@ -0,0 +1,389 @@
#include "common.h"
#include "dequantize.h"
#include "quantize.h"
template<typename T0, typename T1>
kernel void kernel_cpy_t_t(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig[2];
const int32_t i02 = tgpig[1];
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
if (i01 >= args.ne01) {
return;
}
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) {
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
dst_data[i00] = (T1) src[0];
break;
}
}
typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, float>;
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>;
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
template [[host_name("kernel_cpy_i32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
#endif
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<half, float>;
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<half, half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, float>;
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, bfloat>;
#endif
template<short QK,
typename block_q,
void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_cpy_f32_q(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig[2];
const int32_t i02 = tgpig[1];
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
if (i01 >= args.ne01) {
return;
}
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int32_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
quantize_func(src, dst_data[i00]);
break;
}
}
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;
template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
kernel void kernel_cpy_q_f32(
constant ggml_metal_kargs_cpy & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int32_t i03 = tgpig[2];
const int32_t i02 = tgpig[1];
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
if (i01 >= args.ne01) {
return;
}
const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
T4x4 temp;
dequantize_func(src_data + i00/nl, i00%nl, temp);
dst_data[i00] = temp;
break;
}
}
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>;
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>;
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
template<typename T>
kernel void kernel_concat(
constant ggml_metal_kargs_concat & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = ntg.y == 1 ? tgpig.x : tgpig.x*ntg.y + tpitg.y;
if (i1 >= args.ne1) {
return;
}
int o[4] = {0, 0, 0, 0};
o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
device const T * x;
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
x = (device const T *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00);
} else {
x = (device const T *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);
}
device T * y = (device T *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
*y = *x;
}
}
typedef decltype(kernel_concat<float>) kernel_concat_t;
template [[host_name("kernel_concat_f32")]] kernel kernel_concat_t kernel_concat<float>;
template [[host_name("kernel_concat_f16")]] kernel kernel_concat_t kernel_concat<half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_concat_bf16")]] kernel kernel_concat_t kernel_concat<bfloat>;
#endif
template [[host_name("kernel_concat_i8")]] kernel kernel_concat_t kernel_concat<char>;
template [[host_name("kernel_concat_i16")]] kernel kernel_concat_t kernel_concat<short>;
template [[host_name("kernel_concat_i32")]] kernel kernel_concat_t kernel_concat<int>;
template [[host_name("kernel_concat_i64")]] kernel kernel_concat_t kernel_concat<long>;
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows_q(
constant ggml_metal_kargs_get_rows & args,
device const void * src0,
device const void * src1,
device void * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg [[threads_per_threadgroup]]) {
const int32_t iw0 = tgpig.x/args.ne10;
const int32_t i10 = tgpig.x%args.ne10;
const int32_t i11 = tgpig.y;
const int32_t i12 = tgpig.z;
const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
const int32_t i02 = i11;
const int32_t i03 = i12;
auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
float4x4 temp;
dequantize_func(psrc + ind/nl, ind%nl, temp);
pdst[ind] = temp;
break;
}
}
template<typename T0, typename T>
kernel void kernel_get_rows_f(
constant ggml_metal_kargs_get_rows & args,
device const void * src0,
device const void * src1,
device void * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 ntg [[threads_per_threadgroup]]) {
const int32_t iw0 = tgpig.x/args.ne10;
const int32_t i10 = tgpig.x%args.ne10;
const int32_t i11 = tgpig.y;
const int32_t i12 = tgpig.z;
const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
const int32_t i02 = i11;
const int32_t i03 = i12;
auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
pdst[ind] = psrc[ind];
break;
}
}
template<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_set_rows_q32(
constant ggml_metal_kargs_set_rows & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i12 = i03%args.ne12;
const int32_t i11 = i02%args.ne11;
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
if (i01 >= args.ne01) {
return;
}
const int32_t i10 = i01;
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
quantize_func(src_row + 32*ind, dst_row[ind]);
}
}
template<typename T, typename TI>
kernel void kernel_set_rows_f(
constant ggml_metal_kargs_set_rows & args,
device const void * src0,
device const void * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]]) {
const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y;
const int32_t i12 = i03%args.ne12;
const int32_t i11 = i02%args.ne11;
const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
if (i01 >= args.ne01) {
return;
}
const int32_t i10 = i01;
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
dst_row[ind] = (T) src_row[ind];
}
}
//
// get rows
//
typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float, float>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half, float>;
template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>;
#endif
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
//
// set rows
//
typedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t;
template [[host_name("kernel_set_rows_f32_i64")]] kernel set_rows_f_t kernel_set_rows_f<float, int64_t>;
template [[host_name("kernel_set_rows_f32_i32")]] kernel set_rows_f_t kernel_set_rows_f<float, int32_t>;
template [[host_name("kernel_set_rows_f16_i64")]] kernel set_rows_f_t kernel_set_rows_f<half, int64_t>;
template [[host_name("kernel_set_rows_f16_i32")]] kernel set_rows_f_t kernel_set_rows_f<half, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int64_t>;
template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int32_t>;
#endif
typedef decltype(kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>) set_rows_q32_t;
template [[host_name("kernel_set_rows_q8_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_set_rows_q8_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_set_rows_q4_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_set_rows_q4_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_set_rows_q4_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_set_rows_q4_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_set_rows_q5_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_set_rows_q5_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_set_rows_q5_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>;
template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>;

Some files were not shown because too many files have changed in this diff Show More