Compare commits

...

27 Commits

Author SHA1 Message Date
Reese Levine 3ac3c20c96 ggml-webgpu: Add clang-format job (#24308)
* Add clang-format job

* try local formatting
2026-06-08 20:54:24 -07:00
Masashi Yoshimura 1e1aca09da ggml-webgpu: Improve prefill speeds for k-quants + refactor matmul for Q4/Q5/Q8 and k-quants (#24225)
* ggml-webgpu: Improve prefill speeds + refactor matmul for quants

* Fixes for editroconfig checker
2026-06-08 15:19:56 -07:00
Max Krasnyansky 7d2b45b4f7 mtp: support for gemma-4 E2B and E4B assistants (#24282)
* models: update converter to support smaller assistants

* models: add masked_embd tensors to gemma4-assist arch

* gemma-4: remove temp debug for conversion

* gemma-4-mtp: filter out masked_embedding tensors during conversion
2026-06-08 13:48:52 -07:00
Aldehir Rojas 42a0afd594 server : do not parse when flushing http headers (#24281) 2026-06-08 13:32:41 -05:00
Pascal a66d50588b graph: guard iswa kq_mask on its own buffer (#24294)
A SWA-only draft head (e.g. StepFun MTP) leaves the base sub-cache
empty, so its kq_mask buffer stays null and asserts at load. Guard
each mask on its own buffer in set_input and can_reuse, base and swa.

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-06-08 19:20:28 +02:00
Nikhil Jain 1705d434f6 [ggml-webgpu] Handle buffer overlap / buffer aliasing for concat operator (#24000)
* Only run webgpu CI on my fork

* Add webgpu only workflow

* handle buffer overlap case for concat operator

* restore build-webgpu.yml

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* Run clang-format

* Update ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>
2026-06-08 08:07:31 -07:00
Nikhil Jain 3b3da01dc2 [ggml-webgpu] Implement 2D workgroups for scale, binary, and unary ops (#24044)
* Only run webgpu CI on my fork

* Add webgpu only workflow

* Implement 2d workgroups for more operations

* fix

* Fix type

* Move back to global_invocation_id
2026-06-08 08:07:15 -07:00
Xuan-Son Nguyen 3ebe862b5d docker: install ffmpeg in the released image (#24302) 2026-06-08 16:59:57 +02:00
Xuan-Son Nguyen 8f83d6c271 mtmd : add video input support (#24269)
* wip

* ok: lazy bitmap API

* remember to free lazy text

* wip

* add mtmd_helper_video

* support video input on server (base64 input)

* add MTMD_VIDEO config

* add timestamp

* update CLI

* cli: allow auto-completion for video

* add --video arg

* fix build

* update docs

* rename as suggested
2026-06-08 14:40:12 +03:00
Georgi Gerganov c2b1518fd4 sync : ggml 2026-06-08 14:31:33 +03:00
Georgi Gerganov 6a1de6fbf1 ggml : bump version to 0.14.0 (ggml/1533) 2026-06-08 14:31:33 +03:00
Xuan-Son Nguyen 715b86a366 cli: fix spinner not show during prompt processing (#24283) 2026-06-08 11:11:45 +02:00
Jeff Bolz c74759a244 vulkan: Use cm2 decode_vector for mul_mat_id B matrix loads (#23991)
This allows vec4 loads of the B elements. Also increase BK to 64 when this is
enabled. Neither of these alone is consistently faster, but together these give
a nice speedup.

In ggml-vulkan.cpp, we need to make sure the B matrix alignment and stride are
multiples of 4.
2026-06-08 10:40:37 +02:00
Ruben Ortlam 0f7fada56b cuda: reset cuda context after reading memory size (#23935)
* cuda: reset device in get_memory function if no backend is active

* also count device and host buffers

* exclude hip and musa from counting and device reset

* use device mutex instead of atomic

* undo backend_free function move
2026-06-08 10:22:44 +02:00
Harkirat Gill 19bba67c1f HIP: add gfx1152 and gfx1153 to RDNA3.5 (#24129) 2026-06-08 08:33:23 +02:00
Xuan-Son Nguyen daf6bc9f2d metal : fix im2col 1D case (audio models) (#24220) 2026-06-08 09:03:18 +03:00
Neo Zhang d403f00ec3 [SYCL] Update compute runtime version to 26.x in docker (#24070)
* update compute runtime from 25 to 26 in docker

* add comment with old driver for multiple GPUs
2026-06-08 10:35:18 +08:00
ddh0 9e3b928fd8 common : relax sampler name matching (#23744)
* common : relax sampler name matching

Currently, in some cases, the alternative names for samplers (like
`top-k` and `min-p` instead of the canonical `top_k` and `min_p`) are
not always recognized by the `common_sampler_types_from_names` function
in `common/sampling.cpp`.

This PR changes the signature of this function to remove the `bool
allow_alt_names` flag, and removes all occurences of the flag from call
sites. Therefore, the function will now always match all known names.

I also changed the logic of the function to unconditionally check the
provided sampler names against both the canonical and alternative names,
and to be case-insensitive.

This fixes an issue I was seeing wherein samplers specified in the
`llama-server` UI were not recognized as valid when the alternative
names were used.

* add more alt names

* cont. fix

* cast to unsigned char for correctness

* common : unify sampler name mapping

* annotate canonical vs. alt sampler name mappings per @CISC

* Update common/sampling.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* common : auto-generate sampler name aliases per @ngxson

* use merged map for matching

* use `.merge` instead of iterating

* nit: simplify comment

* nit: use insert everywhere, not index assignment

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-06-07 22:48:11 +02:00
David Friehs 8a963fc10e convert : fix conversion for Mistral-Medium-3.5-128B (#24268)
Mistral explicitly sets `moe` and `llama_4_scaling` to `null` in
params.json, breaking `key in dict` checks during conversion. Replace
with `dict.get(key) is not None` where this matters.

Fixes `convert-hf-to-gguf.py --mistral-format Mistral-Medium-3.5-128B`
2026-06-07 21:41:39 +02:00
Georgi Gerganov 379ac6673b kv-cache : avoid kv cells copies (#24277) 2026-06-07 21:42:54 +03:00
Pascal f0156d1401 kv-cache: follow the source cache size when sharing cells (#24267)
A fitted target context can end up smaller than the draft default, the
oversized assistant views then overflow the shared K/V tensors and trip
the ggml_view_4d size assert during graph reserve.
2026-06-07 18:33:00 +03:00
Aman Gupta 04eb4c446d llama : add Gemma4 MTP (#23398) 2026-06-07 20:50:54 +08:00
Sigbjørn Skjæret 8a091c47ab spec : fix vocab compatibility check (#24256) 2026-06-07 14:43:52 +03:00
konradmb 465b1f0e75 arg: Skip mmproj download when user supplied mmproj (#24239) 2026-06-07 11:18:44 +02:00
Sigbjørn Skjæret f71af352a5 convert : fix Gemma4 with no audio encoder (#24242) 2026-06-07 08:43:05 +02:00
Sigbjørn Skjæret 3f7c79d7b5 docker : bump cuda13 to 13.3.0 (#24228) 2026-06-07 08:31:58 +02:00
Tarek Dakhran 98d5e8ba8a common/chat : fix LFM2/LFM2.5 reasoning round-trip and <think> leak (#24234)
* common/chat : fix LFM2 reasoning round-trip and stray <think> leak
* Gate by reasoning format and whether the template supports <think>
2026-06-06 22:39:21 +02:00
82 changed files with 2551 additions and 1119 deletions
+1 -1
View File
@@ -53,7 +53,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
org.opencontainers.image.source=$IMAGE_SOURCE
RUN apt-get update \
&& apt-get install -y libgomp1 curl \
&& apt-get install -y libgomp1 curl ffmpeg \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
+1 -1
View File
@@ -59,7 +59,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
org.opencontainers.image.source=$IMAGE_SOURCE
RUN apt-get update \
&& apt-get install -y libgomp1 curl \
&& apt-get install -y libgomp1 curl ffmpeg \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
+16 -6
View File
@@ -57,11 +57,21 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
org.opencontainers.image.url=$IMAGE_URL \
org.opencontainers.image.source=$IMAGE_SOURCE
ARG IGC_VERSION=v2.20.5
ARG IGC_VERSION_FULL=2_2.20.5+19972
ARG COMPUTE_RUNTIME_VERSION=25.40.35563.10
ARG COMPUTE_RUNTIME_VERSION_FULL=25.40.35563.10-0
ARG IGDGMM_VERSION=22.8.2
#Following versions are for multiple GPUs, since 26.x has known issue:
# https://github.com/ggml-org/llama.cpp/issues/21747,
# https://github.com/intel/compute-runtime/issues/921.
#ARG IGC_VERSION=v2.20.5
#ARG IGC_VERSION_FULL=2_2.20.5+19972
#ARG COMPUTE_RUNTIME_VERSION=25.40.35563.10
#ARG COMPUTE_RUNTIME_VERSION_FULL=25.40.35563.10-0
#ARG IGDGMM_VERSION=22.8.2
ARG IGC_VERSION=v2.34.4
ARG IGC_VERSION_FULL=2_2.34.4+21428
ARG COMPUTE_RUNTIME_VERSION=26.18.38308.1
ARG COMPUTE_RUNTIME_VERSION_FULL=26.18.38308.1-0
ARG IGDGMM_VERSION=22.10.0
RUN mkdir /tmp/neo/ && cd /tmp/neo/ \
&& wget https://github.com/intel/intel-graphics-compiler/releases/download/$IGC_VERSION/intel-igc-core-${IGC_VERSION_FULL}_amd64.deb \
&& wget https://github.com/intel/intel-graphics-compiler/releases/download/$IGC_VERSION/intel-igc-opencl-${IGC_VERSION_FULL}_amd64.deb \
@@ -75,7 +85,7 @@ RUN mkdir /tmp/neo/ && cd /tmp/neo/ \
&& dpkg --install *.deb
RUN apt-get update \
&& apt-get install -y libgomp1 curl \
&& apt-get install -y libgomp1 curl ffmpeg \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
+1 -1
View File
@@ -64,7 +64,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
org.opencontainers.image.source=$IMAGE_SOURCE
RUN apt-get update \
&& apt-get install -y libgomp1 curl \
&& apt-get install -y libgomp1 curl ffmpeg \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
+1 -1
View File
@@ -107,7 +107,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
org.opencontainers.image.source=$IMAGE_SOURCE
RUN apt-get update \
&& apt-get install -y libgomp1 libtbb12 curl wget ocl-icd-libopencl1 \
&& apt-get install -y libgomp1 libtbb12 curl wget ffmpeg ocl-icd-libopencl1 \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
+1 -1
View File
@@ -76,7 +76,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
org.opencontainers.image.source=$IMAGE_SOURCE
RUN apt-get update \
&& apt-get install -y libgomp1 curl \
&& apt-get install -y libgomp1 curl ffmpeg \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
+1 -1
View File
@@ -49,7 +49,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
org.opencontainers.image.source=$IMAGE_SOURCE
RUN apt-get update \
&& apt-get install -y libgomp1 curl libvulkan1 mesa-vulkan-drivers \
&& apt-get install -y libgomp1 curl ffmpeg libvulkan1 mesa-vulkan-drivers \
libglvnd0 libgl1 libglx0 libegl1 libgles2 \
&& apt autoremove -y \
&& apt clean -y \
+1 -1
View File
@@ -46,7 +46,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
org.opencontainers.image.source=$IMAGE_SOURCE
RUN apt-get update \
&& apt-get install -y libgomp1 libnuma1 curl \
&& apt-get install -y libgomp1 libnuma1 curl ffmpeg \
&& apt autoremove -y \
&& apt clean -y \
&& rm -rf /tmp/* /var/tmp/* \
+23
View File
@@ -35,6 +35,29 @@ env:
LLAMA_ARG_LOG_TIMESTAMPS: 1
jobs:
format:
runs-on: ubuntu-24.04
steps:
- name: Clone
uses: actions/checkout@v6
- name: Install clang-format 22
run: |
wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key |
sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc > /dev/null
sudo add-apt-repository -y \
"deb http://apt.llvm.org/noble/ llvm-toolchain-noble-22 main"
sudo apt-get update
sudo apt-get install -y clang-format-22
- name: Check formatting
run: |
find ggml/src/ggml-webgpu \
-type f \( -name '*.cpp' -o -name '*.hpp' -o -name '*.h' \) \
-print0 |
xargs -0 clang-format-22 --dry-run --Werror
macos:
runs-on: macos-latest
+2 -2
View File
@@ -82,8 +82,8 @@ jobs:
{ "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x" },
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.1.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.1.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.3.0", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.3.0", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
{ "tag": "musa", "dockerfile": ".devops/musa.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "intel", "dockerfile": ".devops/intel.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
{ "tag": "vulkan", "dockerfile": ".devops/vulkan.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
+4 -4
View File
@@ -444,7 +444,7 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex)
opts.offline = params.offline;
opts.skip_download = params.skip_download;
opts.download_mtp = spec_type_draft_mtp;
opts.download_mmproj = !params.no_mmproj;
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
// so we should not auto-discover mtp/mmproj siblings for them
@@ -1615,7 +1615,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()),
[](common_params & params, const std::string & value) {
const auto sampler_names = string_split<std::string>(value, ';');
params.sampling.samplers = common_sampler_types_from_names(sampler_names, true);
params.sampling.samplers = common_sampler_types_from_names(sampler_names);
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS;
}
).set_sampling());
@@ -2221,8 +2221,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ_OFFLOAD"));
add_opt(common_arg(
{"--image", "--audio"}, "FILE",
"path to an image or audio file. use with multimodal models, use comma-separated values for multiple files\n",
{"--image", "--audio", "--video"}, "FILE",
"path to an image, audio, or video file. use with multimodal models, use comma-separated values for multiple files\n",
[](common_params & params, const std::string & value) {
for (const auto & item : parse_csv_row(value)) {
params.image.emplace_back(item);
+15 -4
View File
@@ -1625,8 +1625,17 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
const std::string THINK_END = "</think>";
const std::string GEN_PROMPT = "<|im_start|>assistant\n";
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
// Copy reasoning to the "thinking" field the template expects
auto adjusted_messages = json::array();
for (auto msg : inputs.messages) {
if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
msg["thinking"] = msg.at("reasoning_content");
}
adjusted_messages.push_back(msg);
}
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs, adjusted_messages);
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, adjusted_messages);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.preserved_tokens = { TOOL_CALL_START, TOOL_CALL_END, THINK_START, THINK_END };
@@ -1639,7 +1648,9 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
data.thinking_end_tag = THINK_END;
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
// Gate by reasoning format and whether the template supports <think>
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE &&
tmpl.source().find(THINK_START) != std::string::npos;
auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE;
if (inputs.has_continuation()) {
@@ -1658,7 +1669,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
auto end = p.end();
auto reasoning = p.eps();
if (extract_reasoning && inputs.enable_thinking) {
if (extract_reasoning) {
reasoning = p.optional(THINK_START + p.reasoning(p.until(THINK_END)) + THINK_END);
}
+1 -1
View File
@@ -1148,7 +1148,7 @@ static void common_init_sampler_from_model(
if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) {
const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';');
if (!sampler_names.empty()) {
sparams.samplers = common_sampler_types_from_names(sampler_names, true);
sparams.samplers = common_sampler_types_from_names(sampler_names);
}
}
}
+1 -1
View File
@@ -571,7 +571,7 @@ struct common_params {
struct common_params_model mmproj;
bool mmproj_use_gpu = true; // use GPU for multimodal model
bool no_mmproj = false; // explicitly disable multimodal model
std::vector<std::string> image; // path to image file(s)
std::vector<std::string> image; // path to image file(s) ; TODO: change the name to "media"
int image_min_tokens = -1;
int image_max_tokens = -1;
+49 -40
View File
@@ -769,54 +769,63 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
}
}
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
{ "dry", COMMON_SAMPLER_TYPE_DRY },
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
{ "adaptive_p", COMMON_SAMPLER_TYPE_ADAPTIVE_P },
};
// since samplers names are written multiple ways
// make it ready for both system names and input names
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P },
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "typ-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "adaptive-p", COMMON_SAMPLER_TYPE_ADAPTIVE_P },
};
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names) {
// sampler names can be written multiple ways; generate aliases from canonical names
static const auto sampler_name_map = []{
// canonical sampler name mapping
std::unordered_map<std::string, common_sampler_type> canonical_name_map {
{ "dry", COMMON_SAMPLER_TYPE_DRY },
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
{ "adaptive_p", COMMON_SAMPLER_TYPE_ADAPTIVE_P }
};
std::unordered_map<std::string, common_sampler_type> alias_name_map;
for (const auto & entry : canonical_name_map) {
const std::string & canonical = entry.first;
if (canonical.find('_') == std::string::npos) {
continue;
}
// kebab-case: "top-k", "min-p", etc.
{
std::string kebab_case = canonical;
std::replace(kebab_case.begin(), kebab_case.end(), '_', '-');
alias_name_map.insert({kebab_case, entry.second});
}
// no dash: "topk", "minp", etc.
{
std::string no_dash = canonical;
no_dash.erase(std::remove(no_dash.begin(), no_dash.end(), '_'), no_dash.end());
alias_name_map.insert({no_dash, entry.second});
}
}
// misc. aliases
alias_name_map.insert({"nucleus", COMMON_SAMPLER_TYPE_TOP_P});
alias_name_map.insert({"temp", COMMON_SAMPLER_TYPE_TEMPERATURE});
alias_name_map.insert({"typ", COMMON_SAMPLER_TYPE_TYPICAL_P});
// include aliases + canonical names in the complete mapping
alias_name_map.merge(canonical_name_map);
return alias_name_map;
}();
std::vector<common_sampler_type> samplers;
samplers.reserve(names.size());
for (const auto & name : names) {
auto sampler = sampler_canonical_name_map.find(name);
if (sampler != sampler_canonical_name_map.end()) {
std::string name_lower = name;
std::transform(name_lower.begin(), name_lower.end(), name_lower.begin(), ::tolower);
auto sampler = sampler_name_map.find(name_lower);
if (sampler != sampler_name_map.end()) {
samplers.push_back(sampler->second);
continue;
}
if (allow_alt_names) {
sampler = sampler_alt_name_map.find(name);
if (sampler != sampler_alt_name_map.end()) {
samplers.push_back(sampler->second);
continue;
}
}
LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str());
LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name_lower.c_str());
}
return samplers;
+1 -1
View File
@@ -109,7 +109,7 @@ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx,
char common_sampler_type_to_chr(enum common_sampler_type cnstr);
std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names);
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
+53 -44
View File
@@ -3,13 +3,14 @@
#include "common.h"
#include "ggml.h"
#include "llama.h"
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_nextn / llama_get_embeddings_nextn_ith (used by MTP)
#include "log.h"
#include "ngram-cache.h"
#include "ngram-map.h"
#include "ngram-mod.h"
#include "sampling.h"
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_nextn / llama_get_embeddings_nextn_ith (used by MTP)
#include <algorithm>
#include <cassert>
#include <cstring>
@@ -58,10 +59,10 @@ static bool common_speculative_are_compatible(
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
const bool vocab_type_tgt = llama_vocab_type(vocab_tgt);
const auto vocab_type_tgt = llama_vocab_type(vocab_tgt);
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
const bool vocab_type_dft = llama_vocab_type(vocab_dft);
const auto vocab_type_dft = llama_vocab_type(vocab_dft);
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) {
@@ -418,6 +419,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
int32_t n_embd = 0;
bool is_mem_shared = false;
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
// The last h-row of one process() call needs the first token of the NEXT
// call to pair with, so it's stashed here until that next call fires.
@@ -444,7 +447,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
auto * ctx_dft = this->params.ctx_dft;
GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");
n_embd = llama_model_n_embd(llama_get_model(ctx_dft));
n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft));
GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) &&
"MTP input row width must match the target h_nextn width");
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
@@ -490,6 +495,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false);
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt;
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
i_batch_beg.assign(n_seq, -1);
@@ -526,9 +533,11 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
if (N <= 0) {
return;
}
auto * ctx_dft = this->params.ctx_dft;
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_max < N - 1) {
if (pos_max < N - 1 && !is_mem_shared) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
"process() hook may not have run on every prefill ubatch "
"(need_embd / logits=1 on every prompt position?). "
@@ -571,48 +580,42 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
const size_t row_bytes = (size_t) n_embd * sizeof(float);
common_batch_clear(batch);
// if kv is shared with target (e.g Gemma4), then we can skip this catch-up decode
if (!is_mem_shared) {
common_batch_clear(batch);
for (int k = 0; k < n_tokens; ++k) {
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
}
// shift the tgt embeddings to the right by one position
// assumes that the tokens in the batch are sequential for each sequence
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
// ^--- this is a problem
// TODO:this is generally true, but would be nice to assert it
{
const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt);
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
//{
// // string with seq_ids in the batch
// std::stringstream ss;
// for (int i = 0; i < n_tokens; ++i) {
// ss << batch_in.seq_id[i][0] << ",";
// }
// LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str());
//}
}
// fill the pending embeddings from a previous run
auto set_h = [&](int idx, const float * h_row) {
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
};
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
for (int k = 0; k < n_tokens; ++k) {
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
}
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}
// shift the tgt embeddings to the right by one position
// assumes that the tokens in the batch are sequential for each sequence
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
// ^--- this is a problem
// TODO:this is generally true, but would be nice to assert it
{
const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt);
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
}
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
return false;
// fill the pending embeddings from a previous run
auto set_h = [&](int idx, const float * h_row) {
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
};
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
}
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
return false;
}
}
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
@@ -721,7 +724,13 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
continue;
}
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
if (is_mem_shared) {
// note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens
// ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37
common_batch_add(batch, id, dp.n_past, { seq_id }, true);
} else {
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
}
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
}
+2
View File
@@ -75,9 +75,11 @@ TEXT_MODEL_MAP: dict[str, str] = {
"Gemma3TextModel": "gemma",
"Gemma3nForCausalLM": "gemma",
"Gemma3nForConditionalGeneration": "gemma",
"Gemma4AssistantForCausalLM": "gemma",
"Gemma4ForConditionalGeneration": "gemma",
"Gemma4ForCausalLM": "gemma",
"Gemma4UnifiedForConditionalGeneration": "gemma",
"Gemma4UnifiedAssistantForCausalLM": "gemma",
"GemmaForCausalLM": "gemma",
"Glm4ForCausalLM": "glm",
"Glm4MoeForCausalLM": "glm",
+25 -4
View File
@@ -785,6 +785,26 @@ class Gemma4UnifiedModel(Gemma4Model):
self.gguf_writer.add_suppress_tokens(suppress_tokens)
@ModelBase.register("Gemma4AssistantForCausalLM", "Gemma4UnifiedAssistantForCausalLM")
class Gemma4AssistantModel(Gemma4Model):
model_arch = gguf.MODEL_ARCH.GEMMA4_ASSISTANT
@classmethod
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
name, gen = item
if "masked_embedding" in name:
logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.")
return None
return super().filter_tensors(item)
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_embedding_length_out(self.hparams["backbone_hidden_size"])
self.gguf_writer.add_nextn_predict_layers(self.block_count)
@ModelBase.register("Gemma4ForConditionalGeneration")
class Gemma4VisionAudioModel(MmprojModel):
has_audio_encoder = True
@@ -812,10 +832,11 @@ class Gemma4VisionAudioModel(MmprojModel):
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-6))
# audio params
assert self.hparams_audio is not None
self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.GEMMA4A)
self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["feat_in"])
self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-6))
if self.has_audio_encoder:
assert self.hparams_audio is not None
self.gguf_writer.add_clip_audio_projector_type(gguf.VisionProjectorType.GEMMA4A)
self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["feat_in"])
self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-6))
def is_audio_tensor(self, name: str) -> bool:
return "audio_tower" in name or "embed_audio" in name
+3 -2
View File
@@ -105,8 +105,9 @@ class MistralModel(LlamaModel):
gguf_writer.add_rope_scaling_yarn_log_mul(mscale_all_dim)
gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
if "llama_4_scaling" in hparams:
gguf_writer.add_attn_temperature_scale(hparams["llama_4_scaling"]["beta"])
llama_4_scaling = hparams.get("llama_4_scaling")
if llama_4_scaling is not None:
gguf_writer.add_attn_temperature_scale(llama_4_scaling["beta"])
class MistralMoeModel(DeepseekV2Model):
+1 -1
View File
@@ -238,7 +238,7 @@ def main() -> None:
assert hparams.get("vision_encoder") is not None, "This model does not support multimodal"
from conversion.pixtral import PixtralModel
model_class = PixtralModel
elif "moe" in hparams:
elif hparams.get("moe") is not None:
from conversion.mistral import MistralMoeModel
model_class = MistralMoeModel
else:
+2 -2
View File
@@ -4,8 +4,8 @@ project("ggml" C CXX ASM)
### GGML Version
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 13)
set(GGML_VERSION_PATCH 1)
set(GGML_VERSION_MINOR 14)
set(GGML_VERSION_PATCH 0)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
+66 -9
View File
@@ -622,6 +622,18 @@ ggml_backend_cuda_context::~ggml_backend_cuda_context() {
// cuda buffer
struct ggml_backend_cuda_device_context {
int device;
std::string name;
std::string description;
std::string pci_bus_id;
int op_offload_min_batch_size;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
std::mutex device_mutex;
int active_count = 0;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
};
struct ggml_backend_cuda_buffer_context {
int device;
void * dev_ptr = nullptr;
@@ -639,6 +651,13 @@ struct ggml_backend_cuda_buffer_context {
static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count--;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
delete ctx;
}
@@ -791,6 +810,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count++;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
}
@@ -1490,6 +1515,12 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
}
static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count--;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
CUDA_CHECK(cudaFreeHost(buffer->context));
}
@@ -1498,6 +1529,8 @@ static void * ggml_cuda_host_malloc(size_t size) {
return nullptr;
}
ggml_cuda_set_device(0); // cudaMallocHost can create the implicit CUDA device context, make sure that this is consistently done on device 0.
void * ptr = nullptr;
cudaError_t err = cudaMallocHost((void **) &ptr, size);
if (err != cudaSuccess) {
@@ -1523,6 +1556,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggm
buffer->buft = buft;
buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count++;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
return buffer;
}
@@ -3140,6 +3179,12 @@ static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) {
static void ggml_backend_cuda_free(ggml_backend_t backend) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) backend->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count--;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
delete cuda_ctx;
delete backend;
}
@@ -4871,14 +4916,6 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
// backend device
struct ggml_backend_cuda_device_context {
int device;
std::string name;
std::string description;
std::string pci_bus_id;
int op_offload_min_batch_size;
};
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
return ctx->name.c_str();
@@ -4967,6 +5004,11 @@ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_k
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
std::lock_guard<std::mutex> lock(ctx->device_mutex);
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_cuda_set_device(ctx->device);
CUDA_CHECK(cudaMemGetInfo(free, total));
@@ -4993,6 +5035,13 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
}
#endif // defined(__linux__)
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
// If no backends or buffers are active, the cudaMemGetInfo call above lazily created a CUDA
// context that permanently consumes VRAM. Reset the device to free it.
if (ctx->active_count == 0) {
CUDA_CHECK(cudaDeviceReset());
}
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}
static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
@@ -5687,13 +5736,21 @@ ggml_backend_t ggml_backend_cuda_init(int device) {
return nullptr;
}
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device);
ggml_backend_t cuda_backend = new ggml_backend {
/* .guid = */ ggml_backend_cuda_guid(),
/* .iface = */ ggml_backend_cuda_interface,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
/* .device = */ dev,
/* .context = */ ctx,
};
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count++;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
return cuda_backend;
}
+2 -2
View File
@@ -219,9 +219,9 @@
#define RDNA3
#endif // defined(__GFX11__)
#if defined(__gfx1150__) || defined(__gfx1151__)
#if defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__)
#define RDNA3_5
#endif // defined(__gfx1150__) || defined(__gfx1151__)
#endif // defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__)
#if defined(RDNA3) && !defined(RDNA3_5)
#define RDNA3_0
+5 -1
View File
@@ -1738,10 +1738,14 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_meta
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1;
const int64_t KH = is_2D ? ne01 : 1;
const int64_t KW = ne00;
char base[256];
char name[256];
if (ne00*ne01 <= 1024) {
if (KH*KW <= 1024) {
snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
} else {
snprintf(base, 256, "kernel_im2col_ext_%s", ggml_type_name(op->type));
+129 -20
View File
@@ -1976,6 +1976,9 @@ struct ggml_backend_vk_context {
// Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert.
vk_pipeline_struct * prealloc_y_last_pipeline_used {};
const ggml_tensor * prealloc_y_last_tensor_used {};
// True when prealloc_y holds the padded fp16 layout used by the coopmat2 B decode-vector callback.
// If false, then it's contiguous.
bool prealloc_y_last_decode_vector_staging {};
// Track which nodes have been used since the last sync, and whether they were written to
std::vector<const ggml_tensor *> unsynced_nodes_written;
@@ -3652,9 +3655,10 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
s_mmq_wg_denoms_k = { 32, 64, 1 };
// spec constants and tile sizes for quant matmul_id
l_warptile_mmqid = { 256, 128, 128, 32, 1, device->subgroup_size };
m_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
s_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
const uint32_t mmqid_bk = device->coopmat2_decode_vector ? 64u : 32u;
l_warptile_mmqid = { 256, 128, 128, mmqid_bk, 1, device->subgroup_size };
m_warptile_mmqid = { 256, 128, 64, mmqid_bk, 0, device->subgroup_size };
s_warptile_mmqid = { 256, 128, 64, mmqid_bk, 0, device->subgroup_size };
l_mmqid_wg_denoms = { 128, 128, 1 };
m_mmqid_wg_denoms = { 128, 64, 1 };
s_mmqid_wg_denoms = { 128, 64, 1 };
@@ -8110,6 +8114,40 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
ggml_vk_sync_buffers(ctx, subctx);
}
// Copy/convert tensor into a caller-defined dense layout. Destination strides
// are in output elements, not bytes.
static void ggml_vk_cpy_to_strided(
ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor,
const vk_subbuffer & in, const vk_subbuffer & out,
uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13) {
VK_LOG_DEBUG("ggml_vk_cpy_to_strided((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), ";
std::cerr << "dst_nb=(" << nb10 << ", " << nb11 << ", " << nb12 << ", " << nb13 << "), buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")");
const int tensor_type_size = ggml_type_size(tensor->type);
const uint32_t ne = ggml_nelements(tensor);
std::array<uint32_t, 3> elements;
if (ne > 262144) {
elements = { 512, 512, CEIL_DIV(ne, 262144) };
} else if (ne > 512) {
elements = { 512, CEIL_DIV(ne, 512), 1 };
} else {
elements = { ne, 1, 1 };
}
vk_op_unary_push_constants pc = {
(uint32_t)ne,
(uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
(uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], nb10, nb11, nb12, nb13,
0,
0.0f, 0.0f,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
};
init_pushconst_fastdiv(pc);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
ggml_vk_sync_buffers(ctx, subctx);
}
static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
switch(type) {
case GGML_TYPE_Q8_1:
@@ -8367,24 +8405,28 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
}
if (y_non_contig) {
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
ctx->prealloc_y_last_tensor_used != src1 ||
ctx->prealloc_y_last_decode_vector_staging) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
ctx->prealloc_y_last_decode_vector_staging = false;
}
}
if (quantize_y) {
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
ctx->prealloc_y_last_tensor_used != src1 ||
ctx->prealloc_y_last_decode_vector_staging) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
ctx->prealloc_y_last_decode_vector_staging = false;
}
}
@@ -8642,24 +8684,28 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (y_non_contig) {
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
ctx->prealloc_y_last_tensor_used != src1 ||
ctx->prealloc_y_last_decode_vector_staging) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y);
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
ctx->prealloc_y_last_decode_vector_staging = false;
}
}
if (quantize_y) {
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
ctx->prealloc_y_last_tensor_used != src1 ||
ctx->prealloc_y_last_decode_vector_staging) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
ctx->prealloc_y_last_decode_vector_staging = false;
}
}
@@ -9110,12 +9156,30 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
// Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
!ggml_vk_dim01_contiguous(src0);
const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
// If src0 is BF16, try to use a BF16 x BF16 multiply
ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
// B must already be, or be convertible to, the matmul B type used by this path.
const bool y_decode_vector_supported = ctx->device->coopmat2_decode_vector &&
(f16_type != GGML_TYPE_BF16 || ctx->device->coopmat2_bf16_support) &&
(src1->type == GGML_TYPE_F32 || src1->type == f16_type);
// If B is copied to prealloc_y, we can choose a 4-element-aligned row stride.
const bool y_decode_vector_uses_prealloc = !ggml_vk_dim01_contiguous(src1) || src1->type != f16_type;
// Direct B reads are safe only if row starts and the original buffer offset are 4-element aligned.
const bool y_decode_vector_aligned =
(ne10 % 4 == 0) &&
(y_decode_vector_uses_prealloc || get_misalign_bytes(ctx, src1) % (4 * ggml_type_size(src1->type)) == 0);
// Stage B only when decode-vector is available and direct B reads would be misaligned.
const bool y_decode_vector_staging = y_decode_vector_supported && !y_decode_vector_aligned;
#else
const bool y_decode_vector_staging = false;
#endif
const bool y_non_contig = y_decode_vector_staging ||
(ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
(src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
!ggml_vk_dim01_contiguous(src1);
// If src0 is BF16, try to use a BF16 x BF16 multiply
ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
const uint32_t y_staged_row_stride = y_decode_vector_staging ? (uint32_t)ggml_vk_align_size(ne10, 4) : (uint32_t)ne10;
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
@@ -9154,11 +9218,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
const uint64_t x_ne = ggml_nelements(src0);
const uint64_t y_ne = padded_n * ne10 * ne12 * ne13;
const uint64_t y_ne = (uint64_t)y_staged_row_stride * padded_n * ne12 * ne13;
const uint64_t d_ne = ggml_nelements(dst);
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * ggml_nelements(src1) / ggml_blck_size(src1->type);
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
const uint64_t ids_sz = nbi2;
@@ -9168,13 +9232,30 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
vk_pipeline to_fp16_vk_1 = nullptr;
vk_pipeline to_q8_1 = nullptr;
auto make_y_staged_dst = [&]() {
ggml_tensor y_staged_dst = *src1;
y_staged_dst.type = f16_type;
y_staged_dst.nb[0] = ggml_type_size(f16_type);
y_staged_dst.nb[1] = y_staged_dst.nb[0] * y_staged_row_stride;
y_staged_dst.nb[2] = y_staged_dst.nb[1] * padded_n;
y_staged_dst.nb[3] = y_staged_dst.nb[2] * y_staged_dst.ne[2];
return y_staged_dst;
};
if (x_non_contig) {
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
} else {
to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
}
if (y_non_contig) {
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
ggml_tensor y_staged_dst;
const ggml_tensor * y_staged_dst_ptr = nullptr;
if (y_decode_vector_staging) {
y_staged_dst = make_y_staged_dst();
y_staged_dst_ptr = &y_staged_dst;
}
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, y_staged_dst_ptr, f16_type);
} else {
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
}
@@ -9292,30 +9373,47 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
}
if (y_non_contig) {
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
ctx->prealloc_y_last_tensor_used != src1 ||
ctx->prealloc_y_last_decode_vector_staging != y_decode_vector_staging) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
if (y_decode_vector_staging) {
const ggml_tensor y_staged_dst = make_y_staged_dst();
const uint32_t y_staged_dst_type_size = ggml_type_size(y_staged_dst.type);
ggml_vk_cpy_to_strided(
ctx, subctx, to_fp16_vk_1, src1,
ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0),
(uint32_t)(y_staged_dst.nb[0] / y_staged_dst_type_size),
(uint32_t)(y_staged_dst.nb[1] / y_staged_dst_type_size),
(uint32_t)(y_staged_dst.nb[2] / y_staged_dst_type_size),
(uint32_t)(y_staged_dst.nb[3] / y_staged_dst_type_size));
} else {
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
}
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
ctx->prealloc_y_last_decode_vector_staging = y_decode_vector_staging;
}
}
if (quantize_y) {
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
ctx->prealloc_y_last_tensor_used != src1 ||
ctx->prealloc_y_last_decode_vector_staging) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
ctx->prealloc_y_last_decode_vector_staging = false;
}
}
ggml_vk_sync_buffers(ctx, subctx);
uint32_t stride_batch_x = ne00*ne01;
uint32_t stride_batch_y = ne10*ne11;
uint32_t stride_b_y = y_decode_vector_staging ? y_staged_row_stride : ne10;
uint32_t stride_batch_y = y_decode_vector_staging ? y_staged_row_stride * padded_n : ne10*ne11;
if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
@@ -9330,7 +9428,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
ctx, subctx, pipeline,
{ d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz },
{ d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf,
ne01, ne21, ne10, ne10, ne10, ne01,
ne01, ne21, ne10, ne10, stride_b_y, ne01,
stride_batch_x, stride_batch_y, ne20*ne21,
n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
); // NOLINT
@@ -9488,24 +9586,28 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
if (y_non_contig) {
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
ctx->prealloc_y_last_tensor_used != src1 ||
ctx->prealloc_y_last_decode_vector_staging) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y);
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
ctx->prealloc_y_last_decode_vector_staging = false;
}
}
if (quantize_y) {
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
ctx->prealloc_y_last_tensor_used != src1 ||
ctx->prealloc_y_last_decode_vector_staging) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
ctx->prealloc_y_last_decode_vector_staging = false;
}
}
@@ -13730,7 +13832,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
ggml_vk_destroy_buffer(ctx->prealloc_y);
}
ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
ctx->prealloc_y_last_pipeline_used = nullptr;
ctx->prealloc_y_last_tensor_used = nullptr;
ctx->prealloc_y_last_decode_vector_staging = false;
}
if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
@@ -14310,6 +14414,8 @@ static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
ctx->prealloc_y_last_pipeline_used = {};
ctx->prealloc_y_last_tensor_used = nullptr;
ctx->prealloc_y_last_decode_vector_staging = false;
ctx->unsynced_nodes_written.clear();
ctx->unsynced_nodes_read.clear();
@@ -14360,6 +14466,8 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
ggml_vk_destroy_buffer(ctx->sync_staging);
ctx->prealloc_y_last_pipeline_used = nullptr;
ctx->prealloc_y_last_tensor_used = nullptr;
ctx->prealloc_y_last_decode_vector_staging = false;
ctx->prealloc_size_x = 0;
ctx->prealloc_size_y = 0;
@@ -15539,6 +15647,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->prealloc_y_last_pipeline_used = nullptr;
ctx->prealloc_y_last_tensor_used = nullptr;
ctx->prealloc_y_last_decode_vector_staging = false;
if (ctx->prealloc_size_add_rms_partials) {
ggml_vk_preallocate_buffers(ctx, nullptr);
@@ -11,6 +11,9 @@
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#extension GL_NV_cooperative_matrix2 : enable
#ifdef GGML_VULKAN_COOPMAT2_DECODE_VECTOR
#extension GL_NV_cooperative_matrix_decode_vector : enable
#endif
#extension GL_EXT_buffer_reference : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_vote : enable
@@ -69,10 +72,13 @@ layout (push_constant) uniform parameter
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#if defined(MUL_MAT_ID) && defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR)
layout (binding = 1) readonly buffer B4 {B_TYPEV4 data_b_v4[];};
#endif
#if QUANT_K > 1
#include "dequant_funcs_cm2.glsl"
#if defined(dequantFuncA_v) && defined(GL_NV_cooperative_matrix_decode_vector)
#if defined(dequantFuncA_v) && defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR)
#define DECODEFUNCA , dequantFuncA, dequantFuncA_v
#else
#define DECODEFUNCA , dequantFuncA
@@ -113,11 +119,33 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i
const uint row_i = blockCoords[0];
const u16vec4 row_idx = row_ids[row_i];
B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR)
// The decode-vector path gives B a K-dimension tensor-layout block size of BK.
const uint k = blockCoords[1] * BK + coordInBlock[1];
#else
const uint k = blockCoords[1];
#endif
B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + k];
return ret;
}
#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR)
B_TYPEV4 decodeFuncB_v(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const uint row_i = blockCoords[0];
const u16vec4 row_idx = row_ids[row_i];
const uint k = blockCoords[1] * BK + coordInBlock[1];
const uint base = row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + k;
return data_b_v4[base >> 2];
}
#define DECODEFUNCB , decodeFuncB, decodeFuncB_v
#else
#define DECODEFUNCB , decodeFuncB
#endif
D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic)
{
uint dr = ir * BM + r;
@@ -287,6 +315,9 @@ void main() {
tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K);
#endif
#if defined(MUL_MAT_ID) && defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR)
tensorLayoutB = setTensorLayoutBlockSizeNV(tensorLayoutB, 1, BK);
#endif
// Use end_k rather than p.K as the dimension because that's what
// we need to bound check against when using split_k.
@@ -499,7 +530,7 @@ void main() {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose DECODEFUNCB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else {
@@ -507,7 +538,7 @@ void main() {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose DECODEFUNCB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
@@ -543,7 +574,7 @@ void main() {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose DECODEFUNCB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else {
@@ -551,7 +582,7 @@ void main() {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose DECODEFUNCB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
@@ -588,7 +619,7 @@ void main() {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose DECODEFUNCB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif
@@ -600,7 +631,7 @@ void main() {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose DECODEFUNCB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif
@@ -457,6 +457,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
if (coopmat) {
base_dict["COOPMAT"] = "1";
}
#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
if (coopmat2) {
base_dict["GGML_VULKAN_COOPMAT2_DECODE_VECTOR"] = "1";
}
#endif
const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
@@ -523,11 +528,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
};
// Shaders with f16 B_TYPE
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
// bf16
{
@@ -548,8 +553,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
if (!(coopmat || coopmat2))
#endif
{
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
}
@@ -579,13 +584,13 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
// don't generate f32 variants for coopmat2
if (!coopmat2) {
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
if (tname != "f16" && tname != "f32") {
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+25 -11
View File
@@ -448,15 +448,19 @@ struct ggml_webgpu_upscale_pipeline_key_hash {
/** Concat **/
struct ggml_webgpu_concat_pipeline_key {
int type;
int type;
bool src_overlap;
bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; }
bool operator==(const ggml_webgpu_concat_pipeline_key & other) const {
return type == other.type && src_overlap == other.src_overlap;
}
};
struct ggml_webgpu_concat_pipeline_key_hash {
size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.src_overlap);
return seed;
}
};
@@ -640,7 +644,8 @@ inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) {
inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) {
const uint32_t offset_elems =
(uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) / ggml_type_size(K->type));
(uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) /
ggml_type_size(K->type));
return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u;
}
@@ -651,8 +656,10 @@ inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K,
ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment);
}
inline bool ggml_webgpu_flash_attn_kv_direct(
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, uint32_t kv_direct_align) {
inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q,
const ggml_tensor * K,
const ggml_tensor * V,
uint32_t kv_direct_align) {
return K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && (Q->ne[0] % kv_direct_align == 0) &&
(K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
}
@@ -667,10 +674,10 @@ inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_co
key.dst_type = context.dst->type;
key.head_dim_qk = (uint32_t) context.src0->ne[0];
key.head_dim_v = (uint32_t) context.src2->ne[0];
key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align);
key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2);
key.has_mask = context.src3 != nullptr;
key.has_sinks = context.src4 != nullptr;
key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align);
key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2);
key.has_mask = context.src3 != nullptr;
key.has_sinks = context.src4 != nullptr;
key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
return key;
}
@@ -1723,7 +1730,7 @@ class ggml_webgpu_shader_lib {
key.type = context.dst->type;
key.d_state = (int) context.src0->ne[0];
key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
ggml_webgpu_tensor_overlap(context.src1, context.src5);
ggml_webgpu_tensor_overlap(context.src1, context.src5);
auto it = ssm_scan_pipelines.find(key);
if (it != ssm_scan_pipelines.end()) {
@@ -2634,6 +2641,7 @@ class ggml_webgpu_shader_lib {
webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_concat_pipeline_key key = {};
key.type = context.dst->type;
key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
auto it = concat_pipelines.find(key);
if (it != concat_pipelines.end()) {
@@ -2656,11 +2664,17 @@ class ggml_webgpu_shader_lib {
GGML_ABORT("Unsupported type for concat shader");
}
if (key.src_overlap) {
defines.push_back("SRC_OVERLAP");
variant += "_src_overlap";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_concat, defines);
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
auto decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>();
decisions->wg_size = context.max_wg_size;
decisions->src_overlap = key.src_overlap;
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
concat_pipelines[key] = pipeline;
+69 -44
View File
@@ -621,10 +621,11 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
uint32_t value,
size_t offset,
size_t size) {
std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) };
size_t bytes_per_wg = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread;
uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg);
std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) };
size_t bytes_per_wg =
ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread;
uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg);
ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t));
@@ -1362,7 +1363,7 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx,
shader_lib_ctx.src0 = src;
shader_lib_ctx.src1 = nullptr;
shader_lib_ctx.dst = dst;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
@@ -2169,8 +2170,10 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
}
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
uint32_t wg_x, wg_y;
uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
@@ -2244,8 +2247,10 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
}
}
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
uint32_t wg_x, wg_y;
uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_encoded_op ggml_webgpu_add_id(webgpu_context & ctx,
@@ -2305,33 +2310,6 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx,
uint32_t ne = (uint32_t) ggml_nelements(dst);
uint32_t dim = (uint32_t) dst->op_params[0];
std::vector<uint32_t> params = {
ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) dst->ne[0],
(uint32_t) dst->ne[1],
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
dim,
(uint32_t) src0->ne[dim]
};
std::vector<wgpu::BindGroupEntry> entries = {
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
};
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
@@ -2339,8 +2317,52 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx,
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
auto * decisions = static_cast<ggml_webgpu_binary_shader_decisions *>(pipeline.context.get());
uint32_t offset_src0 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type));
uint32_t offset_src1 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type));
size_t merged_offset = 0;
size_t merged_size = 0;
if (decisions->src_overlap) {
const ggml_webgpu_merged_binding_range merged_range =
ggml_webgpu_tensor_merged_binding_range(ctx, { src0, src1 });
merged_offset = merged_range.offset;
merged_size = merged_range.size;
offset_src0 = ggml_webgpu_tensor_merged_element_offset(src0, merged_range);
offset_src1 = ggml_webgpu_tensor_merged_element_offset(src1, merged_range);
}
std::vector<uint32_t> params = { ne,
offset_src0,
offset_src1,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) dst->ne[0],
(uint32_t) dst->ne[1],
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
dim,
(uint32_t) src0->ne[dim] };
std::vector<wgpu::BindGroupEntry> entries = {};
if (decisions->src_overlap) {
entries.push_back(
ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, merged_size));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
} else {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
}
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
}
@@ -2673,8 +2695,10 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
}
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
uint32_t wg_x, wg_y;
uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
}
static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx,
@@ -3751,7 +3775,8 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
// we use the maximum workgroup size for the memset pipeline
size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup *
ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
// Size the bytes_per_thread so that the largest buffer size can be handled
ctx->capabilities.memset_bytes_per_thread =
CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads);
@@ -4228,9 +4253,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
const uint32_t q_tile =
use_subgroup_matrix ? capabilities.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
const uint32_t kv_granularity = use_subgroup_matrix ? capabilities.sg_mat_n : 1u;
const bool kv_direct = use_subgroup_matrix ?
ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) :
false;
const bool kv_direct = use_subgroup_matrix ?
ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) :
false;
const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(
capabilities.limits.maxComputeWorkgroupStorageSize, q_tile, kv_granularity, (uint32_t) src0->ne[0],
(uint32_t) src2->ne[0], op->src[3] != nullptr, kv_direct);
@@ -130,10 +130,13 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
let src0_i = params.offset_src0 + src0_index(gid.x);
let src1_i = params.offset_src1 + src1_index(gid.x);
update(params.offset_dst + gid.x, src0_i, src1_i);
fn main(@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>) {
let threads_per_group = u32(WG_SIZE);
let i = gid.x + (num_wg.x * threads_per_group) * gid.y;
if (i < params.ne) {
let src0_i = params.offset_src0 + src0_index(i);
let src1_i = params.offset_src1 + src1_index(i);
update(params.offset_dst + i, src0_i, src1_i);
}
}
+19 -1
View File
@@ -31,6 +31,16 @@ struct Params {
#define DataType i32
#endif
#ifdef SRC_OVERLAP
@group(0) @binding(0)
var<storage, read_write> merged_src: array<DataType>;
@group(0) @binding(1)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(2)
var<uniform> params: Params;
#else
@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;
@@ -42,7 +52,7 @@ var<storage, read_write> dst: array<DataType>;
@group(0) @binding(3)
var<uniform> params: Params;
#endif
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
@@ -62,14 +72,22 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
ni[1] * params.stride_src0_1 +
ni[2] * params.stride_src0_2 +
ni[3] * params.stride_src0_3;
#ifdef SRC_OVERLAP
dst[params.offset_dst + gid.x] = merged_src[params.offset_src0 + src_i];
#else
dst[params.offset_dst + gid.x] = src0[params.offset_src0 + src_i];
#endif
} else {
ni[params.dim] -= params.src0_nedim;
let src_i = ni[0] * params.stride_src1_0 +
ni[1] * params.stride_src1_1 +
ni[2] * params.stride_src1_2 +
ni[3] * params.stride_src1_3;
#ifdef SRC_OVERLAP
dst[params.offset_dst + gid.x] = merged_src[params.offset_src1 + src_i];
#else
dst[params.offset_dst + gid.x] = src1[params.offset_src1 + src_i];
#endif
}
}
}
@@ -98,72 +98,50 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
}
#endif // INIT_SRC0_SHMEM_Q1_0
#ifdef INIT_SRC0_SHMEM_Q4_0
#if defined(INIT_SRC0_SHMEM_Q4_0) || defined(INIT_SRC0_SHMEM_Q4_1) || defined(INIT_SRC0_SHMEM_Q5_0) || defined(INIT_SRC0_SHMEM_Q5_1) || defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1) || defined(INIT_SRC0_SHMEM_MXFP4)
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 18u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
#if defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1)
const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q
#else
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
#endif
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let shmem_idx = block_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let tile_m = block_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let block_k = block_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
#ifdef INIT_SRC0_SHMEM_Q4_0
let block_byte_base = src0_idx * 18u; // BLOCK_SIZE_BYTES = 18u;
let d = load_f16_at_src0(block_byte_base);
// store NQ(16) weights
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP);
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q4_0
#elif INIT_SRC0_SHMEM_Q4_1
let block_byte_base = src0_idx * 20u; // BLOCK_SIZE_BYTES = 20u;
let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
let d = f16(dm[0]);
let m = f16(dm[1]);
#ifdef INIT_SRC0_SHMEM_Q4_1
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 20u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base);
let m = load_f16_at_src0(block_byte_base + 2u);
// store NQ(16) weights
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
@@ -175,41 +153,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q4_1
#ifdef INIT_SRC0_SHMEM_Q5_0
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 22u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
// tile_k is defined as 32u, so blocks_k ends up being 1 always
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
#elif INIT_SRC0_SHMEM_Q5_0
let block_byte_base = src0_idx * 22u; // BLOCK_SIZE_BYTES = 22u;
let d = load_f16_at_src0(block_byte_base);
let qh_packed = load_u32_at_src0(block_byte_base + 2u);
// store NQ(16) weights
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 6u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
@@ -226,44 +176,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q5_0
#elif INIT_SRC0_SHMEM_Q5_1
let block_byte_base = src0_idx * 24u; // BLOCK_SIZE_BYTES = 24u;
#ifdef INIT_SRC0_SHMEM_Q5_1
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 24u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
let d = f16(dm[0]);
let m = f16(dm[1]);
let qh_packed = load_u32_at_src0_aligned(block_byte_base + 4u);
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base);
let m = load_f16_at_src0(block_byte_base + 2u);
let qh_packed = load_u32_at_src0(block_byte_base + 4u);
// store NQ(16) weights
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 8u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
let q_packed = load_u32_at_src0_aligned(q_byte_offset);
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
let q_byte = get_byte(q_packed, k);
@@ -277,461 +201,306 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q5_1
#ifdef INIT_SRC0_SHMEM_Q8_0
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 34u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
#elif INIT_SRC0_SHMEM_Q8_0
let block_byte_base = src0_idx * 34u; // BLOCK_SIZE_BYTES = 34u;
let d = load_f16_at_src0(block_byte_base);
// store NQ(16) weights
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP);
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q8_0
#elif INIT_SRC0_SHMEM_Q8_1
let block_byte_base = src0_idx * 36u; // BLOCK_SIZE_BYTES = 36u;
let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
let d = f16(dm[0]);
let m = f16(dm[1]);
#ifdef INIT_SRC0_SHMEM_Q8_1
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 36u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base);
let m = load_f16_at_src0(block_byte_base + 2u);
// store NQ(16) weights
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d + m;
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val;
}
}
#elif INIT_SRC0_SHMEM_MXFP4
let block_byte_base = src0_idx * 17u;
let eu8 = get_byte(load_u32_at_src0_aligned(block_byte_base), block_byte_base & 3u);
let e = ldexp(1.0, i32(eu8) - 128);
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e;
let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e;
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo);
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi);
}
}
#endif
}
}
}
#endif // INIT_SRC0_SHMEM_Q8_1
#endif
// k-quants
#if defined(INIT_SRC0_SHMEM_Q2_K) || defined(INIT_SRC0_SHMEM_Q3_K) || defined(INIT_SRC0_SHMEM_Q4_K) || defined(INIT_SRC0_SHMEM_Q5_K) || defined(INIT_SRC0_SHMEM_Q6_K)
const BLOCK_SIZE = 256u;
const NQ = 4u;
fn store_shmem_kquants(val: vec4<f16>, idx: u32) {
shmem[idx] = val.x;
shmem[idx + 1] = val.y;
shmem[idx + 2] = val.z;
shmem[idx + 3] = val.w;
}
fn load_byte_at_src0_aligned(byte_offset: u32) -> u32 {
return get_byte(load_u32_at_src0_aligned(byte_offset), byte_offset % 4u);
}
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id * NQ; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * NQ) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
store_shmem_kquants(vec4<f16>(f16(0.0), f16(0.0), f16(0.0), f16(0.0)), elem_idx);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE; // k_in_block % 4 == 0;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
#ifdef INIT_SRC0_SHMEM_Q2_K
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 84u;
let block_byte_base = src0_idx * 84u; // BLOCK_SIZE_BYTES = 84u;
let scales_byte_base = block_byte_base;
let qs_byte_base = block_byte_base + 16u;
let dm_byte_base = block_byte_base + 80u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
// Use standard thread layout instead of lane/row_group
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let d_packed = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base));
let d = f16(d_packed[0]);
let dmin = f16(d_packed[1]);
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
let chunk = k_in_block / 128u;
let pos_in_chunk = k_in_block % 32u;
let sub_block = k_in_block / 16u;
let shift_phase = (k_in_block % 128u) / 32u;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
// whole 2 bits (4 elems)
let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk);
let qs_vec4 = vec4<f16>(
f16((qs_word >> (2u * shift_phase + 0u)) & 0x3u),
f16((qs_word >> (2u * shift_phase + 8u)) & 0x3u),
f16((qs_word >> (2u * shift_phase + 16u)) & 0x3u),
f16((qs_word >> (2u * shift_phase + 24u)) & 0x3u),
);
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let scale = load_byte_at_src0_aligned(scales_byte_base + sub_block);
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let dl = d * f16(scale & 0xFu);
let ml = dmin * f16(scale >> 4u);
let d = load_f16_at_src0(block_byte_base + 80u);
let dmin = load_f16_at_src0(block_byte_base + 82u);
store_shmem_kquants(qs_vec4 * dl - ml, elem_idx);
#elif INIT_SRC0_SHMEM_Q3_K
let block_byte_base = src0_idx * 110u; // BLOCK_SIZE_BYTES = 110u;
let hmask_byte_base = block_byte_base + 0u;
let qs_byte_base = block_byte_base + 32u;
let scales_byte_base = block_byte_base + 96u;
// Decode the element at position k_in_block
let block_of_32 = k_in_block / 32u;
let pos_in_32 = k_in_block % 32u;
let d_all = load_f16_at_src0(block_byte_base + 108u);
let q_b_idx = (block_of_32 / 4u) * 32u;
let shift = (block_of_32 % 4u) * 2u;
let k = (pos_in_32 / 16u) * 16u;
let l = pos_in_32 % 16u;
let chunk = k_in_block / 128u;
let pos_in_chunk = k_in_block % 32u;
let sub_block = k_in_block / 16u;
let shift_phase = (k_in_block % 128u) / 32u;
let is = k_in_block / 16u;
let hmask_block = pos_in_chunk;
let hmask_shift_phase = k_in_block / 32u;
let sc_packed = load_u32_at_src0(block_byte_base + 4u * (is / 4u));
let sc = get_byte(sc_packed, is % 4u);
// low 2 bits (4 elems)
let q_lo2_word = load_u32_at_src0(qs_byte_base + 32u * chunk + 1u * hmask_block);
let q_lo2_vec4 = vec4<f16>(
f16((q_lo2_word >> (2u * shift_phase + 0u)) & 3u),
f16((q_lo2_word >> (2u * shift_phase + 8u)) & 3u),
f16((q_lo2_word >> (2u * shift_phase + 16u)) & 3u),
f16((q_lo2_word >> (2u * shift_phase + 24u)) & 3u)
);
let dl = d * f16(sc & 0xFu);
let ml = dmin * f16(sc >> 4u);
// high 1 bit (4 elems)
let q_hi1_word = load_u32_at_src0(hmask_byte_base + pos_in_chunk);
let q_hi1_vec4 = vec4<f16>(
f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 0u)) & 1u) == 1u)),
f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 8u)) & 1u) == 1u)),
f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 16u)) & 1u) == 1u)),
f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 24u)) & 1u) == 1u))
);
let q_idx = q_b_idx + k + l;
let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 3u;
let q_vec4 = q_lo2_vec4 - q_hi1_vec4;
let q_val = f16(qs_val) * dl - ml;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q2_K
let scale_low4 = (load_byte_at_src0_aligned(scales_byte_base + (sub_block % 8u)) >> (4u * (sub_block / 8u))) & 0xFu;
let scale_hi2 = (load_byte_at_src0_aligned(scales_byte_base + 8u + (sub_block % 4u)) >> (2u * (sub_block / 4u))) & 3u;
let dl = d_all * (f16((scale_hi2 << 4u) | scale_low4) - 32.0);
#ifdef INIT_SRC0_SHMEM_Q3_K
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 110u;
store_shmem_kquants(dl * q_vec4, elem_idx);
#elif INIT_SRC0_SHMEM_Q4_K
let block_byte_base = src0_idx * 144u; // BLOCK_SIZE_BYTES = 144u;
let dm_byte_base = block_byte_base + 0u;
let scale_byte_base = block_byte_base + 4u;
let qs_byte_base = block_byte_base + 16u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base));
let d = f16(dm[0]);
let dmin = f16(dm[1]);
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
let chunk = k_in_block / 64u;
let pos_in_chunk = (k_in_block % 64u) % 32u;
let sub_block = k_in_block / 32u;
let shift_phase = sub_block & 1u;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base + 108u);
// Load and unpack scales
let kmask1: u32 = 0x03030303u;
let kmask2: u32 = 0x0f0f0f0fu;
var scale_vals: array<u32, 4>;
for (var i: u32 = 0u; i < 4u; i++) {
scale_vals[i] = load_u32_at_src0(block_byte_base + 96u + 4u * i);
}
var tmp: u32 = scale_vals[2];
scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u);
scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u);
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u);
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u);
// Load hmask and qs arrays
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0u; i < 8u; i++) {
hmask_vals[i] = load_u32_at_src0(block_byte_base + 4u * i);
}
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16u; i++) {
qs_vals[i] = load_u32_at_src0(block_byte_base + 32u + 4u * i);
}
let half = k_in_block / 128u; // 0 or 1
let pos_in_half = k_in_block % 128u; // 0-127
let shift_group = pos_in_half / 32u; // 0-3
let pos_in_32 = pos_in_half % 32u; // 0-31
let k_group = pos_in_32 / 16u; // 0 or 1
let l = pos_in_32 % 16u; // 0-15
let q_b_idx = half * 32u; // 0 or 32
let shift = shift_group * 2u; // 0, 2, 4, 6
let k = k_group * 16u; // 0 or 16
let is = k_in_block / 16u; // 0-15
// m increments every 32 elements across entire 256 element block
let m_shift = k_in_block / 32u; // 0-7
let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128
let sc = get_byte(scale_vals[is / 4u], is % 4u);
let dl = d * (f16(sc) - 32.0);
let q_idx = q_b_idx + k + l;
let hm_idx = k + l;
let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u);
let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u);
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
let qs_val = (q_byte >> shift) & 3u;
let q_val = (f16(qs_val) - f16(hm)) * dl;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q3_K
#ifdef INIT_SRC0_SHMEM_Q4_K
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 144u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base);
let dmin = load_f16_at_src0(block_byte_base + 2u);
// Map k_in_block to loop structure:
// Outer loop over 64-element groups (alternating q_b_idx)
// Inner loop over 2 shifts per group
let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx)
let pos_in_64 = k_in_block % 64u; // 0-63
let shift_group = pos_in_64 / 32u; // 0 or 1
let l = pos_in_64 % 32u; // 0-31
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
let shift = shift_group * 4u; // 0 or 4
let is = k_in_block / 32u; // 0-7
// whole 4 bits (4 elems)
let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk);
let qs_vec4 = vec4<f16>(
f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu),
f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu),
f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu),
f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu)
);
var sc: u32;
var mn: u32;
let scale_base = block_byte_base + 4u;
if (is < 4u) {
let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u);
let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
if (sub_block < 4u) {
let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u);
let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u);
let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u);
let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u);
let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u);
let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
}
let dl = d * f16(sc);
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u));
store_shmem_kquants(dl * qs_vec4 - vec4(ml, ml, ml, ml), elem_idx);
#elif INIT_SRC0_SHMEM_Q5_K
let block_byte_base = src0_idx * 176u; // BLOCK_SIZE_BYTES = 176u;
let dm_byte_base = block_byte_base + 0u;
let scale_byte_base = block_byte_base + 4u;
let qh_byte_base = block_byte_base + 16u;
let qs_byte_base = block_byte_base + 48u;
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 0xFu;
let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base));
let d = f16(dm[0]);
let dmin = f16(dm[1]);
let q_val = f16(qs_val) * dl - ml;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q4_K
let chunk = k_in_block / 64u;
let pos_in_chunk = (k_in_block % 64u) % 32u;
let sub_block = k_in_block / 32u;
let shift_phase = sub_block & 1u;
#ifdef INIT_SRC0_SHMEM_Q5_K
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 176u;
let qh_block = k_in_block % 32u;
let qh_shift_phase = sub_block;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
// low 4 bits (4 elems)
let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk);
let qs_lo4_vec4 = vec4<f16>(
f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu),
f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu),
f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu),
f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu)
);
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base);
let dmin = load_f16_at_src0(block_byte_base + 2u);
// The original loop processes elements in groups of 64
// Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
// But u increments EVERY 32 elements (after each l loop)
let group_of_64 = k_in_block / 64u; // 0-3
let pos_in_64 = k_in_block % 64u; // 0-63
let shift_group = pos_in_64 / 32u; // 0 or 1
let l = pos_in_64 % 32u; // 0-31
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
let shift = shift_group * 4u; // 0 or 4
let is = k_in_block / 32u; // 0-7
// u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128)
let u_shift = k_in_block / 32u; // 0-7
let u: u32 = 1u << u_shift;
// high 1 bit (4 elems)
let qh_word = load_u32_at_src0_aligned(qh_byte_base + qh_block);
let qh_vec4 = vec4<f16>(
f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 0u)) & 1u) == 1u)),
f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 8u)) & 1u) == 1u)),
f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 16u)) & 1u) == 1u)),
f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 24u)) & 1u) == 1u))
);
var sc: u32;
var mn: u32;
let scale_base = block_byte_base + 4u;
if (is < 4u) {
let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u);
let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
if (sub_block < 4u) {
let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u);
let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u);
let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u);
let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u);
let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u);
let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
}
let dl = d * f16(sc);
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_packed = load_u32_at_src0(block_byte_base + 48u + 4u * (q_idx / 4u));
store_shmem_kquants((qh_vec4 + qs_lo4_vec4) * dl - vec4<f16>(ml, ml, ml, ml), elem_idx);
#elif INIT_SRC0_SHMEM_Q6_K
let block_byte_base = src0_idx * 210u; // BLOCK_SIZE_BYTES = 210u;
let ql_byte_base = block_byte_base;
let qh_byte_base = block_byte_base + 128u;
let scales_byte_base = block_byte_base + 192u;
let d_byte_base = block_byte_base + 208u;
let q_byte = get_byte(q_packed, q_idx % 4u);
let d = load_f16_at_src0(d_byte_base);
let qh_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (l / 4u));
let chunk = k_in_block / 128u;
let ql_pos_in_chunk = (k_in_block % 128u) % 64u;
let qh_pos_in_chunk = (k_in_block % 128u) % 32u;
let sub_block = k_in_block / 16u;
let ql_shift_phase = (k_in_block % 128u) / 64u;
let qh_shift_phase = (k_in_block % 128u) / 32u;
let qh_byte = get_byte(qh_packed, l % 4u);
// low 4 bits (4 elems)
let ql_word = load_u32_at_src0(ql_byte_base + 64u * chunk + 1u * ql_pos_in_chunk);
let ql_lo4_vec4 = vec4<u32>(
(ql_word >> (4u * ql_shift_phase + 0u)) & 0xFu,
(ql_word >> (4u * ql_shift_phase + 8u)) & 0xFu,
(ql_word >> (4u * ql_shift_phase + 16u)) & 0xFu,
(ql_word >> (4u * ql_shift_phase + 24u)) & 0xFu
);
let qs_val = (q_byte >> shift) & 0xFu;
let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
// hi 2 bits (4 elems)
let qh_word = load_u32_at_src0(qh_byte_base + 32u * chunk + 1u * qh_pos_in_chunk);
let qh_hi2_vec4 = vec4<u32>(
((qh_word >> (2u * qh_shift_phase + 0u)) & 0x3u) << 4u,
((qh_word >> (2u * qh_shift_phase + 8u)) & 0x3u) << 4u,
((qh_word >> (2u * qh_shift_phase + 16u)) & 0x3u) << 4u,
((qh_word >> (2u * qh_shift_phase + 24u)) & 0x3u) << 4u,
);
let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml;
shmem[elem_idx] = q_val;
let q_vec4 = vec4<f16>(qh_hi2_vec4 | ql_lo4_vec4) - vec4<f16>(32.0, 32.0, 32.0, 32.0);
let scale_byte = scales_byte_base + 1u * sub_block;
let scale_word = load_u32_at_src0_aligned(scale_byte);
let scale = get_byte_i32(scale_word, scale_byte & 3u);
store_shmem_kquants(d * q_vec4 * f16(scale), elem_idx);
#endif
}
}
#endif // INIT_SRC0_SHMEM_Q5_K
#ifdef INIT_SRC0_SHMEM_Q6_K
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 210u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let half = k_in_block / 128u;
let pos_in_half = k_in_block % 128u;
let quarter = pos_in_half / 32u;
let l = pos_in_half % 32u;
let ql_b_idx = half * 64u;
let qh_b_idx = half * 32u;
let sc_b_idx = half * 8u;
// Load only ql13 word needed
let ql13_flat = ql_b_idx + l;
let ql13 = load_u32_at_src0(block_byte_base + ql13_flat);
let ql13_b = get_byte(ql13, 0u);
// Load only ql24 word needed
let ql24_flat = ql_b_idx + l + 32u;
let ql24 = load_u32_at_src0(block_byte_base + ql24_flat);
let ql24_b = get_byte(ql24, 0u);
// Load only qh word needed
let qh_flat = qh_b_idx + l;
let qh = load_u32_at_src0(block_byte_base + 128u + qh_flat);
let qh_b = get_byte(qh, 0u);
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0);
let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0);
// Load only the scale word needed
let is = l / 16u;
let sc_idx = sc_b_idx + is + quarter * 2u;
let sc = load_u32_at_src0(block_byte_base + 192u + sc_idx);
let sc_val = get_byte_i32(sc, 0u);
let d = load_f16_at_src0(block_byte_base + 208u);
var q_val: f16;
if (quarter == 0u) {
q_val = q1;
} else if (quarter == 1u) {
q_val = q2;
} else if (quarter == 2u) {
q_val = q3;
} else {
q_val = q4;
}
shmem[elem_idx] = d * f16(sc_val) * q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q6_K
#endif // k-quants
#ifdef INIT_SRC0_SHMEM_IQ4_NL
const BLOCK_SIZE = 32u;
@@ -1155,48 +924,3 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
}
}
#endif // INIT_SRC0_SHMEM_IQ3_S
#ifdef INIT_SRC0_SHMEM_MXFP4
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 17u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const BYTES_PER_THREAD = 8u; // NQ(16) weights uses 8 bytes of q
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0);
let e = ldexp(1.0, i32(eu8) - 128);
// store NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e;
let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e;
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo);
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi);
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_MXFP4
+6 -4
View File
@@ -43,12 +43,14 @@ struct Params {
var<storage, read_write> src: array<f32>;
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
fn main(
@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>) {
let threads_per_group = u32(WG_SIZE);
var i = gid.x + (num_wg.x * threads_per_group) * gid.y;
if (i >= params.ne) {
return;
}
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
+7 -4
View File
@@ -66,11 +66,14 @@ fn erf_approx(x: TYPE) -> TYPE {
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
fn main(@builtin(global_invocation_id) gid: vec3<u32>,
@builtin(num_workgroups) num_wg: vec3<u32>) {
let threads_per_group = u32(WG_SIZE);
let flat_i = gid.x + (num_wg.x * threads_per_group) * gid.y;
if (flat_i >= params.ne) {
return;
}
var i = gid.x;
var i = flat_i;
let ne2 = params.ne2;
#ifdef DIAG
let ne1 = params.ne0;
@@ -205,6 +208,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
#ifdef INPLACE
src[params.offset_src + src_idx] = res;
#else
dst[params.offset_dst + gid.x] = res;
dst[params.offset_dst + flat_i] = res;
#endif
}
+30
View File
@@ -440,6 +440,7 @@ class MODEL_ARCH(IntEnum):
GEMMA3 = auto()
GEMMA3N = auto()
GEMMA4 = auto()
GEMMA4_ASSISTANT = auto()
GEMMA_EMBEDDING = auto()
STARCODER2 = auto()
RWKV6 = auto()
@@ -537,6 +538,8 @@ class VISION_PROJECTOR_TYPE(IntEnum):
class MODEL_TENSOR(IntEnum):
TOKEN_EMBD = auto()
TOKEN_EMBD_NORM = auto()
MASKED_EMBD_CENTROIDS= auto()
MASKED_EMBD_ORDERING = auto()
TOKEN_TYPES = auto()
POS_EMBD = auto()
OUTPUT = auto()
@@ -897,6 +900,8 @@ class MODEL_TENSOR(IntEnum):
A_PER_DIM_K_SCALE = auto() # gemma4
A_PER_DIM_SCALE = auto() # gemma4
# nextn/mtp
NEXTN_PROJ_PRE = auto()
NEXTN_PROJ_POST = auto()
NEXTN_EH_PROJ = auto()
NEXTN_EMBED_TOKENS = auto()
NEXTN_ENORM = auto()
@@ -986,6 +991,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.GEMMA3: "gemma3",
MODEL_ARCH.GEMMA3N: "gemma3n",
MODEL_ARCH.GEMMA4: "gemma4",
MODEL_ARCH.GEMMA4_ASSISTANT: "gemma4-assistant",
MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
@@ -1083,6 +1089,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
MODEL_TENSOR.TOKEN_TYPES: "token_types",
MODEL_TENSOR.MASKED_EMBD_CENTROIDS: "masked_embd_centroids",
MODEL_TENSOR.MASKED_EMBD_ORDERING: "masked_embd_ordering",
MODEL_TENSOR.POS_EMBD: "position_embd",
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
MODEL_TENSOR.OUTPUT: "output",
@@ -1471,6 +1479,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.A_QF_FFN_DOWN: "a.proj_blk.{bid}.ffn_down",
MODEL_TENSOR.A_QF_FFN_NORM: "a.proj_blk.{bid}.ffn_norm",
# NextN/MTP
MODEL_TENSOR.NEXTN_PROJ_PRE: "nextn.pre_projection",
MODEL_TENSOR.NEXTN_PROJ_POST: "nextn.post_projection",
MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj",
MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens",
MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.nextn.enorm",
@@ -2577,6 +2587,26 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.PER_LAYER_PROJ_NORM,
MODEL_TENSOR.PER_LAYER_POST_NORM,
],
MODEL_ARCH.GEMMA4_ASSISTANT: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.MASKED_EMBD_CENTROIDS,
MODEL_TENSOR.MASKED_EMBD_ORDERING,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.NEXTN_PROJ_PRE,
MODEL_TENSOR.NEXTN_PROJ_POST,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
MODEL_TENSOR.LAYER_OUT_SCALE,
],
MODEL_ARCH.GEMMA_EMBEDDING: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
+16
View File
@@ -37,6 +37,14 @@ class TensorNameMap:
"model.embed", # talkie
),
# Masked embeddings
MODEL_TENSOR.MASKED_EMBD_CENTROIDS: (
"masked_embedding.centroids", # gemma-4 E2B/E4B assistants
),
MODEL_TENSOR.MASKED_EMBD_ORDERING: (
"masked_embedding.token_ordering", # gemma-4 E2B/E4B assistants
),
# Token type embeddings
MODEL_TENSOR.TOKEN_TYPES: (
"embeddings.token_type_embeddings", # bert nomic-bert
@@ -2367,6 +2375,14 @@ class TensorNameMap:
),
# NextN/MTP tensors
MODEL_TENSOR.NEXTN_PROJ_PRE: (
"pre_projection",
),
MODEL_TENSOR.NEXTN_PROJ_POST: (
"post_projection",
),
MODEL_TENSOR.NEXTN_EH_PROJ: (
"model.layers.{bid}.eh_proj",
),
+4
View File
@@ -388,6 +388,10 @@ extern "C" {
// note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
struct llama_sampler_seq_config * samplers;
size_t n_samplers;
// a source/target/parent context
// can be utilized in various ways, for example by sharing results or llama_memory between 2 contexts
struct llama_context * ctx_other;
};
struct llama_model_tensor_override {
+115
View File
@@ -0,0 +1,115 @@
{{- bos_token -}}
{%- set preserve_thinking = preserve_thinking | default(false) -%}
{%- macro format_arg_value(arg_value) -%}
{%- if arg_value is string -%}
{{- "'" + arg_value + "'" -}}
{%- elif arg_value is mapping -%}
{{- arg_value | tojson -}}
{%- else -%}
{{- arg_value | string -}}
{%- endif -%}
{%- endmacro -%}
{%- macro parse_content(content) -%}
{%- if content is string -%}
{{- content -}}
{%- else -%}
{%- set _ns = namespace(result="") -%}
{%- for item in content -%}
{%- if item["type"] == "image" -%}
{%- set _ns.result = _ns.result + "<image>" -%}
{%- elif item["type"] == "text" -%}
{%- set _ns.result = _ns.result + item["text"] -%}
{%- else -%}
{%- set _ns.result = _ns.result + item | tojson -%}
{%- endif -%}
{%- endfor -%}
{{- _ns.result -}}
{%- endif -%}
{%- endmacro -%}
{%- macro render_tool_calls(tool_calls) -%}
{%- set tool_calls_ns = namespace(tool_calls=[]) -%}
{%- for tool_call in tool_calls -%}
{%- set func_name = tool_call["function"]["name"] -%}
{%- set func_args = tool_call["function"]["arguments"] -%}
{%- set args_ns = namespace(arg_strings=[]) -%}
{%- for arg_name, arg_value in func_args.items() -%}
{%- set args_ns.arg_strings = args_ns.arg_strings + [arg_name + "=" + format_arg_value(arg_value)] -%}
{%- endfor -%}
{%- set tool_calls_ns.tool_calls = tool_calls_ns.tool_calls + [func_name + "(" + (args_ns.arg_strings | join(", ")) + ")"] -%}
{%- endfor -%}
{{- "<|tool_call_start|>[" + (tool_calls_ns.tool_calls | join(", ")) + "]<|tool_call_end|>" -}}
{%- endmacro -%}
{%- set ns = namespace(system_prompt="", last_user_index=-1) -%}
{%- if messages[0]["role"] == "system" -%}
{%- if messages[0].get("content") -%}
{%- set ns.system_prompt = parse_content(messages[0]["content"]) -%}
{%- endif -%}
{%- set messages = messages[1:] -%}
{%- endif -%}
{%- if tools -%}
{%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: [" -%}
{%- for tool in tools -%}
{%- if tool is not string -%}
{%- set tool = tool | tojson -%}
{%- endif -%}
{%- set ns.system_prompt = ns.system_prompt + tool -%}
{%- if not loop.last -%}
{%- set ns.system_prompt = ns.system_prompt + ", " -%}
{%- endif -%}
{%- endfor -%}
{%- set ns.system_prompt = ns.system_prompt + "]" -%}
{%- endif -%}
{%- if ns.system_prompt -%}
{{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}}
{%- endif -%}
{%- for message in messages -%}
{%- if message["role"] == "user" -%}
{%- set ns.last_user_index = loop.index0 -%}
{%- endif -%}
{%- endfor -%}
{%- for message in messages -%}
{{- "<|im_start|>" + message.role + "\n" -}}
{%- if message.role == "assistant" -%}
{%- generation -%}
{%- if message.thinking is defined and (preserve_thinking or loop.index0 > ns.last_user_index) -%}
{{- "<think>" + message.thinking + "</think>" -}}
{%- endif -%}
{%- set _cfm_tag = "CONTINUE_FINAL_MESSAGE_TAG " -%}
{%- set _has_cfm = false -%}
{%- if message.content is defined -%}
{%- set content = parse_content(message.content) -%}
{%- if not (preserve_thinking or loop.index0 > ns.last_user_index) -%}
{%- if "</think>" in content -%}
{%- set content = content.split("</think>")[-1] | trim -%}
{%- endif -%}
{%- endif -%}
{%- if message.tool_calls is defined and content.endswith(_cfm_tag) -%}
{%- set _has_cfm = true -%}
{%- set _trunc_len = (content | length) - (_cfm_tag | length) -%}
{{- content[:_trunc_len] -}}
{%- else -%}
{{- content -}}
{%- endif -%}
{%- endif -%}
{%- if message.tool_calls is defined -%}
{{- render_tool_calls(message.tool_calls) -}}
{%- endif -%}
{%- if _has_cfm -%}
{{- _cfm_tag -}}
{%- endif -%}
{{- "<|im_end|>\n" -}}
{%- endgeneration -%}
{%- else %}
{%- if message.get("content") -%}
{{- parse_content(message["content"]) -}}
{%- endif -%}
{{- "<|im_end|>\n" -}}
{%- endif %}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- "<|im_start|>assistant\n" -}}
{%- endif -%}
+1 -1
View File
@@ -1 +1 @@
1e33fed33e87c43aa4c4078e2a9c239d4c1f1bd3
7142aa6bf9fcaeec0fef8d80fcd90afe4268adf1
+9
View File
@@ -57,6 +57,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_GEMMA3N, "gemma3n" },
{ LLM_ARCH_GEMMA4, "gemma4" },
{ LLM_ARCH_GEMMA4_ASSISTANT, "gemma4-assistant" },
{ LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
@@ -453,6 +454,8 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" },
{ LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
{ LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
{ LLM_TENSOR_NEXTN_PROJ_PRE, "nextn.pre_projection" },
{ LLM_TENSOR_NEXTN_PROJ_POST, "nextn.post_projection" },
{ LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" },
{ LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" },
{ LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" },
@@ -556,6 +559,8 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" },
{ LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" },
{ LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" },
{ LLM_TENSOR_MASKED_EMBD_CENTROIDS, "masked_embd_centroids" },
{ LLM_TENSOR_MASKED_EMBD_ORDERING, "masked_embd_ordering" },
};
// declare information about the model weight tensors:
@@ -765,6 +770,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_PROJ_PRE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_PROJ_POST, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the
// last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so
// the model loader doesn't fault on the block index.
@@ -778,6 +785,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
// latent projections feed ggml_mul_mat, the buft probe must use MUL_MAT to keep them on GPU
{LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_MASKED_EMBD_CENTROIDS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}},
{LLM_TENSOR_MASKED_EMBD_ORDERING, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}},
};
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
+6
View File
@@ -61,6 +61,7 @@ enum llm_arch {
LLM_ARCH_GEMMA3,
LLM_ARCH_GEMMA3N,
LLM_ARCH_GEMMA4,
LLM_ARCH_GEMMA4_ASSISTANT,
LLM_ARCH_GEMMA_EMBEDDING,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
@@ -557,14 +558,19 @@ enum llm_tensor {
LLM_TENSOR_INDEXER_PROJ,
LLM_TENSOR_INDEXER_ATTN_K,
LLM_TENSOR_INDEXER_ATTN_Q_B,
LLM_TENSOR_NEXTN_PROJ_PRE,
LLM_TENSOR_NEXTN_PROJ_POST,
LLM_TENSOR_NEXTN_EH_PROJ,
LLM_TENSOR_NEXTN_EMBED_TOKENS,
LLM_TENSOR_NEXTN_ENORM,
LLM_TENSOR_NEXTN_HNORM,
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
LLM_TENSOR_MASKED_EMBD_CENTROIDS,
LLM_TENSOR_MASKED_EMBD_ORDERING,
};
enum llm_tensor_layer {
LLM_TENSOR_LAYER_INPUT,
LLM_TENSOR_LAYER_REPEATING,
+37 -18
View File
@@ -69,9 +69,10 @@ llama_context::llama_context(
cparams.embeddings_nextn_masked = false;
cparams.offload_kqv = params.offload_kqv;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.warmup = false;
cparams.ctx_type = params.ctx_type;
cparams.pooling_type = params.pooling_type;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@@ -84,7 +85,17 @@ llama_context::llama_context(
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.ctx_type = params.ctx_type;
cparams.ctx_other = nullptr;
// TODO: more generic
if (model.arch == LLM_ARCH_GEMMA4_ASSISTANT) {
if (params.ctx_other == nullptr) {
// TODO: change from runtime_error to llama_exception to avoid printing error message
throw std::runtime_error("Gemma4Assistant requires ctx_other to be set (this is normal during memory fitting)");
}
cparams.ctx_other = params.ctx_other;
}
// Initialize backend samplers here so they are part of the sampling graph
// before the reserve passes run later in this function. This avoids a later
@@ -300,10 +311,11 @@ llama_context::llama_context(
// init the memory module
if (!hparams.vocab_only) {
llama_memory_params params_mem = {
/*.type_k =*/ params.type_k,
/*.type_v =*/ params.type_v,
/*.swa_full =*/ params.swa_full,
/*.ctx_type= */ cparams.ctx_type,
/*.type_k =*/ params.type_k,
/*.type_v =*/ params.type_v,
/*.swa_full =*/ params.swa_full,
/*.ctx_type =*/ cparams.ctx_type,
/*.mem_other =*/ llama_get_memory(cparams.ctx_other),
};
memory.reset(model.create_memory(params_mem, cparams));
@@ -904,7 +916,7 @@ float * llama_context::get_embeddings_nextn_ith(int32_t i) {
throw std::runtime_error("no nextn embeddings");
}
const uint32_t n_embd = model.hparams.n_embd;
const uint32_t n_embd = model.hparams.n_embd_out();
if (!cparams.embeddings_nextn_masked) {
// unmasked: nextn rows are stored densely, indexed by raw token position.
@@ -1473,7 +1485,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
GGML_ASSERT(backend_h != nullptr);
const uint32_t n_embd = hparams.n_embd;
const uint32_t n_embd = hparams.n_embd_out();
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_nextn.size);
ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn.data, 0, n_tokens*n_embd*sizeof(float));
}
@@ -1924,7 +1936,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn);
GGML_ASSERT(backend_h != nullptr);
const uint32_t n_embd = hparams.n_embd;
const uint32_t n_embd = hparams.n_embd_out();
float * embd_nextn_out = embd_nextn.data + offset*n_embd;
GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_nextn.size);
@@ -2017,7 +2029,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
const auto n_batch = cparams.n_batch;
const auto n_vocab = vocab.n_tokens();
const auto n_embd = hparams.n_embd;
const auto n_embd_out = hparams.n_embd_out();
bool has_logits = true;
@@ -2036,12 +2047,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
embd_nextn.size = has_embd_nextn ? n_embd*n_outputs_max : 0;
embd_nextn.size = has_embd_nextn ? n_embd_out*n_outputs_max : 0;
if (has_embd_nextn && !cparams.embeddings_nextn_masked) {
// unmasked: nextn row exists for every token in the batch, not just
// those flagged via batch.logits[i] -> size by token count instead.
embd_nextn.size = (size_t) n_embd * n_batch;
embd_nextn.size = (size_t) n_embd_out * n_batch;
}
// Allocate backend sampling output buffers if there are backend samplers configured.
@@ -3375,6 +3386,7 @@ llama_context_params llama_context_default_params() {
/*.kv_unified =*/ false,
/*.sampler =*/ nullptr,
/*.n_sampler =*/ 0,
/*.ctx_other =*/ nullptr,
};
return result;
@@ -3454,7 +3466,6 @@ llama_context * llama_init_from_model(
return nullptr;
}
try {
auto * ctx = new llama_context(*model, params);
return ctx;
@@ -3593,6 +3604,14 @@ void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) {
ctx->set_embeddings_nextn(value, masked);
}
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
if (!ctx) {
return nullptr;
}
return ctx->get_memory();
}
float * llama_get_embeddings_nextn(llama_context * ctx) {
ctx->synchronize();
@@ -3656,7 +3675,7 @@ struct ggml_cgraph * llama_graph_reserve(
uint32_t n_tokens,
uint32_t n_seqs,
uint32_t n_outputs) {
auto * memory = ctx->get_memory();
auto memory = ctx->get_memory();
llama_memory_context_ptr mctx;
if (memory) {
mctx = memory->init_full();
@@ -3696,10 +3715,6 @@ int32_t llama_set_adapter_cvec(
// memory
//
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
return ctx->get_memory();
}
void llama_memory_clear(llama_memory_t mem, bool data) {
if (!mem) {
return;
@@ -4010,3 +4025,7 @@ void llama_opt_epoch(
llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx) {
return ctx->memory_breakdown();
}
llama_context * llama_get_ctx_other(struct llama_context * ctx) {
return ctx->get_cparams().ctx_other;
}
+2 -1
View File
@@ -6,6 +6,7 @@
#include "llama-graph.h"
#include "llama-adapter.h"
#include "llama-impl.h"
#include "llama-memory.h"
#include "ggml-cpp.h"
#include "ggml-opt.h"
@@ -273,7 +274,7 @@ private:
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
std::unique_ptr<llama_memory_i> memory;
llama_memory_ptr memory;
// decode output (2-dimensional array: [n_outputs][n_vocab])
buffer_view<float> logits = {nullptr, 0};
+2
View File
@@ -49,4 +49,6 @@ struct llama_cparams {
ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
llama_context * ctx_other;
};
+2
View File
@@ -100,3 +100,5 @@ LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);
// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i);
LLAMA_API llama_context * llama_get_ctx_other(struct llama_context * ctx);
+19 -5
View File
@@ -397,7 +397,7 @@ static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n
case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break;
};
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);
@@ -565,7 +565,10 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
if (self_k_idxs && self_k_idxs->buffer) {
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
}
// the kq mask guards on its own buffer: shared cells leave idxs unbacked while the mask stays live
if (self_kq_mask && self_kq_mask->buffer) {
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
@@ -573,7 +576,9 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
}
if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
@@ -605,7 +610,9 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
if (self_k_idxs && self_k_idxs->buffer) {
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
}
if (self_kq_mask && self_kq_mask->buffer) {
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
}
@@ -613,7 +620,9 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
}
if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
}
@@ -756,7 +765,9 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
}
if (inp_attn->self_kq_mask && inp_attn->self_kq_mask->buffer) {
attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
}
@@ -764,7 +775,9 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
}
if (inp_attn->self_kq_mask_swa && inp_attn->self_kq_mask_swa->buffer) {
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
}
@@ -810,18 +823,18 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params)
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
}
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
// swa tensors may not be allocated if there are no SWA attention layers
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
}
res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
@@ -1006,6 +1019,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
ubatch (params.ubatch),
n_embd (hparams.n_embd),
n_layer (hparams.n_layer()),
n_layer_nextn (hparams.n_layer_nextn),
n_rot (hparams.n_rot()),
n_ctx (cparams.n_ctx),
n_head (hparams.n_head()),
+1
View File
@@ -784,6 +784,7 @@ struct llm_graph_context {
const int64_t n_embd;
const int64_t n_layer;
const int64_t n_layer_nextn;
const int64_t n_rot;
const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train)
const int64_t n_head;
+4
View File
@@ -91,6 +91,10 @@ uint32_t llama_hparams::n_rot(uint32_t il) const {
}
uint32_t llama_hparams::n_embd_inp() const {
if (n_embd_inp_impl > 0) {
return n_embd_inp_impl;
}
uint32_t n_embd_inp = n_embd;
if (n_deepstack_layers > 0) {
+4
View File
@@ -185,6 +185,9 @@ struct llama_hparams {
// for Classifiers
uint32_t n_cls_out = 1;
// input embedding dimension (0 = use n_embd)
uint32_t n_embd_inp_impl = 0;
// output embedding dimension (0 = use n_embd)
uint32_t n_embd_out_impl = 0;
@@ -224,6 +227,7 @@ struct llama_hparams {
// complex mapping. If using deepstack_mapping_arr, also make sure to set
// n_deepstack_layers to the number of unique deepstack layers so that
// n_embd_imp is accurate (see granite.cpp).
// TODO: can be expressed via the `new n_embd_inp_impl` and remove this param
uint32_t n_deepstack_layers = 0;
// deepstack layer array (Granite4 Vision)
+2 -2
View File
@@ -32,7 +32,7 @@ llama_kv_cache_dsa::llama_kv_cache_dsa(
kv_mla = std::make_unique<llama_kv_cache>(
model, model.hparams, type_k, type_v,
v_trans, offload, unified, kv_size, n_seq_max, n_pad,
n_swa, swa_type, filter, reuse);
n_swa, swa_type, nullptr, filter, reuse, nullptr);
// we use llama_kv_cache for caching indexer keys
// by hand-tweaking some hparams we fool it to create
@@ -49,7 +49,7 @@ llama_kv_cache_dsa::llama_kv_cache_dsa(
kv_lid = std::make_unique<llama_kv_cache>(
model, hparams_lid, type_k, type_v,
v_trans, offload, unified, kv_size, n_seq_max, n_pad,
n_swa, swa_type, filter, reuse);
n_swa, swa_type, nullptr, filter, reuse, nullptr);
}
void llama_kv_cache_dsa::clear(bool data) {
+15 -3
View File
@@ -23,8 +23,10 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
const layer_reuse_cb & reuse,
const layer_share_cb & share) : hparams(model.hparams), unified(unified) {
// chain filters
const layer_filter_cb filter_base = [&](int32_t il) {
@@ -59,17 +61,27 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
llama_memory_t mem_other_base = nullptr;
if (mem_other) {
mem_other_base = static_cast<llama_kv_cache_iswa *>(mem_other)->get_base();
}
llama_memory_t mem_other_swa = nullptr;
if (mem_other) {
mem_other_swa = static_cast<llama_kv_cache_iswa *>(mem_other)->get_swa();
}
kv_base = std::make_unique<llama_kv_cache>(
model, hparams, type_k, type_v,
v_trans, offload, unified, size_base, n_seq_max, n_pad,
0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
0, LLAMA_SWA_TYPE_NONE, mem_other_base, filter_base, reuse, share);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache>(
model, hparams, type_k, type_v,
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
hparams.n_swa, hparams.swa_type, filter_swa, reuse);
hparams.n_swa, hparams.swa_type, mem_other_swa, filter_swa, reuse, share);
}
void llama_kv_cache_iswa::clear(bool data) {
+3 -1
View File
@@ -25,8 +25,10 @@ public:
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
const layer_reuse_cb & reuse,
const layer_share_cb & share);
~llama_kv_cache_iswa() = default;
+124 -23
View File
@@ -90,10 +90,26 @@ llama_kv_cache::llama_kv_cache(
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) :
const layer_reuse_cb & reuse,
const layer_share_cb & share) :
model(model), hparams(hparams), v_trans(v_trans),
n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type),
other(static_cast<llama_kv_cache *>(mem_other)),
v_cells_impl(other ? other->v_cells_impl : std::make_shared<llama_kv_cells_vec>()),
v_cells(*v_cells_impl) {
// shared cells view the source cache's K/V tensors, so the cell count
// follows the source allocation: a fitted target can be smaller than the
// draft default and oversized views would overflow the source tensors
if (other) {
const uint32_t size_other = other->get_size();
if (kv_size != size_other) {
LLAMA_LOG_WARN("%s: kv_size = %u overridden to %u to match the shared source cache\n", __func__, kv_size, size_other);
kv_size = size_other;
}
}
GGML_ASSERT(kv_size % n_pad == 0);
@@ -171,6 +187,24 @@ llama_kv_cache::llama_kv_cache(
continue;
}
if (share && other) {
const int32_t il_share = share(il);
if (il_share >= 0) {
const auto & layer_share = other->layers[other->map_layer_ids[il_share]];
LLAMA_LOG_WARN("%s: layer %3d: sharing with layer %d. k = %p, v = %p\n", __func__, il, il_share,
layer_share.k->data, layer_share.v->data);
map_layer_ids[il] = layers.size();
layers.push_back(layer_share);
layers.back().il = il;
continue;
}
}
if (n_embd_head_k_all == 0) {
n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il);
} else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) {
@@ -282,29 +316,38 @@ llama_kv_cache::llama_kv_cache(
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE");
const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false;
if (attn_rot_disable) {
LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__);
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
n_embd_head_k_all = other->n_embd_head_k_all;
n_embd_head_v_all = other->n_embd_head_v_all;
attn_rot_k = other->attn_rot_k;
attn_rot_v = other->attn_rot_v;
} else {
const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE");
const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false;
if (attn_rot_disable) {
LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__);
}
attn_rot_k =
!attn_rot_disable &&
n_embd_head_k_all > 0 &&
ggml_is_quantized(type_k) &&
hparams.n_embd_head_k() % 64 == 0;
// always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer
if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) {
attn_rot_k = true;
}
attn_rot_v =
!attn_rot_disable &&
n_embd_head_v_all > 0 &&
ggml_is_quantized(type_v) &&
hparams.n_embd_head_v() % 64 == 0;
}
attn_rot_k =
!attn_rot_disable &&
n_embd_head_k_all > 0 &&
ggml_is_quantized(type_k) &&
hparams.n_embd_head_k() % 64 == 0;
// always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer
if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) {
attn_rot_k = true;
}
attn_rot_v =
!attn_rot_disable &&
n_embd_head_v_all > 0 &&
ggml_is_quantized(type_v) &&
hparams.n_embd_head_v() % 64 == 0;
LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all);
LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all);
@@ -347,6 +390,11 @@ void llama_kv_cache::clear(bool data) {
}
bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return true;
}
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
if (p0 < 0) {
@@ -410,6 +458,11 @@ bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
}
void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
@@ -497,6 +550,11 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
}
void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -519,6 +577,11 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
}
void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
@@ -564,6 +627,11 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
}
void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
@@ -598,6 +666,11 @@ void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, in
}
llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return other->seq_pos_min(seq_id);
}
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
const auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -606,6 +679,11 @@ llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
}
llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return other->seq_pos_max(seq_id);
}
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
const auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -746,6 +824,11 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_
}
bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return true;
}
bool updated = false;
auto * sched = lctx->get_sched();
@@ -1021,6 +1104,11 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
}
void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
// keep track of the max sequence position that we would overwrite with this ubatch
// for non-SWA cache, this would be always empty
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@@ -1815,6 +1903,9 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
}
ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
GGML_ASSERT(!other);
auto * ctx = res->get_ctx();
auto * gf = res->get_gf();
@@ -1860,6 +1951,11 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
}
void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_UNUSED(flags);
io.write(&n_stream, sizeof(n_stream));
@@ -1925,6 +2021,11 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, lla
}
void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
// TODO: refactor [TAG_KV_CACHE_SHARE_CELLS]
if (other) {
return;
}
GGML_UNUSED(flags);
GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
+10 -3
View File
@@ -98,7 +98,7 @@ public:
// likely through `struct llama_memory_params`
llama_kv_cache(
const llama_model & model,
const llama_hparams & hparams,
const llama_hparams & hparams,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
@@ -109,8 +109,10 @@ public:
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
const layer_reuse_cb & reuse,
const layer_share_cb & share);
~llama_kv_cache() = default;
@@ -264,7 +266,12 @@ private:
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
std::vector<uint32_t> v_heads;
std::vector<llama_kv_cells> v_cells;
// TODO: temporary until we refactor to be able to share the same cells between 2 kv caches [TAG_KV_CACHE_SHARE_CELLS]
llama_kv_cache * other;
std::shared_ptr<llama_kv_cells_vec> v_cells_impl;
llama_kv_cells_vec & v_cells;
// maps from a sequence id to a stream id
std::vector<uint32_t> seq_to_stream;
+2
View File
@@ -531,3 +531,5 @@ private:
}
}
};
using llama_kv_cells_vec = std::vector<llama_kv_cells>;
+2
View File
@@ -43,9 +43,11 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
n_seq_max,
n_ubatch,
n_pad,
nullptr,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recr(il); }
: filter_attn,
nullptr,
nullptr
)),
mem_recr(new llama_memory_recurrent(
+2
View File
@@ -44,9 +44,11 @@ llama_memory_hybrid::llama_memory_hybrid(
n_pad,
n_swa,
swa_type,
nullptr,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recr(il); }
: filter_attn,
nullptr,
nullptr
)),
mem_recr(new llama_memory_recurrent(
+4
View File
@@ -23,6 +23,8 @@ struct llama_memory_params {
bool swa_full;
llama_context_type ctx_type;
llama_memory_t mem_other;
};
enum llama_memory_status {
@@ -76,6 +78,8 @@ struct llama_memory_i {
// return negative value to indicate that the layer il should not reuse memory
using layer_reuse_cb = std::function<int32_t(int32_t il)>;
using layer_share_cb = std::function<int32_t(int32_t il)>;
virtual ~llama_memory_i() = default;
// split the input batch into a set of ubatches and verify that they can fit into the cache
+64 -23
View File
@@ -139,6 +139,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
return new llama_model_gemma3n(params);
case LLM_ARCH_GEMMA4:
return new llama_model_gemma4(params);
case LLM_ARCH_GEMMA4_ASSISTANT:
return new llama_model_gemma4_assistant(params);
case LLM_ARCH_GEMMA_EMBEDDING:
return new llama_model_gemma_embedding(params);
case LLM_ARCH_STARCODER2:
@@ -1717,19 +1719,21 @@ void llama_model::print_info() const {
if (!hparams.vocab_only) {
LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd);
LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp());
LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd);
LLAMA_LOG_INFO("%s: n_embd_out = %u\n", __func__, hparams.n_embd_out());
LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer());
LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer()).c_str());
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer()).c_str());
LLAMA_LOG_INFO("%s: n_layer_all = %u\n", __func__, hparams.n_layer_all);
LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer_all).c_str());
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer_all).c_str());
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full);
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any());
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full);
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full);
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer()).c_str());
LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer()).c_str());
LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer()).c_str());
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer_all).c_str());
LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer_all).c_str());
LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer_all).c_str());
LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps);
LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
@@ -1737,7 +1741,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale);
LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale);
LLAMA_LOG_INFO("%s: f_attn_value_scale = %.4f\n", __func__, hparams.f_attn_value_scale);
LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer()).c_str());
LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer_all).c_str());
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups);
@@ -1764,7 +1768,7 @@ void llama_model::print_info() const {
[](const auto & entry) { return entry >= 0; })) {
LLAMA_LOG_INFO("%s: deepstack_mapping_arr = %s\n", __func__,
print_f([&](uint32_t il) { return hparams.deepstack_mapping_arr[il]; },
hparams.n_layer()).c_str());
hparams.n_layer_all).c_str());
}
// MRoPE (Multi-axis Rotary Position Embedding) sections
if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) {
@@ -2113,8 +2117,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* filter_recr */ std::move(filter_recr));
}
} else {
llama_memory_i::layer_reuse_cb reuse = nullptr;
llama_kv_cache::layer_filter_cb filter = nullptr;
llama_memory_i::layer_reuse_cb reuse = nullptr;
llama_kv_cache::layer_share_cb share = nullptr;
if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) {
reuse = [&](uint32_t il) {
@@ -2143,20 +2148,53 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(hparams.is_swa_any());
res = new llama_kv_cache_iswa(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
1,
filter,
reuse);
if (arch == LLM_ARCH_GEMMA4_ASSISTANT) {
llama_memory_t mem_other = llama_get_memory(cparams.ctx_other);
share = [&](int32_t il) {
const llama_model * model_other = llama_get_model(cparams.ctx_other);
if (hparams.is_swa(il)) {
return llama_model_n_layer(model_other) - 2;
}
return llama_model_n_layer(model_other) - 1;
};
res = new llama_kv_cache_iswa(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
1,
mem_other,
filter,
reuse,
share);
} else {
res = new llama_kv_cache_iswa(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
1,
nullptr,
filter,
reuse,
share);
}
} else {
GGML_ASSERT(!hparams.is_swa_any());
@@ -2173,7 +2211,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
1,
hparams.n_swa,
hparams.swa_type,
nullptr,
filter,
nullptr,
nullptr);
}
}
@@ -2406,6 +2446,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GEMMA3:
case LLM_ARCH_GEMMA3N:
case LLM_ARCH_GEMMA4:
case LLM_ARCH_GEMMA4_ASSISTANT:
case LLM_ARCH_GEMMA_EMBEDDING:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
+5
View File
@@ -548,6 +548,10 @@ struct llama_model {
struct ggml_tensor * output_s = nullptr;
struct ggml_tensor * output_in_s = nullptr;
// NextN/MTP model-level projections
struct ggml_tensor * nextn_proj_pre = nullptr;
struct ggml_tensor * nextn_proj_post = nullptr;
// classifier
struct ggml_tensor * cls = nullptr;
struct ggml_tensor * cls_b = nullptr;
@@ -702,6 +706,7 @@ const char * llm_type_name(llm_type type);
#define LLAMA_LOAD_LOCALS \
const int n_layer = hparams.n_layer(); GGML_UNUSED(n_layer); \
const int n_layer_all = hparams.n_layer_all; GGML_UNUSED(n_layer_all); \
const int n_layer_nextn = hparams.n_layer_nextn; GGML_UNUSED(n_layer_nextn); \
const int64_t n_head = hparams.n_head(); GGML_UNUSED(n_head); \
const int64_t n_head_kv = hparams.n_head_kv(); GGML_UNUSED(n_head_kv); \
const int64_t n_embd = hparams.n_embd; GGML_UNUSED(n_embd); \
+203
View File
@@ -0,0 +1,203 @@
#include "models.h"
void llama_model_gemma4_assistant::load_arch_hparams(llama_model_loader & ml) {
hparams.n_embd_inp_impl = hparams.n_embd_out();
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer());
uint32_t n_kv_shared_layers = 0;
ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false);
hparams.f_attention_scale = 1.0f;
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false);
GGML_ASSERT(hparams.n_layer_nextn == hparams.n_layer_all && "n_layer_nextn must be == n_layer_impl");
ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false);
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa);
}
void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) {
LLAMA_LOAD_LOCALS;
if (n_embd_head_k != n_embd_head_v) {
throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k == n_embd_head_v");
}
if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) {
throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k_swa == n_embd_head_v_swa");
}
if (hparams.n_embd_out() == n_embd) {
throw std::runtime_error("Gemma 4 assistant requires embedding_length_out to carry the target hidden size");
}
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
create_tensor(tn(LLM_TENSOR_MASKED_EMBD_CENTROIDS, "weight"), {}, TENSOR_NOT_REQUIRED);
create_tensor(tn(LLM_TENSOR_MASKED_EMBD_ORDERING), {}, TENSOR_NOT_REQUIRED);
const int64_t n_embd_backbone = hparams.n_embd_inp();
nextn_proj_post = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_POST, "weight"), { n_embd, n_embd_backbone }, 0);
int rope_freqs_flag = 0;
for (int i = 0; i < n_layer_nextn; ++i) {
auto & layer = layers[i];
const int64_t n_head = hparams.n_head(i);
const int64_t n_embd_head = hparams.n_embd_head_k(i);
const int64_t n_ff = hparams.n_ff(i);
if (i == 0) {
nextn_proj_pre = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_PRE, "weight", i), { 2*n_embd_backbone, n_embd }, 0);
}
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head*n_head }, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head*n_head, n_embd }, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head }, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), { 1u }, 0);
if (!hparams.is_swa(i)) {
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_embd_head/2 }, rope_freqs_flag);
rope_freqs_flag = TENSOR_DUPLICATED;
}
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), { n_embd }, 0);
}
}
std::unique_ptr<llm_graph_context> llama_model_gemma4_assistant::build_arch_graph(const llm_graph_params & params) const {
return std::make_unique<graph>(*this, params);
}
llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const int64_t n_embd_backbone = hparams.n_embd_inp();
ggml_tensor * inp_tokens;
ggml_tensor * inp_h;
{
auto inp = std::make_unique<llm_graph_input_embd>(n_embd_backbone);
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
cb(inp->tokens, "inp_tokens", -1);
ggml_set_input(inp->tokens);
inp_tokens = inp->tokens;
res->t_inp_tokens = inp->tokens;
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_backbone, ubatch.n_tokens);
cb(inp->embd, "inp_h", -1);
ggml_set_input(inp->embd);
inp_h = inp->embd;
res->t_inp_embd = inp->embd;
res->add_input(std::move(inp));
}
GGML_ASSERT(cparams.ctx_other != nullptr);
const auto * model_other = llama_get_model(cparams.ctx_other);
ggml_tensor * x = ggml_get_rows(ctx0, model_other->tok_embd, inp_tokens);
x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone));
cb(x, "inp_embd_target", -1);
ggml_tensor * xh = ggml_concat(ctx0, x, inp_h, 0);
cb(xh, "inp_xh", -1);
ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_proj_pre, xh);
cb(cur, "pre_proj", -1);
auto * inp_attn = build_attn_inp_kv_iswa();
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = build_inp_out_ids();
ggml_tensor * inpL = cur;
for (int il = 0; il < n_layer_nextn; ++il) {
const bool is_swa = hparams.is_swa(il);
const int64_t n_embd_head = hparams.n_embd_head_k(il);
const int64_t n_head = hparams.n_head(il);
const float freq_base_l = model.get_rope_freq_base(cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
const int n_rot_l = hparams.n_rot(il);
ggml_tensor * cur_norm = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur_norm, "attn_norm", il);
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur_norm);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
ggml_tensor * freq_factors = is_swa ? nullptr : model.layers[il].rope_freqs;
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig,
freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur_pos", il);
cur = build_attn(inp_attn, model.layers[il].wo, nullptr, nullptr,
Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
if (il == n_layer_nextn - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
cur = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "attn_post_norm", il);
ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL);
cb(attn_out, "attn_out", il);
cur = build_norm(attn_out, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, nullptr, nullptr,
model.layers[il].ffn_gate, nullptr, nullptr,
model.layers[il].ffn_down, nullptr, nullptr,
nullptr,
LLM_FFN_GELU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
cur = build_norm(cur, model.layers[il].ffn_post_norm, nullptr, LLM_NORM_RMS, -1);
cb(cur, "ffn_post_norm", il);
cur = ggml_add(ctx0, cur, attn_out);
cur = ggml_mul(ctx0, cur, model.layers[il].out_scale);
cb(cur, "out_scaled", il);
inpL = cur;
}
cur = inpL;
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
ggml_tensor * logits = build_lora_mm(model.output, cur);
cb(logits, "result_output", -1);
res->t_logits = logits;
ggml_tensor * h_next = ggml_mul_mat(ctx0, model.nextn_proj_post, cur);
cb(h_next, "h_nextn", -1);
res->t_h_nextn = h_next;
ggml_build_forward_expand(gf, logits);
ggml_build_forward_expand(gf, h_next);
}
+18 -4
View File
@@ -155,12 +155,14 @@ public:
}
virtual ~llm_graph_input_logits_bias() = default;
void set_input(const llama_ubatch *) override {
void set_input(const llama_ubatch * /*ubatch*/) override {
const int64_t n_vocab = arr.size();
ggml_backend_tensor_set(logits_bias, arr.data(), 0, n_vocab*ggml_element_size(logits_bias));
}
// bool can_reuse(const llm_graph_params & params) override;
bool can_reuse(const llm_graph_params & /*params*/) override {
return true;
}
ggml_tensor * logits_bias = nullptr; // F32 [n_vocab]
@@ -270,7 +272,8 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
}
// TODO @ngxson : strip unused token right after the last KV layer to speed up prompt processing
if (il == n_layer - 1 && inp_out_ids) {
// keep all rows when extracting unmasked nextn embeddings (MTP target needs the hidden state for every token)
if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
@@ -370,7 +373,7 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens]
// TODO @ngxson : improve this
if (il == n_layer - 1 && inp_out_ids) {
if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) {
inp_this_layer = ggml_get_rows(ctx0, inp_this_layer, inp_out_ids);
}
@@ -401,6 +404,17 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para
model.output_norm, nullptr,
LLM_NORM_RMS, -1);
// Expose the post-output-norm hidden state (the LM-head input feature) so that
// MTP draft contexts can read it via llama_get_embeddings_nextn_ith() as the
// recurrent h input. This matches the reference (transformers/vLLM/SGLang),
// which feeds the drafter the target's post-final-norm hidden state.
cb(cur, "h_nextn", -1);
res->t_h_nextn = cur;
if (!cparams.embeddings_nextn_masked && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}
cb(cur, "result_norm", -1);
res->t_embd = cur;
+13
View File
@@ -822,6 +822,19 @@ struct llama_model_gemma4 : public llama_model_base {
};
struct llama_model_gemma4_assistant : public llama_model_base {
llama_model_gemma4_assistant(const struct llama_model_params & params) : llama_model_base(params) {}
void load_arch_hparams(llama_model_loader & ml) override;
void load_arch_tensors(llama_model_loader & ml) override;
struct graph : public llm_graph_context {
graph(const llama_model & model, const llm_graph_params & params);
};
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override;
};
struct llama_model_gemma_embedding : public llama_model_base {
llama_model_gemma_embedding(const struct llama_model_params & params) : llama_model_base(params) {}
void load_arch_hparams(llama_model_loader & ml) override;
+5
View File
@@ -7771,6 +7771,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 384, 1, 1}, {3, 384, 384, 1}, 1, 0, 1, 0, 1, 0, false));
for (int s0 : {1, 3}) {
for (int p0 : {0, 3}) {
for (int d0 : {1, 3}) {
@@ -8525,6 +8526,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
// gpt-oss issue with Vulkan mmq_id
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
for (ggml_type type_a : all_types) {
test_cases.emplace_back(new test_mul_mat_id(type_a, GGML_TYPE_F32, 4, 2, false, 64, 16, 3*ggml_blck_size(type_a)));
}
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
for (int n_mats : {4, 8}) {
+133 -175
View File
@@ -1825,6 +1825,104 @@ static void test_convert_responses_to_chatcmpl() {
}
}
// Shared LFM2 parser cases - all variants use one output format and parser
static void test_lfm2_parser(const std::string & template_path, bool detailed_debug) {
auto tst = peg_tester(template_path, detailed_debug);
// Basic content only
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
// Single tool call without reasoning
tst.test("<|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>")
.tools({ special_function_tool })
.expect(message_assist_call)
.run();
// Tool call with string argument
tst.test("<|tool_call_start|>[get_time(city=\"XYZCITY\")]<|tool_call_end|>")
.tools({ get_time_tool })
.expect(message_with_tool_calls("get_time", "{\"city\":\"XYZCITY\"}"))
.run();
// Python literals become JSON
tst.test("<|tool_call_start|>[toggle(enabled=True)]<|tool_call_end|>")
.tools({ toggle_tool })
.expect(message_with_tool_calls("toggle", R"({"enabled": true})"))
.run();
tst.test("<|tool_call_start|>[set_nullable(value=None)]<|tool_call_end|>")
.tools({ nullable_tool })
.expect(message_with_tool_calls("set_nullable", R"({"value": null})"))
.run();
// Nested Python literal
tst.test("<|tool_call_start|>[set_config(config={\"enabled\": True, \"count\": 3})]<|tool_call_end|>")
.tools({ config_tool })
.expect(message_with_tool_calls("set_config", R"({"config": {"enabled": true, "count": 3}})"))
.run();
// JSON literals are accepted too
tst.test("<|tool_call_start|>[set_config(config={\"enabled\": true, \"note\": null})]<|tool_call_end|>")
.tools({ config_tool })
.expect(message_with_tool_calls("set_config", R"({"config": {"enabled": true, "note": null}})"))
.run();
// Dotted function name with structured args
tst.test("<|tool_call_start|>[Calendar.create_event(title=\"demo\", participants=[\"Alice\", \"Bob\"], "
"metadata={\"priority\": \"high\", \"reminder\": true})]<|tool_call_end|>")
.tools({ calendar_create_event_tool })
.expect(message_with_tool_calls(
"Calendar.create_event",
R"({"title": "demo", "participants": ["Alice", "Bob"], "metadata": {"priority": "high", "reminder": true}})"))
.run();
// Markdown links stay content
tst.test("Use this format: [link text](url). Example: [Wikipedia](https://www.wikipedia.org).")
.tools({ get_time_tool })
.expect(simple_assist_msg("Use this format: [link text](url). Example: [Wikipedia](https://www.wikipedia.org)."))
.run();
// Python tool with multiline code in string
tst.test("<|tool_call_start|>[python(code=\"def hello():\\n print('hey')\")]<|tool_call_end|>")
.tools({ python_tool })
.expect_tool_calls({
{ "python", R"#({"code": "def hello():\\n print('hey')"})#", "" }
})
.run();
// Content before tool call (no reasoning)
tst.test("Let me check the time.<|tool_call_start|>[get_time(city=\"Paris\")]<|tool_call_end|>")
.tools({ get_time_tool })
.expect(message_with_reasoning_content_and_multiple_tool_calls(
"", "Let me check the time.", { { "get_time", "{\"city\":\"Paris\"}" } }
))
.run();
// Multiple tool calls (parallel)
tst.test("<|tool_call_start|>[special_function(arg1=1), special_function_with_opt(arg1=1, arg2=2)]<|tool_call_end|>")
.parallel_tool_calls(true)
.tools({ special_function_tool, special_function_tool_with_optional_param })
.expect_tool_calls({
{ "special_function", R"({"arg1": 1})", {} },
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
})
.run();
// Partial tool call (streaming)
tst.test("<|tool_call_start|>[special_function(arg1=")
.tools({ special_function_tool })
.is_partial(true)
.expect(simple_assist_msg("", "", "special_function", "{\"arg1\": "))
.run();
// Tool call with empty arguments
tst.test("<|tool_call_start|>[empty_args()]<|tool_call_end|>")
.tools({ empty_args_tool })
.expect(simple_assist_msg("", "", "empty_args", "{}"))
.run();
}
static void test_template_output_peg_parsers(bool detailed_debug) {
LOG_DBG("%s\n", __func__);
@@ -4038,49 +4136,30 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.run();
}
// LFM2-8B-A1B tests - uses <|tool_list_start|>/<|tool_list_end|> and <|tool_call_start|>[name(args)]<|tool_call_end|>
for (const char * tmpl : {
"models/templates/LFM2-8B-A1B.jinja",
"models/templates/LFM2.5-Instruct.jinja",
"models/templates/LFM2.5-8B-A1B.jinja",
}) {
test_lfm2_parser(tmpl, detailed_debug);
}
// Thinking cases only apply to LFM2.5-8B-A1B, the one LFM2 template that emits <think>
{
auto tst = peg_tester("models/templates/LFM2-8B-A1B.jinja", detailed_debug);
auto tst = peg_tester("models/templates/LFM2.5-8B-A1B.jinja", detailed_debug);
// Basic content only
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
// Reasoning is parsed independent of enable_thinking
// Single tool call without reasoning
tst.test("<|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>")
.tools({ special_function_tool })
.expect(message_assist_call)
.run();
// Tool call with string argument
tst.test("<|tool_call_start|>[get_time(city=\"XYZCITY\")]<|tool_call_end|>")
.tools({ get_time_tool })
.expect(message_with_tool_calls("get_time", "{\"city\":\"XYZCITY\"}"))
.run();
// Tool call with reasoning (enable_thinking=true)
// Tool call with reasoning
tst.test("<think>I'm\nthinking</think><|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ special_function_tool })
.expect(message_assist_call_thoughts)
.run();
// Multiple tool calls (parallel)
tst.test("<|tool_call_start|>[special_function(arg1=1), special_function_with_opt(arg1=1, arg2=2)]<|tool_call_end|>")
.parallel_tool_calls(true)
.tools({
special_function_tool, special_function_tool_with_optional_param
})
.expect_tool_calls({
{ "special_function", R"({"arg1": 1})", {} },
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
})
.run();
// Tool call with reasoning and content
tst.test("<think>I need to call a function</think>"
"Let me check the time.<|tool_call_start|>[get_time(city=\"Paris\")]<|tool_call_end|>")
.enable_thinking(true)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ get_time_tool })
.expect(message_with_reasoning_content_and_multiple_tool_calls(
@@ -4088,32 +4167,9 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
))
.run();
// Python tool with multiline code in string
tst.test("<|tool_call_start|>[python(code=\"def hello():\\n print('hey')\")]<|tool_call_end|>")
.tools({ python_tool })
.expect_tool_calls({
{ "python", R"#({"code": "def hello():\\n print('hey')"})#", "" }
})
.run();
// Partial tool call (streaming)
tst.test("<|tool_call_start|>[special_function(arg1=")
.tools({ special_function_tool })
.is_partial(true)
.expect(simple_assist_msg("", "", "special_function", "{\"arg1\": "))
.run();
// Tool call with empty arguments
tst.test("<|tool_call_start|>[empty_args()]<|tool_call_end|>")
.tools({ empty_args_tool })
.expect(simple_assist_msg("", "", "empty_args", "{}"))
.run();
// fake tool call marker in reasoning
tst.test(
"<think>Let me think about <|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|> hmm</think>"
"<|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>")
.enable_thinking(true)
// Fake tool call marker inside reasoning is not parsed as a call
tst.test("<think>Let me think about <|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|> hmm</think>"
"<|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ special_function_tool })
.expect_reasoning("Let me think about <|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|> hmm")
@@ -4122,127 +4178,21 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
})
.run();
// Continuation tests
tst.test("world!\nWhat's up?")
// enable_thinking=false still captures emitted reasoning
tst.test("<think>I'm\nthinking</think>Hello, world!\nWhat's up?")
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
.messages({ message_user, message_assist_prefill_content })
.add_generation_prompt(false)
.continue_final_message(COMMON_CHAT_CONTINUATION_CONTENT)
.expect_reasoning("I'm thinking")
.expect_content("Hello, world!\nWhat's up?")
.expect(message_assist_thoughts)
.run();
tst.test(" thinking</think>Hello, world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
.messages({ message_user, message_assist_prefill_reasoning })
.add_generation_prompt(false)
.continue_final_message(COMMON_CHAT_CONTINUATION_REASONING)
.expect_reasoning("I'm thinking")
.expect_content("Hello, world!\nWhat's up?")
.run();
}
// LFM2.5 tests - format <|tool_call_start|>[name(args)]<|tool_call_end|>
{
auto tst = peg_tester("models/templates/LFM2.5-Instruct.jinja", detailed_debug);
// Basic content only
tst.test("Hello, world!\nWhat's up?").expect(message_assist).run();
// Single tool call without reasoning
tst.test("<|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>")
.tools({ special_function_tool })
.expect(message_assist_call)
.run();
// Tool call with string argument
tst.test("<|tool_call_start|>[get_time(city=\"XYZCITY\")]<|tool_call_end|>")
.tools({ get_time_tool })
.expect(message_with_tool_calls("get_time", "{\"city\":\"XYZCITY\"}"))
.run();
// Python literals become JSON.
tst.test("<|tool_call_start|>[toggle(enabled=True)]<|tool_call_end|>")
.tools({ toggle_tool })
.expect(message_with_tool_calls("toggle", R"({"enabled": true})"))
.run();
tst.test("<|tool_call_start|>[set_nullable(value=None)]<|tool_call_end|>")
.tools({ nullable_tool })
.expect(message_with_tool_calls("set_nullable", R"({"value": null})"))
.run();
// Nested Python literal.
tst.test("<|tool_call_start|>[set_config(config={\"enabled\": True, \"count\": 3})]<|tool_call_end|>")
.tools({ config_tool })
.expect(message_with_tool_calls("set_config", R"({"config": {"enabled": true, "count": 3}})"))
.run();
// JSON literals are accepted too.
tst.test("<|tool_call_start|>[set_config(config={\"enabled\": true, \"note\": null})]<|tool_call_end|>")
.tools({ config_tool })
.expect(message_with_tool_calls("set_config", R"({"config": {"enabled": true, "note": null}})"))
.run();
// Dotted function name with structured args.
tst.test("<|tool_call_start|>[Calendar.create_event(title=\"demo\", participants=[\"Alice\", \"Bob\"], "
"metadata={\"priority\": \"high\", \"reminder\": true})]<|tool_call_end|>")
.tools({ calendar_create_event_tool })
.expect(message_with_tool_calls(
"Calendar.create_event",
R"({"title": "demo", "participants": ["Alice", "Bob"], "metadata": {"priority": "high", "reminder": true}})"))
.run();
// Markdown links stay content.
tst.test("Use this format: [link text](url). Example: [Wikipedia](https://www.wikipedia.org).")
.tools({ get_time_tool })
.expect(simple_assist_msg("Use this format: [link text](url). Example: [Wikipedia](https://www.wikipedia.org)."))
.run();
// Tool call with reasoning (enable_thinking=true)
tst.test("<think>I'm\nthinking</think><|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>")
.enable_thinking(true)
.enable_thinking(false)
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.tools({ special_function_tool })
.expect(message_assist_call_thoughts)
.run();
// Multiple tool calls (parallel)
tst.test("<|tool_call_start|>[special_function(arg1=1), special_function_with_opt(arg1=1, arg2=2)]<|tool_call_end|>")
.parallel_tool_calls(true)
.tools({
special_function_tool, special_function_tool_with_optional_param
})
.expect_tool_calls({
{ "special_function", R"({"arg1": 1})", {} },
{ "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} },
})
.run();
// Tool call with content before tool call
tst.test("Let me check the time.<|tool_call_start|>[get_time(city=\"Paris\")]<|tool_call_end|>")
.tools({ get_time_tool })
.expect(message_with_reasoning_content_and_multiple_tool_calls(
"", "Let me check the time.", { { "get_time", "{\"city\":\"Paris\"}" } }
))
.run();
// Partial tool call (streaming)
tst.test("<|tool_call_start|>[special_function(arg1=")
.tools({ special_function_tool })
.is_partial(true)
.expect(simple_assist_msg("", "", "special_function", "{\"arg1\": "))
.run();
// Tool call with empty arguments
tst.test("<|tool_call_start|>[empty_args()]<|tool_call_end|>")
.tools({ empty_args_tool })
.expect(simple_assist_msg("", "", "empty_args", "{}"))
.run();
// Continuation tests
// Continuation: prefill content
tst.test("world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
@@ -4253,6 +4203,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
.expect_content("Hello, world!\nWhat's up?")
.run();
// Continuation: prefill reasoning
tst.test(" thinking</think>Hello, world!\nWhat's up?")
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
.enable_thinking(true)
@@ -5478,18 +5429,25 @@ static void test_template_generation_prompt() {
check(tmpls, continuation_reasoning(), "<|im_assistant|>assistant<|im_middle|><think>I'm");
}
{
auto tmpls = read_templates("models/templates/LFM2-8B-A1B.jinja");
for (const char * tmpl : {
"models/templates/LFM2-8B-A1B.jinja",
"models/templates/LFM2.5-Instruct.jinja",
"models/templates/LFM2.5-8B-A1B.jinja",
}) {
auto tmpls = read_templates(tmpl);
check(tmpls, basic(), "<|im_start|>assistant\n");
check(tmpls, continuation_content(), "<|im_start|>assistant\n<think>I'm thinking</think>Hello, ");
check(tmpls, continuation_reasoning(), "<|im_start|>assistant\n<think>I'm");
}
{
auto tmpls = read_templates("models/templates/LFM2.5-Instruct.jinja");
check(tmpls, basic(), "<|im_start|>assistant\n");
check(tmpls, continuation_content(), "<|im_start|>assistant\n<think>I'm thinking</think>Hello, ");
check(tmpls, continuation_reasoning(), "<|im_start|>assistant\n<think>I'm");
// 8B-A1B renders prior-turn reasoning via the "thinking" field
auto tmpls = read_templates("models/templates/LFM2.5-8B-A1B.jinja");
common_chat_templates_inputs inputs;
inputs.messages = { message_user, message_assist_call_thoughts, tool_msg };
inputs.add_generation_prompt = true;
auto params = common_chat_templates_apply(tmpls.get(), inputs);
assert_contains(params.prompt, "<think>I'm\nthinking</think>");
}
{
+3 -3
View File
@@ -392,7 +392,7 @@ static bool arch_supported(const llm_arch arch) {
if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
return false; // FIXME CUDA backend crashes.
}
if (arch == LLM_ARCH_GEMMA4) {
if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) {
return false; // FIXME @ngxson
}
if (arch == LLM_ARCH_LLAMA_EMBED || arch == LLM_ARCH_GEMMA_EMBEDDING || arch == LLM_ARCH_T5ENCODER) {
@@ -447,7 +447,7 @@ static int save_models(const llm_arch target_arch, const size_t seed, const ggml
if (target_arch != LLM_ARCH_UNKNOWN && arch != target_arch) {
continue;
}
if (arch == LLM_ARCH_GEMMA4) {
if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) {
continue; // FIXME: ISWA KV cache initialization needs more fixture params
}
for (bool moe : {false, true}) {
@@ -550,7 +550,7 @@ static int test_backends(const llm_arch target_arch, const size_t seed, const gg
if (target_arch != LLM_ARCH_UNKNOWN && arch != target_arch) {
continue;
}
if (arch == LLM_ARCH_GEMMA4) {
if (arch == LLM_ARCH_GEMMA4 || arch == LLM_ARCH_GEMMA4_ASSISTANT) {
continue; // FIXME: ISWA KV cache initialization needs more fixture params
}
+6
View File
@@ -2,6 +2,7 @@
#include <assert.h>
#include "mtmd.h"
#include "mtmd-helper.h"
int main(void) {
printf("\n\nTesting libmtmd C API...\n");
@@ -17,6 +18,11 @@ int main(void) {
return 1;
}
// simple test for the helper
size_t n_tokens_total = mtmd_helper_get_n_tokens(chunks);
printf("Total tokens in chunks: %zu\n", n_tokens_total);
assert(n_tokens_total > 0);
size_t n_chunks = mtmd_input_chunks_size(chunks);
printf("Number of chunks: %zu\n", n_chunks);
assert(n_chunks > 0);
+19 -3
View File
@@ -128,7 +128,18 @@ struct cli_context {
console::spinner::start();
server_task_result_ptr result = rd.next(should_stop);
console::spinner::stop();
while (true) {
auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
if (res_partial && res_partial->is_begin) {
// this is the "send 200 status to client" signal in streaming mode
// skip, do not stop the spinner
result = rd.next(should_stop);
} else {
console::spinner::stop();
break;
}
}
std::string curr_content;
bool is_thinking = false;
@@ -224,7 +235,7 @@ struct cli_context {
};
// TODO?: Make this reusable, enums, docs
static const std::array<std::string_view, 7> cmds = {
static const std::array<std::string_view, 8> cmds = {
"/audio ",
"/clear",
"/exit",
@@ -232,6 +243,7 @@ static const std::array<std::string_view, 7> cmds = {
"/image ",
"/read ",
"/regen",
"/video ",
};
static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std::string_view line, size_t cursor_byte_pos) {
@@ -446,6 +458,9 @@ int llama_cli(int argc, char ** argv) {
if (inf.has_inp_audio) {
console::log(" /audio <file> add an audio file\n");
}
if (inf.has_inp_video) {
console::log(" /video <file> add a video file\n");
}
console::log("\n");
// interactive loop
@@ -542,7 +557,8 @@ int llama_cli(int argc, char ** argv) {
continue;
} else if (
(string_starts_with(buffer, "/image ") && inf.has_inp_image) ||
(string_starts_with(buffer, "/audio ") && inf.has_inp_audio)) {
(string_starts_with(buffer, "/audio ") && inf.has_inp_audio) ||
(string_starts_with(buffer, "/video ") && inf.has_inp_video)) {
// just in case (bad copy-paste for example), we strip all trailing/leading spaces
std::string fname = string_strip(buffer.substr(7));
std::string marker = ctx_cli.load_input_file(fname, true);
+7
View File
@@ -1,5 +1,8 @@
# mtmd
set(MTMD_VIDEO ON CACHE BOOL "enable video support in mtmd (requires ffmpeg binary in PATH)")
# TODO: add MTMD_VIDEO_METHOD in the future to select between ffmpeg and other backends
find_package(Threads REQUIRED)
add_library(mtmd
@@ -63,6 +66,10 @@ target_include_directories(mtmd PRIVATE ../..)
target_include_directories(mtmd PRIVATE ../../vendor)
target_compile_features (mtmd PRIVATE cxx_std_17)
if (MTMD_VIDEO)
target_compile_definitions(mtmd PRIVATE MTMD_VIDEO)
endif()
if (BUILD_SHARED_LIBS)
set_target_properties (mtmd PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(mtmd PRIVATE LLAMA_BUILD)
+14 -5
View File
@@ -77,6 +77,7 @@ struct mtmd_cli_context {
int n_batch;
mtmd::bitmaps bitmaps;
std::vector<mtmd_helper::video_ptr> videos;
// chat template
common_chat_templates_ptr tmpls;
@@ -166,11 +167,14 @@ struct mtmd_cli_context {
}
bool load_media(const std::string & fname) {
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_file(ctx_vision.get(), fname.c_str(), false));
if (!bmp.ptr) {
auto res = mtmd_helper_bitmap_init_from_file(ctx_vision.get(), fname.c_str(), false);
if (!res.bitmap) {
return false;
}
bitmaps.entries.push_back(std::move(bmp));
bitmaps.entries.emplace_back(res.bitmap);
if (res.video_ctx) {
videos.emplace_back(res.video_ctx);
}
return true;
}
};
@@ -253,6 +257,7 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) {
}
ctx.bitmaps.entries.clear();
ctx.videos.clear();
llama_pos new_n_past;
if (mtmd_helper_eval_chunks(ctx.ctx_vision.get(),
@@ -373,6 +378,9 @@ int main(int argc, char ** argv) {
if (mtmd_support_audio(ctx.ctx_vision.get())) {
LOG("\n /audio <path> load an audio");
}
if (mtmd_helper_support_video(ctx.ctx_vision.get())) {
LOG("\n /video <path> load a video");
}
LOG("\n /clear clear the chat history");
LOG("\n /quit or /exit exit the program");
LOG("\n");
@@ -407,14 +415,15 @@ int main(int argc, char ** argv) {
g_is_generating = true;
bool is_image = line == "/image" || line.find("/image ") == 0;
bool is_audio = line == "/audio" || line.find("/audio ") == 0;
if (is_image || is_audio) {
bool is_video = line == "/video" || line.find("/video ") == 0;
if (is_image || is_audio || is_video) {
if (line.size() < 8) {
LOG_ERR("ERR: Missing media filename\n");
continue;
}
std::string media_path = line.substr(7);
if (ctx.load_media(media_path)) {
LOG("%s %s loaded\n", media_path.c_str(), is_image ? "image" : "audio");
LOG("%s %s loaded\n", media_path.c_str(), is_image ? "image" : is_audio ? "audio" : "video");
content += mtmd_default_marker();
}
// else, error is already printed by libmtmd
+490 -16
View File
@@ -36,6 +36,11 @@
#error "mtmd-helper is a public library outside of mtmd. it must not include internal headers"
#endif
#ifdef MTMD_VIDEO
#include "sheredom/subprocess.h"
#include <thread>
#endif
//
// internal logging functions
//
@@ -79,6 +84,7 @@ struct mtmd_helper_logger {
}
} g_logger;
#define LOG_DBG(...) g_logger.log(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
#define LOG_INF(...) g_logger.log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
#define LOG_WRN(...) g_logger.log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
#define LOG_ERR(...) g_logger.log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
@@ -478,42 +484,94 @@ static bool decode_audio_from_buf(const unsigned char * buf_in, size_t len, int
} // namespace audio_helpers
mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len, bool placeholder) {
// Computes FNV-1a hash of the data
static std::string fnv_hash(const uint8_t * data, size_t len) {
const uint64_t fnv_prime = 0x100000001b3ULL;
uint64_t hash = 0xcbf29ce484222325ULL;
for (size_t i = 0; i < len; ++i) {
hash ^= data[i];
hash *= fnv_prime;
}
return std::to_string(hash);
}
mtmd_helper_bitmap_wrapper mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len, bool placeholder) {
// calculate the hash if needed
std::string id;
mtmd_bitmap * result = nullptr;
if (!placeholder) {
id = fnv_hash(buf, len);
}
if (audio_helpers::is_audio_file((const char *)buf, len)) {
std::vector<float> pcmf32;
const int sample_rate = mtmd_get_audio_sample_rate(ctx);
if (sample_rate < 0) {
LOG_ERR("This model does not support audio input\n");
return nullptr;
return {nullptr, nullptr};
}
if (!audio_helpers::decode_audio_from_buf(buf, len, sample_rate, pcmf32)) {
LOG_ERR("Unable to read WAV audio file from buffer\n");
return nullptr;
return {nullptr, nullptr};
}
return mtmd_bitmap_init_from_audio(pcmf32.size(), placeholder ? nullptr : pcmf32.data());
result = mtmd_bitmap_init_from_audio(pcmf32.size(), placeholder ? nullptr : pcmf32.data());
mtmd_bitmap_set_id(result, id.empty() ? nullptr : id.c_str());
return {result, nullptr};
}
// otherwise, we assume it's an image
mtmd_bitmap * result = nullptr;
{
if (!result) {
int nx, ny, nc;
auto * data = stbi_load_from_memory(buf, len, &nx, &ny, &nc, 3);
if (!data) {
LOG_ERR("%s: failed to decode image bytes\n", __func__);
return nullptr;
if (data) {
result = mtmd_bitmap_init(nx, ny, placeholder ? nullptr : data);
mtmd_bitmap_set_id(result, id.empty() ? nullptr : id.c_str());
stbi_image_free(data);
return {result, nullptr};
}
result = mtmd_bitmap_init(nx, ny, placeholder ? nullptr : data);
stbi_image_free(data);
// otherwise, fallthrough to video decoding (if supported)
}
return result;
// last try: load as video
#ifdef MTMD_VIDEO
if (!result) {
auto params = mtmd_helper_video_init_params_default();
auto video_ctx = mtmd_helper_video_init_from_buf(ctx, buf, len, params);
if (!video_ctx) {
LOG_ERR("%s: failed to decode buffer as either image/audio/video\n", __func__);
return {nullptr, nullptr};
}
result = mtmd_bitmap_init_lazy(ctx,
id.empty() ? nullptr : id.c_str(),
video_ctx,
[](size_t, void * user_data, mtmd_bitmap ** out_bitmap, char ** out_text) -> int {
auto * vctx = static_cast<mtmd_helper_video *>(user_data);
char * text = nullptr;
int ret = mtmd_helper_video_read_next(vctx, out_bitmap, &text);
*out_text = text; // heap-allocated by read_next; freed automatically by mtmd
return ret;
});
return {result, video_ctx};
}
#else
if (!result) {
LOG_ERR("%s: failed to decode buffer as either image or audio (video support not compiled in)\n", __func__);
return {nullptr, nullptr};
}
#endif
// should not reach here
return {nullptr, nullptr};
}
mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char * fname, bool placeholder) {
mtmd_helper_bitmap_wrapper mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char * fname, bool placeholder) {
std::vector<unsigned char> buf;
FILE * f = fopen(fname, "rb");
if (!f) {
LOG_ERR("Unable to open file %s: %s\n", fname, strerror(errno));
return nullptr;
return {nullptr, nullptr};
}
fseek(f, 0, SEEK_END);
@@ -522,7 +580,7 @@ mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char *
if (file_size < 0) {
LOG_ERR("Failed to get file size of %s\n", fname);
fclose(f);
return nullptr;
return {nullptr, nullptr};
}
buf.resize(file_size);
@@ -530,9 +588,425 @@ mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char *
fclose(f);
if (n_read != (size_t)file_size) {
LOG_ERR("Failed to read entire file %s", fname);
return nullptr;
return {nullptr, nullptr};
}
return mtmd_helper_bitmap_init_from_buf(ctx, buf.data(), buf.size(), placeholder);
}
bool mtmd_helper_support_video(mtmd_context * ctx) {
#ifdef MTMD_VIDEO
return mtmd_support_vision(ctx);
#else
return false;
#endif
}
//
// Video input helpers
//
#ifdef MTMD_VIDEO
struct mtmd_helper_video {
mtmd_context * mctx;
std::string path;
std::vector<uint8_t> input_buf; // non-empty when initialized from buffer
std::string ffmpeg_bin;
std::string ffprobe_bin;
float fps_target = 0.0f;
mtmd_helper_video_info info = {};
struct subprocess_s proc = {};
bool proc_alive = false;
int32_t current_frame = 0;
std::thread feeder_thread;
std::string prompt_start = "Video:";
int32_t timestamp_interval_ms = 5000; // emit a timestamp text every N ms (0 = disabled)
float next_timestamp_ms = 0.0f; // next elapsed-ms threshold at which to emit
std::vector<uint8_t> frame_buf;
std::string pending_text; // text queued to be returned before the next frame
bool start_emitted = false;
bool is_buf_input() const { return !input_buf.empty(); }
// must run in a separate thread alongside stdout reading to avoid pipe deadlock
void feed_stdin(struct subprocess_s * sp) {
FILE * f = subprocess_stdin(sp);
if (!f) {
LOG_DBG("%s: subprocess has no stdin pipe\n", __func__);
return;
}
LOG_DBG("%s: feeding %zu bytes to stdin\n", __func__, input_buf.size());
size_t written = fwrite(input_buf.data(), 1, input_buf.size(), f);
LOG_DBG("%s: wrote %zu bytes, closing stdin\n", __func__, written);
fclose(f);
}
bool probe(float fps_target_arg) {
const char * input_arg = is_buf_input() ? "pipe:0" : path.c_str();
const char * cmd[] = {
ffprobe_bin.c_str(),
"-v", "quiet",
"-show_entries", "stream=width,height,r_frame_rate,nb_frames,duration",
"-select_streams", "v:0",
"-of", "default=noprint_wrappers=1",
input_arg,
nullptr,
};
LOG_DBG("%s: launching:", __func__);
for (size_t i = 0; cmd[i]; i++) { LOG_DBG(" %s", cmd[i]); }
LOG_DBG("\n");
struct subprocess_s fprobe;
if (subprocess_create(cmd,
subprocess_option_search_user_path | subprocess_option_inherit_environment,
&fprobe) != 0) {
LOG_ERR("%s: failed to launch ffprobe\n", __func__);
return false;
}
std::thread probe_feeder;
if (is_buf_input()) {
probe_feeder = std::thread([this, &fprobe]() { feed_stdin(&fprobe); });
}
uint32_t width = 0;
uint32_t height = 0;
float orig_fps = 0.0f;
float duration = -1.0f;
int32_t n_frames_orig = -1;
char line[256];
FILE * fp = subprocess_stdout(&fprobe);
while (fgets(line, sizeof(line), fp)) {
char * eq = strchr(line, '=');
if (!eq) continue;
*eq = '\0';
const char * key = line;
const char * val = eq + 1;
char * nl = (char *)strchr(val, '\n');
if (nl) *nl = '\0';
if (strcmp(key, "width") == 0) {
width = (uint32_t)atoi(val);
} else if (strcmp(key, "height") == 0) {
height = (uint32_t)atoi(val);
} else if (strcmp(key, "r_frame_rate") == 0) {
orig_fps = parse_rational(val);
} else if (strcmp(key, "nb_frames") == 0 && strcmp(val, "N/A") != 0) {
n_frames_orig = atoi(val);
} else if (strcmp(key, "duration") == 0 && strcmp(val, "N/A") != 0) {
duration = (float)atof(val);
}
}
if (probe_feeder.joinable()) {
probe_feeder.join();
}
int ret_code;
subprocess_join(&fprobe, &ret_code);
subprocess_destroy(&fprobe);
if (width == 0 || height == 0 || orig_fps <= 0.0f) {
return false;
}
if (duration < 0.0f && n_frames_orig > 0) {
duration = (float)n_frames_orig / orig_fps;
}
fps_target = fps_target_arg > 0.0f ? fps_target_arg : orig_fps;
info.width = width;
info.height = height;
info.fps = fps_target;
LOG_DBG("%s: %ux%u fps=%.2f duration=%.2fs n_frames=%d\n",
__func__, width, height, fps_target, duration, info.n_frames);
info.n_frames = duration > 0.0f ? (int32_t)(duration * fps_target + 0.5f) : -1;
frame_buf.resize((size_t)width * height * 3);
return true;
}
bool start_ffmpeg(float seek_seconds) {
char seek_buf[64];
char fps_buf[64];
std::vector<const char *> cmd;
cmd.push_back(ffmpeg_bin.c_str());
if (!is_buf_input() && seek_seconds > 0.0f) {
// input-side seek: fast, keyframe-accurate; only valid for seekable file inputs
snprintf(seek_buf, sizeof(seek_buf), "%.6f", seek_seconds);
cmd.push_back("-ss");
cmd.push_back(seek_buf);
}
cmd.push_back("-i");
// cache:pipe:0 wraps stdin with a seekable in-memory cache, letting ffmpeg seek
// backwards for container headers (e.g. MP4 moov atom at end of file)
cmd.push_back(is_buf_input() ? "cache:pipe:0" : path.c_str());
if (seek_seconds > 0.0f && is_buf_input()) {
// output-side seek: frame-accurate but decodes and discards frames up to seek point
snprintf(seek_buf, sizeof(seek_buf), "%.6f", seek_seconds);
cmd.push_back("-ss");
cmd.push_back(seek_buf);
}
if (fps_target > 0.0f) {
snprintf(fps_buf, sizeof(fps_buf), "fps=%.6f", fps_target);
cmd.push_back("-vf");
cmd.push_back(fps_buf);
}
cmd.push_back("-f");
cmd.push_back("rawvideo");
cmd.push_back("-pix_fmt");
cmd.push_back("rgb24");
cmd.push_back("pipe:1");
cmd.push_back("-loglevel");
cmd.push_back("error");
cmd.push_back(nullptr);
LOG_DBG("%s: launching:", __func__);
for (size_t i = 0; cmd[i]; i++) {
LOG_DBG(" %s", cmd[i]);
}
LOG_DBG("\n");
int ret = subprocess_create(
cmd.data(),
subprocess_option_search_user_path | subprocess_option_inherit_environment,
&proc);
proc_alive = (ret == 0);
LOG_DBG("%s: subprocess_create ret=%d proc_alive=%d\n", __func__, ret, (int)proc_alive);
if (proc_alive && is_buf_input()) {
LOG_DBG("%s: starting feeder thread for %zu-byte buffer\n", __func__, input_buf.size());
feeder_thread = std::thread([this]() { feed_stdin(&proc); });
}
return proc_alive;
}
void stop_ffmpeg() {
if (proc_alive) {
subprocess_terminate(&proc);
subprocess_destroy(&proc);
proc_alive = false;
}
if (feeder_thread.joinable()) {
feeder_thread.join();
}
}
mtmd_bitmap * read_next_frame() {
if (!proc_alive) return nullptr;
FILE * fp = subprocess_stdout(&proc);
const size_t frame_size = (size_t)info.width * info.height * 3;
LOG_DBG("%s: reading frame %d, expecting %zu bytes (%ux%u)\n",
__func__, current_frame, frame_size, info.width, info.height);
size_t total_read = 0;
while (total_read < frame_size) {
size_t n = fread(frame_buf.data() + total_read, 1, frame_size - total_read, fp);
if (n == 0) {
// clean EOF only if no bytes read yet; partial frame is an error
LOG_DBG("%s: fread returned 0 after %zu/%zu bytes (ferror=%d)\n",
__func__, total_read, frame_size, ferror(fp));
proc_alive = false;
return nullptr;
}
total_read += n;
}
LOG_DBG("%s: frame %d read OK\n", __func__, current_frame);
current_frame++;
return mtmd_bitmap_init(info.width, info.height, frame_buf.data());
}
int32_t read_next(mtmd_bitmap ** out_bitmap, char ** out_text) {
*out_bitmap = nullptr;
*out_text = nullptr;
if (!pending_text.empty()) {
*out_text = strdup(pending_text.c_str());
pending_text.clear();
return *out_text ? 0 : -2;
}
LOG_DBG("%s: proc_alive=%d start_emitted=%d current_frame=%d\n",
__func__, (int)proc_alive, (int)start_emitted, current_frame);
if (!proc_alive) {
return (current_frame == 0) ? -2 : -1;
}
if (!start_emitted) {
start_emitted = true;
if (!prompt_start.empty()) {
*out_text = strdup(prompt_start.c_str());
return *out_text ? 0 : -2;
}
}
mtmd_bitmap * frame = read_next_frame();
if (!frame) return -1;
*out_bitmap = frame;
if (timestamp_interval_ms > 0) {
// current_frame was already incremented by read_next_frame(); undo for elapsed calc
float elapsed_ms = (float)(current_frame - 1) / info.fps * 1000.0f;
if (elapsed_ms >= next_timestamp_ms) {
char ts_buf[32];
float elapsed_s = elapsed_ms / 1000.0f;
int minutes = (int)(elapsed_s / 60);
float seconds = elapsed_s - minutes * 60.0f;
snprintf(ts_buf, sizeof(ts_buf), "[%dm%.2fs]", minutes, seconds);
pending_text = ts_buf;
next_timestamp_ms += (float)timestamp_interval_ms;
}
}
return 0;
}
static float parse_rational(const char * s) {
int num = 0, den = 1;
if (sscanf(s, "%d/%d", &num, &den) == 2 && den > 0) {
return (float)num / (float)den;
}
float val;
if (sscanf(s, "%f", &val) == 1) {
return val;
}
return 0.0f;
}
};
#endif
mtmd_helper_video_init_params mtmd_helper_video_init_params_default() {
return {
/* fps_target */ 4.0f,
/* ffmpeg_bin_dir */ nullptr,
/* timestamp_interval_ms */ 5000,
};
}
static std::string video_resolve_bin(const char * bin_dir, const char * name) {
if (!bin_dir || bin_dir[0] == '\0') {
return name; // rely on PATH
}
std::string result = bin_dir;
char last = result.back();
if (last != '/' && last != '\\') {
#ifdef _WIN32
result += '\\';
#else
result += '/';
#endif
}
result += name;
#ifdef _WIN32
result += ".exe";
#endif
return result;
}
mtmd_helper_video * mtmd_helper_video_init(
mtmd_context * mctx,
const char * path,
mtmd_helper_video_init_params params) {
#ifdef MTMD_VIDEO
auto * ctx = new mtmd_helper_video();
ctx->mctx = mctx;
ctx->path = path;
ctx->ffmpeg_bin = video_resolve_bin(params.ffmpeg_bin_dir, "ffmpeg");
ctx->ffprobe_bin = video_resolve_bin(params.ffmpeg_bin_dir, "ffprobe");
ctx->timestamp_interval_ms = params.timestamp_interval_ms;
if (!ctx->probe(params.fps_target)) {
LOG_ERR("%s: ffprobe failed for '%s' (is ffprobe in PATH?)\n", __func__, path);
delete ctx;
return nullptr;
}
if (!ctx->start_ffmpeg(0.0f)) {
LOG_ERR("%s: failed to start ffmpeg for '%s' (is ffmpeg in PATH?)\n", __func__, path);
delete ctx;
return nullptr;
}
return ctx;
#else
LOG_ERR("%s: video is not supported in this build (MTMD_VIDEO is set to OFF)\n", __func__);
return nullptr;
#endif
}
mtmd_helper_video * mtmd_helper_video_init_from_buf(
mtmd_context * mctx,
const unsigned char * buf, size_t len,
mtmd_helper_video_init_params params) {
#ifdef MTMD_VIDEO
auto * ctx = new mtmd_helper_video();
ctx->mctx = mctx;
ctx->input_buf.assign(buf, buf + len);
ctx->ffmpeg_bin = video_resolve_bin(params.ffmpeg_bin_dir, "ffmpeg");
ctx->ffprobe_bin = video_resolve_bin(params.ffmpeg_bin_dir, "ffprobe");
ctx->timestamp_interval_ms = params.timestamp_interval_ms;
if (!ctx->probe(params.fps_target)) {
LOG_ERR("%s: ffprobe failed on buffer (is ffprobe in PATH?)\n", __func__);
delete ctx;
return nullptr;
}
if (!ctx->start_ffmpeg(0.0f)) {
LOG_ERR("%s: failed to start ffmpeg on buffer (is ffmpeg in PATH?)\n", __func__);
delete ctx;
return nullptr;
}
return ctx;
#else
LOG_ERR("%s: video is not supported in this build (MTMD_VIDEO is set to OFF)\n", __func__);
return nullptr;
#endif
}
void mtmd_helper_video_free(mtmd_helper_video * ctx) {
#ifdef MTMD_VIDEO
if (!ctx) return;
ctx->stop_ffmpeg();
delete ctx;
#else
LOG_ERR("%s: video is not supported in this build (MTMD_VIDEO is set to OFF)\n", __func__);
#endif
}
mtmd_helper_video_info mtmd_helper_video_get_info(const mtmd_helper_video * ctx) {
#ifdef MTMD_VIDEO
return ctx->info;
#else
GGML_ASSERT(false && "video is not supported in this build (MTMD_VIDEO is set to OFF)");
#endif
}
int32_t mtmd_helper_video_read_next(mtmd_helper_video * ctx,
mtmd_bitmap ** out_bitmap, char ** out_text) {
#ifdef MTMD_VIDEO
if (!ctx) return -2;
return ctx->read_next(out_bitmap, out_text);
#else
GGML_ASSERT(false && "video is not supported in this build (MTMD_VIDEO is set to OFF)");
#endif
}
+79 -3
View File
@@ -20,25 +20,39 @@ extern "C" {
// BREAKING CHANGES are expected.
//
struct mtmd_helper_video;
typedef struct mtmd_helper_video mtmd_helper_video;
// Set callback for all future logging events.
// If this is not called, or NULL is supplied, everything is output on stderr.
// Note: this also call mtmd_log_set() internally
MTMD_API void mtmd_helper_log_set(ggml_log_callback log_callback, void * user_data);
// Returns true if this build includes video support (MTMD_VIDEO was ON at compile time).
MTMD_API bool mtmd_helper_support_video(mtmd_context * ctx);
struct mtmd_helper_bitmap_wrapper {
mtmd_bitmap * bitmap;
mtmd_helper_video * video_ctx;
};
// helper function to construct a mtmd_bitmap from a file
// it calls mtmd_helper_bitmap_init_from_buf() internally
// returns nullptr on failure
// this function is thread-safe
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char * fname, bool placeholder);
MTMD_API struct mtmd_helper_bitmap_wrapper mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char * fname, bool placeholder);
// helper function to construct a mtmd_bitmap from a buffer containing a file
// supported formats:
// image: formats supported by stb_image: jpg, png, bmp, gif, etc.
// audio: formats supported by miniaudio: wav, mp3, flac
// note: audio files will be auto-detected based on magic bytes
// note:
// - for now, video input is only supported via C++ helper functions
// - audio files will be auto-detected based on magic bytes
// - output bitmap will have FNV hash as the ID
// returns nullptr on failure
// this function is thread-safe
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len, bool placeholder);
MTMD_API struct mtmd_helper_bitmap_wrapper mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len, bool placeholder);
// helper to count the total number of tokens from a list of chunks, useful to keep track of KV cache
MTMD_API size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks);
@@ -89,6 +103,56 @@ MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx,
int32_t n_batch,
llama_pos * new_n_past);
//
// video input helpers (requires ffmpeg/ffprobe installed on the system)
// the notion of video only exists at the helper level, it is not visible to the core mtmd library
//
// NOTE: this implementation is model-agnostic, it can be used with any vision-capable model
// however, it may not be accurate for some specific models
// (this is expected for now, to keep the implementation simple)
//
struct mtmd_helper_video_info {
uint32_t width;
uint32_t height;
float fps; // effective fps (fps_target if set, else original video fps)
int32_t n_frames; // estimated total frames at effective fps (-1 if unknown)
};
struct mtmd_helper_video_init_params {
float fps_target; // desired output fps; <= 0 means use the video's native fps, defaulted to 4.0f
const char * ffmpeg_bin_dir; // directory containing ffmpeg/ffprobe binaries; NULL means search PATH
int64_t timestamp_interval_ms; // interval for adding timestamp as text chunk (example: "[10m50.5s]"); <= 0 means no timestamp, defaulted to 5000ms
// TODO @ngxson : allow "placeholder" bitmap output for counting tokens
};
MTMD_API struct mtmd_helper_video_init_params mtmd_helper_video_init_params_default(void);
// returns NULL on failure (ffprobe not found, file unreadable, etc.)
MTMD_API mtmd_helper_video * mtmd_helper_video_init(
struct mtmd_context * mctx,
const char * path,
struct mtmd_helper_video_init_params params);
// Same as mtmd_helper_video_init(), but reads from an in-memory buffer.
// The buffer is copied internally; the caller does not need to keep it alive.
// Note: pipe input is not seekable, so seeking will use output-side seeking
// (ffmpeg decodes and discards frames up to the target position).
MTMD_API mtmd_helper_video * mtmd_helper_video_init_from_buf(
struct mtmd_context * mctx,
const unsigned char * buf, size_t len,
struct mtmd_helper_video_init_params params);
MTMD_API void mtmd_helper_video_free(mtmd_helper_video * ctx);
MTMD_API struct mtmd_helper_video_info mtmd_helper_video_get_info(const mtmd_helper_video * ctx);
// Read the next item from the video stream; exactly one of out_bitmap or out_text is set per call.
// *out_bitmap - heap-allocated; caller must free with mtmd_bitmap_free()
// *out_text - heap-allocated (always via strdup/malloc); caller must free with free()
// returns 0 on success, -1 on EOF, -2 on error
MTMD_API int32_t mtmd_helper_video_read_next(mtmd_helper_video * ctx,
mtmd_bitmap ** out_bitmap,
char ** out_text);
#ifdef __cplusplus
} // extern "C"
#endif
@@ -97,4 +161,16 @@ MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx,
// C++ wrappers
//
#ifdef __cplusplus
namespace mtmd_helper {
// video-related C++ wrappers
struct mtmd_helper_video_deleter {
void operator()(mtmd_helper_video * val) { mtmd_helper_video_free(val); }
};
using video_ptr = std::unique_ptr<mtmd_helper_video, mtmd_helper_video_deleter>;
} // namespace mtmd_helper
#endif
#endif
+138 -29
View File
@@ -35,6 +35,10 @@ struct mtmd_bitmap {
std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking
bool is_audio = false; // true if the bitmap is audio
// lazy-loaded bitmap
mtmd_bitmap_lazy_callback lazy_callback = nullptr;
void * lazy_user_data = nullptr;
mtmd_bitmap(const unsigned char * data, uint32_t nx, uint32_t ny)
: nx(nx), ny(ny), is_audio(false) {
if (data) {
@@ -732,30 +736,111 @@ void mtmd_free(mtmd_context * ctx) {
struct mtmd_tokenizer {
mtmd_context * ctx;
std::vector<const mtmd_bitmap *> bitmaps;
std::string input_text;
bool add_special;
bool parse_special;
const llama_vocab * vocab;
struct part {
std::string text;
const mtmd_bitmap * bitmap;
};
std::vector<part> parts;
// these will be freed when mtmd_tokenizer finishes
std::vector<mtmd::bitmap> bm_from_lazy; // TODO @ngxson : refactor, free bm_from_lazy progressively
std::vector<const char *> text_from_lazy;
mtmd_input_chunks cur;
uint32_t n_images_added = 0; // 0-based index assigned to the next image chunk
~mtmd_tokenizer() {
// note: mtmd::bitmap is already RAII
for (auto & str : text_from_lazy) {
free((void *)str);
}
}
mtmd_tokenizer(mtmd_context * ctx,
const mtmd_input_text * text,
const mtmd_bitmap ** bitmaps,
size_t n_bitmaps) : ctx(ctx), bitmaps(bitmaps, bitmaps + n_bitmaps) {
const mtmd_bitmap ** bmps,
size_t n_bitmaps) : ctx(ctx) {
add_special = text->add_special;
parse_special = text->parse_special;
input_text = text->text;
vocab = ctx->vocab;
std::vector<const mtmd_bitmap *> bitmaps(bmps, bmps + n_bitmaps);
auto parts_str = split_text(input_text, ctx->media_marker);
size_t i_bm = 0;
for (const auto & part : parts_str) {
if (part == ctx->media_marker) {
if (i_bm >= bitmaps.size()) {
throw std::runtime_error(string_format("number of media markers in text (%zu) exceeds number of bitmaps (%zu)", i_bm + 1, bitmaps.size()));
}
parts.push_back({"", bitmaps[i_bm++]});
} else {
parts.push_back({std::move(part), nullptr});
}
}
size_t n_markers = 0;
for (const auto & part : parts) {
if (part.bitmap != nullptr) {
n_markers++;
}
}
if (n_markers != bitmaps.size()) {
throw std::runtime_error(string_format("number of media markers in text (%zu) does not match number of bitmaps (%zu)", n_markers, bitmaps.size()));
}
expand_lazy_bitmaps();
}
void expand_lazy_bitmaps() {
std::vector<part> expanded;
expanded.reserve(parts.size());
for (auto & p : parts) {
if (p.bitmap != nullptr && p.bitmap->lazy_callback) {
LOG_DBG("%s: expanding lazy bitmap\n", __func__);
for (size_t i = 0;; i++) {
char * out_str = nullptr;
mtmd_bitmap * out_bm = nullptr;
int res = p.bitmap->lazy_callback(i,
p.bitmap->lazy_user_data,
&out_bm,
&out_str);
if (out_bm && out_str) {
throw std::runtime_error(string_format("lazy callback cannot return both bitmap and text"));
}
if (res == 0) {
// OK, append the returned chunk; lazy part is not yet added
if (out_bm) {
auto & ptr = bm_from_lazy.emplace_back(out_bm); // remember to free it later
expanded.push_back({"", ptr.ptr.get()});
LOG_DBG("%s: lazy callback returned bitmap with dimensions %d x %d\n", __func__, out_bm->nx, out_bm->ny);
} else if (out_str) {
auto & ptr = text_from_lazy.emplace_back(out_str); // remember to free it later
expanded.push_back({ptr, nullptr});
LOG_DBG("%s: lazy callback returned text: %s\n", __func__, out_str);
}
} else if (res == -1) {
// EOF: lazy part removes itself (not added to expanded)
break;
} else if (res == -2) {
// error
throw std::runtime_error(string_format("lazy callback returned error"));
}
}
} else {
expanded.push_back(std::move(p));
}
}
parts = std::move(expanded);
}
int32_t tokenize(mtmd_input_chunks * output) {
cur.entries.clear();
std::vector<std::string> parts = split_text(input_text, ctx->media_marker);
size_t i_bm = 0; // index of the current bitmap
// [QWEN_VIDEO] handle frame merging for models that support it (i.e. qwen-vl)
int n_merge_frames = 1;
@@ -764,53 +849,50 @@ struct mtmd_tokenizer {
GGML_ASSERT(n_merge_frames <= 2 && "we only support merging maximum 2 images for now; open an issue if this model supports merging more");
}
// Build merged_bitmaps: each entry is a group of 1 or 2 bitmaps.
// For consecutive mergeable bitmap parts, merge them and collapse the second part out of this->parts.
std::vector<std::vector<const mtmd_bitmap *>> merged_bitmaps;
if (n_merge_frames > 1) {
size_t i_bm_scan = 0;
for (size_t i = 0; i < parts.size(); ++i) {
if (parts[i] != ctx->media_marker) {
if (parts[i].bitmap == nullptr) {
continue;
}
if (i + 1 < parts.size()
&& parts[i + 1] == ctx->media_marker
&& i_bm_scan + 1 < bitmaps.size()) {
const mtmd_bitmap * bm_a = bitmaps[i_bm_scan];
const mtmd_bitmap * bm_b = bitmaps[i_bm_scan + 1];
if (i + 1 < parts.size() && parts[i + 1].bitmap != nullptr) {
const mtmd_bitmap * bm_a = parts[i].bitmap;
const mtmd_bitmap * bm_b = parts[i + 1].bitmap;
if (bm_a->can_batch_with(*bm_b)) {
LOG_DBG("%s: merging 2 frames at bitmap index %zu and %zu\n", __func__, i_bm_scan, i_bm_scan + 1);
LOG_DBG("%s: merging 2 frames at part index %zu and %zu\n", __func__, i, i + 1);
merged_bitmaps.push_back({bm_a, bm_b});
parts.erase(parts.begin() + i + 1); // remove the second marker
i_bm_scan += 2;
parts.erase(parts.begin() + i + 1); // collapse the second bitmap part
continue;
}
}
LOG_DBG("%s: no merging for bitmap index %zu\n", __func__, i_bm_scan);
merged_bitmaps.push_back({bitmaps[i_bm_scan]});
++i_bm_scan;
LOG_DBG("%s: no merging for part index %zu\n", __func__, i);
merged_bitmaps.push_back({parts[i].bitmap});
}
} else {
for (size_t i = 0; i < bitmaps.size(); ++i) {
merged_bitmaps.push_back({bitmaps[i]});
for (const auto & p : parts) {
if (p.bitmap != nullptr) {
merged_bitmaps.push_back({p.bitmap});
}
}
}
i_bm = 0;
for (auto & part : parts) {
if (part == ctx->media_marker) {
// this is a marker, we should add the next bitmap
size_t i_bm = 0;
for (const auto & p : parts) {
if (p.bitmap != nullptr) {
if (i_bm >= merged_bitmaps.size()) {
LOG_ERR("%s: error: number of bitmaps (%zu) does not match number of markers (%zu)\n",
__func__, merged_bitmaps.size(), parts.size() - 1);
return 1;
}
auto & bmps = merged_bitmaps[i_bm++];
auto bmps = merged_bitmaps[i_bm++];
int32_t res = add_media(bmps);
if (res != 0) {
return res;
}
} else {
// this is a text part, we should add it as text
add_text(part, parse_special);
add_text(p.text, parse_special);
}
}
@@ -1236,8 +1318,13 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
const mtmd_input_text * text,
const mtmd_bitmap ** bitmaps,
size_t n_bitmaps) {
mtmd_tokenizer tokenizer(ctx, text, bitmaps, n_bitmaps);
return tokenizer.tokenize(output);
try {
mtmd_tokenizer tokenizer(ctx, text, bitmaps, n_bitmaps);
return tokenizer.tokenize(output);
} catch (const std::exception & e) {
LOG_ERR("%s: error: %s\n", __func__, e.what());
return 2;
}
}
int32_t mtmd_encode_chunk(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
@@ -1373,6 +1460,10 @@ int mtmd_get_audio_sample_rate(const mtmd_context * ctx) {
return clip_get_hparams(ctx->ctx_a)->audio_sample_rate;
}
const char * mtmd_get_marker(const mtmd_context * ctx) {
return ctx->media_marker.c_str();
}
//
// public API functions
//
@@ -1405,10 +1496,16 @@ uint32_t mtmd_bitmap_get_ny(const mtmd_bitmap * bitmap) {
}
const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap) {
if (bitmap->is_placeholder()) {
return nullptr;
}
return bitmap->get_ro_buf().data();
}
size_t mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap) {
if (bitmap->is_placeholder()) {
return 0;
}
return bitmap->get_ro_buf().size();
}
@@ -1428,6 +1525,18 @@ void mtmd_bitmap_set_id(mtmd_bitmap * bitmap, const char * id) {
}
}
mtmd_bitmap * mtmd_bitmap_init_lazy(mtmd_context * ctx,
const char * id,
void * user_data,
mtmd_bitmap_lazy_callback callback) {
GGML_UNUSED(ctx); // reserved for future use
mtmd_bitmap * bitmap = new mtmd_bitmap(nullptr, 0, 0);
bitmap->lazy_callback = callback;
bitmap->lazy_user_data = user_data;
mtmd_bitmap_set_id(bitmap, id);
return bitmap;
}
void mtmd_bitmap_free(mtmd_bitmap * bitmap) {
if (bitmap) {
delete bitmap;
+31
View File
@@ -128,6 +128,9 @@ MTMD_API bool mtmd_support_audio(const mtmd_context * ctx);
// return -1 if audio is not supported
MTMD_API int mtmd_get_audio_sample_rate(const mtmd_context * ctx);
// get the current marker string
MTMD_API const char * mtmd_get_marker(const mtmd_context * ctx);
// mtmd_bitmap
//
// if bitmap is image:
@@ -156,6 +159,34 @@ MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap);
MTMD_API const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap);
MTMD_API void mtmd_bitmap_set_id(mtmd_bitmap * bitmap, const char * id);
// mtmd_bitmap lazy
//
// this is a special bitmap that:
// - does not hold the actual data
// - can be expanded into one or more chunks (either media to text chunks)
// user must provide a callback to fill in the data when mtmd_tokenize() is called
// this is useful for large video inputs:
// - allow reading video frame by frame, without loading the entire video into memory
// - allow tracking the whole video with a single ID (for example, the file hash)
// set (*out_bitmap) to non-nullptr to emit a bitmap chunk; it will be freed automatically
// set (*out_text) to non-nullptr to emit a text chunk; it must be heap-allocated, null-terminated and will be freed automatically
// either out_bitmap or out_text can be set, but not both
// out_bitmap cannot be another lazy bitmap (no nested lazy allowed)
// return value:
// 0 on success
// -1 on EOF (signal to mtmd_tokenize to move on)
// -2 on error (signal to mtmd_tokenize to abort)
typedef int(* mtmd_bitmap_lazy_callback)(
size_t chunk_idx,
void * user_data,
mtmd_bitmap ** out_bitmap,
char ** out_text);
MTMD_API mtmd_bitmap * mtmd_bitmap_init_lazy(mtmd_context * ctx,
const char * id, // usually set to file hash
void * user_data,
mtmd_bitmap_lazy_callback callback);
// mtmd_input_chunks
//
Binary file not shown.
+4
View File
@@ -1252,6 +1252,10 @@ The `response_format` parameter supports both plain JSON output (e.g. `{"type":
`parallel_tool_calls` : Whether to enable parallel/multiple tool calls (only supported on some models, verification is based on jinja template).
For multimodal input:
- Content type `image_url` and `input_audio` are the same as OAI schema
- Content type `input_video` is an extension from OAI schema. For now, it only accepts base64 input
*Examples:*
You can use either Python `openai` library with appropriate checkpoints:
+22 -18
View File
@@ -701,29 +701,19 @@ size_t validate_utf8(const std::string& text) {
return len;
}
// Computes FNV-1a hash of the data
static std::string fnv_hash(const uint8_t * data, size_t len) {
const uint64_t fnv_prime = 0x100000001b3ULL;
uint64_t hash = 0xcbf29ce484222325ULL;
for (size_t i = 0; i < len; ++i) {
hash ^= data[i];
hash *= fnv_prime;
}
return std::to_string(hash);
}
server_tokens process_mtmd_prompt(mtmd_context * mctx, const std::string & prompt, const std::vector<raw_buffer> & files, bool is_placeholder) {
// these will be freed upon going out of scope
mtmd::bitmaps bitmaps;
std::vector<mtmd_helper::video_ptr> videos;
for (auto & file : files) {
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size(), is_placeholder));
if (!bmp.ptr) {
auto out = mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size(), is_placeholder);
if (!out.bitmap) {
throw std::runtime_error("Failed to load image or audio file");
}
// calculate bitmap hash (for KV caching)
std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
bmp.set_id(hash.c_str());
bitmaps.entries.push_back(std::move(bmp));
bitmaps.entries.emplace_back(out.bitmap);
if (out.video_ctx) {
videos.emplace_back(out.video_ctx);
}
}
// process prompt
std::vector<server_tokens> inputs;
@@ -1023,6 +1013,20 @@ json oaicompat_chat_params_parse(
p["text"] = get_media_marker();
p.erase("input_audio");
} else if (type == "input_video") {
if (!opt.allow_video) {
throw std::runtime_error("video input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
}
json input_video = json_value(p, "input_video", json::object());
std::string data = json_value(input_video, "data", std::string());
auto decoded_data = base64_decode(data); // expected to be base64 encoded
out_files.push_back(decoded_data);
p["type"] = "media_marker";
p["text"] = get_media_marker();
p.erase("input_video");
} else if (type != "text") {
throw std::invalid_argument("unsupported content[].type");
}
+1
View File
@@ -294,6 +294,7 @@ struct server_chat_params {
common_chat_templates_ptr tmpls;
bool allow_image;
bool allow_audio;
bool allow_video;
bool enable_thinking = true;
int reasoning_budget = -1;
std::string reasoning_budget_message;
+21 -9
View File
@@ -1,4 +1,3 @@
#include "server-context.h"
#include "server-chat.h"
#include "server-common.h"
@@ -16,6 +15,11 @@
#include "mtmd.h"
#include "mtmd-helper.h"
#include "ggml-cpp.h"
// TODO: tmp until the mtmd draft processing is refactored [TAG_MTMD_DRAFT_PROCESSING]
#include "../../src/llama-ext.h"
#include <algorithm>
#include <cstddef>
#include <cinttypes>
@@ -884,7 +888,7 @@ private:
has_draft ? "draft model" : "MTP context",
total / (1024.0 * 1024.0));
} catch (const std::exception & e) {
SRV_ERR("[spec] failed to measure %s memory: %s\n",
SRV_WRN("[spec] failed to measure %s memory: %s\n",
has_draft ? "draft model" : "MTP context", e.what());
}
}
@@ -940,16 +944,17 @@ private:
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
params_base.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end();
if (spec_mtp) {
cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
}
// note: for small models maybe we can set this to the maximum possible draft from all speculative types
// the extra memory for small models is likely negligible?
cparams.n_rs_seq = 0;
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
cparams.n_rs_seq = 0;
cparams.ctx_other = ctx_tgt;
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
@@ -964,6 +969,7 @@ private:
cparams_mtp.type_v = params_base.speculative.draft.cache_type_v;
cparams_mtp.n_rs_seq = 0;
cparams_mtp.n_outputs_max = params_base.n_parallel;
cparams_mtp.ctx_other = ctx_tgt;
ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp));
if (ctx_dft == nullptr) {
@@ -971,8 +977,6 @@ private:
return false;
}
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
params_base.speculative.draft.ctx_tgt = ctx_tgt;
params_base.speculative.draft.ctx_dft = ctx_dft.get();
}
@@ -1060,6 +1064,10 @@ private:
}
}
if (ctx_dft) {
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
}
if (spec) {
SRV_INF("%s", "speculative decoding context initialized\n");
} else {
@@ -1239,6 +1247,7 @@ private:
/* tmpls */ std::move(chat_templates),
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
/* allow_video */ mctx ? mtmd_helper_support_video(mctx) : false,
/* enable_thinking */ enable_thinking,
/* reasoning_budget */ params_base.sampling.reasoning_budget_tokens,
/* reasoning_budget_msg */ params_base.sampling.reasoning_budget_message,
@@ -2974,10 +2983,11 @@ private:
continue;
}
if (ctx_dft) {
if (ctx_dft && llama_get_ctx_other(ctx_dft.get()) != ctx_tgt) {
// TODO: in the future, figure out how to infuse target embeddings to the images
// for now, we skip this for simplicity
// maybe we simply need to call `common_speculative_process()` on the mtmd batches in the `process_chunk` above?
// [TAG_MTMD_DRAFT_PROCESSING]
res = input_tokens.process_chunk(ctx_dft.get(), mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out);
if (res != 0) {
GGML_ABORT("failed to process multi-modal data on draft context\n");
@@ -3577,6 +3587,7 @@ server_context_meta server_context::get_meta() const {
/* has_mtmd */ impl->mctx != nullptr,
/* has_inp_image */ impl->chat_params.allow_image,
/* has_inp_audio */ impl->chat_params.allow_audio,
/* has_inp_video */ impl->chat_params.allow_video,
/* json_ui_settings */ impl->json_ui_settings,
/* json_webui_settings */ impl->json_webui_settings, // Deprecated
/* slot_n_ctx */ impl->get_slot_n_ctx(),
@@ -4174,6 +4185,7 @@ void server_routes::init_routes() {
{ "model_path", meta->model_path },
{ "modalities", json {
{"vision", meta->has_inp_image},
{"video", meta->has_inp_video},
{"audio", meta->has_inp_audio},
} },
{ "media_marker", get_media_marker() },
@@ -4967,7 +4979,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_count_tokens(const l
n_tokens = tokenize_mixed(vocab, prompt, true, true).size();
}
json response = {{"input_tokens", static_cast<int>(n_tokens)}};
json response = {{"input_tokens", static_cast<int64_t>(n_tokens)}};
if (is_oai) {
response["object"] = "response.input_tokens";
}
+1
View File
@@ -21,6 +21,7 @@ struct server_context_meta {
bool has_mtmd;
bool has_inp_image;
bool has_inp_audio;
bool has_inp_video;
json json_ui_settings; // Primary: new name
json json_webui_settings; // Deprecated: use json_ui_settings instead (kept for backward compat)
int slot_n_ctx;
+4 -1
View File
@@ -605,7 +605,7 @@ task_params server_task::params_from_json_cmpl(
const auto samplers = data.find("samplers");
if (samplers != data.end()) {
if (samplers->is_array()) {
params.sampling.samplers = common_sampler_types_from_names(*samplers, false);
params.sampling.samplers = common_sampler_types_from_names(*samplers);
} else if (samplers->is_string()){
params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
}
@@ -1393,6 +1393,9 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() {
//
void server_task_result_cmpl_partial::update(task_result_state & state) {
is_updated = true;
if (is_begin) {
return; // begin marker only flushes headers, skip parsing
}
state.update_chat_msg(content, true, oaicompat_msg_diffs);
// Copy current state for use in to_json_*() (reflects state BEFORE this chunk)