Compare commits

...

53 Commits

Author SHA1 Message Date
Tarek Dakhran d2462f8f7a chat: fix LFM2/LFM2.5 ignoring json_schema (#24377)
The LFM2 specialized template handler only built a grammar for tool-calling,
silently ignoring json_schema from response_format.
2026-06-10 14:41:41 +02:00
Oliver Simons fb83cc9a07 CUDA: Fix ssm_scan_f32 data-races (#24360)
* Add missing syncthreads before resuing cub_temp_storage

__syncthreads() is required before being allowed to resue TempStorage
smem:
https://nvidia.github.io/cccl/unstable/cub/api/classcub_1_1BlockLoad.html#_CPPv4I0EN3cub9BlockLoad4LoadEv20RandomAccessIteratorRA14ItemsPerThread_1Ti

* Add one more missing __syncthreads

Could also double-buffer, but alternative is to simply ensure all
threads have read smem* before writing to it again in the next loop
iteration

* Remove unused smem from ssm_scan_f32
2026-06-10 14:27:08 +02:00
Sigbjørn Skjæret 039e20a2db ci : bump komac version (#24396) 2026-06-10 09:45:20 +02:00
ddh0 d2e22ed975 speculative : fix "ngram-map-k4v" name in logging (#24253)
This is a non-functional change.

When using `--spec-type ngram-map-k4v`, the log messages at startup and
runtime say `ngram-map-k`. Added logic in the in the constructor of
`common_speculative_impl_ngram_map_k` to pass the correct
`COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V` when `config.key_only` is
`false`.

After this change, the log messages use the correct name.
2026-06-10 09:31:35 +02:00
Rémy Mathieu 76da2450a4 webui: implement pinned conversations support (#21387)
* webui: implement pinned conversations support

* webui: linter/prettier pass

* Fix the unused handleMobileSidebarItemClick from the component.

* the search should find pinned conversations as well

Co-authored-by: Pascal <admin@serveurperso.com>

---------

Co-authored-by: Pascal <admin@serveurperso.com>
2026-06-09 21:33:22 +02:00
Aarnav Pai d73cd07674 graph: Fix granite speech model inference by applying embedding scale when deepstack is not used (#24357)
* llama-graph : apply embedding scale when deepstack is not used

* nits: remove non-existant hunyuan-vl from the tests

* apply suggestion from @gabe-l-hart

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2026-06-09 19:46:27 +02:00
Sigbjørn Skjæret e25a32e98c ci : fix windows release (#24369) 2026-06-09 19:42:23 +03:00
Pascal 483609509d ui: add opt-in run_javascript frontend tool (#24244)
* ui: add opt-in run_javascript frontend tool

Expose a run_javascript tool to the model, executed entirely in the
browser through the existing agentic loop. Code runs in a Web Worker
inside a sandboxed iframe with an opaque origin, isolated from the
WebUI and its API. Console output, errors and the return value are
fed back as the tool result. The parent enforces a hard timeout by
removing the iframe, which terminates the worker.

Disabled by default, toggle in Settings > Developer.

* ui: address review feedback from allozaur

Use the JsonSchemaType enum for the tool definition parameter types
instead of raw string literals, extending it with STRING and NUMBER.

Move the worker shim and the iframe harness html into their own files
so the service no longer carries inline source blobs.

Replace the remaining magic strings with constants: SANDBOX_EMPTY_OUTPUT
and SANDBOX_TRUNCATION_NOTICE, and reuse NEWLINE_SEPARATOR for joins.

* ui: move sandbox worker shim to a raw imported file

Replace the inline worker template string with a real sandbox-worker.js
imported as raw text, and build the iframe harness from it in
sandbox-harness.ts. The raw worker ships as a string, not a module, so
it is excluded from eslint and the typecheck program.
2026-06-09 18:02:31 +02:00
Saba Fallah 49f3542190 mtmd: build_vit batching (#24352) 2026-06-09 16:32:08 +02:00
Jeff Bolz d6d0ce8215 vulkan: reduce iq1 shared memory usage for mul_mm (#24287) 2026-06-09 13:27:38 +02:00
Ruben Ortlam b4e3dc613b vulkan: add v_dot2_f32_f16 support in matrix-matrix multiplication and Flash Attention (#24123)
* vulkan: add support for valve fp16 dot2 extension

* use macro for dot2 path choice

* properly check for the feature

* add dot_product abstraction to reduce preprocessor branching
2026-06-09 13:27:04 +02:00
Nick Towle ae735b1314 ui: Fix excessive style recalculation on hover (#24243) 2026-06-09 12:52:20 +02:00
Xuan-Son Nguyen 9682e351b8 mtmd: refactor video subproc handling (#24316)
* mtmd: refactor video subproc handling

* Update tools/mtmd/mtmd-helper.cpp

Co-authored-by: Mikko Juola <mikjuo@gmail.com>

---------

Co-authored-by: Mikko Juola <mikjuo@gmail.com>
2026-06-09 13:15:12 +03:00
jacekpoplawski 1e912561dd server: log prompts to directory (#22031)
* server: log prompts to directory

Add `--log-prompts-dir` to write each prompt to a separate text file in
the specified directory.

* Apply suggestion from @ngxson

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2026-06-09 12:09:07 +02:00
Pascal efbacf8d21 ui: fix mobile chat form overflow and bust stale bundle cache (#24158) 2026-06-09 11:12:58 +02:00
Pascal 26021699bc ggml : add GGML_OP_COL2IM_1D (#24206)
* cpu: add GGML_OP_COL2IM_1D

Add the overlap-add (scatter-add) step of a 1D transposed convolution.
A ConvTranspose1d factorizes as a GEMM followed by col2im: a weight
pre-permuted to [IC, K*OC] is contracted against the [IC, T_in] input
with mul_mat to produce a column matrix [K*OC, T_in], and col2im_1d
scatters those columns back into the [T_out, OC] signal, with
T_out = (T_in - 1)*s0 + K - 2*p0.

Keeping the contraction as a plain mul_mat leaves the heavy work on the
optimized (and quantizable) matmul kernels, so col2im_1d only does the
cheap overlap-add.

CPU uses a gather formulation parallelized over output channels,
supporting F32, F16 and BF16 with an F32 accumulator.

* tests: add backend coverage for GGML_OP_COL2IM_1D

Add test_col2im_1d next to the conv_transpose_1d cases, covering F32,
F16 and BF16 across eight geometries: the canonical kernel = 2*stride
DAC upsampling shape, overlap, no overlap, cropping (p0 = 1 and
p0 = stride/2), kernel < stride with zeroed gaps, kernel not a
multiple of stride, and a single column unfold.

Perf mode gets three real vocoder stage shapes reporting memory
bandwidth. max_nmse_err relaxes to 5e-4 for F16 and BF16.

* cpu: harden GGML_OP_COL2IM_1D

ggml_col2im_1d validates s0, oc, p0 and input contiguity at graph
build time, before the oc division, protecting every backend at once.
The kernel asserts the contiguity its flat indexing assumes and its
doc states the full output length including the crop term.

The kernel parallelizes over the time axis: the split stays balanced
down to OC = 1, where the previous channel split was single threaded.
Values are bit identical on the three real vocoder chains, two out of
three improve.

* tests: extend the GGML_OP_COL2IM_1D grid

The eval grid grows to eleven geometries: OC = 1 (mono output stage),
K = 1 with stride > 1 (sparse scatter, every gap position zeroed) and
a crop down to T_out = 2 where all the gather bounds act at once.

* tests: add col2im_1d equivalence test

tests/test-col2im-1d.cpp proves mul_mat + col2im_1d matches the
native ggml_conv_transpose_1d on the CPU backend, F32 bit exact, F16
and BF16 through casts of the column matrix. test-backend-ops cannot
cover this for a CPU only op since the CPU backend is its own
reference there.

* rpc: bump protocol patch version for GGML_OP_COL2IM_1D

GGML_OP_COUNT goes from 96 to 97 with the new op, which trips the
static_assert in ggml-rpc.h. Bump RPC_PROTO_PATCH_VERSION since the
op is appended and no existing op code shifts.
2026-06-09 12:01:37 +03:00
fiesh 961e9a3e46 server : do not clear slots without unified KV cache (#24190)
* Always export idle slots to RAM

Without this, a slot's VRAM cache may not be written to RAM.  If this
slot happens to be busy then later on, this triggers needless
preprocessing in another slot.

* cont : clean-up

---------

Co-authored-by: Christoph Weiss <weiss@wsoptics.de>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-06-09 10:45:16 +03:00
Sigbjørn Skjæret f0152efe40 models : fix plamo2 attention_key/value_length regression (#24317) 2026-06-09 10:26:44 +03:00
Yash Raj Pandey fd3271e0b4 ggml-cpu : fix rms_norm_back wrong output under in-place aliasing (#24305)
* ggml-cpu : fix rms_norm_back wrong output under in-place aliasing

* cont : clean-up comment

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-06-09 10:24:27 +03:00
ravel7524 e3471b3e73 Remove case for GGML_TYPE_Q4_K in mvvq.cu (#23528) 2026-06-09 07:46:23 +02:00
Reese Levine 3ac3c20c96 ggml-webgpu: Add clang-format job (#24308)
* Add clang-format job

* try local formatting
2026-06-08 20:54:24 -07:00
Masashi Yoshimura 1e1aca09da ggml-webgpu: Improve prefill speeds for k-quants + refactor matmul for Q4/Q5/Q8 and k-quants (#24225)
* ggml-webgpu: Improve prefill speeds + refactor matmul for quants

* Fixes for editroconfig checker
2026-06-08 15:19:56 -07:00
Max Krasnyansky 7d2b45b4f7 mtp: support for gemma-4 E2B and E4B assistants (#24282)
* models: update converter to support smaller assistants

* models: add masked_embd tensors to gemma4-assist arch

* gemma-4: remove temp debug for conversion

* gemma-4-mtp: filter out masked_embedding tensors during conversion
2026-06-08 13:48:52 -07:00
Aldehir Rojas 42a0afd594 server : do not parse when flushing http headers (#24281) 2026-06-08 13:32:41 -05:00
Pascal a66d50588b graph: guard iswa kq_mask on its own buffer (#24294)
A SWA-only draft head (e.g. StepFun MTP) leaves the base sub-cache
empty, so its kq_mask buffer stays null and asserts at load. Guard
each mask on its own buffer in set_input and can_reuse, base and swa.

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-06-08 19:20:28 +02:00
Nikhil Jain 1705d434f6 [ggml-webgpu] Handle buffer overlap / buffer aliasing for concat operator (#24000)
* Only run webgpu CI on my fork

* Add webgpu only workflow

* handle buffer overlap case for concat operator

* restore build-webgpu.yml

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Run clang-format

* Update ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>
2026-06-08 08:07:31 -07:00
Nikhil Jain 3b3da01dc2 [ggml-webgpu] Implement 2D workgroups for scale, binary, and unary ops (#24044)
* Only run webgpu CI on my fork

* Add webgpu only workflow

* Implement 2d workgroups for more operations

* fix

* Fix type

* Move back to global_invocation_id
2026-06-08 08:07:15 -07:00
Xuan-Son Nguyen 3ebe862b5d docker: install ffmpeg in the released image (#24302) 2026-06-08 16:59:57 +02:00
Xuan-Son Nguyen 8f83d6c271 mtmd : add video input support (#24269)
* wip

* ok: lazy bitmap API

* remember to free lazy text

* wip

* add mtmd_helper_video

* support video input on server (base64 input)

* add MTMD_VIDEO config

* add timestamp

* update CLI

* cli: allow auto-completion for video

* add --video arg

* fix build

* update docs

* rename as suggested
2026-06-08 14:40:12 +03:00
Georgi Gerganov c2b1518fd4 sync : ggml 2026-06-08 14:31:33 +03:00
Georgi Gerganov 6a1de6fbf1 ggml : bump version to 0.14.0 (ggml/1533) 2026-06-08 14:31:33 +03:00
Xuan-Son Nguyen 715b86a366 cli: fix spinner not show during prompt processing (#24283) 2026-06-08 11:11:45 +02:00
Jeff Bolz c74759a244 vulkan: Use cm2 decode_vector for mul_mat_id B matrix loads (#23991)
This allows vec4 loads of the B elements. Also increase BK to 64 when this is
enabled. Neither of these alone is consistently faster, but together these give
a nice speedup.

In ggml-vulkan.cpp, we need to make sure the B matrix alignment and stride are
multiples of 4.
2026-06-08 10:40:37 +02:00
Ruben Ortlam 0f7fada56b cuda: reset cuda context after reading memory size (#23935)
* cuda: reset device in get_memory function if no backend is active

* also count device and host buffers

* exclude hip and musa from counting and device reset

* use device mutex instead of atomic

* undo backend_free function move
2026-06-08 10:22:44 +02:00
Harkirat Gill 19bba67c1f HIP: add gfx1152 and gfx1153 to RDNA3.5 (#24129) 2026-06-08 08:33:23 +02:00
Xuan-Son Nguyen daf6bc9f2d metal : fix im2col 1D case (audio models) (#24220) 2026-06-08 09:03:18 +03:00
Neo Zhang d403f00ec3 [SYCL] Update compute runtime version to 26.x in docker (#24070)
* update compute runtime from 25 to 26 in docker

* add comment with old driver for multiple GPUs
2026-06-08 10:35:18 +08:00
ddh0 9e3b928fd8 common : relax sampler name matching (#23744)
* common : relax sampler name matching

Currently, in some cases, the alternative names for samplers (like
`top-k` and `min-p` instead of the canonical `top_k` and `min_p`) are
not always recognized by the `common_sampler_types_from_names` function
in `common/sampling.cpp`.

This PR changes the signature of this function to remove the `bool
allow_alt_names` flag, and removes all occurences of the flag from call
sites. Therefore, the function will now always match all known names.

I also changed the logic of the function to unconditionally check the
provided sampler names against both the canonical and alternative names,
and to be case-insensitive.

This fixes an issue I was seeing wherein samplers specified in the
`llama-server` UI were not recognized as valid when the alternative
names were used.

* add more alt names

* cont. fix

* cast to unsigned char for correctness

* common : unify sampler name mapping

* annotate canonical vs. alt sampler name mappings per @CISC

* Update common/sampling.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* common : auto-generate sampler name aliases per @ngxson

* use merged map for matching

* use `.merge` instead of iterating

* nit: simplify comment

* nit: use insert everywhere, not index assignment

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-06-07 22:48:11 +02:00
David Friehs 8a963fc10e convert : fix conversion for Mistral-Medium-3.5-128B (#24268)
Mistral explicitly sets `moe` and `llama_4_scaling` to `null` in
params.json, breaking `key in dict` checks during conversion. Replace
with `dict.get(key) is not None` where this matters.

Fixes `convert-hf-to-gguf.py --mistral-format Mistral-Medium-3.5-128B`
2026-06-07 21:41:39 +02:00
Georgi Gerganov 379ac6673b kv-cache : avoid kv cells copies (#24277) 2026-06-07 21:42:54 +03:00
Pascal f0156d1401 kv-cache: follow the source cache size when sharing cells (#24267)
A fitted target context can end up smaller than the draft default, the
oversized assistant views then overflow the shared K/V tensors and trip
the ggml_view_4d size assert during graph reserve.
2026-06-07 18:33:00 +03:00
Aman Gupta 04eb4c446d llama : add Gemma4 MTP (#23398) 2026-06-07 20:50:54 +08:00
Sigbjørn Skjæret 8a091c47ab spec : fix vocab compatibility check (#24256) 2026-06-07 14:43:52 +03:00
konradmb 465b1f0e75 arg: Skip mmproj download when user supplied mmproj (#24239) 2026-06-07 11:18:44 +02:00
Sigbjørn Skjæret f71af352a5 convert : fix Gemma4 with no audio encoder (#24242) 2026-06-07 08:43:05 +02:00
Sigbjørn Skjæret 3f7c79d7b5 docker : bump cuda13 to 13.3.0 (#24228) 2026-06-07 08:31:58 +02:00
Tarek Dakhran 98d5e8ba8a common/chat : fix LFM2/LFM2.5 reasoning round-trip and <think> leak (#24234)
* common/chat : fix LFM2 reasoning round-trip and stray <think> leak
* Gate by reasoning format and whether the template supports <think>
2026-06-06 22:39:21 +02:00
Xuan-Son Nguyen 31e82494c0 mtmd: support "frame merge" for qwen-vl-based models (#21858)
* feat: add video support for Qwen3.5

* various clean up

* revise the design

* fix llava-uhd case

* nits

* nits 2

---------

Co-authored-by: andrewmd5 <1297077+andrewmd5@users.noreply.github.com>
2026-06-06 21:17:25 +02:00
Adrien Gallouët 6b80c74f28 completion : remove useless statics (#24226)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-06 12:16:16 +02:00
Adrien Gallouët 588f0dc2ce completion : fix format specifier in LOG_INF (#24213)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-06 11:24:27 +02:00
Xuan-Son Nguyen f5c6ae1827 mtmd, server: add "placeholder bitmap" for counting tokens , add */input_tokens API (#23913)
* mtmd: add "placeholder bitmap" for counting tokens w/o preprocessing

* fast path skip preproc for placeholder

* fix build

* correct the api

* add server endpoint + tests

* add object name

* update docs

* add proxy handling

* fix build

* fix audio input path

* use is_placeholder in process_mtmd_prompt()

* nits

* nits (2)

* docs: clarify chat/completions/input_tokens is not official

* fix merge problem
2026-06-06 11:06:51 +02:00
Ruben Ortlam 5a69c97439 vulkan: check coopmat2 features before reporting support (#24186) 2026-06-06 09:11:35 +02:00
Sigbjørn Skjæret 5343f4502a model : rename local n_layer_all variable (#24209) 2026-06-06 07:07:20 +03:00
143 changed files with 4487 additions and 1696 deletions
+1 -1
View File
@@ -53,7 +53,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
org.opencontainers.image.source=$IMAGE_SOURCE
RUN apt-get update \
&& apt-get install -y libgomp1 curl \
&& apt-get install -y libgomp1 curl ffmpeg \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
+1 -1
View File
@@ -59,7 +59,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
org.opencontainers.image.source=$IMAGE_SOURCE
RUN apt-get update \
&& apt-get install -y libgomp1 curl \
&& apt-get install -y libgomp1 curl ffmpeg \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
+16 -6
View File
@@ -57,11 +57,21 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
org.opencontainers.image.url=$IMAGE_URL \
org.opencontainers.image.source=$IMAGE_SOURCE
ARG IGC_VERSION=v2.20.5
ARG IGC_VERSION_FULL=2_2.20.5+19972
ARG COMPUTE_RUNTIME_VERSION=25.40.35563.10
ARG COMPUTE_RUNTIME_VERSION_FULL=25.40.35563.10-0
ARG IGDGMM_VERSION=22.8.2
#Following versions are for multiple GPUs, since 26.x has known issue:
# https://github.com/ggml-org/llama.cpp/issues/21747,
# https://github.com/intel/compute-runtime/issues/921.
#ARG IGC_VERSION=v2.20.5
#ARG IGC_VERSION_FULL=2_2.20.5+19972
#ARG COMPUTE_RUNTIME_VERSION=25.40.35563.10
#ARG COMPUTE_RUNTIME_VERSION_FULL=25.40.35563.10-0
#ARG IGDGMM_VERSION=22.8.2
ARG IGC_VERSION=v2.34.4
ARG IGC_VERSION_FULL=2_2.34.4+21428
ARG COMPUTE_RUNTIME_VERSION=26.18.38308.1
ARG COMPUTE_RUNTIME_VERSION_FULL=26.18.38308.1-0
ARG IGDGMM_VERSION=22.10.0
RUN mkdir /tmp/neo/ && cd /tmp/neo/ \
&& wget https://github.com/intel/intel-graphics-compiler/releases/download/$IGC_VERSION/intel-igc-core-${IGC_VERSION_FULL}_amd64.deb \
&& wget https://github.com/intel/intel-graphics-compiler/releases/download/$IGC_VERSION/intel-igc-opencl-${IGC_VERSION_FULL}_amd64.deb \
@@ -75,7 +85,7 @@ RUN mkdir /tmp/neo/ && cd /tmp/neo/ \
&& dpkg --install *.deb
RUN apt-get update \
&& apt-get install -y libgomp1 curl \
&& apt-get install -y libgomp1 curl ffmpeg \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
+1 -1
View File
@@ -64,7 +64,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
org.opencontainers.image.source=$IMAGE_SOURCE
RUN apt-get update \
&& apt-get install -y libgomp1 curl \
&& apt-get install -y libgomp1 curl ffmpeg \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
+1 -1
View File
@@ -107,7 +107,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
org.opencontainers.image.source=$IMAGE_SOURCE
RUN apt-get update \
&& apt-get install -y libgomp1 libtbb12 curl wget ocl-icd-libopencl1 \
&& apt-get install -y libgomp1 libtbb12 curl wget ffmpeg ocl-icd-libopencl1 \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
+1 -1
View File
@@ -76,7 +76,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
org.opencontainers.image.source=$IMAGE_SOURCE
RUN apt-get update \
&& apt-get install -y libgomp1 curl \
&& apt-get install -y libgomp1 curl ffmpeg \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
+1 -1
View File
@@ -49,7 +49,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
org.opencontainers.image.source=$IMAGE_SOURCE
RUN apt-get update \
&& apt-get install -y libgomp1 curl libvulkan1 mesa-vulkan-drivers \
&& apt-get install -y libgomp1 curl ffmpeg libvulkan1 mesa-vulkan-drivers \
libglvnd0 libgl1 libglx0 libegl1 libgles2 \
&& apt autoremove -y \
&& apt clean -y \
+1 -1
View File
@@ -46,7 +46,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
org.opencontainers.image.source=$IMAGE_SOURCE
RUN apt-get update \
&& apt-get install -y libgomp1 libnuma1 curl \
&& apt-get install -y libgomp1 libnuma1 curl ffmpeg \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
+23
View File
@@ -35,6 +35,29 @@ env:
LLAMA_ARG_LOG_TIMESTAMPS: 1
jobs:
format:
runs-on: ubuntu-24.04
steps:
- name: Clone
uses: actions/checkout@v6
- name: Install clang-format 22
run: |
wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key |
sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc > /dev/null
sudo add-apt-repository -y \
"deb http://apt.llvm.org/noble/ llvm-toolchain-noble-22 main"
sudo apt-get update
sudo apt-get install -y clang-format-22
- name: Check formatting
run: |
find ggml/src/ggml-webgpu \
-type f \( -name '*.cpp' -o -name '*.hpp' -o -name '*.h' \) \
-print0 |
xargs -0 clang-format-22 --dry-run --Werror
macos:
runs-on: macos-latest
+2 -2
View File
@@ -82,8 +82,8 @@ jobs:
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x" },
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.1.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.1.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.3.0", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.3.0", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "musa", "dockerfile": ".devops/musa.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "intel", "dockerfile": ".devops/intel.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "vulkan", "dockerfile": ".devops/vulkan.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
+5 -5
View File
@@ -504,7 +504,7 @@ jobs:
needs: [check-release]
if: ${{ needs.check-release.outputs.should_release == 'true' }}
runs-on: windows-2025
runs-on: windows-2025-vs2026
permissions:
actions: write
@@ -535,12 +535,12 @@ jobs:
- name: ccache
uses: ggml-org/ccache-action@v1.2.21
with:
key: release-windows-2025-${{ matrix.arch }}-cpu
key: release-windows-2025-vs2026-${{ matrix.arch }}-cpu
- name: Build
shell: cmd
run: |
call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" ${{ matrix.arch == 'x64' && 'x64' || 'amd64_arm64' }}
call "C:\Program Files\Microsoft Visual Studio\18\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" ${{ matrix.arch == 'x64' && 'x64' || 'amd64_arm64' }}
cmake -S . -B build -G "Ninja Multi-Config" ^
-D CMAKE_TOOLCHAIN_FILE=cmake/${{ matrix.arch }}-windows-llvm.cmake ^
-DLLAMA_BUILD_BORINGSSL=ON ^
@@ -554,12 +554,12 @@ jobs:
- name: ccache-clear
uses: ./.github/actions/ccache-clear
with:
key: release-windows-2025-${{ matrix.arch }}-cpu
key: release-windows-2025-vs2026-${{ matrix.arch }}-cpu
- name: Pack artifacts
id: pack_artifacts
run: |
Copy-Item "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Redist\MSVC\14.44.35112\debug_nonredist\${{ matrix.arch }}\Microsoft.VC143.OpenMP.LLVM\libomp140.${{ matrix.arch == 'x64' && 'x86_64' || 'aarch64' }}.dll" .\build\bin\Release\
Copy-Item "C:\Program Files\Microsoft Visual Studio\18\Enterprise\VC\Redist\MSVC\14.51.36231\debug_nonredist\${{ matrix.arch }}\Microsoft.VC145.OpenMP.LLVM\libomp140.${{ matrix.arch == 'x64' && 'x86_64' || 'aarch64' }}.dll" .\build\bin\Release\
7z a -snl llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\*
- name: Upload artifacts
+1 -1
View File
@@ -17,7 +17,7 @@ jobs:
- name: Install komac
run: |
cargo binstall komac@2.15.0 -y
cargo binstall komac@2.16.0 -y
- name: Find latest release
id: find_latest_release
+12 -5
View File
@@ -444,7 +444,7 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex)
opts.offline = params.offline;
opts.skip_download = params.skip_download;
opts.download_mtp = spec_type_draft_mtp;
opts.download_mmproj = !params.no_mmproj;
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
// so we should not auto-discover mtp/mmproj siblings for them
@@ -1360,7 +1360,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
add_opt(common_arg(
{"--cache-idle-slots"},
{"--no-cache-idle-slots"},
"save and clear idle slots on new task (default: enabled, requires unified KV and cache-ram)",
"save idle slots to the prompt cache on new task, and clear them when using unified KV (default: enabled, requires cache-ram)",
[](common_params & params, bool value) {
params.cache_idle_slots = value;
}
@@ -1615,7 +1615,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()),
[](common_params & params, const std::string & value) {
const auto sampler_names = string_split<std::string>(value, ';');
params.sampling.samplers = common_sampler_types_from_names(sampler_names, true);
params.sampling.samplers = common_sampler_types_from_names(sampler_names);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS;
}
).set_sampling());
@@ -2221,8 +2221,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ_OFFLOAD"));
add_opt(common_arg(
{"--image", "--audio"}, "FILE",
"path to an image or audio file. use with multimodal models, use comma-separated values for multiple files\n",
{"--image", "--audio", "--video"}, "FILE",
"path to an image, audio, or video file. use with multimodal models, use comma-separated values for multiple files\n",
[](common_params & params, const std::string & value) {
for (const auto & item : parse_csv_row(value)) {
params.image.emplace_back(item);
@@ -3333,6 +3333,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
common_log_set_file(common_log_main(), value.c_str());
}
).set_env("LLAMA_ARG_LOG_FILE"));
add_opt(common_arg(
{"--log-prompts-dir"}, "PATH",
"Log prompts to directory (only used for debugging, default: disabled)",
[](common_params & params, const std::string & value) {
params.path_prompts_log_dir = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--log-colors"}, "[on|off|auto]",
"Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
+27 -7
View File
@@ -1625,8 +1625,17 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
const std::string THINK_END = "</think>";
const std::string GEN_PROMPT = "<|im_start|>assistant\n";
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
// Copy reasoning to the "thinking" field the template expects
auto adjusted_messages = json::array();
for (auto msg : inputs.messages) {
if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
msg["thinking"] = msg.at("reasoning_content");
}
adjusted_messages.push_back(msg);
}
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs, adjusted_messages);
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, adjusted_messages);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.preserved_tokens = { TOOL_CALL_START, TOOL_CALL_END, THINK_START, THINK_END };
@@ -1638,9 +1647,12 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
data.thinking_start_tag = THINK_START;
data.thinking_end_tag = THINK_END;
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto has_response_format = !inputs.json_schema.is_null() && inputs.json_schema.is_object();
// Gate by reasoning format and whether the template supports <think>
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE &&
tmpl.source().find(THINK_START) != std::string::npos;
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
if (inputs.has_continuation()) {
const auto & msg = inputs.continue_msg;
@@ -1658,11 +1670,15 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
auto end = p.end();
auto reasoning = p.eps();
if (extract_reasoning && inputs.enable_thinking) {
if (extract_reasoning) {
reasoning = p.optional(THINK_START + p.reasoning(p.until(THINK_END)) + THINK_END);
}
if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
if (has_response_format) {
auto response_format = p.content(p.schema(p.json(), "response-format-schema", inputs.json_schema));
return generation_prompt + reasoning + response_format + end;
}
return generation_prompt + reasoning + p.content(p.rest()) + end;
}
auto tool_calls = p.rule("tool-calls",
@@ -1681,13 +1697,17 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
data.parser = parser.save();
if (include_grammar) {
data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
auto schema = function.at("parameters");
builder.resolve_refs(schema);
});
if (has_response_format) {
auto schema = inputs.json_schema;
builder.resolve_refs(schema);
}
parser.build_grammar(builder, data.grammar_lazy);
});
+1 -1
View File
@@ -1148,7 +1148,7 @@ static void common_init_sampler_from_model(
if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) {
const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';');
if (!sampler_names.empty()) {
sparams.samplers = common_sampler_types_from_names(sampler_names, true);
sparams.samplers = common_sampler_types_from_names(sampler_names);
}
}
}
+2 -1
View File
@@ -489,6 +489,7 @@ struct common_params {
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
std::string input_suffix = ""; // string to suffix user inputs with // NOLINT
std::string logits_file = ""; // file for saving *all* logits // NOLINT
std::string path_prompts_log_dir = ""; // directory with logged prompts // NOLINT
// llama-debug specific options
std::string logits_output_dir = "data"; // directory for saving logits output files // NOLINT
@@ -571,7 +572,7 @@ struct common_params {
struct common_params_model mmproj;
bool mmproj_use_gpu = true; // use GPU for multimodal model
bool no_mmproj = false; // explicitly disable multimodal model
std::vector<std::string> image; // path to image file(s)
std::vector<std::string> image; // path to image file(s) ; TODO: change the name to "media"
int image_min_tokens = -1;
int image_max_tokens = -1;
+49 -40
View File
@@ -769,54 +769,63 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
}
}
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
{ "dry", COMMON_SAMPLER_TYPE_DRY },
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
{ "adaptive_p", COMMON_SAMPLER_TYPE_ADAPTIVE_P },
};
// since samplers names are written multiple ways
// make it ready for both system names and input names
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P },
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "typ-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "adaptive-p", COMMON_SAMPLER_TYPE_ADAPTIVE_P },
};
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names) {
// sampler names can be written multiple ways; generate aliases from canonical names
static const auto sampler_name_map = []{
// canonical sampler name mapping
std::unordered_map<std::string, common_sampler_type> canonical_name_map {
{ "dry", COMMON_SAMPLER_TYPE_DRY },
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
{ "adaptive_p", COMMON_SAMPLER_TYPE_ADAPTIVE_P }
};
std::unordered_map<std::string, common_sampler_type> alias_name_map;
for (const auto & entry : canonical_name_map) {
const std::string & canonical = entry.first;
if (canonical.find('_') == std::string::npos) {
continue;
}
// kebab-case: "top-k", "min-p", etc.
{
std::string kebab_case = canonical;
std::replace(kebab_case.begin(), kebab_case.end(), '_', '-');
alias_name_map.insert({kebab_case, entry.second});
}
// no dash: "topk", "minp", etc.
{
std::string no_dash = canonical;
no_dash.erase(std::remove(no_dash.begin(), no_dash.end(), '_'), no_dash.end());
alias_name_map.insert({no_dash, entry.second});
}
}
// misc. aliases
alias_name_map.insert({"nucleus", COMMON_SAMPLER_TYPE_TOP_P});
alias_name_map.insert({"temp", COMMON_SAMPLER_TYPE_TEMPERATURE});
alias_name_map.insert({"typ", COMMON_SAMPLER_TYPE_TYPICAL_P});
// include aliases + canonical names in the complete mapping
alias_name_map.merge(canonical_name_map);
return alias_name_map;
}();
std::vector<common_sampler_type> samplers;
samplers.reserve(names.size());
for (const auto & name : names) {
auto sampler = sampler_canonical_name_map.find(name);
if (sampler != sampler_canonical_name_map.end()) {
std::string name_lower = name;
std::transform(name_lower.begin(), name_lower.end(), name_lower.begin(), ::tolower);
auto sampler = sampler_name_map.find(name_lower);
if (sampler != sampler_name_map.end()) {
samplers.push_back(sampler->second);
continue;
}
if (allow_alt_names) {
sampler = sampler_alt_name_map.find(name);
if (sampler != sampler_alt_name_map.end()) {
samplers.push_back(sampler->second);
continue;
}
}
LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str());
LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name_lower.c_str());
}
return samplers;
+1 -1
View File
@@ -109,7 +109,7 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx,
char common_sampler_type_to_chr(enum common_sampler_type cnstr);
std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names);
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
+55 -45
View File
@@ -3,13 +3,14 @@
#include "common.h"
#include "ggml.h"
#include "llama.h"
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_nextn / llama_get_embeddings_nextn_ith (used by MTP)
#include "log.h"
#include "ngram-cache.h"
#include "ngram-map.h"
#include "ngram-mod.h"
#include "sampling.h"
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_nextn / llama_get_embeddings_nextn_ith (used by MTP)
#include <algorithm>
#include <cassert>
#include <cstring>
@@ -58,10 +59,10 @@ static bool common_speculative_are_compatible(
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
const bool vocab_type_tgt = llama_vocab_type(vocab_tgt);
const auto vocab_type_tgt = llama_vocab_type(vocab_tgt);
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
const bool vocab_type_dft = llama_vocab_type(vocab_dft);
const auto vocab_type_dft = llama_vocab_type(vocab_dft);
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) {
@@ -418,6 +419,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
int32_t n_embd = 0;
bool is_mem_shared = false;
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
// The last h-row of one process() call needs the first token of the NEXT
// call to pair with, so it's stashed here until that next call fires.
@@ -444,7 +447,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
auto * ctx_dft = this->params.ctx_dft;
GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");
n_embd = llama_model_n_embd(llama_get_model(ctx_dft));
n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft));
GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) &&
"MTP input row width must match the target h_nextn width");
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
@@ -490,6 +495,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false);
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt;
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
i_batch_beg.assign(n_seq, -1);
@@ -526,9 +533,11 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
if (N <= 0) {
return;
}
auto * ctx_dft = this->params.ctx_dft;
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_max < N - 1) {
if (pos_max < N - 1 && !is_mem_shared) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
"process() hook may not have run on every prefill ubatch "
"(need_embd / logits=1 on every prompt position?). "
@@ -571,48 +580,42 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
const size_t row_bytes = (size_t) n_embd * sizeof(float);
common_batch_clear(batch);
// if kv is shared with target (e.g Gemma4), then we can skip this catch-up decode
if (!is_mem_shared) {
common_batch_clear(batch);
for (int k = 0; k < n_tokens; ++k) {
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
}
// shift the tgt embeddings to the right by one position
// assumes that the tokens in the batch are sequential for each sequence
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
// ^--- this is a problem
// TODO:this is generally true, but would be nice to assert it
{
const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt);
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
//{
// // string with seq_ids in the batch
// std::stringstream ss;
// for (int i = 0; i < n_tokens; ++i) {
// ss << batch_in.seq_id[i][0] << ",";
// }
// LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str());
//}
}
// fill the pending embeddings from a previous run
auto set_h = [&](int idx, const float * h_row) {
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
};
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
for (int k = 0; k < n_tokens; ++k) {
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
}
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}
// shift the tgt embeddings to the right by one position
// assumes that the tokens in the batch are sequential for each sequence
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
// ^--- this is a problem
// TODO:this is generally true, but would be nice to assert it
{
const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt);
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
}
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
return false;
// fill the pending embeddings from a previous run
auto set_h = [&](int idx, const float * h_row) {
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
};
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
}
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
return false;
}
}
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
@@ -721,7 +724,13 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
continue;
}
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
if (is_mem_shared) {
// note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens
// ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37
common_batch_add(batch, id, dp.n_past, { seq_id }, true);
} else {
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
}
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
}
@@ -834,7 +843,8 @@ struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
common_speculative_impl_ngram_map_k(
const common_ngram_map & config,
uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, n_seq)
: common_speculative_impl(config.key_only ? COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K
: COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, n_seq)
{
for (uint32_t i = 0; i < n_seq; i++) {
this->config.push_back(config);
+2
View File
@@ -75,9 +75,11 @@ TEXT_MODEL_MAP: dict[str, str] = {
"Gemma3TextModel": "gemma",
"Gemma3nForCausalLM": "gemma",
"Gemma3nForConditionalGeneration": "gemma",
"Gemma4AssistantForCausalLM": "gemma",
"Gemma4ForConditionalGeneration": "gemma",
"Gemma4ForCausalLM": "gemma",
"Gemma4UnifiedForConditionalGeneration": "gemma",
"Gemma4UnifiedAssistantForCausalLM": "gemma",
"GemmaForCausalLM": "gemma",
"Glm4ForCausalLM": "glm",
"Glm4MoeForCausalLM": "glm",
+25 -4
View File
@@ -785,6 +785,26 @@ class Gemma4UnifiedModel(Gemma4Model):
self.gguf_writer.add_suppress_tokens(suppress_tokens)
@ModelBase.register("Gemma4AssistantForCausalLM", "Gemma4UnifiedAssistantForCausalLM")
class Gemma4AssistantModel(Gemma4Model):
model_arch = gguf.MODEL_ARCH.GEMMA4_ASSISTANT
@classmethod
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
name, gen = item
if "masked_embedding" in name:
logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.")
return None
return super().filter_tensors(item)
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_embedding_length_out(self.hparams["backbone_hidden_size"])
self.gguf_writer.add_nextn_predict_layers(self.block_count)
@ModelBase.register("Gemma4ForConditionalGeneration")
class Gemma4VisionAudioModel(MmprojModel):
has_audio_encoder = True
@@ -812,10 +832,11 @@ class Gemma4VisionAudioModel(MmprojModel):
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-6))
# audio params
assert self.hparams_audio is not None
self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.GEMMA4A)
self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["feat_in"])
self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-6))
if self.has_audio_encoder:
assert self.hparams_audio is not None
self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.GEMMA4A)
self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["feat_in"])
self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-6))
def is_audio_tensor(self, name: str) -> bool:
return "audio_tower" in name or "embed_audio" in name
+3 -2
View File
@@ -105,8 +105,9 @@ class MistralModel(LlamaModel):
gguf_writer.add_rope_scaling_yarn_log_mul(mscale_all_dim)
gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
if "llama_4_scaling" in hparams:
gguf_writer.add_attn_temperature_scale(hparams["llama_4_scaling"]["beta"])
llama_4_scaling = hparams.get("llama_4_scaling")
if llama_4_scaling is not None:
gguf_writer.add_attn_temperature_scale(llama_4_scaling["beta"])
class MistralMoeModel(DeepseekV2Model):
+1 -1
View File
@@ -238,7 +238,7 @@ def main() -> None:
assert hparams.get("vision_encoder") is not None, "This model does not support multimodal"
from conversion.pixtral import PixtralModel
model_class = PixtralModel
elif "moe" in hparams:
elif hparams.get("moe") is not None:
from conversion.mistral import MistralMoeModel
model_class = MistralMoeModel
else:
+2 -2
View File
@@ -4,8 +4,8 @@ project("ggml" C CXX ASM)
### GGML Version
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 13)
set(GGML_VERSION_PATCH 1)
set(GGML_VERSION_MINOR 14)
set(GGML_VERSION_PATCH 0)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
+2 -2
View File
@@ -8,10 +8,10 @@ extern "C" {
#define RPC_PROTO_MAJOR_VERSION 4
#define RPC_PROTO_MINOR_VERSION 0
#define RPC_PROTO_PATCH_VERSION 0
#define RPC_PROTO_PATCH_VERSION 1
#ifdef __cplusplus
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION");
static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION");
#endif
#define GGML_RPC_MAX_SERVERS 16
+11
View File
@@ -535,6 +535,7 @@ extern "C" {
GGML_OP_IM2COL,
GGML_OP_IM2COL_BACK,
GGML_OP_IM2COL_3D,
GGML_OP_COL2IM_1D,
GGML_OP_CONV_2D,
GGML_OP_CONV_3D,
GGML_OP_CONV_2D_DW,
@@ -2007,6 +2008,16 @@ extern "C" {
int d1, // dilation dimension 1
bool is_2D);
// col2im_1d: scatter-add GEMM columns back to 1D signal
// a: [K*OC, T_in] (columns from matmul, K = a->ne[0]/OC)
// result: [T_out, OC] where T_out = (T_in - 1)*s0 + K - 2*p0
GGML_API struct ggml_tensor * ggml_col2im_1d(
struct ggml_context * ctx,
struct ggml_tensor * a, // columns [K*OC, T_in]
int s0, // stride
int oc, // output channels
int p0); // padding to crop from both sides
GGML_API struct ggml_tensor * ggml_conv_1d(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel
+5
View File
@@ -1912,6 +1912,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_im2col_3d(params, tensor);
} break;
case GGML_OP_COL2IM_1D:
{
ggml_compute_forward_col2im_1d(params, tensor);
} break;
case GGML_OP_CONV_2D:
{
ggml_compute_forward_conv_2d(params, tensor);
@@ -2343,6 +2347,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_CONV_2D:
case GGML_OP_CONV_3D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_COL2IM_1D:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_CONV_TRANSPOSE_2D:
{
+78 -6
View File
@@ -4008,12 +4008,12 @@ static void ggml_compute_forward_rms_norm_back_f32(
// dx := scale(dx, rrms)
float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
// dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
ggml_vec_cpy_f32 (ne00, dx, x);
// ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
ggml_vec_acc_f32 (ne00, dx, dz);
ggml_vec_scale_f32(ne00, dx, rrms);
// dx[i00] = (dz + x*(-sum_xdz/sum_eps)) * rrms
// note: https://github.com/ggml-org/ggml/issues/1491
const float scale_x = (float) (-sum_xdz) / sum_eps;
for (int64_t i00 = 0; i00 < ne00; i00++) {
dx[i00] = (dz[i00] + x[i00] * scale_x) * rrms;
}
}
}
}
@@ -6730,6 +6730,78 @@ static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
return (coord + size) % size; // adding size avoids negative number weirdness
}
// ggml_compute_forward_col2im_1d
//
// Scatter-add columns [K*OC, T_in] -> signal [T_out, OC]
// where T_out = (T_in - 1)*s + K - 2*p. Gather approach: each output reads ceil(K/s) inputs.
// Parallelized over the time axis so the split stays balanced whatever OC is.
// Supports F32, F16, BF16 input/output (same type), F32 accumulator.
template <typename elem_t>
static void ggml_compute_forward_col2im_1d_impl(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src = dst->src[0]; // [K*OC, T_in]
GGML_ASSERT(ggml_is_contiguous(src));
GGML_ASSERT(ggml_is_contiguous(dst));
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t OC = ((const int32_t *)(dst->op_params))[1];
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
const int64_t K_OC = src->ne[0];
const int64_t T_in = src->ne[1];
const int64_t K = K_OC / OC;
const int64_t T_out = dst->ne[0];
const elem_t * col_data = (const elem_t *) src->data;
elem_t * dst_data = (elem_t *) dst->data;
const int ith = params->ith;
const int nth = params->nth;
// Parallelize over the time axis: the split stays balanced whatever OC is,
// down to OC = 1 for mono audio, and threads read disjoint column bands
const int64_t dr = (T_out + nth - 1) / nth;
const int64_t it0 = dr * ith;
const int64_t it1 = it0 + dr < T_out ? it0 + dr : T_out;
for (int64_t oc = 0; oc < OC; oc++) {
for (int64_t t_out = it0; t_out < it1; t_out++) {
const int64_t t_abs = t_out + p0; // absolute position in uncropped signal
// Gather: find all (t_in, k) where t_in * s + k == t_abs, 0 <= k < K
int64_t t_in_min = (t_abs - K + 1 + s0 - 1) / s0; // ceil((t_abs-K+1)/s)
if (t_in_min < 0) t_in_min = 0;
int64_t t_in_max = t_abs / s0;
if (t_in_max >= T_in) t_in_max = T_in - 1;
float sum = 0.0f;
for (int64_t t_in = t_in_min; t_in <= t_in_max; t_in++) {
int64_t k = t_abs - t_in * s0;
if (k >= 0 && k < K) {
// col layout: [K*OC, T_in], element (oc*K+k, t_in)
sum += type_conversion_table<elem_t>::to_f32(col_data[(oc * K + k) + t_in * K_OC]);
}
}
// dst layout: [T_out, OC], element (t_out, oc)
dst_data[t_out + oc * T_out] = type_conversion_table<elem_t>::from_f32(sum);
}
}
}
void ggml_compute_forward_col2im_1d(
const ggml_compute_params * params,
ggml_tensor * dst) {
switch (dst->src[0]->type) {
case GGML_TYPE_F32: ggml_compute_forward_col2im_1d_impl<float> (params, dst); break;
case GGML_TYPE_F16: ggml_compute_forward_col2im_1d_impl<ggml_fp16_t>(params, dst); break;
case GGML_TYPE_BF16: ggml_compute_forward_col2im_1d_impl<ggml_bf16_t>(params, dst); break;
default: GGML_ABORT("col2im_1d: unsupported type %d", dst->src[0]->type);
}
}
// ggml_compute_forward_conv_2d
+1
View File
@@ -68,6 +68,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p
void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_col2im_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+66 -9
View File
@@ -622,6 +622,18 @@ ggml_backend_cuda_context::~ggml_backend_cuda_context() {
// cuda buffer
struct ggml_backend_cuda_device_context {
int device;
std::string name;
std::string description;
std::string pci_bus_id;
int op_offload_min_batch_size;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
std::mutex device_mutex;
int active_count = 0;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
};
struct ggml_backend_cuda_buffer_context {
int device;
void * dev_ptr = nullptr;
@@ -639,6 +651,13 @@ struct ggml_backend_cuda_buffer_context {
static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count--;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
delete ctx;
}
@@ -791,6 +810,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count++;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
}
@@ -1490,6 +1515,12 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
}
static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count--;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
CUDA_CHECK(cudaFreeHost(buffer->context));
}
@@ -1498,6 +1529,8 @@ static void * ggml_cuda_host_malloc(size_t size) {
return nullptr;
}
ggml_cuda_set_device(0); // cudaMallocHost can create the implicit CUDA device context, make sure that this is consistently done on device 0.
void * ptr = nullptr;
cudaError_t err = cudaMallocHost((void **) &ptr, size);
if (err != cudaSuccess) {
@@ -1523,6 +1556,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggm
buffer->buft = buft;
buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count++;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
return buffer;
}
@@ -3140,6 +3179,12 @@ static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) {
static void ggml_backend_cuda_free(ggml_backend_t backend) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) backend->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count--;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
delete cuda_ctx;
delete backend;
}
@@ -4871,14 +4916,6 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
// backend device
struct ggml_backend_cuda_device_context {
int device;
std::string name;
std::string description;
std::string pci_bus_id;
int op_offload_min_batch_size;
};
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
return ctx->name.c_str();
@@ -4967,6 +5004,11 @@ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_k
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
std::lock_guard<std::mutex> lock(ctx->device_mutex);
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_cuda_set_device(ctx->device);
CUDA_CHECK(cudaMemGetInfo(free, total));
@@ -4993,6 +5035,13 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
}
#endif // defined(__linux__)
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
// If no backends or buffers are active, the cudaMemGetInfo call above lazily created a CUDA
// context that permanently consumes VRAM. Reset the device to free it.
if (ctx->active_count == 0) {
CUDA_CHECK(cudaDeviceReset());
}
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}
static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
@@ -5687,13 +5736,21 @@ ggml_backend_t ggml_backend_cuda_init(int device) {
return nullptr;
}
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device);
ggml_backend_t cuda_backend = new ggml_backend {
/* .guid = */ ggml_backend_cuda_guid(),
/* .iface = */ ggml_backend_cuda_interface,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
/* .device = */ dev,
/* .context = */ ctx,
};
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count++;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
return cuda_backend;
}
-1
View File
@@ -411,7 +411,6 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_K:
return 8;
case GGML_TYPE_Q6_K:
return 2;
+3 -2
View File
@@ -67,6 +67,7 @@ __global__ void __launch_bounds__(splitD, 1)
__shared__ CubTempStorage cub_temp_storage;
BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
__syncthreads();
BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
#else
const int stride_s0 = src0_nb2 / sizeof(float);
@@ -105,6 +106,7 @@ __global__ void __launch_bounds__(splitD, 1)
regs0[n] = state;
}
y_block[i * stride_y + threadIdx.x] = sumf;
__syncthreads();
}
#ifdef USE_CUB
@@ -249,9 +251,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
GGML_ASSERT(head_dim == 1);
GGML_ASSERT(n_group == 1);
const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
if (d_state == 16) {
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, smem_size, stream);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream);
switch (n_tok)
{
case 1:
+2 -2
View File
@@ -219,9 +219,9 @@
#define RDNA3
#endif // defined(__GFX11__)
#if defined(__gfx1150__) || defined(__gfx1151__)
#if defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__)
#define RDNA3_5
#endif // defined(__gfx1150__) || defined(__gfx1151__)
#endif // defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__)
#if defined(RDNA3) && !defined(RDNA3_5)
#define RDNA3_0
+5 -1
View File
@@ -1738,10 +1738,14 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_meta
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1;
const int64_t KH = is_2D ? ne01 : 1;
const int64_t KW = ne00;
char base[256];
char name[256];
if (ne00*ne01 <= 1024) {
if (KH*KW <= 1024) {
snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
} else {
snprintf(base, 256, "kernel_im2col_ext_%s", ggml_type_name(op->type));
+228 -31
View File
@@ -113,6 +113,21 @@ typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
} VkPhysicalDeviceShaderBfloat16FeaturesKHR;
#endif
#if !defined(VK_VALVE_shader_mixed_float_dot_product)
#define VK_VALVE_shader_mixed_float_dot_product 1
#define VK_VALVE_SHADER_MIXED_FLOAT_DOT_PRODUCT_SPEC_VERSION 1
#define VK_VALVE_SHADER_MIXED_FLOAT_DOT_PRODUCT_EXTENSION_NAME "VK_VALVE_shader_mixed_float_dot_product"
#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE ((VkStructureType)1000673000)
typedef struct VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE {
VkStructureType sType;
void* pNext;
VkBool32 shaderMixedFloatDotProductFloat16AccFloat32;
VkBool32 shaderMixedFloatDotProductFloat16AccFloat16;
VkBool32 shaderMixedFloatDotProductBFloat16Acc;
VkBool32 shaderMixedFloatDotProductFloat8AccFloat32;
} VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE;
#endif
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
@@ -705,6 +720,8 @@ struct vk_device_struct {
bool coopmat2_bf16_support {};
bool coopmat2_decode_vector;
bool dot2_f16 {};
bool pipeline_executable_properties_support {};
size_t idx;
@@ -1976,6 +1993,9 @@ struct ggml_backend_vk_context {
// Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert.
vk_pipeline_struct * prealloc_y_last_pipeline_used {};
const ggml_tensor * prealloc_y_last_tensor_used {};
// True when prealloc_y holds the padded fp16 layout used by the coopmat2 B decode-vector callback.
// If false, then it's contiguous.
bool prealloc_y_last_decode_vector_staging {};
// Track which nodes have been used since the last sync, and whether they were written to
std::vector<const ggml_tensor *> unsynced_nodes_written;
@@ -3374,7 +3394,9 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
switch (src0_type) {
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
lut_size = 2*2048 + 4*2048;
// Regular matmul uses the compact uint16_t IQ1 grid; the expanded
// uint32_t grid is only enabled for the q8_1/int-dot vector path.
lut_size = 2*2048;
break;
case GGML_TYPE_IQ2_XXS:
lut_size = 8*256;
@@ -3652,9 +3674,10 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
s_mmq_wg_denoms_k = { 32, 64, 1 };
// spec constants and tile sizes for quant matmul_id
l_warptile_mmqid = { 256, 128, 128, 32, 1, device->subgroup_size };
m_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
s_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
const uint32_t mmqid_bk = device->coopmat2_decode_vector ? 64u : 32u;
l_warptile_mmqid = { 256, 128, 128, mmqid_bk, 1, device->subgroup_size };
m_warptile_mmqid = { 256, 128, 64, mmqid_bk, 0, device->subgroup_size };
s_warptile_mmqid = { 256, 128, 64, mmqid_bk, 0, device->subgroup_size };
l_mmqid_wg_denoms = { 128, 128, 1 };
m_mmqid_wg_denoms = { 128, 64, 1 };
s_mmqid_wg_denoms = { 128, 64, 1 };
@@ -3916,8 +3939,13 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16";
} else {
if (device->fp16) {
if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; }
else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; }
if (device->dot2_f16) {
if (f32acc) { spv_data = flash_attn_f32_f16_dot2_data; spv_size = flash_attn_f32_f16_dot2_len; }
else { spv_data = flash_attn_f32_f16_dot2_f16acc_data; spv_size = flash_attn_f32_f16_dot2_f16acc_len; }
} else {
if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; }
else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; }
}
} else {
spv_data = flash_attn_f32_f16_fp32_data;
spv_size = flash_attn_f32_f16_fp32_len;
@@ -4211,7 +4239,23 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->fp16) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
// Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
// bf16 scalar path promotes to f32, no dot2 variant
#define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
@@ -4246,7 +4290,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
@@ -4254,7 +4298,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
@@ -4294,8 +4337,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
@@ -4340,8 +4382,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
@@ -4386,6 +4427,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
#undef CREATE_MM2
#undef CREATE_MMQ
#undef CREATE_MM
#undef CREATE_MM_NODOT2
} else {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
@@ -5449,6 +5491,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->integer_dot_product = false;
device->shader_64b_indexing = false;
bool bfloat16_support = false;
bool dot2_f16_support = false;
for (const auto& properties : ext_props) {
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -5491,6 +5534,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
bfloat16_support = true;
#endif
} else if (strcmp("VK_VALVE_shader_mixed_float_dot_product", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_DOT2")) {
dot2_f16_support = true;
} else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) {
pipeline_executable_properties_support = true;
} else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 &&
@@ -5798,6 +5844,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
device_extensions.push_back("VK_KHR_shader_integer_dot_product");
}
VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE dot2_features {};
dot2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE;
if (dot2_f16_support) {
last_struct->pNext = (VkBaseOutStructure *)&dot2_features;
last_struct = (VkBaseOutStructure *)&dot2_features;
device_extensions.push_back("VK_VALVE_shader_mixed_float_dot_product");
}
VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {};
pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR;
if (pipeline_executable_properties_support) {
@@ -5832,6 +5886,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->bf16 = false;
#endif
device->dot2_f16 = dot2_f16_support && dot2_features.shaderMixedFloatDotProductFloat16AccFloat32;
device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
@@ -6246,6 +6302,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
bool coopmat2_decode_vector_support = false;
bool integer_dot_product = false;
bool bfloat16_support = false;
bool dot2_f16_support = false;
for (auto properties : ext_props) {
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
@@ -6275,6 +6332,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
bfloat16_support = true;
#endif
} else if (strcmp("VK_VALVE_shader_mixed_float_dot_product", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_DOT2")) {
dot2_f16_support = true;
}
}
@@ -6349,6 +6409,15 @@ static void ggml_vk_print_gpu_info(size_t idx) {
}
#endif
#if defined(VK_NV_cooperative_matrix2)
VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV;
if (coopmat2_support) {
last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features;
last_struct = (VkBaseOutStructure *)&coopmat2_features;
}
#endif
VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {};
coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV;
if (coopmat2_decode_vector_support) {
@@ -6356,6 +6425,13 @@ static void ggml_vk_print_gpu_info(size_t idx) {
last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
}
VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE dot2_features {};
dot2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE;
if (dot2_f16_support) {
last_struct->pNext = (VkBaseOutStructure *)&dot2_features;
last_struct = (VkBaseOutStructure *)&dot2_features;
}
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
fp16 = fp16 && vk12_features.shaderFloat16;
@@ -6380,6 +6456,19 @@ static void ggml_vk_print_gpu_info(size_t idx) {
#endif
&& ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
coopmat2_support = coopmat2_support &&
coopmat2_features.cooperativeMatrixWorkgroupScope &&
coopmat2_features.cooperativeMatrixFlexibleDimensions &&
coopmat2_features.cooperativeMatrixReductions &&
coopmat2_features.cooperativeMatrixConversions &&
coopmat2_features.cooperativeMatrixPerElementOperations &&
coopmat2_features.cooperativeMatrixTensorAddressing &&
coopmat2_features.cooperativeMatrixBlockLoads;
#else
coopmat2_support = false;
#endif
coopmat2_decode_vector_support = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector;
#if !defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
coopmat2_decode_vector_support = false;
@@ -6389,9 +6478,12 @@ static void ggml_vk_print_gpu_info(size_t idx) {
: coopmat_support ? "KHR_coopmat"
: "none";
bool dot2_f16 = dot2_f16_support && dot2_features.shaderMixedFloatDotProductFloat16AccFloat32;
const char *fp16_str = fp16 ? (dot2_f16 ? "dot2" : "1") : "0";
std::string device_name = props2.properties.deviceName.data();
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size,
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %s | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16_str, bf16, subgroup_size,
props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
@@ -8088,6 +8180,40 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
ggml_vk_sync_buffers(ctx, subctx);
}
// Copy/convert tensor into a caller-defined dense layout. Destination strides
// are in output elements, not bytes.
static void ggml_vk_cpy_to_strided(
ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor,
const vk_subbuffer & in, const vk_subbuffer & out,
uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13) {
VK_LOG_DEBUG("ggml_vk_cpy_to_strided((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), ";
std::cerr << "dst_nb=(" << nb10 << ", " << nb11 << ", " << nb12 << ", " << nb13 << "), buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")");
const int tensor_type_size = ggml_type_size(tensor->type);
const uint32_t ne = ggml_nelements(tensor);
std::array<uint32_t, 3> elements;
if (ne > 262144) {
elements = { 512, 512, CEIL_DIV(ne, 262144) };
} else if (ne > 512) {
elements = { 512, CEIL_DIV(ne, 512), 1 };
} else {
elements = { ne, 1, 1 };
}
vk_op_unary_push_constants pc = {
(uint32_t)ne,
(uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
(uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], nb10, nb11, nb12, nb13,
0,
0.0f, 0.0f,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
};
init_pushconst_fastdiv(pc);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
ggml_vk_sync_buffers(ctx, subctx);
}
static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
switch(type) {
case GGML_TYPE_Q8_1:
@@ -8345,24 +8471,28 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
}
if (y_non_contig) {
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
ctx->prealloc_y_last_tensor_used != src1 ||
ctx->prealloc_y_last_decode_vector_staging) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
ctx->prealloc_y_last_decode_vector_staging = false;
}
}
if (quantize_y) {
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
ctx->prealloc_y_last_tensor_used != src1 ||
ctx->prealloc_y_last_decode_vector_staging) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
ctx->prealloc_y_last_decode_vector_staging = false;
}
}
@@ -8620,24 +8750,28 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (y_non_contig) {
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
ctx->prealloc_y_last_tensor_used != src1 ||
ctx->prealloc_y_last_decode_vector_staging) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y);
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
ctx->prealloc_y_last_decode_vector_staging = false;
}
}
if (quantize_y) {
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
ctx->prealloc_y_last_tensor_used != src1 ||
ctx->prealloc_y_last_decode_vector_staging) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
ctx->prealloc_y_last_decode_vector_staging = false;
}
}
@@ -9088,12 +9222,30 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
// Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
!ggml_vk_dim01_contiguous(src0);
const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
// If src0 is BF16, try to use a BF16 x BF16 multiply
ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
// B must already be, or be convertible to, the matmul B type used by this path.
const bool y_decode_vector_supported = ctx->device->coopmat2_decode_vector &&
(f16_type != GGML_TYPE_BF16 || ctx->device->coopmat2_bf16_support) &&
(src1->type == GGML_TYPE_F32 || src1->type == f16_type);
// If B is copied to prealloc_y, we can choose a 4-element-aligned row stride.
const bool y_decode_vector_uses_prealloc = !ggml_vk_dim01_contiguous(src1) || src1->type != f16_type;
// Direct B reads are safe only if row starts and the original buffer offset are 4-element aligned.
const bool y_decode_vector_aligned =
(ne10 % 4 == 0) &&
(y_decode_vector_uses_prealloc || get_misalign_bytes(ctx, src1) % (4 * ggml_type_size(src1->type)) == 0);
// Stage B only when decode-vector is available and direct B reads would be misaligned.
const bool y_decode_vector_staging = y_decode_vector_supported && !y_decode_vector_aligned;
#else
const bool y_decode_vector_staging = false;
#endif
const bool y_non_contig = y_decode_vector_staging ||
(ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
(src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
!ggml_vk_dim01_contiguous(src1);
// If src0 is BF16, try to use a BF16 x BF16 multiply
ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
const uint32_t y_staged_row_stride = y_decode_vector_staging ? (uint32_t)ggml_vk_align_size(ne10, 4) : (uint32_t)ne10;
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
@@ -9132,11 +9284,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
const uint64_t x_ne = ggml_nelements(src0);
const uint64_t y_ne = padded_n * ne10 * ne12 * ne13;
const uint64_t y_ne = (uint64_t)y_staged_row_stride * padded_n * ne12 * ne13;
const uint64_t d_ne = ggml_nelements(dst);
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * ggml_nelements(src1) / ggml_blck_size(src1->type);
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
const uint64_t ids_sz = nbi2;
@@ -9146,13 +9298,30 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
vk_pipeline to_fp16_vk_1 = nullptr;
vk_pipeline to_q8_1 = nullptr;
auto make_y_staged_dst = [&]() {
ggml_tensor y_staged_dst = *src1;
y_staged_dst.type = f16_type;
y_staged_dst.nb[0] = ggml_type_size(f16_type);
y_staged_dst.nb[1] = y_staged_dst.nb[0] * y_staged_row_stride;
y_staged_dst.nb[2] = y_staged_dst.nb[1] * padded_n;
y_staged_dst.nb[3] = y_staged_dst.nb[2] * y_staged_dst.ne[2];
return y_staged_dst;
};
if (x_non_contig) {
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
} else {
to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
}
if (y_non_contig) {
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
ggml_tensor y_staged_dst;
const ggml_tensor * y_staged_dst_ptr = nullptr;
if (y_decode_vector_staging) {
y_staged_dst = make_y_staged_dst();
y_staged_dst_ptr = &y_staged_dst;
}
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, y_staged_dst_ptr, f16_type);
} else {
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
}
@@ -9270,30 +9439,47 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
}
if (y_non_contig) {
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
ctx->prealloc_y_last_tensor_used != src1 ||
ctx->prealloc_y_last_decode_vector_staging != y_decode_vector_staging) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
if (y_decode_vector_staging) {
const ggml_tensor y_staged_dst = make_y_staged_dst();
const uint32_t y_staged_dst_type_size = ggml_type_size(y_staged_dst.type);
ggml_vk_cpy_to_strided(
ctx, subctx, to_fp16_vk_1, src1,
ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0),
(uint32_t)(y_staged_dst.nb[0] / y_staged_dst_type_size),
(uint32_t)(y_staged_dst.nb[1] / y_staged_dst_type_size),
(uint32_t)(y_staged_dst.nb[2] / y_staged_dst_type_size),
(uint32_t)(y_staged_dst.nb[3] / y_staged_dst_type_size));
} else {
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
}
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
ctx->prealloc_y_last_decode_vector_staging = y_decode_vector_staging;
}
}
if (quantize_y) {
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
ctx->prealloc_y_last_tensor_used != src1 ||
ctx->prealloc_y_last_decode_vector_staging) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
ctx->prealloc_y_last_decode_vector_staging = false;
}
}
ggml_vk_sync_buffers(ctx, subctx);
uint32_t stride_batch_x = ne00*ne01;
uint32_t stride_batch_y = ne10*ne11;
uint32_t stride_b_y = y_decode_vector_staging ? y_staged_row_stride : ne10;
uint32_t stride_batch_y = y_decode_vector_staging ? y_staged_row_stride * padded_n : ne10*ne11;
if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
@@ -9308,7 +9494,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
ctx, subctx, pipeline,
{ d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz },
{ d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf,
ne01, ne21, ne10, ne10, ne10, ne01,
ne01, ne21, ne10, ne10, stride_b_y, ne01,
stride_batch_x, stride_batch_y, ne20*ne21,
n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
); // NOLINT
@@ -9466,24 +9652,28 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
if (y_non_contig) {
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
ctx->prealloc_y_last_tensor_used != src1 ||
ctx->prealloc_y_last_decode_vector_staging) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y);
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
ctx->prealloc_y_last_decode_vector_staging = false;
}
}
if (quantize_y) {
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
ctx->prealloc_y_last_tensor_used != src1 ||
ctx->prealloc_y_last_decode_vector_staging) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
ctx->prealloc_y_last_decode_vector_staging = false;
}
}
@@ -13708,7 +13898,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
ggml_vk_destroy_buffer(ctx->prealloc_y);
}
ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
ctx->prealloc_y_last_pipeline_used = nullptr;
ctx->prealloc_y_last_tensor_used = nullptr;
ctx->prealloc_y_last_decode_vector_staging = false;
}
if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
@@ -14288,6 +14480,8 @@ static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
ctx->prealloc_y_last_pipeline_used = {};
ctx->prealloc_y_last_tensor_used = nullptr;
ctx->prealloc_y_last_decode_vector_staging = false;
ctx->unsynced_nodes_written.clear();
ctx->unsynced_nodes_read.clear();
@@ -14338,6 +14532,8 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
ggml_vk_destroy_buffer(ctx->sync_staging);
ctx->prealloc_y_last_pipeline_used = nullptr;
ctx->prealloc_y_last_tensor_used = nullptr;
ctx->prealloc_y_last_decode_vector_staging = false;
ctx->prealloc_size_x = 0;
ctx->prealloc_size_y = 0;
@@ -15517,6 +15713,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->prealloc_y_last_pipeline_used = nullptr;
ctx->prealloc_y_last_tensor_used = nullptr;
ctx->prealloc_y_last_decode_vector_staging = false;
if (ctx->prealloc_size_add_rms_partials) {
ggml_vk_preallocate_buffers(ctx, nullptr);
@@ -0,0 +1,27 @@
#ifdef DOT2_F16
#extension GL_EXT_spirv_intrinsics : require
spirv_instruction(extensions = ["SPV_VALVE_mixed_float_dot_product"],
capabilities = [6912], id = 6916)
float v_dot2_f32_f16(f16vec2 a, f16vec2 b, float acc);
ACC_TYPE dot_product(f16vec4 a, f16vec4 b, ACC_TYPE acc) {
return ACC_TYPE(v_dot2_f32_f16(a.zw, b.zw, v_dot2_f32_f16(a.xy, b.xy, float(acc))));
}
ACC_TYPE dot_product(f16vec2 a, f16vec2 b, ACC_TYPE acc) {
return ACC_TYPE(v_dot2_f32_f16(a, b, float(acc)));
}
#else
ACC_TYPE dot_product(FLOAT_TYPEV4 a, FLOAT_TYPEV4 b, ACC_TYPE acc) {
return fma(ACC_TYPE(a.x), ACC_TYPE(b.x), fma(ACC_TYPE(a.y), ACC_TYPE(b.y),
fma(ACC_TYPE(a.z), ACC_TYPE(b.z), fma(ACC_TYPE(a.w), ACC_TYPE(b.w), acc))));
}
ACC_TYPE dot_product(FLOAT_TYPEV2 a, FLOAT_TYPEV2 b, ACC_TYPE acc) {
return fma(ACC_TYPE(a.x), ACC_TYPE(b.x), fma(ACC_TYPE(a.y), ACC_TYPE(b.y), acc));
}
#endif
@@ -21,6 +21,7 @@
#extension GL_KHR_shader_subgroup_vote : enable
#include "types.glsl"
#include "dot_product_funcs.glsl"
#include "flash_attn_base.glsl"
#include "flash_attn_dequant.glsl"
@@ -318,7 +319,7 @@ void main() {
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf));
Sf[r][c] = dot_product(Q_cache[r], K_Tf, Sf[r][c]);
}
}
}
@@ -341,7 +342,7 @@ void main() {
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf));
Sf[r][c] = dot_product(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf, Sf[r][c]);
}
}
}
@@ -4,6 +4,7 @@
#extension GL_EXT_integer_dot_product : require
#define MMQ
#define NEEDS_IQ1S_GRID_GPU
#define B_TYPE block_q8_1_x4
#include "mul_mat_vec_base.glsl"
@@ -29,6 +29,7 @@
#endif
#include "types.glsl"
#include "dot_product_funcs.glsl"
#ifndef LOAD_VEC_A
#define LOAD_VEC_A 1
@@ -329,15 +330,8 @@ void main() {
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
// [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
#if defined(DATA_A_F32) || defined(DATA_A_F16)
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y),
fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x))));
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y),
fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y))));
#else
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
#endif
sums[sums_idx].x = dot_product(cache_a[wsir * TM + 2 * cr ], cache_b, sums[sums_idx].x);
sums[sums_idx].y = dot_product(cache_a[wsir * TM + 2 * cr + 1], cache_b, sums[sums_idx].y);
}
}
}
@@ -11,6 +11,9 @@
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#extension GL_NV_cooperative_matrix2 : enable
#ifdef GGML_VULKAN_COOPMAT2_DECODE_VECTOR
#extension GL_NV_cooperative_matrix_decode_vector : enable
#endif
#extension GL_EXT_buffer_reference : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_vote : enable
@@ -69,10 +72,13 @@ layout (push_constant) uniform parameter
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#if defined(MUL_MAT_ID) && defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR)
layout (binding = 1) readonly buffer B4 {B_TYPEV4 data_b_v4[];};
#endif
#if QUANT_K > 1
#include "dequant_funcs_cm2.glsl"
#if defined(dequantFuncA_v) && defined(GL_NV_cooperative_matrix_decode_vector)
#if defined(dequantFuncA_v) && defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR)
#define DECODEFUNCA , dequantFuncA, dequantFuncA_v
#else
#define DECODEFUNCA , dequantFuncA
@@ -113,11 +119,33 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i
const uint row_i = blockCoords[0];
const u16vec4 row_idx = row_ids[row_i];
B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR)
// The decode-vector path gives B a K-dimension tensor-layout block size of BK.
const uint k = blockCoords[1] * BK + coordInBlock[1];
#else
const uint k = blockCoords[1];
#endif
B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + k];
return ret;
}
#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR)
B_TYPEV4 decodeFuncB_v(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const uint row_i = blockCoords[0];
const u16vec4 row_idx = row_ids[row_i];
const uint k = blockCoords[1] * BK + coordInBlock[1];
const uint base = row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + k;
return data_b_v4[base >> 2];
}
#define DECODEFUNCB , decodeFuncB, decodeFuncB_v
#else
#define DECODEFUNCB , decodeFuncB
#endif
D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic)
{
uint dr = ir * BM + r;
@@ -287,6 +315,9 @@ void main() {
tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K);
#endif
#if defined(MUL_MAT_ID) && defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR)
tensorLayoutB = setTensorLayoutBlockSizeNV(tensorLayoutB, 1, BK);
#endif
// Use end_k rather than p.K as the dimension because that's what
// we need to bound check against when using split_k.
@@ -499,7 +530,7 @@ void main() {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose DECODEFUNCB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else {
@@ -507,7 +538,7 @@ void main() {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose DECODEFUNCB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
@@ -543,7 +574,7 @@ void main() {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose DECODEFUNCB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else {
@@ -551,7 +582,7 @@ void main() {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose DECODEFUNCB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
@@ -588,7 +619,7 @@ void main() {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose DECODEFUNCB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif
@@ -600,7 +631,7 @@ void main() {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose DECODEFUNCB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif
@@ -598,9 +598,10 @@ const uint[1024] iq1s_grid_const = {
0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
};
#if defined(NEEDS_IQ1S_GRID_GPU)
// Same content as iq1s_grid_const except each 2-bit value is expanded to 4-bit
// and has 1 added to it (allows packed values to be extracted with & 0x0F0F0F0F
// and 0xF0F0F0F0).
// and 0xF0F0F0F0). This is only used by the q8_1/int-dot vector path.
const uint32_t[2048] iq1s_grid_gpu_const = {
0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,
0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,
@@ -859,9 +860,12 @@ const uint32_t[2048] iq1s_grid_gpu_const = {
0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020,
0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,
};
#endif
shared uint16_t iq1s_grid[2048];
#if defined(NEEDS_IQ1S_GRID_GPU)
shared uint32_t iq1s_grid_gpu[2048];
#endif
#define NEEDS_INIT_IQ_SHMEM
void init_iq_shmem(uvec3 wgsize)
@@ -875,12 +879,14 @@ void init_iq_shmem(uvec3 wgsize)
iq1s_grid[2*idx+1] = g.y;
}
}
#if defined(NEEDS_IQ1S_GRID_GPU)
[[unroll]] for (uint i = 0; i < iq1s_grid_gpu_const.length(); i += wgsize.x) {
uint idx = i + gl_LocalInvocationIndex.x;
if (iq1s_grid_gpu_const.length() % wgsize.x == 0 || idx < iq1s_grid_gpu_const.length()) {
iq1s_grid_gpu[idx] = iq1s_grid_gpu_const[idx];
}
}
#endif
barrier();
}
#endif
@@ -336,7 +336,8 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
// disable spirv-opt for coopmat shaders for https://github.com/ggml-org/llama.cpp/issues/10734
// disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344
// disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860
if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos) {
// disable spirv-opt for dot2 shaders (spirv-opt doesn't recognize SPV_VALVE_mixed_float_dot_product capability)
if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos && name.find("_dot2") == std::string::npos) {
cmd.push_back("-O");
}
@@ -427,10 +428,11 @@ void string_to_spv(std::string name, const std::string& source, const std::map<s
generate_dep_file = false;
}
void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) {
void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc, bool dot2 = false) {
std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
std::string dot2_sfx = dot2 ? "_dot2" : "";
std::map<std::string, std::string> base_dict;
std::string shader_name = "matmul";
@@ -457,6 +459,15 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
if (coopmat) {
base_dict["COOPMAT"] = "1";
}
#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
if (coopmat2) {
base_dict["GGML_VULKAN_COOPMAT2_DECODE_VECTOR"] = "1";
}
#endif
if (dot2) {
base_dict["DOT2_F16"] = "1";
}
const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
@@ -523,11 +534,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
};
// Shaders with f16 B_TYPE
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
// bf16
{
@@ -548,8 +559,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
if (!(coopmat || coopmat2))
#endif
{
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
if (!dot2) {
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
}
}
@@ -579,18 +592,18 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
// don't generate f32 variants for coopmat2
if (!coopmat2) {
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
if (tname != "f16" && tname != "f32") {
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
// Integer dot mmq performs better with f32 accumulators
if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
// Integer dot mmq performs better with f32 accumulators (different shader, skip for dot2)
if (!f16acc && !coopmat && !coopmat2 && !dot2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
}
#endif
@@ -608,6 +621,10 @@ void process_shaders() {
matmul_shaders(true, matmul_id_type, false, false, false);
matmul_shaders(true, matmul_id_type, false, false, true);
// dot2 variants (scalar fp16 only)
matmul_shaders(true, matmul_id_type, false, false, false, true);
matmul_shaders(true, matmul_id_type, false, false, true, true);
if (matmul_id_type != MatMulIdType::DEFAULT) {
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
// Coopmat, fp32acc and fp16acc
@@ -655,6 +672,12 @@ void process_shaders() {
string_to_spv("flash_attn_f32_f16", "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
if (fp16) {
string_to_spv("flash_attn_f32_f16_dot2", "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DOT2_F16", "1"}}), fp16, false, false, f16acc);
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
string_to_spv("flash_attn_f32_f16", "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"MMQ", "1"}, {"FA_MMQ_MIXED", "1"}}), fp16, false, false, f16acc, "_int8");
+25 -11
View File
@@ -448,15 +448,19 @@ struct ggml_webgpu_upscale_pipeline_key_hash {
/** Concat **/
struct ggml_webgpu_concat_pipeline_key {
int type;
int type;
bool src_overlap;
bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; }
bool operator==(const ggml_webgpu_concat_pipeline_key & other) const {
return type == other.type && src_overlap == other.src_overlap;
}
};
struct ggml_webgpu_concat_pipeline_key_hash {
size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.src_overlap);
return seed;
}
};
@@ -640,7 +644,8 @@ inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) {
inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) {
const uint32_t offset_elems =
(uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) / ggml_type_size(K->type));
(uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) /
ggml_type_size(K->type));
return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u;
}
@@ -651,8 +656,10 @@ inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K,
ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment);
}
inline bool ggml_webgpu_flash_attn_kv_direct(
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, uint32_t kv_direct_align) {
inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q,
const ggml_tensor * K,
const ggml_tensor * V,
uint32_t kv_direct_align) {
return K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && (Q->ne[0] % kv_direct_align == 0) &&
(K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
}
@@ -667,10 +674,10 @@ inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_co
key.dst_type = context.dst->type;
key.head_dim_qk = (uint32_t) context.src0->ne[0];
key.head_dim_v = (uint32_t) context.src2->ne[0];
key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align);
key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2);
key.has_mask = context.src3 != nullptr;
key.has_sinks = context.src4 != nullptr;
key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align);
key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2);
key.has_mask = context.src3 != nullptr;
key.has_sinks = context.src4 != nullptr;
key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
return key;
}
@@ -1723,7 +1730,7 @@ class ggml_webgpu_shader_lib {
key.type = context.dst->type;
key.d_state = (int) context.src0->ne[0];
key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
ggml_webgpu_tensor_overlap(context.src1, context.src5);
ggml_webgpu_tensor_overlap(context.src1, context.src5);
auto it = ssm_scan_pipelines.find(key);
if (it != ssm_scan_pipelines.end()) {
@@ -2634,6 +2641,7 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_concat_pipeline_key key = {};
key.type = context.dst->type;
key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
auto it = concat_pipelines.find(key);
if (it != concat_pipelines.end()) {
@@ -2656,11 +2664,17 @@ class ggml_webgpu_shader_lib {
GGML_ABORT("Unsupported type for concat shader");
}
if (key.src_overlap) {
defines.push_back("SRC_OVERLAP");
variant += "_src_overlap";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_concat, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
auto decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>();
decisions->wg_size = context.max_wg_size;
decisions->src_overlap = key.src_overlap;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
concat_pipelines[key] = pipeline;
+69 -44
View File
@@ -621,10 +621,11 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
uint32_t value,
size_t offset,
size_t size) {
std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) };
size_t bytes_per_wg = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread;
uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg);
std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) };
size_t bytes_per_wg =
ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread;
uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg);
ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t));
@@ -1362,7 +1363,7 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx,
shader_lib_ctx.src0 = src;
shader_lib_ctx.src1 = nullptr;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
@@ -2169,8 +2170,10 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
}
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
uint32_t wg_x, wg_y;
uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
@@ -2244,8 +2247,10 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
}
}
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
uint32_t wg_x, wg_y;
uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_encoded_op ggml_webgpu_add_id(webgpu_context & ctx,
@@ -2305,33 +2310,6 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx,
uint32_t ne = (uint32_t) ggml_nelements(dst);
uint32_t dim = (uint32_t) dst->op_params[0];
std::vector<uint32_t> params = {
ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) dst->ne[0],
(uint32_t) dst->ne[1],
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
dim,
(uint32_t) src0->ne[dim]
};
std::vector<wgpu::BindGroupEntry> entries = {
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
};
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
@@ -2339,8 +2317,52 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx,
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
auto * decisions = static_cast<ggml_webgpu_binary_shader_decisions *>(pipeline.context.get());
uint32_t offset_src0 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type));
uint32_t offset_src1 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type));
size_t merged_offset = 0;
size_t merged_size = 0;
if (decisions->src_overlap) {
const ggml_webgpu_merged_binding_range merged_range =
ggml_webgpu_tensor_merged_binding_range(ctx, { src0, src1 });
merged_offset = merged_range.offset;
merged_size = merged_range.size;
offset_src0 = ggml_webgpu_tensor_merged_element_offset(src0, merged_range);
offset_src1 = ggml_webgpu_tensor_merged_element_offset(src1, merged_range);
}
std::vector<uint32_t> params = { ne,
offset_src0,
offset_src1,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) dst->ne[0],
(uint32_t) dst->ne[1],
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
dim,
(uint32_t) src0->ne[dim] };
std::vector<wgpu::BindGroupEntry> entries = {};
if (decisions->src_overlap) {
entries.push_back(
ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, merged_size));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
} else {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
}
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
@@ -2673,8 +2695,10 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
}
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
uint32_t wg_x, wg_y;
uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx,
@@ -3751,7 +3775,8 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
// we use the maximum workgroup size for the memset pipeline
size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup *
ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
// Size the bytes_per_thread so that the largest buffer size can be handled
ctx->capabilities.memset_bytes_per_thread =
CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads);
@@ -4228,9 +4253,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
const uint32_t q_tile =
use_subgroup_matrix ? capabilities.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
const uint32_t kv_granularity = use_subgroup_matrix ? capabilities.sg_mat_n : 1u;
const bool kv_direct = use_subgroup_matrix ?
ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) :
false;
const bool kv_direct = use_subgroup_matrix ?
ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) :
false;
const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(
capabilities.limits.maxComputeWorkgroupStorageSize, q_tile, kv_granularity, (uint32_t) src0->ne[0],
(uint32_t) src2->ne[0], op->src[3] != nullptr, kv_direct);
@@ -130,10 +130,13 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
let src0_i = params.offset_src0 + src0_index(gid.x);
let src1_i = params.offset_src1 + src1_index(gid.x);
update(params.offset_dst + gid.x, src0_i, src1_i);
fn main(@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>) {
let threads_per_group = u32(WG_SIZE);
let i = gid.x + (num_wg.x * threads_per_group) * gid.y;
if (i < params.ne) {
let src0_i = params.offset_src0 + src0_index(i);
let src1_i = params.offset_src1 + src1_index(i);
update(params.offset_dst + i, src0_i, src1_i);
}
}
+19 -1
View File
@@ -31,6 +31,16 @@ struct Params {
#define DataType i32
#endif
#ifdef SRC_OVERLAP
@group(0) @binding(0)
var<storage, read_write> merged_src: array<DataType>;
@group(0) @binding(1)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(2)
var<uniform> params: Params;
#else
@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;
@@ -42,7 +52,7 @@ var<storage, read_write> dst: array<DataType>;
@group(0) @binding(3)
var<uniform> params: Params;
#endif
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
@@ -62,14 +72,22 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
ni[1] * params.stride_src0_1 +
ni[2] * params.stride_src0_2 +
ni[3] * params.stride_src0_3;
#ifdef SRC_OVERLAP
dst[params.offset_dst + gid.x] = merged_src[params.offset_src0 + src_i];
#else
dst[params.offset_dst + gid.x] = src0[params.offset_src0 + src_i];
#endif
} else {
ni[params.dim] -= params.src0_nedim;
let src_i = ni[0] * params.stride_src1_0 +
ni[1] * params.stride_src1_1 +
ni[2] * params.stride_src1_2 +
ni[3] * params.stride_src1_3;
#ifdef SRC_OVERLAP
dst[params.offset_dst + gid.x] = merged_src[params.offset_src1 + src_i];
#else
dst[params.offset_dst + gid.x] = src1[params.offset_src1 + src_i];
#endif
}
}
}
@@ -98,72 +98,50 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
}
#endif // INIT_SRC0_SHMEM_Q1_0
#ifdef INIT_SRC0_SHMEM_Q4_0
#if defined(INIT_SRC0_SHMEM_Q4_0) || defined(INIT_SRC0_SHMEM_Q4_1) || defined(INIT_SRC0_SHMEM_Q5_0) || defined(INIT_SRC0_SHMEM_Q5_1) || defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1) || defined(INIT_SRC0_SHMEM_MXFP4)
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 18u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
#if defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1)
const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q
#else
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
#endif
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let shmem_idx = block_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let tile_m = block_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let block_k = block_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
#ifdef INIT_SRC0_SHMEM_Q4_0
let block_byte_base = src0_idx * 18u; // BLOCK_SIZE_BYTES = 18u;
let d = load_f16_at_src0(block_byte_base);
// store NQ(16) weights
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP);
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q4_0
#elif INIT_SRC0_SHMEM_Q4_1
let block_byte_base = src0_idx * 20u; // BLOCK_SIZE_BYTES = 20u;
let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
let d = f16(dm[0]);
let m = f16(dm[1]);
#ifdef INIT_SRC0_SHMEM_Q4_1
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 20u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base);
let m = load_f16_at_src0(block_byte_base + 2u);
// store NQ(16) weights
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
@@ -175,41 +153,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q4_1
#ifdef INIT_SRC0_SHMEM_Q5_0
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 22u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
// tile_k is defined as 32u, so blocks_k ends up being 1 always
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
#elif INIT_SRC0_SHMEM_Q5_0
let block_byte_base = src0_idx * 22u; // BLOCK_SIZE_BYTES = 22u;
let d = load_f16_at_src0(block_byte_base);
let qh_packed = load_u32_at_src0(block_byte_base + 2u);
// store NQ(16) weights
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 6u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
@@ -226,44 +176,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q5_0
#elif INIT_SRC0_SHMEM_Q5_1
let block_byte_base = src0_idx * 24u; // BLOCK_SIZE_BYTES = 24u;
#ifdef INIT_SRC0_SHMEM_Q5_1
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 24u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
let d = f16(dm[0]);
let m = f16(dm[1]);
let qh_packed = load_u32_at_src0_aligned(block_byte_base + 4u);
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base);
let m = load_f16_at_src0(block_byte_base + 2u);
let qh_packed = load_u32_at_src0(block_byte_base + 4u);
// store NQ(16) weights
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 8u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
let q_packed = load_u32_at_src0_aligned(q_byte_offset);
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
let q_byte = get_byte(q_packed, k);
@@ -277,461 +201,306 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q5_1
#ifdef INIT_SRC0_SHMEM_Q8_0
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 34u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
#elif INIT_SRC0_SHMEM_Q8_0
let block_byte_base = src0_idx * 34u; // BLOCK_SIZE_BYTES = 34u;
let d = load_f16_at_src0(block_byte_base);
// store NQ(16) weights
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP);
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q8_0
#elif INIT_SRC0_SHMEM_Q8_1
let block_byte_base = src0_idx * 36u; // BLOCK_SIZE_BYTES = 36u;
let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
let d = f16(dm[0]);
let m = f16(dm[1]);
#ifdef INIT_SRC0_SHMEM_Q8_1
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 36u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base);
let m = load_f16_at_src0(block_byte_base + 2u);
// store NQ(16) weights
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d + m;
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val;
}
}
#elif INIT_SRC0_SHMEM_MXFP4
let block_byte_base = src0_idx * 17u;
let eu8 = get_byte(load_u32_at_src0_aligned(block_byte_base), block_byte_base & 3u);
let e = ldexp(1.0, i32(eu8) - 128);
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e;
let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e;
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo);
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi);
}
}
#endif
}
}
}
#endif // INIT_SRC0_SHMEM_Q8_1
#endif
// k-quants
#if defined(INIT_SRC0_SHMEM_Q2_K) || defined(INIT_SRC0_SHMEM_Q3_K) || defined(INIT_SRC0_SHMEM_Q4_K) || defined(INIT_SRC0_SHMEM_Q5_K) || defined(INIT_SRC0_SHMEM_Q6_K)
const BLOCK_SIZE = 256u;
const NQ = 4u;
fn store_shmem_kquants(val: vec4<f16>, idx: u32) {
shmem[idx] = val.x;
shmem[idx + 1] = val.y;
shmem[idx + 2] = val.z;
shmem[idx + 3] = val.w;
}
fn load_byte_at_src0_aligned(byte_offset: u32) -> u32 {
return get_byte(load_u32_at_src0_aligned(byte_offset), byte_offset % 4u);
}
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id * NQ; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * NQ) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
store_shmem_kquants(vec4<f16>(f16(0.0), f16(0.0), f16(0.0), f16(0.0)), elem_idx);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE; // k_in_block % 4 == 0;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
#ifdef INIT_SRC0_SHMEM_Q2_K
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 84u;
let block_byte_base = src0_idx * 84u; // BLOCK_SIZE_BYTES = 84u;
let scales_byte_base = block_byte_base;
let qs_byte_base = block_byte_base + 16u;
let dm_byte_base = block_byte_base + 80u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
// Use standard thread layout instead of lane/row_group
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let d_packed = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base));
let d = f16(d_packed[0]);
let dmin = f16(d_packed[1]);
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
let chunk = k_in_block / 128u;
let pos_in_chunk = k_in_block % 32u;
let sub_block = k_in_block / 16u;
let shift_phase = (k_in_block % 128u) / 32u;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
// whole 2 bits (4 elems)
let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk);
let qs_vec4 = vec4<f16>(
f16((qs_word >> (2u * shift_phase + 0u)) & 0x3u),
f16((qs_word >> (2u * shift_phase + 8u)) & 0x3u),
f16((qs_word >> (2u * shift_phase + 16u)) & 0x3u),
f16((qs_word >> (2u * shift_phase + 24u)) & 0x3u),
);
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let scale = load_byte_at_src0_aligned(scales_byte_base + sub_block);
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let dl = d * f16(scale & 0xFu);
let ml = dmin * f16(scale >> 4u);
let d = load_f16_at_src0(block_byte_base + 80u);
let dmin = load_f16_at_src0(block_byte_base + 82u);
store_shmem_kquants(qs_vec4 * dl - ml, elem_idx);
#elif INIT_SRC0_SHMEM_Q3_K
let block_byte_base = src0_idx * 110u; // BLOCK_SIZE_BYTES = 110u;
let hmask_byte_base = block_byte_base + 0u;
let qs_byte_base = block_byte_base + 32u;
let scales_byte_base = block_byte_base + 96u;
// Decode the element at position k_in_block
let block_of_32 = k_in_block / 32u;
let pos_in_32 = k_in_block % 32u;
let d_all = load_f16_at_src0(block_byte_base + 108u);
let q_b_idx = (block_of_32 / 4u) * 32u;
let shift = (block_of_32 % 4u) * 2u;
let k = (pos_in_32 / 16u) * 16u;
let l = pos_in_32 % 16u;
let chunk = k_in_block / 128u;
let pos_in_chunk = k_in_block % 32u;
let sub_block = k_in_block / 16u;
let shift_phase = (k_in_block % 128u) / 32u;
let is = k_in_block / 16u;
let hmask_block = pos_in_chunk;
let hmask_shift_phase = k_in_block / 32u;
let sc_packed = load_u32_at_src0(block_byte_base + 4u * (is / 4u));
let sc = get_byte(sc_packed, is % 4u);
// low 2 bits (4 elems)
let q_lo2_word = load_u32_at_src0(qs_byte_base + 32u * chunk + 1u * hmask_block);
let q_lo2_vec4 = vec4<f16>(
f16((q_lo2_word >> (2u * shift_phase + 0u)) & 3u),
f16((q_lo2_word >> (2u * shift_phase + 8u)) & 3u),
f16((q_lo2_word >> (2u * shift_phase + 16u)) & 3u),
f16((q_lo2_word >> (2u * shift_phase + 24u)) & 3u)
);
let dl = d * f16(sc & 0xFu);
let ml = dmin * f16(sc >> 4u);
// high 1 bit (4 elems)
let q_hi1_word = load_u32_at_src0(hmask_byte_base + pos_in_chunk);
let q_hi1_vec4 = vec4<f16>(
f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 0u)) & 1u) == 1u)),
f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 8u)) & 1u) == 1u)),
f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 16u)) & 1u) == 1u)),
f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 24u)) & 1u) == 1u))
);
let q_idx = q_b_idx + k + l;
let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 3u;
let q_vec4 = q_lo2_vec4 - q_hi1_vec4;
let q_val = f16(qs_val) * dl - ml;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q2_K
let scale_low4 = (load_byte_at_src0_aligned(scales_byte_base + (sub_block % 8u)) >> (4u * (sub_block / 8u))) & 0xFu;
let scale_hi2 = (load_byte_at_src0_aligned(scales_byte_base + 8u + (sub_block % 4u)) >> (2u * (sub_block / 4u))) & 3u;
let dl = d_all * (f16((scale_hi2 << 4u) | scale_low4) - 32.0);
#ifdef INIT_SRC0_SHMEM_Q3_K
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 110u;
store_shmem_kquants(dl * q_vec4, elem_idx);
#elif INIT_SRC0_SHMEM_Q4_K
let block_byte_base = src0_idx * 144u; // BLOCK_SIZE_BYTES = 144u;
let dm_byte_base = block_byte_base + 0u;
let scale_byte_base = block_byte_base + 4u;
let qs_byte_base = block_byte_base + 16u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base));
let d = f16(dm[0]);
let dmin = f16(dm[1]);
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
let chunk = k_in_block / 64u;
let pos_in_chunk = (k_in_block % 64u) % 32u;
let sub_block = k_in_block / 32u;
let shift_phase = sub_block & 1u;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base + 108u);
// Load and unpack scales
let kmask1: u32 = 0x03030303u;
let kmask2: u32 = 0x0f0f0f0fu;
var scale_vals: array<u32, 4>;
for (var i: u32 = 0u; i < 4u; i++) {
scale_vals[i] = load_u32_at_src0(block_byte_base + 96u + 4u * i);
}
var tmp: u32 = scale_vals[2];
scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u);
scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u);
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u);
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u);
// Load hmask and qs arrays
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0u; i < 8u; i++) {
hmask_vals[i] = load_u32_at_src0(block_byte_base + 4u * i);
}
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16u; i++) {
qs_vals[i] = load_u32_at_src0(block_byte_base + 32u + 4u * i);
}
let half = k_in_block / 128u; // 0 or 1
let pos_in_half = k_in_block % 128u; // 0-127
let shift_group = pos_in_half / 32u; // 0-3
let pos_in_32 = pos_in_half % 32u; // 0-31
let k_group = pos_in_32 / 16u; // 0 or 1
let l = pos_in_32 % 16u; // 0-15
let q_b_idx = half * 32u; // 0 or 32
let shift = shift_group * 2u; // 0, 2, 4, 6
let k = k_group * 16u; // 0 or 16
let is = k_in_block / 16u; // 0-15
// m increments every 32 elements across entire 256 element block
let m_shift = k_in_block / 32u; // 0-7
let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128
let sc = get_byte(scale_vals[is / 4u], is % 4u);
let dl = d * (f16(sc) - 32.0);
let q_idx = q_b_idx + k + l;
let hm_idx = k + l;
let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u);
let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u);
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
let qs_val = (q_byte >> shift) & 3u;
let q_val = (f16(qs_val) - f16(hm)) * dl;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q3_K
#ifdef INIT_SRC0_SHMEM_Q4_K
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 144u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base);
let dmin = load_f16_at_src0(block_byte_base + 2u);
// Map k_in_block to loop structure:
// Outer loop over 64-element groups (alternating q_b_idx)
// Inner loop over 2 shifts per group
let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx)
let pos_in_64 = k_in_block % 64u; // 0-63
let shift_group = pos_in_64 / 32u; // 0 or 1
let l = pos_in_64 % 32u; // 0-31
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
let shift = shift_group * 4u; // 0 or 4
let is = k_in_block / 32u; // 0-7
// whole 4 bits (4 elems)
let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk);
let qs_vec4 = vec4<f16>(
f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu),
f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu),
f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu),
f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu)
);
var sc: u32;
var mn: u32;
let scale_base = block_byte_base + 4u;
if (is < 4u) {
let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u);
let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
if (sub_block < 4u) {
let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u);
let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u);
let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u);
let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u);
let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u);
let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
}
let dl = d * f16(sc);
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u));
store_shmem_kquants(dl * qs_vec4 - vec4(ml, ml, ml, ml), elem_idx);
#elif INIT_SRC0_SHMEM_Q5_K
let block_byte_base = src0_idx * 176u; // BLOCK_SIZE_BYTES = 176u;
let dm_byte_base = block_byte_base + 0u;
let scale_byte_base = block_byte_base + 4u;
let qh_byte_base = block_byte_base + 16u;
let qs_byte_base = block_byte_base + 48u;
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 0xFu;
let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base));
let d = f16(dm[0]);
let dmin = f16(dm[1]);
let q_val = f16(qs_val) * dl - ml;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q4_K
let chunk = k_in_block / 64u;
let pos_in_chunk = (k_in_block % 64u) % 32u;
let sub_block = k_in_block / 32u;
let shift_phase = sub_block & 1u;
#ifdef INIT_SRC0_SHMEM_Q5_K
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 176u;
let qh_block = k_in_block % 32u;
let qh_shift_phase = sub_block;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
// low 4 bits (4 elems)
let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk);
let qs_lo4_vec4 = vec4<f16>(
f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu),
f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu),
f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu),
f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu)
);
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base);
let dmin = load_f16_at_src0(block_byte_base + 2u);
// The original loop processes elements in groups of 64
// Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
// But u increments EVERY 32 elements (after each l loop)
let group_of_64 = k_in_block / 64u; // 0-3
let pos_in_64 = k_in_block % 64u; // 0-63
let shift_group = pos_in_64 / 32u; // 0 or 1
let l = pos_in_64 % 32u; // 0-31
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
let shift = shift_group * 4u; // 0 or 4
let is = k_in_block / 32u; // 0-7
// u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128)
let u_shift = k_in_block / 32u; // 0-7
let u: u32 = 1u << u_shift;
// high 1 bit (4 elems)
let qh_word = load_u32_at_src0_aligned(qh_byte_base + qh_block);
let qh_vec4 = vec4<f16>(
f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 0u)) & 1u) == 1u)),
f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 8u)) & 1u) == 1u)),
f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 16u)) & 1u) == 1u)),
f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 24u)) & 1u) == 1u))
);
var sc: u32;
var mn: u32;
let scale_base = block_byte_base + 4u;
if (is < 4u) {
let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u);
let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
if (sub_block < 4u) {
let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u);
let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u);
let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u);
let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u);
let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u);
let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
}
let dl = d * f16(sc);
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_packed = load_u32_at_src0(block_byte_base + 48u + 4u * (q_idx / 4u));
store_shmem_kquants((qh_vec4 + qs_lo4_vec4) * dl - vec4<f16>(ml, ml, ml, ml), elem_idx);
#elif INIT_SRC0_SHMEM_Q6_K
let block_byte_base = src0_idx * 210u; // BLOCK_SIZE_BYTES = 210u;
let ql_byte_base = block_byte_base;
let qh_byte_base = block_byte_base + 128u;
let scales_byte_base = block_byte_base + 192u;
let d_byte_base = block_byte_base + 208u;
let q_byte = get_byte(q_packed, q_idx % 4u);
let d = load_f16_at_src0(d_byte_base);
let qh_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (l / 4u));
let chunk = k_in_block / 128u;
let ql_pos_in_chunk = (k_in_block % 128u) % 64u;
let qh_pos_in_chunk = (k_in_block % 128u) % 32u;
let sub_block = k_in_block / 16u;
let ql_shift_phase = (k_in_block % 128u) / 64u;
let qh_shift_phase = (k_in_block % 128u) / 32u;
let qh_byte = get_byte(qh_packed, l % 4u);
// low 4 bits (4 elems)
let ql_word = load_u32_at_src0(ql_byte_base + 64u * chunk + 1u * ql_pos_in_chunk);
let ql_lo4_vec4 = vec4<u32>(
(ql_word >> (4u * ql_shift_phase + 0u)) & 0xFu,
(ql_word >> (4u * ql_shift_phase + 8u)) & 0xFu,
(ql_word >> (4u * ql_shift_phase + 16u)) & 0xFu,
(ql_word >> (4u * ql_shift_phase + 24u)) & 0xFu
);
let qs_val = (q_byte >> shift) & 0xFu;
let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
// hi 2 bits (4 elems)
let qh_word = load_u32_at_src0(qh_byte_base + 32u * chunk + 1u * qh_pos_in_chunk);
let qh_hi2_vec4 = vec4<u32>(
((qh_word >> (2u * qh_shift_phase + 0u)) & 0x3u) << 4u,
((qh_word >> (2u * qh_shift_phase + 8u)) & 0x3u) << 4u,
((qh_word >> (2u * qh_shift_phase + 16u)) & 0x3u) << 4u,
((qh_word >> (2u * qh_shift_phase + 24u)) & 0x3u) << 4u,
);
let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml;
shmem[elem_idx] = q_val;
let q_vec4 = vec4<f16>(qh_hi2_vec4 | ql_lo4_vec4) - vec4<f16>(32.0, 32.0, 32.0, 32.0);
let scale_byte = scales_byte_base + 1u * sub_block;
let scale_word = load_u32_at_src0_aligned(scale_byte);
let scale = get_byte_i32(scale_word, scale_byte & 3u);
store_shmem_kquants(d * q_vec4 * f16(scale), elem_idx);
#endif
}
}
#endif // INIT_SRC0_SHMEM_Q5_K
#ifdef INIT_SRC0_SHMEM_Q6_K
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 210u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let half = k_in_block / 128u;
let pos_in_half = k_in_block % 128u;
let quarter = pos_in_half / 32u;
let l = pos_in_half % 32u;
let ql_b_idx = half * 64u;
let qh_b_idx = half * 32u;
let sc_b_idx = half * 8u;
// Load only ql13 word needed
let ql13_flat = ql_b_idx + l;
let ql13 = load_u32_at_src0(block_byte_base + ql13_flat);
let ql13_b = get_byte(ql13, 0u);
// Load only ql24 word needed
let ql24_flat = ql_b_idx + l + 32u;
let ql24 = load_u32_at_src0(block_byte_base + ql24_flat);
let ql24_b = get_byte(ql24, 0u);
// Load only qh word needed
let qh_flat = qh_b_idx + l;
let qh = load_u32_at_src0(block_byte_base + 128u + qh_flat);
let qh_b = get_byte(qh, 0u);
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0);
let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0);
// Load only the scale word needed
let is = l / 16u;
let sc_idx = sc_b_idx + is + quarter * 2u;
let sc = load_u32_at_src0(block_byte_base + 192u + sc_idx);
let sc_val = get_byte_i32(sc, 0u);
let d = load_f16_at_src0(block_byte_base + 208u);
var q_val: f16;
if (quarter == 0u) {
q_val = q1;
} else if (quarter == 1u) {
q_val = q2;
} else if (quarter == 2u) {
q_val = q3;
} else {
q_val = q4;
}
shmem[elem_idx] = d * f16(sc_val) * q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q6_K
#endif // k-quants
#ifdef INIT_SRC0_SHMEM_IQ4_NL
const BLOCK_SIZE = 32u;
@@ -1155,48 +924,3 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
}
}
#endif // INIT_SRC0_SHMEM_IQ3_S
#ifdef INIT_SRC0_SHMEM_MXFP4
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 17u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const BYTES_PER_THREAD = 8u; // NQ(16) weights uses 8 bytes of q
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0);
let e = ldexp(1.0, i32(eu8) - 128);
// store NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e;
let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e;
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo);
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi);
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_MXFP4
+6 -4
View File
@@ -43,12 +43,14 @@ struct Params {
var<storage, read_write> src: array<f32>;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>) {
let threads_per_group = u32(WG_SIZE);
var i = gid.x + (num_wg.x * threads_per_group) * gid.y;
if (i >= params.ne) {
return;
}
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
+7 -4
View File
@@ -66,11 +66,14 @@ fn erf_approx(x: TYPE) -> TYPE {
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
fn main(@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>) {
let threads_per_group = u32(WG_SIZE);
let flat_i = gid.x + (num_wg.x * threads_per_group) * gid.y;
if (flat_i >= params.ne) {
return;
}
var i = gid.x;
var i = flat_i;
let ne2 = params.ne2;
#ifdef DIAG
let ne1 = params.ne0;
@@ -205,6 +208,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
#ifdef INPLACE
src[params.offset_src + src_idx] = res;
#else
dst[params.offset_dst + gid.x] = res;
dst[params.offset_dst + flat_i] = res;
#endif
}
+39 -2
View File
@@ -1031,6 +1031,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"IM2COL",
"IM2COL_BACK",
"IM2COL_3D",
"COL2IM_1D",
"CONV_2D",
"CONV_3D",
"CONV_2D_DW",
@@ -1080,7 +1081,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -1141,6 +1142,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"im2col(x)",
"im2col_back(x)",
"im2col_3d(x)",
"col2im_1d(x)",
"conv_2d(x)",
"conv_3d(x)",
"conv_2d_dw(x)",
@@ -1190,7 +1192,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)",
};
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -4541,6 +4543,41 @@ struct ggml_tensor * ggml_conv_1d_dw_ph(
return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0);
}
// ggml_col2im_1d
struct ggml_tensor * ggml_col2im_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int s0,
int oc,
int p0) {
GGML_ASSERT(ggml_is_matrix(a));
GGML_ASSERT(ggml_is_contiguous(a));
GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16 || a->type == GGML_TYPE_BF16);
GGML_ASSERT(s0 > 0);
GGML_ASSERT(oc > 0);
GGML_ASSERT(p0 >= 0);
const int64_t K_OC = a->ne[0];
const int64_t T_in = a->ne[1];
const int64_t K = K_OC / oc;
const int64_t T_out = (T_in - 1) * s0 + K - 2 * p0;
GGML_ASSERT(K_OC == K * oc); // a->ne[0] must be a whole number of oc blocks
GGML_ASSERT(K > 0 && T_out > 0);
const int64_t ne[4] = { T_out, oc, 1, 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 2, ne);
int32_t params[] = { s0, (int32_t)oc, (int32_t)p0 };
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_COL2IM_1D;
result->src[0] = a;
return result;
}
// ggml_conv_transpose_1d
static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
+30
View File
@@ -440,6 +440,7 @@ class MODEL_ARCH(IntEnum):
GEMMA3 = auto()
GEMMA3N = auto()
GEMMA4 = auto()
GEMMA4_ASSISTANT = auto()
GEMMA_EMBEDDING = auto()
STARCODER2 = auto()
RWKV6 = auto()
@@ -537,6 +538,8 @@ class VISION_PROJECTOR_TYPE(IntEnum):
class MODEL_TENSOR(IntEnum):
TOKEN_EMBD = auto()
TOKEN_EMBD_NORM = auto()
MASKED_EMBD_CENTROIDS= auto()
MASKED_EMBD_ORDERING = auto()
TOKEN_TYPES = auto()
POS_EMBD = auto()
OUTPUT = auto()
@@ -897,6 +900,8 @@ class MODEL_TENSOR(IntEnum):
A_PER_DIM_K_SCALE = auto() # gemma4
A_PER_DIM_SCALE = auto() # gemma4
# nextn/mtp
NEXTN_PROJ_PRE = auto()
NEXTN_PROJ_POST = auto()
NEXTN_EH_PROJ = auto()
NEXTN_EMBED_TOKENS = auto()
NEXTN_ENORM = auto()
@@ -986,6 +991,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.GEMMA3: "gemma3",
MODEL_ARCH.GEMMA3N: "gemma3n",
MODEL_ARCH.GEMMA4: "gemma4",
MODEL_ARCH.GEMMA4_ASSISTANT: "gemma4-assistant",
MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
@@ -1083,6 +1089,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
MODEL_TENSOR.TOKEN_TYPES: "token_types",
MODEL_TENSOR.MASKED_EMBD_CENTROIDS: "masked_embd_centroids",
MODEL_TENSOR.MASKED_EMBD_ORDERING: "masked_embd_ordering",
MODEL_TENSOR.POS_EMBD: "position_embd",
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
MODEL_TENSOR.OUTPUT: "output",
@@ -1471,6 +1479,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.A_QF_FFN_DOWN: "a.proj_blk.{bid}.ffn_down",
MODEL_TENSOR.A_QF_FFN_NORM: "a.proj_blk.{bid}.ffn_norm",
# NextN/MTP
MODEL_TENSOR.NEXTN_PROJ_PRE: "nextn.pre_projection",
MODEL_TENSOR.NEXTN_PROJ_POST: "nextn.post_projection",
MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj",
MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens",
MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.nextn.enorm",
@@ -2577,6 +2587,26 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.PER_LAYER_PROJ_NORM,
MODEL_TENSOR.PER_LAYER_POST_NORM,
],
MODEL_ARCH.GEMMA4_ASSISTANT: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.MASKED_EMBD_CENTROIDS,
MODEL_TENSOR.MASKED_EMBD_ORDERING,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.NEXTN_PROJ_PRE,
MODEL_TENSOR.NEXTN_PROJ_POST,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
MODEL_TENSOR.LAYER_OUT_SCALE,
],
MODEL_ARCH.GEMMA_EMBEDDING: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
+16
View File
@@ -37,6 +37,14 @@ class TensorNameMap:
"model.embed", # talkie
),
# Masked embeddings
MODEL_TENSOR.MASKED_EMBD_CENTROIDS: (
"masked_embedding.centroids", # gemma-4 E2B/E4B assistants
),
MODEL_TENSOR.MASKED_EMBD_ORDERING: (
"masked_embedding.token_ordering", # gemma-4 E2B/E4B assistants
),
# Token type embeddings
MODEL_TENSOR.TOKEN_TYPES: (
"embeddings.token_type_embeddings", # bert nomic-bert
@@ -2367,6 +2375,14 @@ class TensorNameMap:
),
# NextN/MTP tensors
MODEL_TENSOR.NEXTN_PROJ_PRE: (
"pre_projection",
),
MODEL_TENSOR.NEXTN_PROJ_POST: (
"post_projection",
),
MODEL_TENSOR.NEXTN_EH_PROJ: (
"model.layers.{bid}.eh_proj",
),
+4
View File
@@ -388,6 +388,10 @@ extern "C" {
// note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
struct llama_sampler_seq_config * samplers;
size_t n_samplers;
// a source/target/parent context
// can be utilized in various ways, for example by sharing results or llama_memory between 2 contexts
struct llama_context * ctx_other;
};
struct llama_model_tensor_override {
+115
View File
@@ -0,0 +1,115 @@
{{- bos_token -}}
{%- set preserve_thinking = preserve_thinking | default(false) -%}
{%- macro format_arg_value(arg_value) -%}
{%- if arg_value is string -%}
{{- "'" + arg_value + "'" -}}
{%- elif arg_value is mapping -%}
{{- arg_value | tojson -}}
{%- else -%}
{{- arg_value | string -}}
{%- endif -%}
{%- endmacro -%}
{%- macro parse_content(content) -%}
{%- if content is string -%}
{{- content -}}
{%- else -%}
{%- set _ns = namespace(result="") -%}
{%- for item in content -%}
{%- if item["type"] == "image" -%}
{%- set _ns.result = _ns.result + "<image>" -%}
{%- elif item["type"] == "text" -%}
{%- set _ns.result = _ns.result + item["text"] -%}
{%- else -%}
{%- set _ns.result = _ns.result + item | tojson -%}
{%- endif -%}
{%- endfor -%}
{{- _ns.result -}}
{%- endif -%}
{%- endmacro -%}
{%- macro render_tool_calls(tool_calls) -%}
{%- set tool_calls_ns = namespace(tool_calls=[]) -%}
{%- for tool_call in tool_calls -%}
{%- set func_name = tool_call["function"]["name"] -%}
{%- set func_args = tool_call["function"]["arguments"] -%}
{%- set args_ns = namespace(arg_strings=[]) -%}
{%- for arg_name, arg_value in func_args.items() -%}
{%- set args_ns.arg_strings = args_ns.arg_strings + [arg_name + "=" + format_arg_value(arg_value)] -%}
{%- endfor -%}
{%- set tool_calls_ns.tool_calls = tool_calls_ns.tool_calls + [func_name + "(" + (args_ns.arg_strings | join(", ")) + ")"] -%}
{%- endfor -%}
{{- "<|tool_call_start|>[" + (tool_calls_ns.tool_calls | join(", ")) + "]<|tool_call_end|>" -}}
{%- endmacro -%}
{%- set ns = namespace(system_prompt="", last_user_index=-1) -%}
{%- if messages[0]["role"] == "system" -%}
{%- if messages[0].get("content") -%}
{%- set ns.system_prompt = parse_content(messages[0]["content"]) -%}
{%- endif -%}
{%- set messages = messages[1:] -%}
{%- endif -%}
{%- if tools -%}
{%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: [" -%}
{%- for tool in tools -%}
{%- if tool is not string -%}
{%- set tool = tool | tojson -%}
{%- endif -%}
{%- set ns.system_prompt = ns.system_prompt + tool -%}
{%- if not loop.last -%}
{%- set ns.system_prompt = ns.system_prompt + ", " -%}
{%- endif -%}
{%- endfor -%}
{%- set ns.system_prompt = ns.system_prompt + "]" -%}
{%- endif -%}
{%- if ns.system_prompt -%}
{{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}}
{%- endif -%}
{%- for message in messages -%}
{%- if message["role"] == "user" -%}
{%- set ns.last_user_index = loop.index0 -%}
{%- endif -%}
{%- endfor -%}
{%- for message in messages -%}
{{- "<|im_start|>" + message.role + "\n" -}}
{%- if message.role == "assistant" -%}
{%- generation -%}
{%- if message.thinking is defined and (preserve_thinking or loop.index0 > ns.last_user_index) -%}
{{- "<think>" + message.thinking + "</think>" -}}
{%- endif -%}
{%- set _cfm_tag = "CONTINUE_FINAL_MESSAGE_TAG " -%}
{%- set _has_cfm = false -%}
{%- if message.content is defined -%}
{%- set content = parse_content(message.content) -%}
{%- if not (preserve_thinking or loop.index0 > ns.last_user_index) -%}
{%- if "</think>" in content -%}
{%- set content = content.split("</think>")[-1] | trim -%}
{%- endif -%}
{%- endif -%}
{%- if message.tool_calls is defined and content.endswith(_cfm_tag) -%}
{%- set _has_cfm = true -%}
{%- set _trunc_len = (content | length) - (_cfm_tag | length) -%}
{{- content[:_trunc_len] -}}
{%- else -%}
{{- content -}}
{%- endif -%}
{%- endif -%}
{%- if message.tool_calls is defined -%}
{{- render_tool_calls(message.tool_calls) -}}
{%- endif -%}
{%- if _has_cfm -%}
{{- _cfm_tag -}}
{%- endif -%}
{{- "<|im_end|>\n" -}}
{%- endgeneration -%}
{%- else %}
{%- if message.get("content") -%}
{{- parse_content(message["content"]) -}}
{%- endif -%}
{{- "<|im_end|>\n" -}}
{%- endif %}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- "<|im_start|>assistant\n" -}}
{%- endif -%}
+1 -1
View File
@@ -1 +1 @@
1e33fed33e87c43aa4c4078e2a9c239d4c1f1bd3
7142aa6bf9fcaeec0fef8d80fcd90afe4268adf1
+9
View File
@@ -57,6 +57,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_GEMMA3N, "gemma3n" },
{ LLM_ARCH_GEMMA4, "gemma4" },
{ LLM_ARCH_GEMMA4_ASSISTANT, "gemma4-assistant" },
{ LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
@@ -453,6 +454,8 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" },
{ LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
{ LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
{ LLM_TENSOR_NEXTN_PROJ_PRE, "nextn.pre_projection" },
{ LLM_TENSOR_NEXTN_PROJ_POST, "nextn.post_projection" },
{ LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" },
{ LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" },
{ LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" },
@@ -556,6 +559,8 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" },
{ LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" },
{ LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" },
{ LLM_TENSOR_MASKED_EMBD_CENTROIDS, "masked_embd_centroids" },
{ LLM_TENSOR_MASKED_EMBD_ORDERING, "masked_embd_ordering" },
};
// declare information about the model weight tensors:
@@ -765,6 +770,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_PROJ_PRE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_PROJ_POST, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the
// last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so
// the model loader doesn't fault on the block index.
@@ -778,6 +785,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
// latent projections feed ggml_mul_mat, the buft probe must use MUL_MAT to keep them on GPU
{LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_MASKED_EMBD_CENTROIDS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}},
{LLM_TENSOR_MASKED_EMBD_ORDERING, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}},
};
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
+6
View File
@@ -61,6 +61,7 @@ enum llm_arch {
LLM_ARCH_GEMMA3,
LLM_ARCH_GEMMA3N,
LLM_ARCH_GEMMA4,
LLM_ARCH_GEMMA4_ASSISTANT,
LLM_ARCH_GEMMA_EMBEDDING,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
@@ -557,14 +558,19 @@ enum llm_tensor {
LLM_TENSOR_INDEXER_PROJ,
LLM_TENSOR_INDEXER_ATTN_K,
LLM_TENSOR_INDEXER_ATTN_Q_B,
LLM_TENSOR_NEXTN_PROJ_PRE,
LLM_TENSOR_NEXTN_PROJ_POST,
LLM_TENSOR_NEXTN_EH_PROJ,
LLM_TENSOR_NEXTN_EMBED_TOKENS,
LLM_TENSOR_NEXTN_ENORM,
LLM_TENSOR_NEXTN_HNORM,
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
LLM_TENSOR_MASKED_EMBD_CENTROIDS,
LLM_TENSOR_MASKED_EMBD_ORDERING,
};
enum llm_tensor_layer {
LLM_TENSOR_LAYER_INPUT,
LLM_TENSOR_LAYER_REPEATING,
+37 -18
View File
@@ -69,9 +69,10 @@ llama_context::llama_context(
cparams.embeddings_nextn_masked = false;
cparams.offload_kqv = params.offload_kqv;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.warmup = false;
cparams.ctx_type = params.ctx_type;
cparams.pooling_type = params.pooling_type;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@@ -84,7 +85,17 @@ llama_context::llama_context(
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.ctx_type = params.ctx_type;
cparams.ctx_other = nullptr;
// TODO: more generic
if (model.arch == LLM_ARCH_GEMMA4_ASSISTANT) {
if (params.ctx_other == nullptr) {
// TODO: change from runtime_error to llama_exception to avoid printing error message
throw std::runtime_error("Gemma4Assistant requires ctx_other to be set (this is normal during memory fitting)");
}
cparams.ctx_other = params.ctx_other;
}
// Initialize backend samplers here so they are part of the sampling graph
// before the reserve passes run later in this function. This avoids a later
@@ -300,10 +311,11 @@ llama_context::llama_context(
// init the memory module
if (!hparams.vocab_only) {
llama_memory_params params_mem = {
/*.type_k =*/ params.type_k,
/*.type_v =*/ params.type_v,
/*.swa_full =*/ params.swa_full,
/*.ctx_type= */ cparams.ctx_type,
/*.type_k =*/ params.type_k,
/*.type_v =*/ params.type_v,
/*.swa_full =*/ params.swa_full,
/*.ctx_type =*/ cparams.ctx_type,
/*.mem_other =*/ llama_get_memory(cparams.ctx_other),
};
memory.reset(model.create_memory(params_mem, cparams));
@@ -904,7 +916,7 @@ float * llama_context::get_embeddings_nextn_ith(int32_t i) {
throw std::runtime_error("no nextn embeddings");
}
const uint32_t n_embd = model.hparams.n_embd;
const uint32_t n_embd = model.hparams.n_embd_out();
if (!cparams.embeddings_nextn_masked) {
// unmasked: nextn rows are stored densely, indexed by raw token position.
@@ -1473,7 +1485,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
GGML_ASSERT(backend_h != nullptr);
const uint32_t n_embd = hparams.n_embd;
const uint32_t n_embd = hparams.n_embd_out();
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_nextn.size);
ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn.data, 0, n_tokens*n_embd*sizeof(float));
}
@@ -1924,7 +1936,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
GGML_ASSERT(backend_h != nullptr);
const uint32_t n_embd = hparams.n_embd;
const uint32_t n_embd = hparams.n_embd_out();
float * embd_nextn_out = embd_nextn.data + offset*n_embd;
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_nextn.size);
@@ -2017,7 +2029,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
const auto n_batch = cparams.n_batch;
const auto n_vocab = vocab.n_tokens();
const auto n_embd = hparams.n_embd;
const auto n_embd_out = hparams.n_embd_out();
bool has_logits = true;
@@ -2036,12 +2047,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
embd_nextn.size = has_embd_nextn ? n_embd*n_outputs_max : 0;
embd_nextn.size = has_embd_nextn ? n_embd_out*n_outputs_max : 0;
if (has_embd_nextn && !cparams.embeddings_nextn_masked) {
// unmasked: nextn row exists for every token in the batch, not just
// those flagged via batch.logits[i] -> size by token count instead.
embd_nextn.size = (size_t) n_embd * n_batch;
embd_nextn.size = (size_t) n_embd_out * n_batch;
}
// Allocate backend sampling output buffers if there are backend samplers configured.
@@ -3375,6 +3386,7 @@ llama_context_params llama_context_default_params() {
/*.kv_unified =*/ false,
/*.sampler =*/ nullptr,
/*.n_sampler =*/ 0,
/*.ctx_other =*/ nullptr,
};
return result;
@@ -3454,7 +3466,6 @@ llama_context * llama_init_from_model(
return nullptr;
}
try {
auto * ctx = new llama_context(*model, params);
return ctx;
@@ -3593,6 +3604,14 @@ void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) {
ctx->set_embeddings_nextn(value, masked);
}
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
if (!ctx) {
return nullptr;
}
return ctx->get_memory();
}
float * llama_get_embeddings_nextn(llama_context * ctx) {
ctx->synchronize();
@@ -3656,7 +3675,7 @@ struct ggml_cgraph * llama_graph_reserve(
uint32_t n_tokens,
uint32_t n_seqs,
uint32_t n_outputs) {
auto * memory = ctx->get_memory();
auto memory = ctx->get_memory();
llama_memory_context_ptr mctx;
if (memory) {
mctx = memory->init_full();
@@ -3696,10 +3715,6 @@ int32_t llama_set_adapter_cvec(
// memory
//
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
return ctx->get_memory();
}
void llama_memory_clear(llama_memory_t mem, bool data) {
if (!mem) {
return;
@@ -4010,3 +4025,7 @@ void llama_opt_epoch(
llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx) {
return ctx->memory_breakdown();
}
llama_context * llama_get_ctx_other(struct llama_context * ctx) {
return ctx->get_cparams().ctx_other;
}
+2 -1
View File
@@ -6,6 +6,7 @@
#include "llama-graph.h"
#include "llama-adapter.h"
#include "llama-impl.h"
#include "llama-memory.h"
#include "ggml-cpp.h"
#include "ggml-opt.h"
@@ -273,7 +274,7 @@ private:
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
std::unique_ptr<llama_memory_i> memory;
llama_memory_ptr memory;
// decode output (2-dimensional array: [n_outputs][n_vocab])
buffer_view<float> logits = {nullptr, 0};
+2
View File
@@ -49,4 +49,6 @@ struct llama_cparams {
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
llama_context * ctx_other;
};
+2
View File
@@ -100,3 +100,5 @@ LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);
// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i);
LLAMA_API llama_context * llama_get_ctx_other(struct llama_context * ctx);
+22 -8
View File
@@ -397,7 +397,7 @@ static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n
case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
};
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@@ -565,7 +565,10 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
if (self_k_idxs && self_k_idxs->buffer) {
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
}
// the kq mask guards on its own buffer: shared cells leave idxs unbacked while the mask stays live
if (self_kq_mask && self_kq_mask->buffer) {
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
@@ -573,7 +576,9 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
}
if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
@@ -605,7 +610,9 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
if (self_k_idxs && self_k_idxs->buffer) {
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
}
if (self_kq_mask && self_kq_mask->buffer) {
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
}
@@ -613,7 +620,9 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
}
if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
}
@@ -756,7 +765,9 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
}
if (inp_attn->self_kq_mask && inp_attn->self_kq_mask->buffer) {
attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
}
@@ -764,7 +775,9 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
}
if (inp_attn->self_kq_mask_swa && inp_attn->self_kq_mask_swa->buffer) {
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
}
@@ -810,18 +823,18 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params)
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
}
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
// swa tensors may not be allocated if there are no SWA attention layers
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
}
res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
@@ -1006,6 +1019,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
ubatch (params.ubatch),
n_embd (hparams.n_embd),
n_layer (hparams.n_layer()),
n_layer_nextn (hparams.n_layer_nextn),
n_rot (hparams.n_rot()),
n_ctx (cparams.n_ctx),
n_head (hparams.n_head()),
@@ -1859,9 +1873,9 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
res->t_inp_embd = cur;
// For Granite architecture
// NOTE: Only apply scale to token inputs. Raw embeddings are assumed to be
// multimodal inputs that should not be scaled.
if (ubatch.token && hparams.f_embedding_scale != 0.0f) {
// NOTE: For deepstack models, only apply scale to token inputs (ie text-only input).
// Raw embeddings are assumed to be multimodal inputs that should not be scaled.
if (hparams.f_embedding_scale != 0.0f && (ubatch.token || hparams.n_deepstack_layers == 0)) {
if (!ggml_is_contiguous(cur)) {
cur = ggml_cont(ctx0, cur);
}
+1
View File
@@ -784,6 +784,7 @@ struct llm_graph_context {
const int64_t n_embd;
const int64_t n_layer;
const int64_t n_layer_nextn;
const int64_t n_rot;
const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
const int64_t n_head;
+4
View File
@@ -91,6 +91,10 @@ uint32_t llama_hparams::n_rot(uint32_t il) const {
}
uint32_t llama_hparams::n_embd_inp() const {
if (n_embd_inp_impl > 0) {
return n_embd_inp_impl;
}
uint32_t n_embd_inp = n_embd;
if (n_deepstack_layers > 0) {
+4
View File
@@ -185,6 +185,9 @@ struct llama_hparams {
// for Classifiers
uint32_t n_cls_out = 1;
// input embedding dimension (0 = use n_embd)
uint32_t n_embd_inp_impl = 0;
// output embedding dimension (0 = use n_embd)
uint32_t n_embd_out_impl = 0;
@@ -224,6 +227,7 @@ struct llama_hparams {
// complex mapping. If using deepstack_mapping_arr, also make sure to set
// n_deepstack_layers to the number of unique deepstack layers so that
// n_embd_imp is accurate (see granite.cpp).
// TODO: can be expressed via the `new n_embd_inp_impl` and remove this param
uint32_t n_deepstack_layers = 0;
// deepstack layer array (Granite4 Vision)
+2 -2
View File
@@ -32,7 +32,7 @@ llama_kv_cache_dsa::llama_kv_cache_dsa(
kv_mla = std::make_unique<llama_kv_cache>(
model, model.hparams, type_k, type_v,
v_trans, offload, unified, kv_size, n_seq_max, n_pad,
n_swa, swa_type, filter, reuse);
n_swa, swa_type, nullptr, filter, reuse, nullptr);
// we use llama_kv_cache for caching indexer keys
// by hand-tweaking some hparams we fool it to create
@@ -49,7 +49,7 @@ llama_kv_cache_dsa::llama_kv_cache_dsa(
kv_lid = std::make_unique<llama_kv_cache>(
model, hparams_lid, type_k, type_v,
v_trans, offload, unified, kv_size, n_seq_max, n_pad,
n_swa, swa_type, filter, reuse);
n_swa, swa_type, nullptr, filter, reuse, nullptr);
}
void llama_kv_cache_dsa::clear(bool data) {
+15 -3
View File
@@ -23,8 +23,10 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
const layer_reuse_cb & reuse,
const layer_share_cb & share) : hparams(model.hparams), unified(unified) {
// chain filters
const layer_filter_cb filter_base = [&](int32_t il) {
@@ -59,17 +61,27 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
llama_memory_t mem_other_base = nullptr;
if (mem_other) {
mem_other_base = static_cast<llama_kv_cache_iswa *>(mem_other)->get_base();
}
llama_memory_t mem_other_swa = nullptr;
if (mem_other) {
mem_other_swa = static_cast<llama_kv_cache_iswa *>(mem_other)->get_swa();
}
kv_base = std::make_unique<llama_kv_cache>(
model, hparams, type_k, type_v,
v_trans, offload, unified, size_base, n_seq_max, n_pad,
0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
0, LLAMA_SWA_TYPE_NONE, mem_other_base, filter_base, reuse, share);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache>(
model, hparams, type_k, type_v,
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
hparams.n_swa, hparams.swa_type, filter_swa, reuse);
hparams.n_swa, hparams.swa_type, mem_other_swa, filter_swa, reuse, share);
}
void llama_kv_cache_iswa::clear(bool data) {
+3 -1
View File
@@ -25,8 +25,10 @@ public:
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
const layer_reuse_cb & reuse,
const layer_share_cb & share);
~llama_kv_cache_iswa() = default;
+124 -23
View File
@@ -90,10 +90,26 @@ llama_kv_cache::llama_kv_cache(
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) :
const layer_reuse_cb & reuse,
const layer_share_cb & share) :
model(model), hparams(hparams), v_trans(v_trans),
n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type),
other(static_cast<llama_kv_cache *>(mem_other)),
v_cells_impl(other ? other->v_cells_impl : std::make_shared<llama_kv_cells_vec>()),
v_cells(*v_cells_impl) {
// shared cells view the source cache's K/V tensors, so the cell count
// follows the source allocation: a fitted target can be smaller than the
// draft default and oversized views would overflow the source tensors
if (other) {
const uint32_t size_other = other->get_size();
if (kv_size != size_other) {
LLAMA_LOG_WARN("%s: kv_size = %u overridden to %u to match the shared source cache\n", __func__, kv_size, size_other);
kv_size = size_other;
}
}
GGML_ASSERT(kv_size % n_pad == 0);
@@ -171,6 +187,24 @@ llama_kv_cache::llama_kv_cache(
continue;
}
if (share && other) {
const int32_t il_share = share(il);
if (il_share >= 0) {
const auto & layer_share = other->layers[other->map_layer_ids[il_share]];
LLAMA_LOG_WARN("%s: layer %3d: sharing with layer %d. k = %p, v = %p\n", __func__, il, il_share,
layer_share.k->data, layer_share.v->data);
map_layer_ids[il] = layers.size();
layers.push_back(layer_share);
layers.back().il = il;
continue;
}
}
if (n_embd_head_k_all == 0) {
n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il);
} else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) {
@@ -282,29 +316,38 @@ llama_kv_cache::llama_kv_cache(
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE");
const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false;
if (attn_rot_disable) {
LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__);
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
n_embd_head_k_all = other->n_embd_head_k_all;
n_embd_head_v_all = other->n_embd_head_v_all;
attn_rot_k = other->attn_rot_k;
attn_rot_v = other->attn_rot_v;
} else {
const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE");
const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false;
if (attn_rot_disable) {
LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__);
}
attn_rot_k =
!attn_rot_disable &&
n_embd_head_k_all > 0 &&
ggml_is_quantized(type_k) &&
hparams.n_embd_head_k() % 64 == 0;
// always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer
if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) {
attn_rot_k = true;
}
attn_rot_v =
!attn_rot_disable &&
n_embd_head_v_all > 0 &&
ggml_is_quantized(type_v) &&
hparams.n_embd_head_v() % 64 == 0;
}
attn_rot_k =
!attn_rot_disable &&
n_embd_head_k_all > 0 &&
ggml_is_quantized(type_k) &&
hparams.n_embd_head_k() % 64 == 0;
// always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer
if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) {
attn_rot_k = true;
}
attn_rot_v =
!attn_rot_disable &&
n_embd_head_v_all > 0 &&
ggml_is_quantized(type_v) &&
hparams.n_embd_head_v() % 64 == 0;
LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all);
LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all);
@@ -347,6 +390,11 @@ void llama_kv_cache::clear(bool data) {
}
bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return true;
}
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
if (p0 < 0) {
@@ -410,6 +458,11 @@ bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
}
void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
@@ -497,6 +550,11 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
}
void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -519,6 +577,11 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
}
void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
@@ -564,6 +627,11 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
}
void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
@@ -598,6 +666,11 @@ void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, in
}
llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return other->seq_pos_min(seq_id);
}
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
const auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -606,6 +679,11 @@ llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
}
llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return other->seq_pos_max(seq_id);
}
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
const auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -746,6 +824,11 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_
}
bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return true;
}
bool updated = false;
auto * sched = lctx->get_sched();
@@ -1021,6 +1104,11 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
}
void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
// keep track of the max sequence position that we would overwrite with this ubatch
// for non-SWA cache, this would be always empty
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@@ -1815,6 +1903,9 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
}
ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
GGML_ASSERT(!other);
auto * ctx = res->get_ctx();
auto * gf = res->get_gf();
@@ -1860,6 +1951,11 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
}
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_UNUSED(flags);
io.write(&n_stream, sizeof(n_stream));
@@ -1925,6 +2021,11 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, lla
}
void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_UNUSED(flags);
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
+10 -3
View File
@@ -98,7 +98,7 @@ public:
// likely through `struct llama_memory_params`
llama_kv_cache(
const llama_model & model,
const llama_hparams & hparams,
const llama_hparams & hparams,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
@@ -109,8 +109,10 @@ public:
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
const layer_reuse_cb & reuse,
const layer_share_cb & share);
~llama_kv_cache() = default;
@@ -264,7 +266,12 @@ private:
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
std::vector<uint32_t> v_heads;
std::vector<llama_kv_cells> v_cells;
// TODO: temporary until we refactor to be able to share the same cells between 2 kv caches [TAG_KV_CACHE_SHARE_CELLS]
llama_kv_cache * other;
std::shared_ptr<llama_kv_cells_vec> v_cells_impl;
llama_kv_cells_vec & v_cells;
// maps from a sequence id to a stream id
std::vector<uint32_t> seq_to_stream;
+2
View File
@@ -531,3 +531,5 @@ private:
}
}
};
using llama_kv_cells_vec = std::vector<llama_kv_cells>;
+2
View File
@@ -43,9 +43,11 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
n_seq_max,
n_ubatch,
n_pad,
nullptr,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recr(il); }
: filter_attn,
nullptr,
nullptr
)),
mem_recr(new llama_memory_recurrent(
+2
View File
@@ -44,9 +44,11 @@ llama_memory_hybrid::llama_memory_hybrid(
n_pad,
n_swa,
swa_type,
nullptr,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recr(il); }
: filter_attn,
nullptr,
nullptr
)),
mem_recr(new llama_memory_recurrent(
+4
View File
@@ -23,6 +23,8 @@ struct llama_memory_params {
bool swa_full;
llama_context_type ctx_type;
llama_memory_t mem_other;
};
enum llama_memory_status {
@@ -76,6 +78,8 @@ struct llama_memory_i {
// return negative value to indicate that the layer il should not reuse memory
using layer_reuse_cb = std::function<int32_t(int32_t il)>;
using layer_share_cb = std::function<int32_t(int32_t il)>;
virtual ~llama_memory_i() = default;
// split the input batch into a set of ubatches and verify that they can fit into the cache
+76 -35
View File
@@ -139,6 +139,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
return new llama_model_gemma3n(params);
case LLM_ARCH_GEMMA4:
return new llama_model_gemma4(params);
case LLM_ARCH_GEMMA4_ASSISTANT:
return new llama_model_gemma4_assistant(params);
case LLM_ARCH_GEMMA_EMBEDDING:
return new llama_model_gemma_embedding(params);
case LLM_ARCH_STARCODER2:
@@ -1205,7 +1207,7 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) {
const auto & use_mlock = params.use_mlock;
const auto & tensor_split = params.tensor_split;
const int n_layer = hparams.n_layer_all;
const int n_layer_all = hparams.n_layer_all;
const int n_gpu_layers = this->n_gpu_layers();
const bool use_mmap_buffer = true;
@@ -1262,10 +1264,10 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) {
splits[i] /= split_sum;
}
const int i_gpu_start = std::max(n_layer + 1 - n_gpu_layers, 0);
const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, n_layer + 1);
const int i_gpu_start = std::max(n_layer_all + 1 - n_gpu_layers, 0);
const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, n_layer_all + 1);
auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
const bool is_swa = il < n_layer && hparams.is_swa(il);
const bool is_swa = il < n_layer_all && hparams.is_swa(il);
if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) {
LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa);
return {cpu_dev, &pimpl->cpu_buft_list};
@@ -1281,13 +1283,13 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) {
pimpl->dev_input = { cpu_dev, &pimpl->cpu_buft_list };
// assign the repeating layers to the devices according to the splits
pimpl->dev_layer.resize(n_layer);
for (int il = 0; il < n_layer; ++il) {
pimpl->dev_layer.resize(n_layer_all);
for (int il = 0; il < n_layer_all; ++il) {
pimpl->dev_layer[il] = get_layer_buft_list(il);
}
// assign the output layer
pimpl->dev_output = get_layer_buft_list(n_layer);
pimpl->dev_output = get_layer_buft_list(n_layer_all);
const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED;
@@ -1303,14 +1305,14 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) {
throw std::runtime_error("model has expert layers but no expert layers are used");
}
layers.resize(n_layer);
layers.resize(n_layer_all);
// call the per-model loading function
load_arch_tensors(ml);
// generic pass: load optional per-tensor/per-expert ".scale" tensors (e.g. NVFP4 scale2)
// this avoids having to add scale loading to every architecture
for (int i = 0; i < n_layer; ++i) {
for (int i = 0; i < n_layer_all; ++i) {
auto & layer = layers[i];
// attention weight scales (per-tensor, shape {1})
@@ -1568,7 +1570,7 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) {
}
if (llama_supports_gpu_offload()) {
const int n_gpu = std::min(n_gpu_layers, n_layer);
const int n_gpu = std::min(n_gpu_layers, n_layer_all);
int n_repeating = n_gpu;
if (n_repeating > 0) {
@@ -1577,8 +1579,8 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) {
}
LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_repeating);
const int max_backend_supported_layers = n_layer + 1;
const int max_offloadable_layers = n_layer + 1;
const int max_backend_supported_layers = n_layer_all + 1;
const int max_offloadable_layers = n_layer_all + 1;
LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
}
@@ -1717,19 +1719,21 @@ void llama_model::print_info() const {
if (!hparams.vocab_only) {
LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd);
LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp());
LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd);
LLAMA_LOG_INFO("%s: n_embd_out = %u\n", __func__, hparams.n_embd_out());
LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer());
LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer()).c_str());
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer()).c_str());
LLAMA_LOG_INFO("%s: n_layer_all = %u\n", __func__, hparams.n_layer_all);
LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer_all).c_str());
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer_all).c_str());
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full);
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full);
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full);
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer()).c_str());
LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer()).c_str());
LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer()).c_str());
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer_all).c_str());
LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer_all).c_str());
LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer_all).c_str());
LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps);
LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
@@ -1737,7 +1741,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale);
LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale);
LLAMA_LOG_INFO("%s: f_attn_value_scale = %.4f\n", __func__, hparams.f_attn_value_scale);
LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer()).c_str());
LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer_all).c_str());
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups);
@@ -1764,7 +1768,7 @@ void llama_model::print_info() const {
[](const auto & entry) { return entry >= 0; })) {
LLAMA_LOG_INFO("%s: deepstack_mapping_arr = %s\n", __func__,
print_f([&](uint32_t il) { return hparams.deepstack_mapping_arr[il]; },
hparams.n_layer()).c_str());
hparams.n_layer_all).c_str());
}
// MRoPE (Multi-axis Rotary Position Embedding) sections
if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) {
@@ -2113,8 +2117,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* filter_recr */ std::move(filter_recr));
}
} else {
llama_memory_i::layer_reuse_cb reuse = nullptr;
llama_kv_cache::layer_filter_cb filter = nullptr;
llama_memory_i::layer_reuse_cb reuse = nullptr;
llama_kv_cache::layer_share_cb share = nullptr;
if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) {
reuse = [&](uint32_t il) {
@@ -2143,20 +2148,53 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(hparams.is_swa_any());
res = new llama_kv_cache_iswa(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
1,
filter,
reuse);
if (arch == LLM_ARCH_GEMMA4_ASSISTANT) {
llama_memory_t mem_other = llama_get_memory(cparams.ctx_other);
share = [&](int32_t il) {
const llama_model * model_other = llama_get_model(cparams.ctx_other);
if (hparams.is_swa(il)) {
return llama_model_n_layer(model_other) - 2;
}
return llama_model_n_layer(model_other) - 1;
};
res = new llama_kv_cache_iswa(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
1,
mem_other,
filter,
reuse,
share);
} else {
res = new llama_kv_cache_iswa(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
1,
nullptr,
filter,
reuse,
share);
}
} else {
GGML_ASSERT(!hparams.is_swa_any());
@@ -2173,7 +2211,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
1,
hparams.n_swa,
hparams.swa_type,
nullptr,
filter,
nullptr,
nullptr);
}
}
@@ -2406,6 +2446,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GEMMA3:
case LLM_ARCH_GEMMA3N:
case LLM_ARCH_GEMMA4:
case LLM_ARCH_GEMMA4_ASSISTANT:
case LLM_ARCH_GEMMA_EMBEDDING:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
+5
View File
@@ -548,6 +548,10 @@ struct llama_model {
struct ggml_tensor * output_s = nullptr;
struct ggml_tensor * output_in_s = nullptr;
// NextN/MTP model-level projections
struct ggml_tensor * nextn_proj_pre = nullptr;
struct ggml_tensor * nextn_proj_post = nullptr;
// classifier
struct ggml_tensor * cls = nullptr;
struct ggml_tensor * cls_b = nullptr;
@@ -702,6 +706,7 @@ const char * llm_type_name(llm_type type);
#define LLAMA_LOAD_LOCALS \
const int n_layer = hparams.n_layer(); GGML_UNUSED(n_layer); \
const int n_layer_all = hparams.n_layer_all; GGML_UNUSED(n_layer_all); \
const int n_layer_nextn = hparams.n_layer_nextn; GGML_UNUSED(n_layer_nextn); \
const int64_t n_head = hparams.n_head(); GGML_UNUSED(n_head); \
const int64_t n_head_kv = hparams.n_head_kv(); GGML_UNUSED(n_head_kv); \
const int64_t n_embd = hparams.n_embd; GGML_UNUSED(n_embd); \
+203
View File
@@ -0,0 +1,203 @@
#include "models.h"
void llama_model_gemma4_assistant::load_arch_hparams(llama_model_loader & ml) {
hparams.n_embd_inp_impl = hparams.n_embd_out();
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer());
uint32_t n_kv_shared_layers = 0;
ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false);
hparams.f_attention_scale = 1.0f;
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false);
GGML_ASSERT(hparams.n_layer_nextn == hparams.n_layer_all && "n_layer_nextn must be == n_layer_impl");
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa);
}
void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) {
LLAMA_LOAD_LOCALS;
if (n_embd_head_k != n_embd_head_v) {
throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k == n_embd_head_v");
}
if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) {
throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k_swa == n_embd_head_v_swa");
}
if (hparams.n_embd_out() == n_embd) {
throw std::runtime_error("Gemma 4 assistant requires embedding_length_out to carry the target hidden size");
}
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
create_tensor(tn(LLM_TENSOR_MASKED_EMBD_CENTROIDS, "weight"), {}, TENSOR_NOT_REQUIRED);
create_tensor(tn(LLM_TENSOR_MASKED_EMBD_ORDERING), {}, TENSOR_NOT_REQUIRED);
const int64_t n_embd_backbone = hparams.n_embd_inp();
nextn_proj_post = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_POST, "weight"), { n_embd, n_embd_backbone }, 0);
int rope_freqs_flag = 0;
for (int i = 0; i < n_layer_nextn; ++i) {
auto & layer = layers[i];
const int64_t n_head = hparams.n_head(i);
const int64_t n_embd_head = hparams.n_embd_head_k(i);
const int64_t n_ff = hparams.n_ff(i);
if (i == 0) {
nextn_proj_pre = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_PRE, "weight", i), { 2*n_embd_backbone, n_embd }, 0);
}
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head*n_head }, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head*n_head, n_embd }, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head }, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), { 1u }, 0);
if (!hparams.is_swa(i)) {
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_embd_head/2 }, rope_freqs_flag);
rope_freqs_flag = TENSOR_DUPLICATED;
}
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), { n_embd }, 0);
}
}
std::unique_ptr<llm_graph_context> llama_model_gemma4_assistant::build_arch_graph(const llm_graph_params & params) const {
return std::make_unique<graph>(*this, params);
}
llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_backbone = hparams.n_embd_inp();
ggml_tensor * inp_tokens;
ggml_tensor * inp_h;
{
auto inp = std::make_unique<llm_graph_input_embd>(n_embd_backbone);
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
cb(inp->tokens, "inp_tokens", -1);
ggml_set_input(inp->tokens);
inp_tokens = inp->tokens;
res->t_inp_tokens = inp->tokens;
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_backbone, ubatch.n_tokens);
cb(inp->embd, "inp_h", -1);
ggml_set_input(inp->embd);
inp_h = inp->embd;
res->t_inp_embd = inp->embd;
res->add_input(std::move(inp));
}
GGML_ASSERT(cparams.ctx_other != nullptr);
const auto * model_other = llama_get_model(cparams.ctx_other);
ggml_tensor * x = ggml_get_rows(ctx0, model_other->tok_embd, inp_tokens);
x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone));
cb(x, "inp_embd_target", -1);
ggml_tensor * xh = ggml_concat(ctx0, x, inp_h, 0);
cb(xh, "inp_xh", -1);
ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_proj_pre, xh);
cb(cur, "pre_proj", -1);
auto * inp_attn = build_attn_inp_kv_iswa();
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = build_inp_out_ids();
ggml_tensor * inpL = cur;
for (int il = 0; il < n_layer_nextn; ++il) {
const bool is_swa = hparams.is_swa(il);
const int64_t n_embd_head = hparams.n_embd_head_k(il);
const int64_t n_head = hparams.n_head(il);
const float freq_base_l = model.get_rope_freq_base(cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
const int n_rot_l = hparams.n_rot(il);
ggml_tensor * cur_norm = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur_norm, "attn_norm", il);
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur_norm);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
ggml_tensor * freq_factors = is_swa ? nullptr : model.layers[il].rope_freqs;
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig,
freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur_pos", il);
cur = build_attn(inp_attn, model.layers[il].wo, nullptr, nullptr,
Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
if (il == n_layer_nextn - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
cur = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "attn_post_norm", il);
ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL);
cb(attn_out, "attn_out", il);
cur = build_norm(attn_out, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, nullptr, nullptr,
model.layers[il].ffn_gate, nullptr, nullptr,
model.layers[il].ffn_down, nullptr, nullptr,
nullptr,
LLM_FFN_GELU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
cur = build_norm(cur, model.layers[il].ffn_post_norm, nullptr, LLM_NORM_RMS, -1);
cb(cur, "ffn_post_norm", il);
cur = ggml_add(ctx0, cur, attn_out);
cur = ggml_mul(ctx0, cur, model.layers[il].out_scale);
cb(cur, "out_scaled", il);
inpL = cur;
}
cur = inpL;
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
ggml_tensor * logits = build_lora_mm(model.output, cur);
cb(logits, "result_output", -1);
res->t_logits = logits;
ggml_tensor * h_next = ggml_mul_mat(ctx0, model.nextn_proj_post, cur);
cb(h_next, "h_nextn", -1);
res->t_h_nextn = h_next;
ggml_build_forward_expand(gf, logits);
ggml_build_forward_expand(gf, h_next);
}
+18 -4
View File
@@ -155,12 +155,14 @@ public:
}
virtual ~llm_graph_input_logits_bias() = default;
void set_input(const llama_ubatch *) override {
void set_input(const llama_ubatch * /*ubatch*/) override {
const int64_t n_vocab = arr.size();
ggml_backend_tensor_set(logits_bias, arr.data(), 0, n_vocab*ggml_element_size(logits_bias));
}
// bool can_reuse(const llm_graph_params & params) override;
bool can_reuse(const llm_graph_params & /*params*/) override {
return true;
}
ggml_tensor * logits_bias = nullptr; // F32 [n_vocab]
@@ -270,7 +272,8 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
}
// TODO @ngxson : strip unused token right after the last KV layer to speed up prompt processing
if (il == n_layer - 1 && inp_out_ids) {
// keep all rows when extracting unmasked nextn embeddings (MTP target needs the hidden state for every token)
if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
@@ -370,7 +373,7 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens]
// TODO @ngxson : improve this
if (il == n_layer - 1 && inp_out_ids) {
if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
inp_this_layer = ggml_get_rows(ctx0, inp_this_layer, inp_out_ids);
}
@@ -401,6 +404,17 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
model.output_norm, nullptr,
LLM_NORM_RMS, -1);
// Expose the post-output-norm hidden state (the LM-head input feature) so that
// MTP draft contexts can read it via llama_get_embeddings_nextn_ith() as the
// recurrent h input. This matches the reference (transformers/vLLM/SGLang),
// which feeds the drafter the target's post-final-norm hidden state.
cb(cur, "h_nextn", -1);
res->t_h_nextn = cur;
if (!cparams.embeddings_nextn_masked && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}
cb(cur, "result_norm", -1);
res->t_embd = cur;
+13
View File
@@ -822,6 +822,19 @@ struct llama_model_gemma4 : public llama_model_base {
};
struct llama_model_gemma4_assistant : public llama_model_base {
llama_model_gemma4_assistant(const struct llama_model_params & params) : llama_model_base(params) {}
void load_arch_hparams(llama_model_loader & ml) override;
void load_arch_tensors(llama_model_loader & ml) override;
struct graph : public llm_graph_context {
graph(const llama_model & model, const llm_graph_params & params);
};
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
};
struct llama_model_gemma_embedding : public llama_model_base {
llama_model_gemma_embedding(const struct llama_model_params & params) : llama_model_base(params) {}
void load_arch_hparams(llama_model_loader & ml) override;
+5 -1
View File
@@ -11,6 +11,10 @@ void llama_model_plamo2::load_arch_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
// Load attention parameters
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full, false);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full, false);
for (uint32_t i = 0; i < hparams.n_layer(); ++i) {
hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0;
}
@@ -273,7 +277,7 @@ ggml_tensor * llama_model_plamo2::graph::build_plamo2_mamba_layer(llm_graph_inpu
GGML_ASSERT(n_seqs != 0);
GGML_ASSERT(ubatch.equal_seqs());
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
GGML_ASSERT(d_inner % n_head == 0);
GGML_ASSERT(d_inner % n_heads == 0);
GGML_ASSERT(n_group == 0);
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+1
View File
@@ -265,6 +265,7 @@ if (NOT GGML_BACKEND_DL)
llama_build_and_test(test-quantize-fns.cpp)
llama_build_and_test(test-quantize-perf.cpp)
llama_build_and_test(test-rope.cpp)
llama_build_and_test(test-col2im-1d.cpp)
endif()
# libmtmd
+58
View File
@@ -5098,6 +5098,39 @@ struct test_conv_transpose_1d : public test_case {
}
};
// GGML_OP_COL2IM_1D
struct test_col2im_1d : public test_case {
const ggml_type type;
const int64_t K; // kernel size
const int64_t OC; // output channels
const int64_t T_in; // input length (number of columns)
const int s0; // stride
const int p0; // padding cropped from both sides
std::string vars() override {
return VARS_TO_STR6(type, K, OC, T_in, s0, p0);
}
double max_nmse_err() override {
return type == GGML_TYPE_F32 ? 1e-7 : 5e-4;
}
test_col2im_1d(ggml_type type = GGML_TYPE_F32,
int64_t K = 4, int64_t OC = 3, int64_t T_in = 7,
int s0 = 2, int p0 = 0)
: type(type), K(K), OC(OC), T_in(T_in), s0(s0), p0(p0) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * cols = ggml_new_tensor_2d(ctx, type, K*OC, T_in);
ggml_set_name(cols, "cols");
ggml_tensor * out = ggml_col2im_1d(ctx, cols, s0, (int) OC, p0);
ggml_set_name(out, "out");
return out;
}
};
// GGML_OP_CONV_TRANSPOSE_2D
struct test_conv_transpose_2d : public test_case {
// Dimensions
@@ -7771,6 +7804,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 384, 1, 1}, {3, 384, 384, 1}, 1, 0, 1, 0, 1, 0, false));
for (int s0 : {1, 3}) {
for (int p0 : {0, 3}) {
for (int d0 : {1, 3}) {
@@ -8012,6 +8046,21 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16}) {
// ConvTranspose1d expressed as mul_mat + col2im (DAC decoder upsampling)
test_cases.emplace_back(new test_col2im_1d(type, 16, 32, 197, 8, 0)); // kernel = 2*stride
test_cases.emplace_back(new test_col2im_1d(type, 4, 3, 7, 2, 0));
test_cases.emplace_back(new test_col2im_1d(type, 1, 5, 13, 1, 0)); // stride 1, no overlap
test_cases.emplace_back(new test_col2im_1d(type, 6, 4, 11, 3, 1)); // with cropping
test_cases.emplace_back(new test_col2im_1d(type, 2, 3, 9, 3, 0)); // kernel < stride, gap positions are zeroed
test_cases.emplace_back(new test_col2im_1d(type, 5, 4, 11, 2, 0)); // kernel not a multiple of stride, alternating overlap
test_cases.emplace_back(new test_col2im_1d(type, 8, 4, 13, 4, 2)); // padding = stride/2 (DAC causal cropping)
test_cases.emplace_back(new test_col2im_1d(type, 4, 3, 1, 2, 0)); // single column, pure kernel unfold
test_cases.emplace_back(new test_col2im_1d(type, 16, 1, 197, 8, 0)); // OC = 1, mono output stage
test_cases.emplace_back(new test_col2im_1d(type, 1, 5, 13, 3, 0)); // K = 1 with stride > 1, sparse scatter
test_cases.emplace_back(new test_col2im_1d(type, 8, 2, 3, 2, 5)); // cropping eats most of the signal, T_out = 2
}
for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1, kernel_type));
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2, kernel_type));
@@ -8525,6 +8574,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
// gpt-oss issue with Vulkan mmq_id
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
for (ggml_type type_a : all_types) {
test_cases.emplace_back(new test_mul_mat_id(type_a, GGML_TYPE_F32, 4, 2, false, 64, 16, 3*ggml_blck_size(type_a)));
}
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
for (int n_mats : {4, 8}) {
@@ -9361,6 +9414,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2, kernel_type));
}
// Memory bound overlap-add of the GEMM + col2im_1d transposed conv path, real vocoder stage shapes
test_cases.emplace_back(new test_col2im_1d(GGML_TYPE_F32, 16, 512, 2048, 8, 0));
test_cases.emplace_back(new test_col2im_1d(GGML_TYPE_F32, 4, 128, 65536, 2, 0));
test_cases.emplace_back(new test_col2im_1d(GGML_TYPE_F16, 16, 512, 2048, 8, 0));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
+133 -175
View File
@@ -1825,6 +1825,104 @@ static void test_convert_responses_to_chatcmpl() {
}
}
// Shared LFM2 parser cases - all variants use one output format and parser
static void test_lfm2_parser(const std::string & template_path, bool detailed_debug) {
auto tst = peg_tester(template_path, detailed_debug);
// Basic content only
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
// Single tool call without reasoning
tst.test("<|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>")
.tools({ special_function_tool })
.expect(message_assist_call)
.run();
// Tool call with string argument
tst.test("<|tool_call_start|>[get_time(city=\"XYZCITY\")]<|tool_call_end|>")
.tools({ get_time_tool })
.expect(message_with_tool_calls("get_time", "{\"city\":\"XYZCITY\"}"))
.run();
// Python literals become JSON
tst.test("<|tool_call_start|>[toggle(enabled=True)]<|tool_call_end|>")
.tools({ toggle_tool })
.expect(message_with_tool_calls("toggle", R"({"enabled": true})"))
.run();
tst.test("<|tool_call_start|>[set_nullable(value=None)]<|tool_call_end|>")
.tools({ nullable_tool })
.expect(message_with_tool_calls("set_nullable", R"({"value": null})"))
.run();
// Nested Python literal
tst.test("<|tool_call_start|>[set_config(config={\"enabled\": True, \"count\": 3})]<|tool_call_end|>")
.tools({ config_tool })
.expect(message_with_tool_calls("set_config", R"({"config": {"enabled": true, "count": 3}})"))
.run();
// JSON literals are accepted too
tst.test("<|tool_call_start|>[set_config(config={\"enabled\": true, \"note\": null})]<|tool_call_end|>")
.tools({ config_tool })
.expect(message_with_tool_calls("set_config", R"({"config": {"enabled": true, "note": null}})"))
.run();
// Dotted function name with structured args
tst.test("<|tool_call_start|>[Calendar.create_event(title=\"demo\", participants=[\"Alice\", \"Bob\"], "
"metadata={\"priority\": \"high\", \"reminder\": true})]<|tool_call_end|>")
.tools({ calendar_create_event_tool })
.expect(message_with_tool_calls(
"Calendar.create_event",
R"({"title": "demo", "participants": ["Alice", "Bob"], "metadata": {"priority": "high", "reminder": true}})"))
.run();
// Markdown links stay content
tst.test("Use this format: [link text](url). Example: [Wikipedia](https://www.wikipedia.org).")
.tools({ get_time_tool })
.expect(simple_assist_msg("Use this format: [link text](url). Example: [Wikipedia](https://www.wikipedia.org)."))
.run();
// Python tool with multiline code in string
tst.test("<|tool_call_start|>[python(code=\"def hello():\\n print('hey')\")]<|tool_call_end|>")
.tools({ python_tool })
.expect_tool_calls({
{ "python", R"#({"code": "def hello():\\n print('hey')"})#", "" }
})
.run();
// Content before tool call (no reasoning)
tst.test("Let me check the time.<|tool_call_start|>[get_time(city=\"Paris\")]<|tool_call_end|>")
.tools({ get_time_tool })
.expect(message_with_reasoning_content_and_multiple_tool_calls(
"", "Let me check the time.", { { "get_time", "{\"city\":\"Paris\"}" } }
))
.run();
// Multiple tool calls (parallel)
tst.test("<|tool_call_start|>[special_function(arg1=1), special_function_with_opt(arg1=1, arg2=2)]<|tool_call_end|>")
.parallel_tool_calls(true)
.tools({ special_function_tool, special_function_tool_with_optional_param })
.expect_tool_calls({
{ "special_function", R"({"arg1": 1})", {} },
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
})
.run();
// Partial tool call (streaming)
tst.test("<|tool_call_start|>[special_function(arg1=")
.tools({ special_function_tool })
.is_partial(true)
.expect(simple_assist_msg("", "", "special_function", "{\"arg1\": "))
.run();
// Tool call with empty arguments
tst.test("<|tool_call_start|>[empty_args()]<|tool_call_end|>")
.tools({ empty_args_tool })
.expect(simple_assist_msg("", "", "empty_args", "{}"))
.run();
}
static void test_template_output_peg_parsers(bool detailed_debug) {
LOG_DBG("%s\n", __func__);
@@ -4038,49 +4136,30 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.run();
}
// LFM2-8B-A1B tests - uses <|tool_list_start|>/<|tool_list_end|> and <|tool_call_start|>[name(args)]<|tool_call_end|>
for (const char * tmpl : {
"models/templates/LFM2-8B-A1B.jinja",
"models/templates/LFM2.5-Instruct.jinja",
"models/templates/LFM2.5-8B-A1B.jinja",
}) {
test_lfm2_parser(tmpl, detailed_debug);
}
// Thinking cases only apply to LFM2.5-8B-A1B, the one LFM2 template that emits <think>
{
auto tst = peg_tester("models/templates/LFM2-8B-A1B.jinja", detailed_debug);
auto tst = peg_tester("models/templates/LFM2.5-8B-A1B.jinja", detailed_debug);
// Basic content only
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
// Reasoning is parsed independent of enable_thinking
// Single tool call without reasoning
tst.test("<|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>")
.tools({ special_function_tool })
.expect(message_assist_call)
.run();
// Tool call with string argument
tst.test("<|tool_call_start|>[get_time(city=\"XYZCITY\")]<|tool_call_end|>")
.tools({ get_time_tool })
.expect(message_with_tool_calls("get_time", "{\"city\":\"XYZCITY\"}"))
.run();
// Tool call with reasoning (enable_thinking=true)
// Tool call with reasoning
tst.test("<think>I'm\nthinking</think><|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ special_function_tool })
.expect(message_assist_call_thoughts)
.run();
// Multiple tool calls (parallel)
tst.test("<|tool_call_start|>[special_function(arg1=1), special_function_with_opt(arg1=1, arg2=2)]<|tool_call_end|>")
.parallel_tool_calls(true)
.tools({
special_function_tool, special_function_tool_with_optional_param
})
.expect_tool_calls({
{ "special_function", R"({"arg1": 1})", {} },
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
})
.run();
// Tool call with reasoning and content
tst.test("<think>I need to call a function</think>"
"Let me check the time.<|tool_call_start|>[get_time(city=\"Paris\")]<|tool_call_end|>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ get_time_tool })
.expect(message_with_reasoning_content_and_multiple_tool_calls(
@@ -4088,32 +4167,9 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
))
.run();
// Python tool with multiline code in string
tst.test("<|tool_call_start|>[python(code=\"def hello():\\n print('hey')\")]<|tool_call_end|>")
.tools({ python_tool })
.expect_tool_calls({
{ "python", R"#({"code": "def hello():\\n print('hey')"})#", "" }
})
.run();
// Partial tool call (streaming)
tst.test("<|tool_call_start|>[special_function(arg1=")
.tools({ special_function_tool })
.is_partial(true)
.expect(simple_assist_msg("", "", "special_function", "{\"arg1\": "))
.run();
// Tool call with empty arguments
tst.test("<|tool_call_start|>[empty_args()]<|tool_call_end|>")
.tools({ empty_args_tool })
.expect(simple_assist_msg("", "", "empty_args", "{}"))
.run();
// fake tool call marker in reasoning
tst.test(
"<think>Let me think about <|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|> hmm</think>"
"<|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>")
.enable_thinking(true)
// Fake tool call marker inside reasoning is not parsed as a call
tst.test("<think>Let me think about <|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|> hmm</think>"
"<|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ special_function_tool })
.expect_reasoning("Let me think about <|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|> hmm")
@@ -4122,127 +4178,21 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
})
.run();
// Continuation tests
tst.test("world!\nWhat's up?")
// enable_thinking=false still captures emitted reasoning
tst.test("<think>I'm\nthinking</think>Hello, world!\nWhat's up?")
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
.messages({ message_user, message_assist_prefill_content })
.add_generation_prompt(false)
.continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT)
.expect_reasoning("I'm thinking")
.expect_content("Hello, world!\nWhat's up?")
.expect(message_assist_thoughts)
.run();
tst.test(" thinking</think>Hello, world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
.messages({ message_user, message_assist_prefill_reasoning })
.add_generation_prompt(false)
.continue_final_message(COMMON_CHAT_CONTINUATION_REASONING)
.expect_reasoning("I'm thinking")
.expect_content("Hello, world!\nWhat's up?")
.run();
}
// LFM2.5 tests - format <|tool_call_start|>[name(args)]<|tool_call_end|>
{
auto tst = peg_tester("models/templates/LFM2.5-Instruct.jinja", detailed_debug);
// Basic content only
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
// Single tool call without reasoning
tst.test("<|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>")
.tools({ special_function_tool })
.expect(message_assist_call)
.run();
// Tool call with string argument
tst.test("<|tool_call_start|>[get_time(city=\"XYZCITY\")]<|tool_call_end|>")
.tools({ get_time_tool })
.expect(message_with_tool_calls("get_time", "{\"city\":\"XYZCITY\"}"))
.run();
// Python literals become JSON.
tst.test("<|tool_call_start|>[toggle(enabled=True)]<|tool_call_end|>")
.tools({ toggle_tool })
.expect(message_with_tool_calls("toggle", R"({"enabled": true})"))
.run();
tst.test("<|tool_call_start|>[set_nullable(value=None)]<|tool_call_end|>")
.tools({ nullable_tool })
.expect(message_with_tool_calls("set_nullable", R"({"value": null})"))
.run();
// Nested Python literal.
tst.test("<|tool_call_start|>[set_config(config={\"enabled\": True, \"count\": 3})]<|tool_call_end|>")
.tools({ config_tool })
.expect(message_with_tool_calls("set_config", R"({"config": {"enabled": true, "count": 3}})"))
.run();
// JSON literals are accepted too.
tst.test("<|tool_call_start|>[set_config(config={\"enabled\": true, \"note\": null})]<|tool_call_end|>")
.tools({ config_tool })
.expect(message_with_tool_calls("set_config", R"({"config": {"enabled": true, "note": null}})"))
.run();
// Dotted function name with structured args.
tst.test("<|tool_call_start|>[Calendar.create_event(title=\"demo\", participants=[\"Alice\", \"Bob\"], "
"metadata={\"priority\": \"high\", \"reminder\": true})]<|tool_call_end|>")
.tools({ calendar_create_event_tool })
.expect(message_with_tool_calls(
"Calendar.create_event",
R"({"title": "demo", "participants": ["Alice", "Bob"], "metadata": {"priority": "high", "reminder": true}})"))
.run();
// Markdown links stay content.
tst.test("Use this format: [link text](url). Example: [Wikipedia](https://www.wikipedia.org).")
.tools({ get_time_tool })
.expect(simple_assist_msg("Use this format: [link text](url). Example: [Wikipedia](https://www.wikipedia.org)."))
.run();
// Tool call with reasoning (enable_thinking=true)
tst.test("<think>I'm\nthinking</think><|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>")
.enable_thinking(true)
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ special_function_tool })
.expect(message_assist_call_thoughts)
.run();
// Multiple tool calls (parallel)
tst.test("<|tool_call_start|>[special_function(arg1=1), special_function_with_opt(arg1=1, arg2=2)]<|tool_call_end|>")
.parallel_tool_calls(true)
.tools({
special_function_tool, special_function_tool_with_optional_param
})
.expect_tool_calls({
{ "special_function", R"({"arg1": 1})", {} },
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
})
.run();
// Tool call with content before tool call
tst.test("Let me check the time.<|tool_call_start|>[get_time(city=\"Paris\")]<|tool_call_end|>")
.tools({ get_time_tool })
.expect(message_with_reasoning_content_and_multiple_tool_calls(
"", "Let me check the time.", { { "get_time", "{\"city\":\"Paris\"}" } }
))
.run();
// Partial tool call (streaming)
tst.test("<|tool_call_start|>[special_function(arg1=")
.tools({ special_function_tool })
.is_partial(true)
.expect(simple_assist_msg("", "", "special_function", "{\"arg1\": "))
.run();
// Tool call with empty arguments
tst.test("<|tool_call_start|>[empty_args()]<|tool_call_end|>")
.tools({ empty_args_tool })
.expect(simple_assist_msg("", "", "empty_args", "{}"))
.run();
// Continuation tests
// Continuation: prefill content
tst.test("world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
@@ -4253,6 +4203,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.expect_content("Hello, world!\nWhat's up?")
.run();
// Continuation: prefill reasoning
tst.test(" thinking</think>Hello, world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
@@ -5478,18 +5429,25 @@ static void test_template_generation_prompt() {
check(tmpls, continuation_reasoning(), "<|im_assistant|>assistant<|im_middle|><think>I'm");
}
{
auto tmpls = read_templates("models/templates/LFM2-8B-A1B.jinja");
for (const char * tmpl : {
"models/templates/LFM2-8B-A1B.jinja",
"models/templates/LFM2.5-Instruct.jinja",
"models/templates/LFM2.5-8B-A1B.jinja",
}) {
auto tmpls = read_templates(tmpl);
check(tmpls, basic(), "<|im_start|>assistant\n");
check(tmpls, continuation_content(), "<|im_start|>assistant\n<think>I'm thinking</think>Hello, ");
check(tmpls, continuation_reasoning(), "<|im_start|>assistant\n<think>I'm");
}
{
auto tmpls = read_templates("models/templates/LFM2.5-Instruct.jinja");
check(tmpls, basic(), "<|im_start|>assistant\n");
check(tmpls, continuation_content(), "<|im_start|>assistant\n<think>I'm thinking</think>Hello, ");
check(tmpls, continuation_reasoning(), "<|im_start|>assistant\n<think>I'm");
// 8B-A1B renders prior-turn reasoning via the "thinking" field
auto tmpls = read_templates("models/templates/LFM2.5-8B-A1B.jinja");
common_chat_templates_inputs inputs;
inputs.messages = { message_user, message_assist_call_thoughts, tool_msg };
inputs.add_generation_prompt = true;
auto params = common_chat_templates_apply(tmpls.get(), inputs);
assert_contains(params.prompt, "<think>I'm\nthinking</think>");
}
{
+159
View File
@@ -0,0 +1,159 @@
// test-col2im-1d.cpp: validate GGML_OP_COL2IM_1D against ggml_conv_transpose_1d.
//
// A ConvTranspose1d factorizes as a GEMM followed by an overlap-add:
// conv_transpose_1d(w, x) equals col2im_1d(mul_mat(w_perm, x_t), s0, OC, p0)
// with w_perm the [IC, K*OC] permutation of the [K, OC, IC] kernel and x_t the
// [IC, T_in] transpose of the [T_in, IC] input. The test derives both alternative
// layouts from one logical weight and one logical input with graph ops only
// (permute + cont + reshape), runs the two paths on the CPU backend, and compares
// them in F32. The F16 and BF16 kernels are exercised by casting the column
// matrix before the scatter. Cropping (p0 > 0) is checked against the shifted
// slice of the uncropped reference, which conv_transpose_1d cannot express.
#include "ggml.h"
#include "ggml-cpu.h"
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>
// One geometry: kernel size, output channels, input length, stride, crop
struct col2im_case {
int64_t K;
int64_t OC;
int64_t T_in;
int s0;
int p0;
};
// Mirrors the eval grid of test-backend-ops
static const col2im_case CASES[] = {
{ 16, 32, 197, 8, 0 }, // kernel = 2*stride, DAC upsampling shape
{ 4, 3, 7, 2, 0 },
{ 1, 5, 13, 1, 0 }, // stride 1, no overlap
{ 6, 4, 11, 3, 1 }, // with cropping
{ 2, 3, 9, 3, 0 }, // kernel < stride, gap positions are zeroed
{ 5, 4, 11, 2, 0 }, // kernel not a multiple of stride, alternating overlap
{ 8, 4, 13, 4, 2 }, // padding = stride/2, DAC causal cropping
{ 4, 3, 1, 2, 0 }, // single column, pure kernel unfold
{ 16, 1, 197, 8, 0 }, // OC = 1, mono output stage
{ 1, 5, 13, 3, 0 }, // K = 1 with stride > 1, sparse scatter
{ 8, 2, 3, 2, 5 }, // cropping eats most of the signal, T_out = 2
};
// Input channels of the GEMM, shared by every case
static const int64_t IC = 7;
// Deterministic LCG mapped to [-1, 1]
static uint64_t g_rng = 0x12345678ULL;
static float frand(void) {
g_rng = g_rng * 6364136223846793005ULL + 1442695040888963407ULL;
return (float)((g_rng >> 33) & 0xffffff) / (float)0x800000 - 1.0f;
}
// Read a F32/F16/BF16 tensor back as a flat F32 vector
static std::vector<float> tensor_to_f32(const struct ggml_tensor * t) {
const int64_t n = ggml_nelements(t);
std::vector<float> out(n);
if (t->type == GGML_TYPE_F32) {
memcpy(out.data(), t->data, n * sizeof(float));
} else if (t->type == GGML_TYPE_F16) {
for (int64_t i = 0; i < n; i++) {
out[i] = ggml_fp16_to_fp32(((const ggml_fp16_t *) t->data)[i]);
}
} else {
for (int64_t i = 0; i < n; i++) {
out[i] = ggml_bf16_to_fp32(((const ggml_bf16_t *) t->data)[i]);
}
}
return out;
}
// NMSE of the cropped output against the p0 shifted slice of the full reference
static double nmse_cropped(const float * y, const float * ref, int64_t T_out, int64_t T_ref, int64_t OC, int p0) {
double num = 0.0;
double den = 0.0;
for (int64_t oc = 0; oc < OC; oc++) {
for (int64_t t = 0; t < T_out; t++) {
const double a = y [t + oc * T_out];
const double b = ref[t + p0 + oc * T_ref];
num += (a - b) * (a - b);
den += b * b;
}
}
return num / (den + 1e-30);
}
int main(void) {
int fails = 0;
for (const col2im_case & c : CASES) {
const int64_t T_ref = (c.T_in - 1) * c.s0 + c.K;
const int64_t T_out = T_ref - 2 * c.p0;
struct ggml_init_params params = {
/* .mem_size = */ (size_t) 64 << 20,
/* .mem_base = */ NULL,
/* .no_alloc = */ false,
};
struct ggml_context * ctx = ggml_init(params);
// One logical weight and one logical input feed both paths
struct ggml_tensor * w = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, c.K, c.OC, IC);
struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, c.T_in, IC);
for (int64_t i = 0; i < ggml_nelements(w); i++) {
((float *) w->data)[i] = frand();
}
for (int64_t i = 0; i < ggml_nelements(x); i++) {
((float *) x->data)[i] = frand();
}
// Reference path: the native op, uncropped
struct ggml_tensor * y_ref = ggml_conv_transpose_1d(ctx, w, x, c.s0, 0, 1);
// Decomposed path: [K, OC, IC] -> [IC, K, OC] -> [IC, K*OC], k fastest inside each oc block
struct ggml_tensor * w_perm = ggml_cont(ctx, ggml_permute(ctx, w, 1, 2, 0, 3));
w_perm = ggml_reshape_2d(ctx, w_perm, IC, c.K * c.OC);
struct ggml_tensor * x_t = ggml_cont(ctx, ggml_transpose(ctx, x));
struct ggml_tensor * col = ggml_mul_mat(ctx, w_perm, x_t);
struct ggml_tensor * y32 = ggml_col2im_1d(ctx, col, c.s0, (int) c.OC, c.p0);
// Half precision kernels: the same columns cast before the scatter
struct ggml_tensor * y16 = ggml_col2im_1d(ctx, ggml_cast(ctx, col, GGML_TYPE_F16), c.s0, (int) c.OC, c.p0);
struct ggml_tensor * ybf = ggml_col2im_1d(ctx, ggml_cast(ctx, col, GGML_TYPE_BF16), c.s0, (int) c.OC, c.p0);
GGML_ASSERT(y_ref->ne[0] == T_ref && y_ref->ne[1] == c.OC);
GGML_ASSERT(y32->ne[0] == T_out && y32->ne[1] == c.OC);
struct ggml_cgraph * gf = ggml_new_graph(ctx);
ggml_build_forward_expand(gf, y_ref);
ggml_build_forward_expand(gf, y32);
ggml_build_forward_expand(gf, y16);
ggml_build_forward_expand(gf, ybf);
ggml_graph_compute_with_ctx(ctx, gf, 4);
const std::vector<float> f32 = tensor_to_f32(y32);
const std::vector<float> f16 = tensor_to_f32(y16);
const std::vector<float> fbf = tensor_to_f32(ybf);
const float * ref = (const float *) y_ref->data;
const double e32 = nmse_cropped(f32.data(), ref, T_out, T_ref, c.OC, c.p0);
const double e16 = nmse_cropped(f16.data(), ref, T_out, T_ref, c.OC, c.p0);
const double ebf = nmse_cropped(fbf.data(), ref, T_out, T_ref, c.OC, c.p0);
// Same thresholds as test-backend-ops: 1e-7 full precision, 5e-4 half
const bool ok = e32 <= 1e-7 && e16 <= 5e-4 && ebf <= 5e-4;
if (!ok) {
fails++;
}
printf("col2im_1d K=%2d OC=%2d T_in=%3d s0=%d p0=%d: nmse f32=%.2e f16=%.2e bf16=%.2e %s\n",
(int) c.K, (int) c.OC, (int) c.T_in, c.s0, c.p0, e32, e16, ebf, ok ? "OK" : "FAIL");
ggml_free(ctx);
}
printf(fails == 0 ? "all col2im_1d checks passed\n" : "%d col2im_1d checks FAILED\n", fails);
return fails == 0 ? 0 : 1;
}
+3 -3
View File
@@ -392,7 +392,7 @@ static bool arch_supported(const llm_arch arch) {
if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
return false; // FIXME CUDA backend crashes.
}
if (arch == LLM_ARCH_GEMMA4) {
if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) {
return false; // FIXME @ngxson
}
if (arch == LLM_ARCH_LLAMA_EMBED || arch == LLM_ARCH_GEMMA_EMBEDDING || arch == LLM_ARCH_T5ENCODER) {
@@ -447,7 +447,7 @@ static int save_models(const llm_arch target_arch, const size_t seed, const ggml
if (target_arch != LLM_ARCH_UNKNOWN && arch != target_arch) {
continue;
}
if (arch == LLM_ARCH_GEMMA4) {
if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) {
continue; // FIXME: ISWA KV cache initialization needs more fixture params
}
for (bool moe : {false, true}) {
@@ -550,7 +550,7 @@ static int test_backends(const llm_arch target_arch, const size_t seed, const gg
if (target_arch != LLM_ARCH_UNKNOWN && arch != target_arch) {
continue;
}
if (arch == LLM_ARCH_GEMMA4) {
if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) {
continue; // FIXME: ISWA KV cache initialization needs more fixture params
}
+6
View File
@@ -2,6 +2,7 @@
#include <assert.h>
#include "mtmd.h"
#include "mtmd-helper.h"
int main(void) {
printf("\n\nTesting libmtmd C API...\n");
@@ -17,6 +18,11 @@ int main(void) {
return 1;
}
// simple test for the helper
size_t n_tokens_total = mtmd_helper_get_n_tokens(chunks);
printf("Total tokens in chunks: %zu\n", n_tokens_total);
assert(n_tokens_total > 0);
size_t n_chunks = mtmd_input_chunks_size(chunks);
printf("Number of chunks: %zu\n", n_chunks);
assert(n_chunks > 0);
+19 -3
View File
@@ -128,7 +128,18 @@ struct cli_context {
console::spinner::start();
server_task_result_ptr result = rd.next(should_stop);
console::spinner::stop();
while (true) {
auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
if (res_partial && res_partial->is_begin) {
// this is the "send 200 status to client" signal in streaming mode
// skip, do not stop the spinner
result = rd.next(should_stop);
} else {
console::spinner::stop();
break;
}
}
std::string curr_content;
bool is_thinking = false;
@@ -224,7 +235,7 @@ struct cli_context {
};
// TODO?: Make this reusable, enums, docs
static const std::array<std::string_view, 7> cmds = {
static const std::array<std::string_view, 8> cmds = {
"/audio ",
"/clear",
"/exit",
@@ -232,6 +243,7 @@ static const std::array<std::string_view, 7> cmds = {
"/image ",
"/read ",
"/regen",
"/video ",
};
static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std::string_view line, size_t cursor_byte_pos) {
@@ -446,6 +458,9 @@ int llama_cli(int argc, char ** argv) {
if (inf.has_inp_audio) {
console::log(" /audio <file> add an audio file\n");
}
if (inf.has_inp_video) {
console::log(" /video <file> add a video file\n");
}
console::log("\n");
// interactive loop
@@ -542,7 +557,8 @@ int llama_cli(int argc, char ** argv) {
continue;
} else if (
(string_starts_with(buffer, "/image ") && inf.has_inp_image) ||
(string_starts_with(buffer, "/audio ") && inf.has_inp_audio)) {
(string_starts_with(buffer, "/audio ") && inf.has_inp_audio) ||
(string_starts_with(buffer, "/video ") && inf.has_inp_video)) {
// just in case (bad copy-paste for example), we strip all trailing/leading spaces
std::string fname = string_strip(buffer.substr(7));
std::string marker = ctx_cli.load_input_file(fname, true);
+4 -9
View File
@@ -33,12 +33,8 @@
#endif
static llama_context ** g_ctx;
static llama_model ** g_model;
static common_sampler ** g_smpl;
static common_params * g_params;
static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss;
static std::vector<llama_token> * g_output_tokens;
static bool is_interacting = false;
static bool need_insert_eot = false;
@@ -136,7 +132,6 @@ int llama_completion(int argc, char ** argv) {
llama_context * ctx = nullptr;
common_sampler * smpl = nullptr;
g_model = &model;
g_ctx = &ctx;
g_smpl = &smpl;
@@ -549,9 +544,9 @@ int llama_completion(int argc, char ** argv) {
int n_consumed = 0;
int n_session_consumed = 0;
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
std::ostringstream output_ss; g_output_ss = &output_ss;
std::vector<int> input_tokens;
std::vector<int> output_tokens;
std::ostringstream output_ss;
std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode
// the first thing we will do is to output the prompt, so set color accordingly
@@ -989,7 +984,7 @@ int llama_completion(int argc, char ** argv) {
LOG("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
LOG_INF("saved final session to %s, n_tokens = %ld\n", path_session.data(), session_tokens.size());
LOG_INF("saved final session to %s, n_tokens = %zu\n", path_session.data(), session_tokens.size());
}
+7
View File
@@ -1,5 +1,8 @@
# mtmd
set(MTMD_VIDEO ON CACHE BOOL "enable video support in mtmd (requires ffmpeg binary in PATH)")
# TODO: add MTMD_VIDEO_METHOD in the future to select between ffmpeg and other backends
find_package(Threads REQUIRED)
add_library(mtmd
@@ -63,6 +66,10 @@ target_include_directories(mtmd PRIVATE ../..)
target_include_directories(mtmd PRIVATE ../../vendor)
target_compile_features (mtmd PRIVATE cxx_std_17)
if (MTMD_VIDEO)
target_compile_definitions(mtmd PRIVATE MTMD_VIDEO)
endif()
if (BUILD_SHARED_LIBS)
set_target_properties (mtmd PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(mtmd PRIVATE LLAMA_BUILD)
+3
View File
@@ -37,6 +37,9 @@ struct clip_graph {
float kq_scale; // TODO: maybe move this to hparams
const clip_flash_attn_type flash_attn_type;
// TODO [QWEN_VIDEO]: improve this in the future
int n_batch = 1;
ggml_context_ptr ctx0_ptr;
ggml_context * ctx0;
ggml_cgraph * gf;
+142 -8
View File
@@ -4,6 +4,7 @@
#include "gguf.h"
#include "clip.h"
#include <array>
#include <climits>
#include <cstdarg>
#include <cinttypes>
@@ -429,26 +430,158 @@ static projector_type clip_projector_type_from_string(const std::string & str) {
// RGB uint8 image
struct clip_image_u8 {
int nx;
int ny;
clip_image_size get_size() const {
return { nx, ny };
}
void set_size(clip_image_size size, bool is_placeholder) {
nx = size.width;
ny = size.height;
if (is_placeholder) {
buf.clear();
} else {
buf.resize((size_t) nx * (size_t) ny * 3);
}
}
void cpy_buf(const std::vector<uint8_t> & new_buf) {
buf = new_buf;
}
const std::vector<uint8_t> & get_ro_buf() const {
if (is_placeholder()) {
throw std::runtime_error("this clip_image_u8 is a placeholder");
}
return buf;
}
// note to contributors: NEVER add a get_rw_buf(), it is a DANGEROUS pattern. always use get_pixel / set_pixel for buffer manipulation
bool is_placeholder() const {
return buf.empty();
}
std::array<uint8_t, 3> get_pixel(int x, int y) const {
if (is_placeholder()) {
// return a dummy value, so that legacy code can still process image without errors
return { 0, 0, 0 };
}
int idx = (y * nx + x) * 3;
return { buf[idx], buf[idx + 1], buf[idx + 2] };
}
void set_pixel(int x, int y, const std::array<uint8_t, 3> & rgb) {
if (is_placeholder()) {
return; // no-op
}
int idx = (y * nx + x) * 3;
buf[idx] = rgb[0];
buf[idx + 1] = rgb[1];
buf[idx + 2] = rgb[2];
}
size_t n_elements() const {
return n_pixels() * 3;
}
private:
std::vector<uint8_t> buf;
int nx = 0;
int ny = 0;
size_t n_pixels() const {
return (size_t) nx * (size_t) ny;
}
};
// For images, buf.size() == nx*ny*3
// Memory layout: RGBRGBRGB...
// For seq, buf.size() == nx*ny*3*nt
// Memory layout: RGBRGB...RGBRGB... (nt times)
// For audio, only one channel is used, buf.size() == nx*ny
// nx will be n_frames and ny will be n_mel
struct clip_image_f32 {
int nx;
int ny;
std::vector<float> buf;
// marks the global view in e.g., DeepSeek-OCR Models
bool add_viewsep = false;
// whether a learned newline token should be appended after the image (eg Granite4 Vision)
// whether a learned newline (or EOI) token should be appended after the image (eg Granite4 Vision)
bool add_newline = false;
clip_image_size get_size() const {
return { nx_, ny_ };
}
int nx() const { return nx_; }
int ny() const { return ny_; }
void set_size(clip_image_size size, bool is_placeholder, bool is_audio) {
nx_ = size.width;
ny_ = size.height;
if (is_placeholder) {
buf.clear();
} else {
if (is_audio) {
buf.resize((size_t) nx_ * (size_t) ny_);
} else {
buf.resize((size_t) nx_ * (size_t) ny_ * 3);
}
}
}
void cpy_buf(const std::vector<float> & new_buf) {
buf = new_buf;
}
void from_u8(const clip_image_u8 & img) {
auto size = img.get_size();
nx_ = size.width;
ny_ = size.height;
if (img.is_placeholder()) {
buf.clear();
return; // no-op
}
buf.resize(img.n_elements());
const auto & u8_buf = img.get_ro_buf();
for (size_t i = 0; i < img.n_elements(); ++i) {
buf[i] = (float) u8_buf[i] / 255.0f;
}
}
size_t n_elements() const {
return n_pixels() * 3;
}
void normalize(const float mean[3], const float std[3]) {
if (is_placeholder()) {
return; // no-op
}
for (size_t i = 0; i < n_pixels(); ++i) {
buf[i * 3 + 0] = (buf[i * 3 + 0] - mean[0]) / std[0];
buf[i * 3 + 1] = (buf[i * 3 + 1] - mean[1]) / std[1];
buf[i * 3 + 2] = (buf[i * 3 + 2] - mean[2]) / std[2];
}
}
const std::vector<float> & get_ro_buf() const {
if (is_placeholder()) {
throw std::runtime_error("this clip_image_f32 is a placeholder");
}
return buf;
}
// note to contributors: NEVER add a get_rw_buf(), it is a DANGEROUS pattern
bool is_placeholder() const {
return buf.empty();
}
private:
std::vector<float> buf;
int nx_ = 0;
int ny_ = 0;
size_t n_pixels() const {
return (size_t) nx_ * (size_t) ny_;
}
};
//
@@ -496,6 +629,7 @@ static void clip_log_internal(enum ggml_log_level level, const char * format, ..
va_end(args);
}
#define LOG_TRC(...) clip_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
#define LOG_DBG(...) clip_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
#define LOG_INF(...) clip_log_internal(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
#define LOG_WRN(...) clip_log_internal(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
+139 -152
View File
@@ -39,12 +39,14 @@ static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::s
}
// PPM header: P6 format, width, height, and max color value
file << "P6\n" << img.nx << " " << img.ny << "\n255\n";
const auto ppm_size = img.get_size();
file << "P6\n" << ppm_size.width << " " << ppm_size.height << "\n255\n";
// Write pixel data
for (size_t i = 0; i < img.buf.size(); i += 3) {
const auto & ppm_buf = img.get_ro_buf();
for (size_t i = 0; i < ppm_buf.size(); i += 3) {
// PPM expects binary data in RGB format, which matches our image buffer
file.write(reinterpret_cast<const char*>(&img.buf[i]), 3);
file.write(reinterpret_cast<const char*>(&ppm_buf[i]), 3);
}
file.close();
@@ -57,9 +59,10 @@ static void clip_image_save_to_bmp(const clip_image_u8& img, const std::string&
return;
}
int fileSize = 54 + 3 * img.nx * img.ny; // File header + info header + pixel data
const auto bmp_size = img.get_size();
int fileSize = 54 + 3 * bmp_size.width * bmp_size.height; // File header + info header + pixel data
int bytesPerPixel = 3;
int widthInBytes = img.nx * bytesPerPixel;
int widthInBytes = bmp_size.width * bytesPerPixel;
int paddingAmount = (4 - (widthInBytes % 4)) % 4;
int stride = widthInBytes + paddingAmount;
@@ -72,7 +75,7 @@ static void clip_image_save_to_bmp(const clip_image_u8& img, const std::string&
};
// Total file size
fileSize = 54 + (stride * img.ny);
fileSize = 54 + (stride * bmp_size.height);
fileHeader[2] = (unsigned char)(fileSize);
fileHeader[3] = (unsigned char)(fileSize >> 8);
fileHeader[4] = (unsigned char)(fileSize >> 16);
@@ -94,14 +97,14 @@ static void clip_image_save_to_bmp(const clip_image_u8& img, const std::string&
};
// Width and height in the information header
infoHeader[4] = (unsigned char)(img.nx);
infoHeader[5] = (unsigned char)(img.nx >> 8);
infoHeader[6] = (unsigned char)(img.nx >> 16);
infoHeader[7] = (unsigned char)(img.nx >> 24);
infoHeader[8] = (unsigned char)(img.ny);
infoHeader[9] = (unsigned char)(img.ny >> 8);
infoHeader[10] = (unsigned char)(img.ny >> 16);
infoHeader[11] = (unsigned char)(img.ny >> 24);
infoHeader[4] = (unsigned char)(bmp_size.width);
infoHeader[5] = (unsigned char)(bmp_size.width >> 8);
infoHeader[6] = (unsigned char)(bmp_size.width >> 16);
infoHeader[7] = (unsigned char)(bmp_size.width >> 24);
infoHeader[8] = (unsigned char)(bmp_size.height);
infoHeader[9] = (unsigned char)(bmp_size.height >> 8);
infoHeader[10] = (unsigned char)(bmp_size.height >> 16);
infoHeader[11] = (unsigned char)(bmp_size.height >> 24);
// Write file headers
file.write(reinterpret_cast<char*>(fileHeader), sizeof(fileHeader));
@@ -109,14 +112,14 @@ static void clip_image_save_to_bmp(const clip_image_u8& img, const std::string&
// Pixel data
std::vector<unsigned char> padding(3, 0); // Max padding size to be added to each row
for (int y = img.ny - 1; y >= 0; --y) { // BMP files are stored bottom-to-top
for (int x = 0; x < img.nx; ++x) {
for (int y = bmp_size.height - 1; y >= 0; --y) { // BMP files are stored bottom-to-top
for (int x = 0; x < bmp_size.width; ++x) {
// Each pixel
size_t pixelIndex = (y * img.nx + x) * 3;
const auto px = img.get_pixel(x, y);
unsigned char pixel[3] = {
img.buf[pixelIndex + 2], // BMP stores pixels in BGR format
img.buf[pixelIndex + 1],
img.buf[pixelIndex]
px[2], // BMP stores pixels in BGR format
px[1],
px[0]
};
file.write(reinterpret_cast<char*>(pixel), 3);
}
@@ -129,12 +132,13 @@ static void clip_image_save_to_bmp(const clip_image_u8& img, const std::string&
// debug function to convert f32 to u8
static void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u8& dst) {
dst.nx = src.nx;
dst.ny = src.ny;
dst.buf.resize(3 * src.nx * src.ny);
for (size_t i = 0; i < src.buf.size(); ++i) {
dst.buf[i] = static_cast<uint8_t>(std::min(std::max(int(src.buf[i] * 255.0f), 0), 255));
dst.set_size(src.get_size(), false);
const auto & src_buf = src.get_ro_buf();
std::vector<uint8_t> dst_buf(src.n_elements());
for (size_t i = 0; i < src.n_elements(); ++i) {
dst_buf[i] = static_cast<uint8_t>(std::min(std::max(int(src_buf[i] * 255.0f), 0), 255));
}
dst.cpy_buf(dst_buf);
}
#endif
@@ -241,8 +245,8 @@ clip_graph::clip_graph(clip_ctx * ctx, const clip_image_f32 & img) :
proj_type(ctx->proj_type()),
img(img),
patch_size(hparams.patch_size),
n_patches_x(img.nx / patch_size),
n_patches_y(img.ny / patch_size),
n_patches_x(img.nx() / patch_size),
n_patches_y(img.ny() / patch_size),
n_patches(n_patches_x * n_patches_y),
n_embd(hparams.n_embd),
n_head(hparams.n_head),
@@ -278,8 +282,8 @@ void clip_graph::cb(ggml_tensor * cur, const char * name, int il) const {
// siglip2 naflex
ggml_tensor * clip_graph::resize_position_embeddings(uint32_t interpolation_mode) {
ggml_tensor * pos_embd = model.position_embeddings;
const int height = img.ny / patch_size;
const int width = img.nx / patch_size;
const int height = img.ny() / patch_size;
const int width = img.nx() / patch_size;
const uint32_t mode = interpolation_mode;
const int n_per_side = (int)std::sqrt(pos_embd->ne[1]);
@@ -310,11 +314,17 @@ ggml_tensor * clip_graph::build_vit(
std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos,
const build_vit_opts & opts
) {
// batch dim: inp is [n_embd, n_pos] (B==1) or [n_embd, n_pos, B] (multi-tile encode)
const int64_t B = inp->ne[2];
if (learned_pos_embd) {
inp = ggml_add(ctx0, inp, learned_pos_embd);
cb(inp, "pos_embed", -1);
}
// flatten batch; unflatten again in attention
inp = ggml_reshape_2d(ctx0, inp, n_embd, n_pos * B);
ggml_tensor * inpL = inp;
// pre-layernorm
@@ -344,20 +354,24 @@ ggml_tensor * clip_graph::build_vit(
cur = ggml_add(ctx0, cur, layer.qkv_b);
}
Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* offset */ 0);
// Q/K/V as [d_head, n_head, n_pos, B], the batch stride is cur->nb[1]*n_pos.
Qcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* nb3 */ cur->nb[1] * n_pos,
/* offset */ 0);
Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* offset */ ggml_row_size(cur->type, n_embd));
Kcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* nb3 */ cur->nb[1] * n_pos,
/* offset */ ggml_row_size(cur->type, n_embd));
Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* offset */ ggml_row_size(cur->type, 2 * n_embd));
Vcur = ggml_view_4d(ctx0, cur, d_head, n_head, n_pos, B,
/* nb1 */ ggml_row_size(cur->type, d_head),
/* nb2 */ cur->nb[1],
/* nb3 */ cur->nb[1] * n_pos,
/* offset */ ggml_row_size(cur->type, 2 * n_embd));
if (layer.q_norm) {
GGML_ASSERT(layer.q_norm->ne[0] == Qcur->ne[0]);
@@ -402,9 +416,9 @@ ggml_tensor * clip_graph::build_vit(
}
}
Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head_kv, n_pos);
Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head_kv, n_pos);
Qcur = ggml_reshape_4d(ctx0, Qcur, d_head, n_head, n_pos, B);
Kcur = ggml_reshape_4d(ctx0, Kcur, d_head, n_head_kv, n_pos, B);
Vcur = ggml_reshape_4d(ctx0, Vcur, d_head, n_head_kv, n_pos, B);
if (norm_per_head) {
if (layer.q_norm) {
@@ -434,6 +448,7 @@ ggml_tensor * clip_graph::build_vit(
cb(Vcur, "Vcur_normed", il);
}
// build_attn returns a flat 2D [n_embd, n_pos*B]
cur = build_attn(layer.o_w, layer.o_b,
Qcur, Kcur, Vcur, opts.attn_mask, kq_scale, il);
cb(cur, "attn_out", il);
@@ -505,6 +520,10 @@ ggml_tensor * clip_graph::build_vit(
if (model.post_ln_w) {
inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1);
}
// restore the batch dim
GGML_ASSERT(inpL->ne[1] % B == 0);
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, inpL->ne[1] / B, B);
return inpL;
}
@@ -523,7 +542,7 @@ ggml_tensor * clip_graph::build_inp() {
}
ggml_tensor * clip_graph::build_inp_raw(int channels) {
ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, channels);
ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, img.nx(), img.ny(), channels, n_batch);
ggml_set_name(inp_raw, "inp_raw");
ggml_set_input(inp_raw);
return inp_raw;
@@ -816,8 +835,8 @@ ggml_tensor * clip_graph::build_patch_merge_permute(ggml_tensor * cur, int scale
GGML_ASSERT(scale_factor > 1);
const int n_embd = cur->ne[0];
int width = img.nx / patch_size;
int height = img.ny / patch_size;
int width = img.nx() / patch_size;
int height = img.ny() / patch_size;
// pad width and height to factor
const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width;
@@ -844,8 +863,6 @@ ggml_tensor * clip_graph::build_patch_merge_permute(ggml_tensor * cur, int scale
}
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
GGML_ASSERT(imgs.entries.size() == 1 && "n_batch > 1 is not supported");
const clip_image_f32 & img = *imgs.entries[0];
std::unique_ptr<clip_graph> builder;
@@ -1005,6 +1022,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
GGML_ABORT("missing cgraph builder");
}
// TODO [QWEN_VIDEO]: improve this in the future
builder->n_batch = imgs.entries.size();
return builder->build();
}
@@ -2805,13 +2825,12 @@ struct clip_model_loader {
clip_image_f32_batch batch;
clip_image_f32_ptr img(clip_image_f32_init());
if (ctx_clip.model.modality == CLIP_MODALITY_VISION) {
img->nx = hparams.warmup_image_size;
img->ny = hparams.warmup_image_size;
LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny);
const int sz = hparams.warmup_image_size;
img->set_size({sz, sz}, false, false);
LOG_INF("%s: warmup with image size = %d x %d\n", __func__, sz, sz);
} else {
img->nx = hparams.warmup_audio_size;
img->ny = hparams.n_mel_bins;
LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx);
img->set_size({hparams.warmup_audio_size, hparams.n_mel_bins}, false, false);
LOG_INF("%s: warmup with audio size = %d\n", __func__, hparams.warmup_audio_size);
}
batch.entries.push_back(std::move(img));
warmup(ctx_clip, batch);
@@ -3108,12 +3127,6 @@ struct clip_image_f32_batch * clip_image_f32_batch_init() {
return new clip_image_f32_batch();
}
unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny) {
if (nx) *nx = img->nx;
if (ny) *ny = img->ny;
return img->buf.data();
}
void clip_image_size_free(struct clip_image_size * load_image_size) {
if (load_image_size == nullptr) {
return;
@@ -3134,7 +3147,7 @@ size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int id
LOG_ERR("%s: invalid index %d\n", __func__, idx);
return 0;
}
return batch->entries[idx]->nx;
return batch->entries[idx]->nx();
}
size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx) {
@@ -3142,7 +3155,7 @@ size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int id
LOG_ERR("%s: invalid index %d\n", __func__, idx);
return 0;
}
return batch->entries[idx]->ny;
return batch->entries[idx]->ny();
}
clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx) {
@@ -3153,13 +3166,6 @@ clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batc
return batch->entries[idx].get();
}
void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img) {
img->nx = nx;
img->ny = ny;
img->buf.resize(3 * nx * ny);
memcpy(img->buf.data(), rgb_pixels, img->buf.size());
}
void clip_free(clip_ctx * ctx) {
if (ctx == nullptr) {
return;
@@ -3167,20 +3173,6 @@ void clip_free(clip_ctx * ctx) {
delete ctx;
}
// deprecated
size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
const int32_t nx = ctx->model.hparams.image_size;
const int32_t ny = ctx->model.hparams.image_size;
return clip_embd_nbytes_by_img(ctx, nx, ny);
}
size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h) {
clip_image_f32 img;
img.nx = img_w;
img.ny = img_h;
return clip_n_output_tokens(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
}
int32_t clip_get_image_size(const struct clip_ctx * ctx) {
return ctx->model.hparams.image_size;
}
@@ -3211,9 +3203,9 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 *
case PROJECTOR_TYPE_PADDLEOCR:
case PROJECTOR_TYPE_HUNYUANVL:
case PROJECTOR_TYPE_YOUTUVL:
return (img->nx / params.patch_size) / 2;
return (img->nx() / params.patch_size) / 2;
case PROJECTOR_TYPE_STEP3VL:
return img->nx / (params.patch_size * params.n_merge);
return img->nx() / (params.patch_size * params.n_merge);
default:
break;
}
@@ -3233,9 +3225,9 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 *
case PROJECTOR_TYPE_PADDLEOCR:
case PROJECTOR_TYPE_HUNYUANVL:
case PROJECTOR_TYPE_YOUTUVL:
return (img->ny / params.patch_size) / 2;
return (img->ny() / params.patch_size) / 2;
case PROJECTOR_TYPE_STEP3VL:
return img->ny / (params.patch_size * params.n_merge);
return img->ny() / (params.patch_size * params.n_merge);
default:
break;
}
@@ -3247,7 +3239,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
// for models with fixed size image, the input image is already pre-processed and resized to square
int patch_size = params.patch_size;
int n_patches = (img->nx / patch_size) * (img->ny / patch_size);
int n_patches = (img->nx() / patch_size) * (img->ny() / patch_size);
projector_type proj = ctx->proj_type();
@@ -3313,14 +3305,14 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
case PROJECTOR_TYPE_YOUTUVL:
{
// dynamic size (2 conv, so double patch size)
int x_patch = img->nx / (params.patch_size * 2);
int y_patch = img->ny / (params.patch_size * 2);
int x_patch = img->nx() / (params.patch_size * 2);
int y_patch = img->ny() / (params.patch_size * 2);
n_patches = x_patch * y_patch;
} break;
case PROJECTOR_TYPE_STEP3VL:
{
int x_patch = img->nx / (params.patch_size * params.n_merge);
int y_patch = img->ny / (params.patch_size * params.n_merge);
int x_patch = img->nx() / (params.patch_size * params.n_merge);
int y_patch = img->ny() / (params.patch_size * params.n_merge);
n_patches = x_patch * y_patch;
} break;
case PROJECTOR_TYPE_GEMMA3:
@@ -3347,8 +3339,8 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
{
// dynamic size
int out_patch_size = params.patch_size * ctx->model.hparams.n_merge;
int x_patch = CLIP_ALIGN(img->nx, out_patch_size) / out_patch_size;
int y_patch = CLIP_ALIGN(img->ny, out_patch_size) / out_patch_size;
int x_patch = CLIP_ALIGN(img->nx(), out_patch_size) / out_patch_size;
int y_patch = CLIP_ALIGN(img->ny(), out_patch_size) / out_patch_size;
n_patches = x_patch * y_patch;
} break;
case PROJECTOR_TYPE_PADDLEOCR:
@@ -3364,8 +3356,8 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
{
// dynamic size
int n_merge = ctx->model.hparams.n_merge;
int n_patches_x = img->nx / patch_size / (n_merge > 0 ? n_merge : 1);
int n_patches_y = img->ny / patch_size / (n_merge > 0 ? n_merge : 1);
int n_patches_x = img->nx() / patch_size / (n_merge > 0 ? n_merge : 1);
int n_patches_y = img->ny() / patch_size / (n_merge > 0 ? n_merge : 1);
if (ctx->model.token_embd_img_break) {
n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
} else {
@@ -3378,7 +3370,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
case PROJECTOR_TYPE_MERALION:
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
{
n_patches = img->nx;
n_patches = img->nx();
const int proj_stack_factor = ctx->model.hparams.proj_stack_factor;
if (ctx->model.audio_has_stack_frames()) {
@@ -3400,11 +3392,11 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
// chunk_size=100 frames --> 3x stride-2 conv2d --> 13 tokens per chunk
const int chunk_size = 100;
const int tokens_per_chunk = 13;
n_patches = (img->nx / chunk_size) * tokens_per_chunk;
n_patches = (img->nx() / chunk_size) * tokens_per_chunk;
} break;
case PROJECTOR_TYPE_GLMA:
{
n_patches = img->nx;
n_patches = img->nx();
// whisper downscales input token by half after conv1d
n_patches /= 2;
// reshape by merge_factor
@@ -3431,8 +3423,8 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
case PROJECTOR_TYPE_HUNYUANVL:
{
int merge = ctx->model.hparams.n_merge;
int ow = (img->nx / patch_size) / merge;
int oh = (img->ny / patch_size) / merge;
int ow = (img->nx() / patch_size) / merge;
int oh = (img->ny() / patch_size) / merge;
n_patches = (ow + 1) * oh + 2;
} break;
case PROJECTOR_TYPE_DEEPSEEKOCR2:
@@ -3446,13 +3438,13 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
} break;
case PROJECTOR_TYPE_LFM2A:
{
n_patches = ((((img->nx + 1) / 2) + 1) / 2 + 1) / 2;
n_patches = ((((img->nx() + 1) / 2) + 1) / 2 + 1) / 2;
} break;
case PROJECTOR_TYPE_GEMMA4A:
{
// Two Conv2D stride-2: O = floor((I + 2p - k) / s) + 1, p=1, k=3, s=2
// O = floor((I - 1) / 2) + 1
int n = img->nx;
int n = img->nx();
for (int i = 0; i < 2; i++) {
n = (n - 1) / 2 + 1;
}
@@ -3460,13 +3452,13 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
} break;
case PROJECTOR_TYPE_GEMMA4UA:
{
n_patches = img->nx; // no downsampling: one token per raw waveform frame
n_patches = img->nx(); // no downsampling: one token per raw waveform frame
} break;
case PROJECTOR_TYPE_GRANITE_SPEECH:
{
const int ws = ctx->model.hparams.audio_proj_window_size;
const int ds = ctx->model.hparams.audio_proj_downsample_rate;
n_patches = ((img->nx + ws - 1) / ws) * (ws / ds);
n_patches = ((img->nx() + ws - 1) / ws) * (ws / ds);
} break;
case PROJECTOR_TYPE_GRANITE4_VISION:
{
@@ -3475,7 +3467,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
// For 384×384 input: n = 24/8 = 3, query_side = 4 → 144.
const int window_side = ctx->model.hparams.downsample_window_side;
const int query_side = ctx->model.hparams.downsample_query_side;
const int side = img->nx / params.patch_size;
const int side = img->nx() / params.patch_size;
const int n = side / window_side;
n_patches = (query_side * n) * (query_side * n);
if (img->add_newline) {
@@ -3503,12 +3495,15 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3
bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) {
const clip_image_f32_batch & imgs = *imgs_c_ptr;
int batch_size = imgs.entries.size();
int n_batch_cur = imgs.entries.size();
// maximum supported batch size, usually == 2 for qwen-vl-based models
int n_batch_max = clip_model_n_batch_max(ctx);
// TODO @ngxson : implement batch size > 1 as a loop
// we don't need true batching support because the cgraph will gonna be big anyway
if (batch_size != 1) {
return false; // only support batch size of 1
if (n_batch_cur > n_batch_max) {
return false;
}
// if buffers are not allocated, we need to do a warmup run to allocate them
@@ -3525,8 +3520,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
const auto & model = ctx->model;
const auto & hparams = model.hparams;
const int image_size_width = imgs.entries[0]->nx;
const int image_size_height = imgs.entries[0]->ny;
const int image_size_width = imgs.entries[0]->nx();
const int image_size_height = imgs.entries[0]->ny();
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
@@ -3546,7 +3541,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
return inp;
};
auto set_input_f32 = [&get_inp_tensor](const char * name, std::vector<float> & values) {
auto set_input_f32 = [&get_inp_tensor](const char * name, const std::vector<float> & values) {
ggml_tensor * cur = get_inp_tensor(name);
GGML_ASSERT(cur->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_nelements(cur) == (int64_t)values.size());
@@ -3564,7 +3559,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
if (!imgs.is_audio) {
size_t nelem = 0;
for (const auto & img : imgs.entries) {
nelem += img->nx * img->ny * 3;
nelem += img->nx() * img->ny() * 3;
}
std::vector<float> inp_raw(nelem);
@@ -3579,20 +3574,23 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
// └─────┘ │
// ──────┘ x B
for (size_t i = 0; i < imgs.entries.size(); i++) {
const int nx = imgs.entries[i]->nx;
const int ny = imgs.entries[i]->ny;
const int n = nx * ny;
// IMPORTANT: [QWEN_VIDEO] the batch dim is currently used for temporal dim in Qwen-VL models
// All entries must have the same spatial size (enforced by can_batch_with() during merging)
{
const int nx = imgs.entries[0]->nx();
const int ny = imgs.entries[0]->ny();
const int n = nx * ny;
for (int b = 0; b < batch_size; b++) {
for (int b = 0; b < n_batch_cur; b++) {
const auto & buf = imgs.entries[b]->get_ro_buf();
float * batch_entry = inp_raw.data() + b * (3*n);
for (int y = 0; y < ny; y++) {
for (int x = 0; x < nx; x++) {
size_t base_src = 3*(y * nx + x); // idx of the first channel
size_t base_dst = y * nx + x; // idx of the first channel
batch_entry[ base_dst] = imgs.entries[b]->buf[base_src ];
batch_entry[1*n + base_dst] = imgs.entries[b]->buf[base_src + 1];
batch_entry[2*n + base_dst] = imgs.entries[b]->buf[base_src + 2];
size_t base_src = 3*(y * nx + x);
size_t base_dst = y * nx + x;
batch_entry[ base_dst] = buf[base_src ];
batch_entry[1*n + base_dst] = buf[base_src + 1];
batch_entry[2*n + base_dst] = buf[base_src + 2];
}
}
}
@@ -3602,12 +3600,14 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
} else {
// audio input
GGML_ASSERT(imgs.entries.size() == 1);
const auto & mel_inp = imgs.entries[0];
const int n_step = mel_inp->nx;
const int n_mel = mel_inp->ny;
std::vector<float> inp_raw(n_step * n_mel);
std::memcpy(inp_raw.data(), mel_inp->buf.data(), n_step * n_mel * sizeof(float));
set_input_f32("inp_raw", inp_raw);
const auto & buf = mel_inp->get_ro_buf();
const int n_step = mel_inp->nx();
const int n_mel = mel_inp->ny();
GGML_ASSERT((size_t)n_step * n_mel == buf.size());
set_input_f32("inp_raw", buf);
}
// set input per projector
@@ -4218,7 +4218,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
GGML_ASSERT(imgs.entries.size() == 1);
const auto & img0 = imgs.entries.front();
// Compute n_pos matching SSCP output: two stride-2 convs
int n_pos = img0->nx;
int n_pos = img0->nx();
for (int i = 0; i < 2; i++) { n_pos = (n_pos - 1) / 2 + 1; }
// Chunked local attention: blocked causal mask and RPE
@@ -4324,7 +4324,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
// reshapes as ggml_get_rows gathers. The names are set
// by g4v_gather() in models/granite4-vision.cpp.
const int patch_size = model.hparams.patch_size;
const int image_side = imgs.entries.front()->nx / patch_size;
const int image_side = imgs.entries.front()->nx() / patch_size;
const int window_side = hparams.downsample_window_side;
const int query_side = hparams.downsample_query_side;
const int n = image_side / window_side;
@@ -4570,17 +4570,15 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
return ctx->model.modality == CLIP_MODALITY_AUDIO;
}
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
clip_image_f32 clip_img;
clip_img.buf.resize(h * w * 3);
for (int i = 0; i < h*w*3; i++)
{
clip_img.buf[i] = img[i];
int clip_model_n_batch_max(const struct clip_ctx * ctx) {
switch (ctx->proj_type()) {
case PROJECTOR_TYPE_QWEN2VL:
case PROJECTOR_TYPE_QWEN25VL:
case PROJECTOR_TYPE_QWEN3VL:
return 2;
default:
return 1;
}
clip_img.nx = w;
clip_img.ny = h;
clip_image_encode(ctx, n_threads, &clip_img, vec);
return true;
}
//
@@ -4591,17 +4589,6 @@ projector_type clip_get_projector_type(const struct clip_ctx * ctx) {
return ctx->proj_type();
}
void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel) {
clip_image_f32 * audio = new clip_image_f32;
audio->nx = n_frames;
audio->ny = n_mel;
audio->buf.resize(n_frames * n_mel);
std::memcpy(audio->buf.data(), mel, n_frames * n_mel * sizeof(float));
batch->entries.push_back(clip_image_f32_ptr(audio));
batch->is_audio = true;
}
const clip_hparams * clip_get_hparams(const struct clip_ctx * ctx) {
return &ctx->model.hparams;
}
+11 -17
View File
@@ -17,6 +17,15 @@ struct clip_ctx;
struct clip_image_size {
int width;
int height;
bool operator==(const clip_image_size & other) const {
return width == other.width && height == other.height;
}
bool operator!=(const clip_image_size & other) const {
return !(*this == other);
}
int area() const {
return width * height;
}
};
struct clip_image_f32;
@@ -54,9 +63,6 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params
void clip_free(struct clip_ctx * ctx);
size_t clip_embd_nbytes(const struct clip_ctx * ctx);
size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h);
int32_t clip_get_image_size (const struct clip_ctx * ctx);
int32_t clip_get_patch_size (const struct clip_ctx * ctx);
int32_t clip_get_hidden_size(const struct clip_ctx * ctx);
@@ -79,9 +85,6 @@ struct clip_image_u8 * clip_image_u8_init (void);
struct clip_image_f32 * clip_image_f32_init(void);
struct clip_image_f32_batch * clip_image_f32_batch_init(void); // only used by libllava
// nx, ny are the output image dimensions
unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
void clip_image_size_free (struct clip_image_size * img_size);
void clip_image_u8_free (struct clip_image_u8 * img);
void clip_image_f32_free(struct clip_image_f32 * img);
@@ -94,12 +97,6 @@ size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int id
size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
/**
* Build image from pixels decoded by other libraries instead of stb_image.h for better performance.
* The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes
*/
void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img);
bool clip_image_encode (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec);
bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec);
@@ -107,14 +104,11 @@ bool clip_is_llava(const struct clip_ctx * ctx);
// note for contributor: this clip_is_(model) pattern is deprecated
// do NOT add new functions like this
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
// use by audio input
void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel);
bool clip_has_vision_encoder(const struct clip_ctx * ctx);
bool clip_has_audio_encoder(const struct clip_ctx * ctx);
int clip_model_n_batch_max(const struct clip_ctx * ctx);
std::map<ggml_backend_dev_t, size_t> clip_get_mem_usage(const struct clip_ctx * ctx);
struct clip_cap {
+1 -1
View File
@@ -1,7 +1,7 @@
#include "models.h"
ggml_cgraph * clip_graph_conformer::build() {
const int n_frames = img.nx;
const int n_frames = img.nx();
const int n_pos = n_frames / 2;
const int n_pos_embd = (((((n_frames + 1) / 2) + 1) / 2 + 1) / 2) * 2 - 1;
GGML_ASSERT(model.position_embeddings->ne[1] >= n_pos);
+2 -2
View File
@@ -22,8 +22,8 @@ ggml_cgraph * clip_graph_exaone4_5::build() {
ggml_tensor * inp_raw = build_inp_raw();
ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
GGML_ASSERT(img.nx % (patch_size * 2) == 0);
GGML_ASSERT(img.ny % (patch_size * 2) == 0);
GGML_ASSERT(img.nx() % (patch_size * 2) == 0);
GGML_ASSERT(img.ny() % (patch_size * 2) == 0);
{
ggml_tensor * inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+2 -2
View File
@@ -16,8 +16,8 @@ ggml_cgraph * clip_graph_glm4v::build() {
ggml_set_name(positions, "positions");
ggml_set_input(positions);
GGML_ASSERT(img.nx % (patch_size * 2) == 0);
GGML_ASSERT(img.ny % (patch_size * 2) == 0);
GGML_ASSERT(img.nx() % (patch_size * 2) == 0);
GGML_ASSERT(img.ny() % (patch_size * 2) == 0);
// second conv dimension
{
+1 -1
View File
@@ -1,7 +1,7 @@
#include "models.h"
ggml_cgraph * clip_graph_granite_speech::build() {
const int n_frames = img.nx;
const int n_frames = img.nx();
const int context_size = hparams.audio_chunk_size;
const int ctc_layer = n_layer / 2;
const int conv_kernel = hparams.audio_conv_kernel_size;
+2 -2
View File
@@ -7,8 +7,8 @@
// with a w*h? Also the permute is a bit different at (2, 1, 0, 3) instead of (2, 0, 1, 3).
ggml_tensor * clip_graph_kimik25::resize_position_embeddings_3d(uint32_t interpolation_mode) {
ggml_tensor * pos_embd = model.position_embeddings;
const int height = img.ny / patch_size;
const int width = img.nx / patch_size;
const int height = img.ny() / patch_size;
const int width = img.nx() / patch_size;
const uint32_t mode = interpolation_mode;
GGML_ASSERT(pos_embd);
+2 -2
View File
@@ -56,8 +56,8 @@ ggml_cgraph * clip_graph_mimovl::build() {
patch_size, patch_size, 0, 0, 1, 1);
inp = ggml_add(ctx0, inp, inp_1);
GGML_ASSERT(img.nx % (patch_size * 2) == 0);
GGML_ASSERT(img.ny % (patch_size * 2) == 0);
GGML_ASSERT(img.nx() % (patch_size * 2) == 0);
GGML_ASSERT(img.ny() % (patch_size * 2) == 0);
inp = ggml_permute(ctx0, inp, 1, 2, 0, 3); // [w,h,c,b] -> [c,w,h,b]
inp = ggml_cont_4d(ctx0, inp, n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
+3 -2
View File
@@ -31,10 +31,11 @@ struct clip_graph_pixtral : clip_graph {
struct clip_graph_qwen2vl : clip_graph {
clip_graph_qwen2vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
ggml_cgraph * build() override;
ggml_tensor * build_inp_with_temporal_merge();
};
struct clip_graph_qwen3vl : clip_graph {
clip_graph_qwen3vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
struct clip_graph_qwen3vl : clip_graph_qwen2vl {
clip_graph_qwen3vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph_qwen2vl(ctx, img) {}
ggml_cgraph * build() override;
};

Some files were not shown because too many files have changed in this diff Show More