Compare commits

...

25 Commits

Author SHA1 Message Date
Georgi Gerganov 37964f44f9 mtmd : fix padding of n_tokens (#19930) 2026-02-26 18:39:49 +02:00
Georgi Gerganov 01cd448b8c server : fix ctx checkpoint restore logic (#19924) 2026-02-26 18:20:16 +02:00
Georgi Gerganov 99bd67c9b2 kv-cache : fix can_shift() check to take into account M-RoPE (#19928) 2026-02-26 18:08:54 +02:00
Aman Gupta b68d75165a llama: Add option to merge gate and exp weights (#19139)
* llama: Add option to merge gate and exp weights

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* update constants.py

* add gate_up for the all MoE models

* convert: simplify merge tensor condition

* update constants.py

* reduce number of models, add create_tensor_gate_up helper

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-26 21:01:08 +08:00
Kevin Pouget ffaafde16f ggml-virtgpu: improve the reliability of the code (#19846)
* ggml-virtgpu-backend: validate the consistency of the received objects

This patch adds consistency checks in the
ggml-virtgpu-backend (running on the host side) to ensure that the
data received from the guest is consistent (valid pointers, valid
sizes and offsets).

* ggml-virtgpu-backend: add fallback/skips for optional ggml backend methods

```
  1. bck->iface.synchronize(bck)
  2. buft->iface.get_alloc_size(buft, op)
  3. buft->iface.get_max_size(buft)
```

these three methods are optional in the GGML interface. `get_max_size`
was already properly defaulted, but `backend sychronize` and `butf
get_max_size` would have segfaulted the backend if not implemented.

* ggml-virtgpu-backend: fix log format missing argument

* ggml-virtgpu-backend: improve the abort message

* ggml-virtgpu-backend: more safety checks

* ggml-virtgpu-backend: new error code

* ggml-virtgpu-backend: initialize all the error codes

* ggml-virtgpu: add a missing comment generated by the code generator

* ggml-virtgpu: add the '[virtgpu]' prefix to the device/buffer names

* ggml-virtgpu: apir_device_buffer_from_ptr: improve the error message

* ggml-virtgpu: shared: make it match the latest api_remoting.h of Virglrenderer APIR

(still unmerged)

* ggml-virtgpu: update the code generator to have dispatch_command_name in a host/guest shared file

* ggml-virtgpu: REMOTE_CALL: fail if the backend returns an error

* docs/backend/VirtGPU.md: indicate that the RAM+VRAM size is limed to 64 GB with libkrun

* ggml-virtgpu: turn off clang-format header ordering for some of the files

Compilation breaks when ordered alphabetically.

* ggml-virtgpu: clang-format

* ggml-virtgpu/backend/shared/api_remoting: better comments for the APIR return codes
2026-02-26 20:00:57 +08:00
drrros efba35a860 server: fix load-on-startup not respected in ini file (#19897)
Co-authored-by: Roman Marchenko <r.marchenko@ideco.ru>
2026-02-26 12:32:31 +01:00
Eric Zhang 9b62913b40 jinja : correct default size for string slices (#19913) 2026-02-26 12:28:09 +01:00
Maximilian Werk 66287bdaac model : add Jina Embeddings v5 Nano (partial EuroBERT) support (#19826)
* WIP: Add EuroBERT support with autoformatting changes

This commit includes:
- EuroBERT model implementation for GGUF conversion
- C++ backend support for EuroBERT architecture
- Unintended autoformatting changes to Python files

Saving before reverting formatting-only changes.

* feat: add back eos assert when not last token pooling

* feat: removed duplicated code and cleanup

* feat: removed not working architectures and unnecessary check

* fix: typo

* fix: dynamic pooling config

* feat: added an example model for eurobert

* feat: proper llama-vocab implementation for jina-v5

* fix: removed unnecessary comments
2026-02-26 12:14:09 +01:00
Georgi Gerganov 1ca3d1de15 gguf : avoid too many file size calls (#19919) 2026-02-26 12:46:32 +02:00
yggdrasil75 bd72300591 server : fix typo in server README.md (#19900)
fix typo
2026-02-26 11:26:16 +01:00
Neo Zhang 2943210c1e support permuted, remove check s0/s10 (#19889)
Co-authored-by: Neo Zhang Jianyu <jianyu.zhang@intel.com>
2026-02-26 10:27:20 +08:00
Jeff Bolz 3769fe6eb7 vulkan: check for memory overlap before doing fusion (#19768)
* vulkan: check for memory overlap before doing fusion

* Update ggml/src/ggml-vulkan/ggml-vulkan.cpp

* address feedback
2026-02-25 18:25:38 +01:00
ddh0 832aa94762 common : add more aliases for sampler CLI params (#19797)
* common : add more aliases for sampler CLI params
2026-02-25 16:34:25 +01:00
Slobodan Josic 3af34b9ff5 ci : update the ROCm/HIP toolchain versions [no ci] (#19891)
* [HIP] Update ROCm build container to rocm/dev-ubuntu-22.04:7.2 and HIP_SDK to 26.Q1

* revert container version

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-25 15:54:49 +01:00
Georgi Gerganov f20469d919 server : enable multi-modal prompt caching (#19877) 2026-02-25 15:15:42 +02:00
Georgi Gerganov d7d826b3c1 server : support multi-modal context checkpoints (#19849)
* Modify llama-memory-hybrid-iswa.cpp

* Modify llama-memory-recurrent.cpp

* Modify server-common.cpp

* Modify server-common.h

* Modify server-context.cpp

* Modify server-task.h

* Added comment to llama-memory-hybrid-iswa.cpp

* Remove comment from server-context.cpp

* Stylistic fix server-context.cpp

* Fix an issue when seqrm isn't called in server-context.cpp

* cont : alternative impl

* cont : cleanup

* cont : n_tokens -> int64_t

---------

Co-authored-by: timkhronos <timkhronos@gmail.com>
2026-02-25 15:14:27 +02:00
Xuan-Son Nguyen c747294b2d scripts: update corpus of compare-logprobs (#19326)
* scripts: update corpus of compare-logprobs

* fix
2026-02-25 12:57:34 +01:00
Mario Limonciello 8fdf269dad ci : update Windows ROCm build to 26.Q1 [no ci] (#19810)
* Update build command to build llama-* tools not just ggml-hip
* Update rocWMMA headers to 7.2
* Add GFX1150 target
* Correct library paths for AMD libraries in 26.Q1
2026-02-25 12:30:19 +01:00
Aldehir Rojas a96a1120b4 gguf : fix ftell/fseek for Windows (#19870) 2026-02-25 06:58:11 +02:00
Georgi Gerganov 244641955f models : fix graph splits (#19866) 2026-02-25 00:01:13 +02:00
Pascal 47eb12b953 server: fix query params lost when proxying requests in multi-model router mode (#19854)
* server: fix query params lost when proxying requests in multi-model router mode

* server: re-encode query params using httplib::encode_query_component in proxy
2026-02-24 21:46:06 +01:00
Georgi Gerganov 418dea39ce ggml/gguf : prevent integer overflows (#19856)
* gguf : prevent integer overflow for ggml_context mem size

* ggml : fix int overflows in ggml_new_object()

* gguf : prevent string exhaustion

* gguf : prevent array elements exhaustion

* ggml : fix negative tensor type oob

* py : assert that alignment is non-zero power of 2

* ggml : check int overflow in ggml_new_tensor_impl and ggml_new_object

* gguf-py : error on duplicate keys when reading

* py : restore tensor_fields

* enforce proper alignment in add_custom_alignment

* gguf : better name

* gguf : fix ctx size for no_alloc == true

* gguf : minor print fix

* ggml : print values when overflow

* ggml : remove deprecated ggml_type_sizef()

* ggml : relax ggml_type asserts to debug-only

* gguf : add mem_size overflow test

* gguf : add file size check for arrays

* ggml : relax asseerts for ggml_get_type_traits()

* flake8 fix

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
2026-02-24 20:17:11 +02:00
Tarek Dakhran da426cb250 model : update label for LFM2-24B-A2B (#19848)
* model : Update label for LFM2-24B-A2B

```
❯ build/bin/llama-bench -m /data/playground/checkpoints/LFM2-24B-A2B-Preview-Q4_0.gguf,/data/playground/checkpoints/LFM2-8B-A1B-Q4_0.gguf -p 1 -n 0
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| lfm2moe 24B.A2B Q4_0           |  12.54 GiB |    23.84 B | CPU        |      10 |             pp1 |         30.35 ± 2.49 |
| lfm2moe 8B.A1B Q4_0            |   4.41 GiB |     8.34 B | CPU        |      10 |             pp1 |         49.24 ± 1.93 |
```

* Remove extra line
2026-02-24 14:27:42 +01:00
Radoslav Gerganov c830f99cfa server : support max_completion_tokens request property (#19831)
"max_tokens" is deprectated in favor of "max_completion_tokens" which
sets the upper bound for reasoning+output token.

Closes: #13700
2026-02-24 10:30:00 +02:00
Ruben Ortlam aa6f918c1c Vulkan Scalar Flash Attention Refactor (#19625)
* vulkan: allow using fp16 in scalar flash attention shader

* split rows inside of subgroups for faster synchronization

* use row_split when Br >= 4, change reductions to use shared memory if row_split == 1

* use f32 scalar FA if f16 is not supported by device

* fix amd workgroup size issue

* optimize masksh use

* add medium rows FA shader Br size

* fixes

* add padding to mask shmem buffer

* cache q values into registers for KQ

* fuse lf accumulation, pf and v accumulation into a loop

* stage K loads through shmem

* stage V loads through shmem

* only stage through shmem on Nvidia

* default to Bc 32

* also stage V through shmem when this is done for K

* dynamic subgroups for intel

* use vectorized stores

* use float_type for dequantize4 functions

* use smaller scalar rows size for smaller rows count

* relax flash attention split_k condition to allow non-gqa use

* use minimal subgroup size on Intel

* fix shmem support function

* fix rebase issues

* fixes

* Bc 4 for scalar FA is not a valid configuration

* Use wave32 on AMD RDNA for scalar FA

* add Intel shader core count lookup-table

* fix regressions

* device tuning

* tmpsh size fix

* fix editorconfig

* refactor fa tuning logic into a single place

* fix gqa opt logic

* fix block_rows with small n_rows

* amd tuning

* fix hsk=72/80 issue

* tuning

* allow condition skipping for column check

* use float16 for Of if available

* address feedback

* fix bad RDNA performance on head size <= 128 by limiting occupancy

* allow printing pipeline stats

* cleanup and fixes

* limit occupancy for GCN for small batch FA with large HSK

* disable f16 FA for GCN AMD GPUs on the proprietary driver
2026-02-24 08:35:48 +01:00
83 changed files with 2216 additions and 1007 deletions
@@ -11,5 +11,5 @@ runs:
- name: Setup ROCm
uses: ./.github/actions/install-exe
with:
url: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ inputs.version }}-WinSvr2022-For-HIP.exe
url: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ inputs.version }}-Win11-For-HIP.exe
args: -install
+1 -1
View File
@@ -68,7 +68,7 @@ jobs:
env:
# Make sure this is in sync with build.yml
HIPSDK_INSTALLER_VERSION: "25.Q3"
HIPSDK_INSTALLER_VERSION: "26.Q1"
steps:
- name: Clone
+3 -5
View File
@@ -1175,10 +1175,8 @@ jobs:
runs-on: windows-2022
env:
# The ROCm version must correspond to the version used in the HIP SDK.
ROCM_VERSION: "6.4.2"
# Make sure this is in sync with build-cache.yml
HIPSDK_INSTALLER_VERSION: "25.Q3"
HIPSDK_INSTALLER_VERSION: "26.Q1"
steps:
- name: Clone
@@ -1188,7 +1186,7 @@ jobs:
- name: Grab rocWMMA package
id: grab_rocwmma
run: |
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/${{ env.ROCM_VERSION }}/pool/main/r/rocwmma-dev/rocwmma-dev_1.7.0.60402-120~24.04_amd64.deb"
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70200-43~24.04_amd64.deb"
7z x rocwmma.deb
7z x data.tar
@@ -1231,7 +1229,7 @@ jobs:
cmake -G "Unix Makefiles" -B build -S . `
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-${{ env.ROCM_VERSION }}/include/" `
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.0/include/" `
-DCMAKE_BUILD_TYPE=Release `
-DLLAMA_BUILD_BORINGSSL=ON `
-DROCM_DIR="${env:HIP_PATH}" `
+8 -8
View File
@@ -616,13 +616,13 @@ jobs:
runs-on: windows-2022
env:
HIPSDK_INSTALLER_VERSION: "25.Q3"
HIPSDK_INSTALLER_VERSION: "26.Q1"
strategy:
matrix:
include:
- name: "radeon"
gpu_targets: "gfx1151;gfx1200;gfx1201;gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"
gpu_targets: "gfx1150;gfx1151;gfx1200;gfx1201;gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"
steps:
- name: Clone
@@ -632,7 +632,7 @@ jobs:
- name: Grab rocWMMA package
id: grab_rocwmma
run: |
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.0.1/pool/main/r/rocwmma-dev/rocwmma-dev_2.0.0.70001-42~24.04_amd64.deb"
curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70200-43~24.04_amd64.deb"
7z x rocwmma.deb
7z x data.tar
@@ -655,7 +655,7 @@ jobs:
run: |
$ErrorActionPreference = "Stop"
write-host "Downloading AMD HIP SDK Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ env.HIPSDK_INSTALLER_VERSION }}-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ env.HIPSDK_INSTALLER_VERSION }}-Win11-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP SDK"
$proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru
$completed = $proc.WaitForExit(600000)
@@ -689,20 +689,20 @@ jobs:
cmake -G "Unix Makefiles" -B build -S . `
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.0.1/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
-DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.0/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
-DCMAKE_BUILD_TYPE=Release `
-DGGML_BACKEND_DL=ON `
-DGGML_NATIVE=OFF `
-DGGML_CPU=OFF `
-DAMDGPU_TARGETS="${{ matrix.gpu_targets }}" `
-DGPU_TARGETS="${{ matrix.gpu_targets }}" `
-DGGML_HIP_ROCWMMA_FATTN=ON `
-DGGML_HIP=ON `
-DLLAMA_BUILD_BORINGSSL=ON
cmake --build build --target ggml-hip -j ${env:NUMBER_OF_PROCESSORS}
md "build\bin\rocblas\library\"
md "build\bin\hipblaslt\library"
cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\"
cp "${env:HIP_PATH}\bin\hipblaslt.dll" "build\bin\"
cp "${env:HIP_PATH}\bin\libhipblas.dll" "build\bin\"
cp "${env:HIP_PATH}\bin\libhipblaslt.dll" "build\bin\"
cp "${env:HIP_PATH}\bin\rocblas.dll" "build\bin\"
cp "${env:HIP_PATH}\bin\rocblas\library\*" "build\bin\rocblas\library\"
cp "${env:HIP_PATH}\bin\hipblaslt\library\*" "build\bin\hipblaslt\library\"
+3 -3
View File
@@ -1578,7 +1578,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_sparam());
add_opt(common_arg(
{"--temp"}, "N",
{"--temp", "--temperature"}, "N",
string_format("temperature (default: %.2f)", (double)params.sampling.temp),
[](common_params & params, const std::string & value) {
params.sampling.temp = std::stof(value);
@@ -1611,7 +1611,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_sparam());
add_opt(common_arg(
{"--top-nsigma"}, "N",
{"--top-nsigma", "--top-n-sigma"}, "N",
string_format("top-n-sigma sampling (default: %.2f, -1.0 = disabled)", params.sampling.top_n_sigma),
[](common_params & params, const std::string & value) {
params.sampling.top_n_sigma = std::stof(value);
@@ -1634,7 +1634,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_sparam());
add_opt(common_arg(
{"--typical"}, "N",
{"--typical", "--typical-p"}, "N",
string_format("locally typical sampling, parameter p (default: %.2f, 1.0 = disabled)", (double)params.sampling.typ_p),
[](common_params & params, const std::string & value) {
params.sampling.typ_p = std::stof(value);
+2
View File
@@ -721,6 +721,8 @@ value member_expression::execute_impl(context & ctx) {
int64_t arr_size = 0;
if (is_val<value_array>(object)) {
arr_size = object->as_array().size();
} else if (is_val<value_string>(object)) {
arr_size = object->as_string().length();
}
if (is_stmt<slice_expression>(this->property)) {
+66 -4
View File
@@ -116,7 +116,8 @@ class ModelBase:
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
disable_mistral_community_chat_template: bool = False,
sentence_transformers_dense_modules: bool = False):
sentence_transformers_dense_modules: bool = False,
fuse_gate_up_exps: bool = False):
if type(self) is ModelBase or \
type(self) is TextModel or \
type(self) is MmprojModel:
@@ -135,6 +136,9 @@ class ModelBase:
self.dry_run = dry_run
self.remote_hf_model_id = remote_hf_model_id
self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
self.fuse_gate_up_exps = fuse_gate_up_exps
self._gate_exp_buffer: dict[int, Tensor] = {}
self._up_exp_buffer: dict[int, Tensor] = {}
self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
self.metadata_override = metadata_override
@@ -512,8 +516,31 @@ class ModelBase:
raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses")
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
return [(self.map_tensor_name(name), data_torch)]
new_name = self.map_tensor_name(name)
# Handle gate/up expert tensor fusion if enabled
if self.fuse_gate_up_exps and bid is not None:
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid):
self._gate_exp_buffer[bid] = data_torch
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
self._up_exp_buffer[bid] = data_torch
# Check if both gate and up are buffered for this layer
if bid in self._gate_exp_buffer and bid in self._up_exp_buffer:
gate_data = self._gate_exp_buffer.pop(bid)
up_data = self._up_exp_buffer.pop(bid)
# gate/up shape: (n_expert, n_ff, n_embd), concatenate to (n_expert, n_ff*2, n_embd)
fused_data = torch.cat([gate_data, up_data], dim=1)
fused_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_UP_EXP, bid)
logger.info(f"Fused gate_exps and up_exps for layer {bid}")
return [(fused_name, fused_data)]
# If we buffered a gate/up tensor, wait for the other
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid) or \
self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
return []
return [(new_name, data_torch)]
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
del name, new_name, bid, n_dims # unused
@@ -1148,6 +1175,9 @@ class TextModel(ModelBase):
if chkhsh == "27949a2493fc4a9f53f5b9b029c82689cfbe5d3a1929bb25e043089e28466de6":
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-de
res = "jina-v2-de"
if chkhsh == "a023e9fdc5a11f034d3ef515b92350e56fb2af1f66c6b6811a4444ea9bf8763d":
# ref: https://huggingface.co/jinaai/jina-embeddings-v5-text-nano
res = "jina-v5-nano"
if chkhsh == "c136ed14d01c2745d4f60a9596ae66800e2b61fa45643e72436041855ad4089d":
# ref: https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct
res = "smaug-bpe"
@@ -6125,6 +6155,32 @@ class NeoBert(BertModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("EuroBertModel", "JinaEmbeddingsV5Model")
class EuroBertModel(TextModel):
model_arch = gguf.MODEL_ARCH.EUROBERT
def set_vocab(self):
self.gguf_writer.add_add_bos_token(False)
self._set_vocab_gpt2()
def set_gguf_parameters(self):
super().set_gguf_parameters()
# EuroBert is bidirectional (encoder)
self.gguf_writer.add_causal_attention(False)
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self._try_set_pooling_type()
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Strip "model." prefix from tensor names
if name.startswith("model."):
name = name[6:]
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
class XLMRobertaModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT
@@ -11913,6 +11969,11 @@ def parse_args() -> argparse.Namespace:
"Default these modules are not included.")
)
parser.add_argument(
"--fuse-gate-up-exps", action="store_true",
help="Fuse gate_exps and up_exps tensors into a single gate_up_exps tensor for MoE models.",
)
args = parser.parse_args()
if not args.print_supported_models and args.model is None:
parser.error("the following arguments are required: model")
@@ -12050,7 +12111,8 @@ def main() -> None:
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split,
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules,
fuse_gate_up_exps=args.fuse_gate_up_exps
)
if args.vocab_only:
+1
View File
@@ -107,6 +107,7 @@ models = [
{"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM!
{"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", },
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", },
{"name": "jina-v5-nano", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v5-text-nano", },
{"name": "smaug-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", },
{"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", },
{"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", },
+3 -1
View File
@@ -152,7 +152,9 @@ Commands and data are serialized using a custom binary protocol with:
- **VM-specific**: Only works in virtual machines with virtio-gpu support
- **Host dependency**: Requires properly configured host-side backend
- **Latency**: Small overhead from VM escaping for each operation
- **Shared-memory size**: with the `libkrun` hypervisor, the RAM + VRAM
addressable memory is limited to 64 GB. So the maximum GPU memory
will be `64GB - RAM`, regardless of the hardware VRAM size.
* This work is pending upstream changes in the VirglRenderer
project.
-4
View File
@@ -730,10 +730,6 @@ extern "C" {
GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block
GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
GGML_DEPRECATED(
GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
"use ggml_row_size() instead");
GGML_API const char * ggml_type_name(enum ggml_type type);
GGML_API const char * ggml_op_name (enum ggml_op op);
GGML_API const char * ggml_op_symbol(enum ggml_op op);
+21 -20
View File
@@ -11,8 +11,8 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3,
/*int s00,*/ int s01, int s02, int s03,
/*int s10,*/ int s11, int s12, int s13,
int s00, int s01, int s02, int s03,
int s10, int s11, int s12, int s13,
const sycl::nd_item<3> &item_ct1) {
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
@@ -44,7 +44,7 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
for (int i0 = i0s; i0 < ne0;
i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
const int i10 = i0 % ne10;
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0*s00] : 0.0f, (float)src1_row[i10*s10]);
}
}
@@ -53,8 +53,8 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3,
/*int s00,*/ int s01, int s02, int s03,
/*int s10,*/ int s11, int s12, int s13,
int s00, int s01, int s02, int s03,
int s10, int s11, int s12, int s13,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -82,7 +82,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
dst_t * dst_row = dst + i_dst;
const int i10 = i0 % ne10;
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0*s00] : 0.0f, (float)src1_row[i10*s10]);
}
@@ -95,7 +95,8 @@ struct bin_bcast_sycl {
const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03,
const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
const bool src1_is_contiguous, const bool src0_is_permuted, const bool src1_is_permuted,
queue_ptr stream) {
int nr0 = ne10 / ne0;
int nr1 = ne11/ne1;
int nr2 = ne12/ne2;
@@ -123,7 +124,7 @@ struct bin_bcast_sycl {
cnb[3] *= cne[3];
};
if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
if (src0_is_contiguous && src1_is_contiguous && !src0_is_permuted && !src1_is_permuted) {
for (int i = 0; i < 4; i++) {
if (nr[i] != 1) {
break;
@@ -164,7 +165,7 @@ struct bin_bcast_sycl {
size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3];
size_t s0 = nb0 / sizeof(dst_t);
// size_t s0 = nb0 / sizeof(dst_t);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
@@ -196,9 +197,6 @@ struct bin_bcast_sycl {
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
GGML_ASSERT(s0 == 1);
GGML_ASSERT(s10 == 1);
const int block_size = 128;
int64_t hne0 = std::max(ne0/2LL, 1LL);
@@ -232,8 +230,8 @@ struct bin_bcast_sycl {
[=](sycl::nd_item<3> item_ct1) {
k_bin_bcast_unravel<bin_op>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
s03, s11, s12, s13, item_ct1);
ne10, ne11, ne12, ne13, s1, s2, s3, s00, s01, s02,
s03, s10, s11, s12, s13, item_ct1);
});
}
} else {
@@ -251,7 +249,7 @@ struct bin_bcast_sycl {
[=](sycl::nd_item<3> item_ct1) {
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
ne2, ne3, ne10, ne11, ne12, ne13,
s1, s2, s3, s01, s02, s03, s11, s12, s13,
s1, s2, s3, s00, s01, s02, s03, s10, s11, s12, s13,
item_ct1);
});
}
@@ -268,24 +266,27 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
op()((const float *) src0->data, (const float *) src1->data, (float *) dst->data, ne00, ne01, ne02, ne03, ne10,
ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3,
ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
op()((const sycl::half *) src0->data, (const sycl::half *) src1->data, (sycl::half *) dst->data, ne00, ne01,
ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13,
nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst),
nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
op()((const sycl::half *) src0->data, (const float *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02,
ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1,
nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
main_stream);
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
op()((const int32_t *) src0->data, (const int32_t *) src1->data, (int32_t *) dst->data, ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
main_stream);
} else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
op()((const int16_t *) src0->data, (const int16_t *) src1->data, (int16_t *) dst->data, ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
main_stream);
} else {
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type),
ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -7,9 +7,21 @@
#include <cstdint>
static uint32_t validate_graph_operation(size_t cgraph_size, uint32_t shmem_res_id, const char * operation) {
if (cgraph_size == 0) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Zero-size computation graph\n", operation);
return 1;
}
// place-holder: validate that the size of shmem_res_id is <= cgraph_size
// need to add another method in the Virgl->APIR callback interface
GGML_UNUSED(shmem_res_id);
return 0; // Valid
}
uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
GGML_UNUSED(enc);
static bool async_backend_initialized = false;
static bool async_backend;
@@ -34,10 +46,26 @@ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, v
size_t cgraph_size;
apir_decode_size_t(dec, &cgraph_size);
if (validate_graph_operation(cgraph_size, shmem_res_id, __func__) != 0) {
apir_decoder_set_fatal(dec);
return 1;
}
apir_decoder secondary_dec = apir_new_decoder((const char *) shmem_data, cgraph_size);
ggml_cgraph * cgraph = apir_decode_ggml_cgraph(&secondary_dec, cgraph_size);
if (!cgraph || apir_decoder_get_fatal(&secondary_dec)) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Failed to deserialize computation graph\n", __func__);
return 1;
}
if (cgraph->n_nodes < 0 || cgraph->n_leafs < 0) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid negative node/leaf count: nodes=%d leafs=%d\n", __func__,
cgraph->n_nodes, cgraph->n_leafs);
return 1;
}
ggml_status status;
#if APIR_BACKEND_CHECK_SUPPORTS_OP == 1
for (int idx = 0; idx < cgraph->n_nodes; idx++) {
@@ -45,7 +73,8 @@ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, v
if (dev->iface.supports_op(dev, op)) {
continue;
}
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Graph node %d (%s) not supported by the backend\n", idx, ggml_op_desc(op));
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Graph node %d (%s) not supported by the backend\n", __func__, idx,
ggml_op_desc(op));
status = GGML_STATUS_ABORTED;
apir_encode_ggml_status(enc, &status);
@@ -53,9 +82,17 @@ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, v
return 0;
}
#endif
// Check if backend is properly initialized
if (!bck) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Backend not initialized (bck is null)\n", __func__);
return 1;
}
status = bck->iface.graph_compute(bck, cgraph);
if (async_backend) {
if (async_backend && bck->iface.synchronize) {
bck->iface.synchronize(bck);
}
@@ -85,7 +85,19 @@ uint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * d
const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec);
size_t value = buft->iface.get_alloc_size(buft, op);
// Check for decode error
if (op == nullptr) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Failed to decode tensor\n", __func__);
apir_decoder_set_fatal(dec);
return 1;
}
size_t value;
if (buft->iface.get_alloc_size) {
value = buft->iface.get_alloc_size(buft, op);
} else {
value = ggml_nbytes(op); // Default fallback
}
apir_encode_size_t(enc, &value);
@@ -6,11 +6,26 @@
#include <cstdint>
static uint32_t validate_buffer_operation(size_t offset, size_t size, const char * operation) {
// Only check for critical integer overflow - no arbitrary size limits
if (offset > SIZE_MAX - size) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Integer overflow in offset+size: %zu + %zu\n", operation, offset, size);
return 1;
}
return 0; // Valid
}
uint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
GGML_UNUSED(ctx);
ggml_backend_buffer_t buffer;
buffer = apir_decode_ggml_buffer(dec);
if (!buffer || apir_decoder_get_fatal(dec)) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
return 1;
}
uintptr_t base = (uintptr_t) buffer->iface.get_base(buffer);
apir_encode_uintptr_t(enc, &base);
@@ -24,6 +39,11 @@ uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl
ggml_backend_buffer_t buffer;
buffer = apir_decode_ggml_buffer(dec);
if (!buffer || apir_decoder_get_fatal(dec)) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
return 1;
}
ggml_tensor * tensor;
// safe to remove the const qualifier here
tensor = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec);
@@ -37,6 +57,10 @@ uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl
size_t size;
apir_decode_size_t(dec, &size);
if (validate_buffer_operation(offset, size, __func__) != 0) {
return 1;
}
void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
if (!shmem_data) {
@@ -56,6 +80,11 @@ uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl
ggml_backend_buffer_t buffer;
buffer = apir_decode_ggml_buffer(dec);
if (!buffer || apir_decoder_get_fatal(dec)) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
return 1;
}
const ggml_tensor * tensor;
// safe to remove the const qualifier here
tensor = apir_decode_ggml_tensor(dec);
@@ -69,6 +98,10 @@ uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl
size_t size;
apir_decode_size_t(dec, &size);
if (validate_buffer_operation(offset, size, __func__) != 0) {
return 1;
}
void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
if (!shmem_data) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__);
@@ -86,6 +119,11 @@ uint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl
ggml_backend_buffer_t buffer;
buffer = apir_decode_ggml_buffer(dec);
if (!buffer || apir_decoder_get_fatal(dec)) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
return 1;
}
const ggml_tensor * src;
// safe to remove the const qualifier here
src = apir_decode_ggml_tensor(dec);
@@ -105,6 +143,11 @@ uint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir
ggml_backend_buffer_t buffer;
buffer = apir_decode_ggml_buffer(dec);
if (!buffer || apir_decoder_get_fatal(dec)) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
return 1;
}
uint8_t value;
apir_decode_uint8_t(dec, &value);
@@ -120,6 +163,11 @@ uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virg
ggml_backend_buffer_t buffer;
buffer = apir_decode_ggml_buffer(dec);
if (!buffer || apir_decoder_get_fatal(dec)) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
return 1;
}
if (!apir_untrack_backend_buffer(buffer)) {
GGML_LOG_WARN(GGML_VIRTGPU_BCK "%s: unknown buffer %p\n", __func__, (void *) buffer);
return 1;
@@ -1,6 +1,6 @@
#include "backend-dispatched.h"
#include "backend-virgl-apir.h"
#include "backend-virgl-apir.h"
#include "ggml-backend-impl.h"
#include "ggml-backend.h"
#include "ggml-impl.h"
@@ -28,19 +28,24 @@ uint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p) {
return APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED;
}
if (!reg->iface.get_device_count(reg)) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed: no device found\n", __func__);
size_t device_count = reg->iface.get_device_count(reg);
if (!device_count) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: no device found\n", __func__);
return APIR_BACKEND_INITIALIZE_NO_DEVICE;
}
dev = reg->iface.get_device(reg, 0);
if (!dev) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed: no device received\n", __func__);
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: failed to get device\n", __func__);
return APIR_BACKEND_INITIALIZE_NO_DEVICE;
}
bck = dev->iface.init_backend(dev, NULL);
if (!bck) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed\n", __func__);
return APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED;
}
return APIR_BACKEND_INITIALIZE_SUCCESS;
}
@@ -32,64 +32,6 @@ uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virg
/* backend */
uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
static inline const char * backend_dispatch_command_name(ApirBackendCommandType type) {
switch (type) {
/* device */
case APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT:
return "backend_device_get_device_count";
case APIR_COMMAND_TYPE_DEVICE_GET_COUNT:
return "backend_device_get_count";
case APIR_COMMAND_TYPE_DEVICE_GET_NAME:
return "backend_device_get_name";
case APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION:
return "backend_device_get_description";
case APIR_COMMAND_TYPE_DEVICE_GET_TYPE:
return "backend_device_get_type";
case APIR_COMMAND_TYPE_DEVICE_GET_MEMORY:
return "backend_device_get_memory";
case APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP:
return "backend_device_supports_op";
case APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE:
return "backend_device_get_buffer_type";
case APIR_COMMAND_TYPE_DEVICE_GET_PROPS:
return "backend_device_get_props";
case APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR:
return "backend_device_buffer_from_ptr";
/* buffer-type */
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME:
return "backend_buffer_type_get_name";
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT:
return "backend_buffer_type_get_alignment";
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE:
return "backend_buffer_type_get_max_size";
case APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST:
return "backend_buffer_type_is_host (DEPRECATED)";
case APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER:
return "backend_buffer_type_alloc_buffer";
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE:
return "backend_buffer_type_get_alloc_size";
/* buffer */
case APIR_COMMAND_TYPE_BUFFER_GET_BASE:
return "backend_buffer_get_base";
case APIR_COMMAND_TYPE_BUFFER_SET_TENSOR:
return "backend_buffer_set_tensor";
case APIR_COMMAND_TYPE_BUFFER_GET_TENSOR:
return "backend_buffer_get_tensor";
case APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR:
return "backend_buffer_cpy_tensor";
case APIR_COMMAND_TYPE_BUFFER_CLEAR:
return "backend_buffer_clear";
case APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER:
return "backend_buffer_free_buffer";
/* backend */
case APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE:
return "backend_backend_graph_compute";
default:
return "unknown";
}
}
extern "C" {
static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {
@@ -1,5 +1,6 @@
#pragma once
// clang-format off
#include <cstdint>
#include <cstddef>
@@ -10,6 +11,7 @@
#include "shared/apir_backend.h"
#include "shared/apir_cs.h"
#include "shared/apir_cs_ggml.h"
// clang-format on
#define GGML_VIRTGPU_BCK "ggml-virtgpu-backend: "
@@ -19,7 +19,7 @@ struct virgl_apir_callbacks {
};
extern "C" {
ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs);
ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks * virgl_cbs);
void apir_backend_deinit(uint32_t virgl_ctx_id);
uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id,
virgl_apir_callbacks * virgl_cbs,
+15 -23
View File
@@ -1,6 +1,5 @@
#include "backend-dispatched.h"
#include "backend-virgl-apir.h"
#include "shared/api_remoting.h"
#include "shared/apir_backend.h"
#include "shared/apir_cs.h"
@@ -17,10 +16,10 @@
#define GGML_DEFAULT_BACKEND_REG "ggml_backend_init"
static void * backend_library_handle = NULL;
static FILE * apir_logfile = NULL;
static FILE * apir_logfile = NULL;
static void log_to_file_callback(enum ggml_log_level level, const char * text, void * user_data) {
FILE * logfile = (FILE *)user_data;
FILE * logfile = (FILE *) user_data;
fprintf(logfile, "[%d] %s", level, text);
fflush(logfile);
}
@@ -48,9 +47,9 @@ void apir_backend_deinit(uint32_t virgl_ctx_id) {
}
#define APIR_GGML_LIBRARY_PATH_KEY "ggml.library.path"
#define APIR_GGML_LIBRARY_REG_KEY "ggml.library.reg"
#define APIR_GGML_LIBRARY_REG_KEY "ggml.library.reg"
ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs) {
ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks * virgl_cbs) {
const char * dlsym_error;
const char * apir_log_to_file = getenv(APIR_LLAMA_CPP_LOG_TO_FILE_ENV);
@@ -63,15 +62,13 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct
}
}
const char * library_name = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_PATH_KEY);
const char * library_name = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_PATH_KEY);
const char * virgl_library_reg = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_REG_KEY);
const char * library_reg = virgl_library_reg ? virgl_library_reg : GGML_DEFAULT_BACKEND_REG;
const char * library_reg = virgl_library_reg ? virgl_library_reg : GGML_DEFAULT_BACKEND_REG;
if (!library_name) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK
"%s: cannot open the GGML library: env var '%s' not defined\n",
__func__, APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV);
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot open the GGML library: env var '%s' not defined\n", __func__,
APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV);
return APIR_LOAD_LIBRARY_ENV_VAR_MISSING;
}
@@ -79,16 +76,14 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct
backend_library_handle = dlopen(library_name, RTLD_LAZY);
if (!backend_library_handle) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK
"%s: cannot open the GGML library: %s\n", __func__, dlerror());
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot open the GGML library: %s\n", __func__, dlerror());
return APIR_LOAD_LIBRARY_CANNOT_OPEN;
}
if (!library_reg) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK
"%s: cannot register the GGML library: env var '%s' not defined\n",
__func__, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV);
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot register the GGML library: env var '%s' not defined\n", __func__,
APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV);
return APIR_LOAD_LIBRARY_ENV_VAR_MISSING;
}
@@ -96,11 +91,9 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct
void * ggml_backend_reg_fct = dlsym(backend_library_handle, library_reg);
dlsym_error = dlerror();
if (dlsym_error) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK
"%s: cannot find the GGML backend registration symbol '%s' (from %s): %s\n",
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot find the GGML backend registration symbol '%s' (from %s): %s\n",
__func__, library_reg, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV, dlsym_error);
return APIR_LOAD_LIBRARY_SYMBOL_MISSING;
}
@@ -132,13 +125,12 @@ uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id,
virgl_apir_context ctx = {
.ctx_id = virgl_ctx_id,
.iface = virgl_cbs,
.iface = virgl_cbs,
};
if (cmd_type >= APIR_BACKEND_DISPATCH_TABLE_COUNT) {
GGML_LOG_ERROR(GGML_VIRTGPU_BCK
"%s: Received an invalid dispatch index (%d >= %d)\n",
__func__, cmd_type, APIR_BACKEND_DISPATCH_TABLE_COUNT);
GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Received an invalid dispatch index (%d >= %d)\n", __func__, cmd_type,
APIR_BACKEND_DISPATCH_TABLE_COUNT);
return APIR_BACKEND_FORWARD_INDEX_INVALID;
}
@@ -16,28 +16,32 @@ enum ApirCommandType {
APIR_COMMAND_TYPE_LOADLIBRARY = 1,
APIR_COMMAND_TYPE_FORWARD = 2,
APIR_COMMAND_TYPE_LENGTH = 3,
APIR_COMMAND_TYPE_LENGTH = 3,
};
typedef uint64_t ApirCommandFlags;
enum ApirLoadLibraryReturnCode {
APIR_LOAD_LIBRARY_SUCCESS = 0,
// these error codes are returned by the Virglrenderer APIR component
APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR = 1,
APIR_LOAD_LIBRARY_ALREADY_LOADED = 2,
APIR_LOAD_LIBRARY_ENV_VAR_MISSING = 3,
APIR_LOAD_LIBRARY_CANNOT_OPEN = 4,
APIR_LOAD_LIBRARY_SYMBOL_MISSING = 5,
APIR_LOAD_LIBRARY_INIT_BASE_INDEX = 6, // anything above this is a APIR backend library initialization return code
// any value greater than this is an APIR *backend library* initialization return code
APIR_LOAD_LIBRARY_INIT_BASE_INDEX = 6,
};
enum ApirForwardReturnCode {
APIR_FORWARD_SUCCESS = 0,
APIR_FORWARD_NO_DISPATCH_FCT = 1,
APIR_FORWARD_TIMEOUT = 2,
APIR_FORWARD_BASE_INDEX = 3, // anything above this is a APIR backend library forward return code
} ;
APIR_FORWARD_SUCCESS = 0,
// these error codes are returned by the Virglrenderer APIR component
APIR_FORWARD_NO_DISPATCH_FCT = 1,
APIR_FORWARD_TIMEOUT = 2,
APIR_FORWARD_FAILED_TO_SYNC_STREAMS = 3,
// any value greater than this index an APIR *backend library* forward return code
APIR_FORWARD_BASE_INDEX = 4,
};
__attribute__((unused)) static inline const char * apir_command_name(ApirCommandType type) {
switch (type) {
@@ -82,6 +86,7 @@ __attribute__((unused)) static const char * apir_forward_error(ApirForwardReturn
APIR_FORWARD_ERROR(APIR_FORWARD_SUCCESS);
APIR_FORWARD_ERROR(APIR_FORWARD_NO_DISPATCH_FCT);
APIR_FORWARD_ERROR(APIR_FORWARD_TIMEOUT);
APIR_FORWARD_ERROR(APIR_FORWARD_FAILED_TO_SYNC_STREAMS);
APIR_FORWARD_ERROR(APIR_FORWARD_BASE_INDEX);
return "Unknown APIR_COMMAND_TYPE_FORWARD error";
@@ -34,3 +34,61 @@ typedef enum ApirBackendCommandType {
// last command_type index + 1
APIR_BACKEND_DISPATCH_TABLE_COUNT = 23,
} ApirBackendCommandType;
static inline const char * apir_dispatch_command_name(ApirBackendCommandType type) {
switch (type) {
/* device */
case APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT:
return "device_get_device_count";
case APIR_COMMAND_TYPE_DEVICE_GET_COUNT:
return "device_get_count";
case APIR_COMMAND_TYPE_DEVICE_GET_NAME:
return "device_get_name";
case APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION:
return "device_get_description";
case APIR_COMMAND_TYPE_DEVICE_GET_TYPE:
return "device_get_type";
case APIR_COMMAND_TYPE_DEVICE_GET_MEMORY:
return "device_get_memory";
case APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP:
return "device_supports_op";
case APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE:
return "device_get_buffer_type";
case APIR_COMMAND_TYPE_DEVICE_GET_PROPS:
return "device_get_props";
case APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR:
return "device_buffer_from_ptr";
/* buffer-type */
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME:
return "buffer_type_get_name";
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT:
return "buffer_type_get_alignment";
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE:
return "buffer_type_get_max_size";
case APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST:
return "buffer_type_is_host";
case APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER:
return "buffer_type_alloc_buffer";
case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE:
return "buffer_type_get_alloc_size";
/* buffer */
case APIR_COMMAND_TYPE_BUFFER_GET_BASE:
return "buffer_get_base";
case APIR_COMMAND_TYPE_BUFFER_SET_TENSOR:
return "buffer_set_tensor";
case APIR_COMMAND_TYPE_BUFFER_GET_TENSOR:
return "buffer_get_tensor";
case APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR:
return "buffer_cpy_tensor";
case APIR_COMMAND_TYPE_BUFFER_CLEAR:
return "buffer_clear";
case APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER:
return "buffer_free_buffer";
/* backend */
case APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE:
return "backend_graph_compute";
default:
return "unknown";
}
}
@@ -14,7 +14,7 @@
#define APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED 6
#define APIR_BACKEND_INITIALIZE_ALREADY_INITED 7
#define APIR_BACKEND_INITIALIZE_NO_DEVICE 8
#define APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED 9
// new entries here need to be added to the apir_backend_initialize_error function below
@@ -39,6 +39,10 @@ static const char * apir_backend_initialize_error(int code) {
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS);
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS);
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_FAILED);
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED);
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_ALREADY_INITED);
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_NO_DEVICE);
APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED);
return "Unknown APIR_BACKEND_INITIALIZE error:/";
+7 -13
View File
@@ -13,7 +13,6 @@ struct apir_encoder {
const char * start;
const char * end;
bool fatal;
};
struct apir_decoder {
@@ -28,8 +27,8 @@ struct apir_decoder {
static apir_decoder apir_new_decoder(const char * ptr, size_t size) {
apir_decoder dec = {
.cur = ptr,
.end = ptr + size,
.cur = ptr,
.end = ptr + size,
.fatal = false,
};
@@ -79,10 +78,7 @@ static inline bool apir_decoder_get_fatal(const apir_decoder * dec) {
* encode peek
*/
static inline bool apir_decoder_peek_internal(apir_decoder * dec,
size_t size,
void * val,
size_t val_size) {
static inline bool apir_decoder_peek_internal(apir_decoder * dec, size_t size, void * val, size_t val_size) {
assert(val_size <= size);
if (unlikely(size > (size_t) (dec->end - dec->cur))) {
@@ -332,8 +328,7 @@ static inline void apir_decode_char_array(apir_decoder * dec, char * val, size_t
static inline void * apir_decoder_alloc_array(size_t size, size_t count) {
size_t alloc_size;
if (unlikely(__builtin_mul_overflow(size, count, &alloc_size))) {
GGML_LOG_ERROR("%s: overflow in array allocation of %zu * %zu bytes\n",
__func__, size, count);
GGML_LOG_ERROR("%s: overflow in array allocation of %zu * %zu bytes\n", __func__, size, count);
return NULL;
}
@@ -352,20 +347,19 @@ static inline void apir_decode_bool_t(apir_decoder * dec, bool * val) {
/* apir_buffer_type_host_handle_t */
static inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder * enc,
static inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder * enc,
const apir_buffer_type_host_handle_t * val) {
apir_encode(enc, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t));
}
static inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder * dec,
static inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder * dec,
apir_buffer_type_host_handle_t * val) {
apir_decode(dec, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t));
}
/* apir_buffer_host_handle_t */
static inline void apir_encode_apir_buffer_host_handle_t(apir_encoder * enc,
const apir_buffer_host_handle_t * val) {
static inline void apir_encode_apir_buffer_host_handle_t(apir_encoder * enc, const apir_buffer_host_handle_t * val) {
apir_encode(enc, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t));
}
@@ -1,11 +1,10 @@
#include "ggml-impl.h"
#include "apir_cs.h"
#include "apir_cs_rpc.h"
#include "ggml-impl.h"
// ggml_buffer_to_apir_host_handle(ggml_backend_buffer_t buffer);
static inline void apir_encode_ggml_buffer_host_handle(apir_encoder * enc,
const apir_buffer_host_handle_t * handle);
static inline void apir_encode_ggml_buffer_host_handle(apir_encoder * enc, const apir_buffer_host_handle_t * handle);
static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec);
@@ -22,8 +21,7 @@ static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_inplace(apir_decoder
return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size);
}
static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_array_inplace(apir_decoder * dec,
uint32_t n_tensors) {
static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_array_inplace(apir_decoder * dec, uint32_t n_tensors) {
size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor) * n_tensors;
return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size);
@@ -45,9 +43,9 @@ static inline const ggml_tensor * apir_decode_ggml_tensor(apir_decoder * dec) {
}
ggml_init_params params{
/*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
/*.mem_size =*/ggml_tensor_overhead(),
/*.mem_buffer =*/NULL,
/*.no_alloc =*/true,
};
ggml_context * ctx = ggml_init(params);
@@ -105,6 +103,19 @@ static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec)
apir_decoder_read(dec, buffer_ptr_size, &buffer, buffer_ptr_size);
// SECURITY: Validate buffer handle against tracked buffers to prevent
// guest VM from providing arbitrary host memory addresses
if (buffer) {
extern std::unordered_set<ggml_backend_buffer_t> backend_buffers;
if (backend_buffers.find(buffer) == backend_buffers.end()) {
GGML_LOG_WARN("ggml-virtgpu-backend: %s: Invalid buffer handle from guest: %p\n", __func__,
(void *) buffer);
// Set fatal flag to prevent further processing with invalid handle
apir_decoder_set_fatal(dec);
return NULL;
}
}
return buffer;
}
@@ -1,3 +1,6 @@
#pragma once
// clang-format off
#include "ggml.h"
#include "ggml-backend-impl.h"
@@ -5,6 +8,7 @@
#include <unordered_set>
#include <vector>
#include <cstdint>
// clang-format on
// ggml_tensor is serialized into apir_rpc_tensor
struct apir_rpc_tensor {
@@ -34,6 +34,7 @@ static ggml_backend_buffer_t ggml_backend_remoting_buffer_type_alloc_buffer(ggml
static const char * ggml_backend_remoting_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
virtgpu * gpu = BUFT_TO_GPU(buft);
// Return the prefixed name that was built once during initialization
return gpu->cached_buffer_type.name;
}
@@ -53,9 +54,8 @@ static size_t ggml_backend_remoting_buffer_type_get_alloc_size(ggml_backend_buff
const ggml_tensor * tensor) {
virtgpu * gpu = BUFT_TO_GPU(buft);
if (tensor->buffer == NULL
|| !tensor->buffer->context
|| !buft->device->iface.supports_buft(buft->device, tensor->buffer->buft)) {
if (tensor->buffer == NULL || !tensor->buffer->context ||
!buft->device->iface.supports_buft(buft->device, tensor->buffer->buft)) {
return ggml_nbytes(tensor);
}
@@ -3,6 +3,7 @@
static const char * ggml_backend_remoting_device_get_name(ggml_backend_dev_t dev) {
virtgpu * gpu = DEV_TO_GPU(dev);
// Return the prefixed name that was built once during initialization
return gpu->cached_device_info.name;
}
@@ -22,7 +23,7 @@ static enum ggml_backend_dev_type ggml_backend_remoting_device_get_type(ggml_bac
static void ggml_backend_remoting_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
virtgpu * gpu = DEV_TO_GPU(dev);
*free = gpu->cached_device_info.memory_free;
*free = gpu->cached_device_info.memory_free;
*total = gpu->cached_device_info.memory_total;
}
@@ -72,7 +73,7 @@ static void ggml_backend_remoting_device_get_props(ggml_backend_dev_t dev, ggml_
ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_backend_dev_t dev) {
virtgpu * gpu = DEV_TO_GPU(dev);
static std::atomic<bool> initialized = false;
static std::atomic<bool> initialized = false;
static ggml_backend_buffer_type buft;
if (!initialized) {
@@ -95,7 +96,7 @@ ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_bac
static ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_from_ptr_type(ggml_backend_dev_t dev) {
virtgpu * gpu = DEV_TO_GPU(dev);
static std::atomic<bool> initialized = false;
static std::atomic<bool> initialized = false;
static ggml_backend_buffer_type buft;
if (!initialized) {
+40 -16
View File
@@ -7,8 +7,8 @@
void ggml_virtgpu_cleanup(virtgpu * gpu);
static virtgpu * apir_initialize() {
static virtgpu * gpu = NULL;
static std::atomic<bool> initialized = false;
static virtgpu * gpu = NULL;
static std::atomic<bool> initialized = false;
if (initialized) {
// fast track
@@ -31,29 +31,53 @@ static virtgpu * apir_initialize() {
}
// Pre-fetch and cache all device information, it will not change
gpu->cached_device_info.description = apir_device_get_description(gpu);
gpu->cached_device_info.description = apir_device_get_description(gpu);
if (!gpu->cached_device_info.description) {
GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device description", __func__);
}
gpu->cached_device_info.name = apir_device_get_name(gpu);
if (!gpu->cached_device_info.name) {
GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device name", __func__);
}
gpu->cached_device_info.device_count = apir_device_get_count(gpu);
gpu->cached_device_info.type = apir_device_get_type(gpu);
apir_device_get_memory(gpu,
&gpu->cached_device_info.memory_free,
&gpu->cached_device_info.memory_total);
{
// Get the remote name and create prefixed version
char * rmt_device_name = apir_device_get_name(gpu);
if (!rmt_device_name) {
GGML_ABORT(GGML_VIRTGPU "%s: failed to get the virtgpu device name", __func__);
}
size_t device_name_len = strlen(rmt_device_name) + 11; // "[virtgpu] " + null terminator
gpu->cached_device_info.name = (char *) malloc(device_name_len);
if (!gpu->cached_device_info.name) {
free(rmt_device_name);
GGML_ABORT(GGML_VIRTGPU "%s: failed to allocate memory for prefixed device name", __func__);
}
snprintf(gpu->cached_device_info.name, device_name_len, "[virtgpu] %s", rmt_device_name);
free(rmt_device_name);
}
apir_device_get_memory(gpu, &gpu->cached_device_info.memory_free, &gpu->cached_device_info.memory_total);
apir_buffer_type_host_handle_t buft_host_handle = apir_device_get_buffer_type(gpu);
gpu->cached_buffer_type.host_handle = buft_host_handle;
gpu->cached_buffer_type.name = apir_buffer_type_get_name(gpu, buft_host_handle);
if (!gpu->cached_buffer_type.name) {
GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu buffer type name", __func__);
{
// Get the remote name and create prefixed version
char * rmt_name = apir_buffer_type_get_name(gpu, buft_host_handle);
if (!rmt_name) {
GGML_ABORT(GGML_VIRTGPU "%s: failed to get the virtgpu buffer type name", __func__);
}
size_t prefixed_len = strlen(rmt_name) + 11; // "[virtgpu] " + null terminator
gpu->cached_buffer_type.name = (char *) malloc(prefixed_len);
if (!gpu->cached_buffer_type.name) {
free(rmt_name);
GGML_ABORT(GGML_VIRTGPU "%s: failed to allocate memory for prefixed buffer type name", __func__);
}
snprintf(gpu->cached_buffer_type.name, prefixed_len, "[virtgpu] %s", rmt_name);
free(rmt_name);
}
gpu->cached_buffer_type.alignment = apir_buffer_type_get_alignment(gpu, buft_host_handle);
gpu->cached_buffer_type.max_size = apir_buffer_type_get_max_size(gpu, buft_host_handle);
gpu->cached_buffer_type.alignment = apir_buffer_type_get_alignment(gpu, buft_host_handle);
gpu->cached_buffer_type.max_size = apir_buffer_type_get_max_size(gpu, buft_host_handle);
initialized = true;
}
@@ -98,7 +122,7 @@ static void ggml_backend_remoting_reg_init_devices(ggml_backend_reg_t reg) {
static std::atomic<bool> initialized = false;
if (initialized) {
return; // fast track
return; // fast track
}
{
+1 -1
View File
@@ -1,5 +1,5 @@
#include "ggml-remoting.h"
#include "../../include/ggml-virtgpu.h"
#include "ggml-remoting.h"
static const char * ggml_backend_remoting_get_name(ggml_backend_t backend) {
UNUSED(backend);
+1 -1
View File
@@ -9,7 +9,7 @@
#include <string>
#define GGML_VIRTGPU_NAME "ggml-virtgpu"
#define GGML_VIRTGPU "ggml-virtgpu: "
#define GGML_VIRTGPU "ggml-virtgpu: "
// USE_ALWAYS_TRUE_SUPPORTS_OP: 1 is fast, 0 avoid micro-benchmark crashes
+3 -3
View File
@@ -3,7 +3,7 @@
#include <stdint.h>
struct virgl_renderer_capset_apir {
uint32_t apir_version;
uint32_t supports_blob_resources;
uint32_t reserved[4]; // For future expansion
uint32_t apir_version;
uint32_t supports_blob_resources;
uint32_t reserved[4]; // For future expansion
};
+24 -23
View File
@@ -145,8 +145,31 @@ class RemotingCodebaseGenerator:
enum_lines.append(f" APIR_BACKEND_DISPATCH_TABLE_COUNT = {total_count},")
enum_lines.append("} ApirBackendCommandType;")
# Generate function name mapping
func_lines = []
func_lines.append("static inline const char * apir_dispatch_command_name(ApirBackendCommandType type) {")
func_lines.append(" switch (type) {")
current_group = None
for func in functions:
# Add comment for new group
if func['group_name'] != current_group:
func_lines.append(f" /* {func['group_description']} */")
current_group = func['group_name']
# Generate clean function name without backend_ prefix
clean_name = f"{func['group_name']}_{func['function_name']}"
func_lines.append(f" case {func['enum_name']}:")
func_lines.append(f" return \"{clean_name}\";")
func_lines.append("")
func_lines.append(" default:")
func_lines.append(" return \"unknown\";")
func_lines.append(" }")
func_lines.append("}")
# Full header template
header_content = NL.join(enum_lines) + "\n"
header_content = NL.join(enum_lines) + "\n\n" + NL.join(func_lines) + "\n"
return header_content
@@ -170,19 +193,6 @@ class RemotingCodebaseGenerator:
decl_lines.append(f"{signature} {func['backend_function']}({params});")
# Switch cases
switch_lines = []
current_group = None
for func in functions:
if func['group_name'] != current_group:
switch_lines.append(f" /* {func['group_description']} */")
current_group = func['group_name']
deprecated = " (DEPRECATED)" if func['deprecated'] else ""
switch_lines.append(f" case {func['enum_name']}: return \"{func['backend_function']}{deprecated}\";")
# Dispatch table
table_lines = []
current_group = None
@@ -201,15 +211,6 @@ class RemotingCodebaseGenerator:
{NL.join(decl_lines)}
static inline const char *backend_dispatch_command_name(ApirBackendCommandType type)
{{
switch (type) {{
{NL.join(switch_lines)}
default: return "unknown";
}}
}}
extern "C" {{
static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {{
{NL.join(table_lines)}
@@ -17,8 +17,8 @@ ggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) {
size_t cgraph_size = apir_serialize_ggml_cgraph(cgraph, cgraph_data);
virtgpu_shmem temp_shmem; // Local storage for large buffers
virtgpu_shmem * shmem = &temp_shmem;
bool using_shared_shmem = false;
virtgpu_shmem * shmem = &temp_shmem;
bool using_shared_shmem = false;
if (cgraph_size <= gpu->data_shmem.mmap_size) {
// Lock mutex before using shared data_shmem buffer
@@ -26,7 +26,7 @@ ggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) {
GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__);
}
using_shared_shmem = true;
shmem = &gpu->data_shmem;
shmem = &gpu->data_shmem;
} else if (virtgpu_shmem_create(gpu, cgraph_size, shmem)) {
GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__);
}
@@ -62,7 +62,9 @@ size_t apir_buffer_type_get_max_size(virtgpu * gpu, apir_buffer_type_host_handle
return max_size;
}
apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle, size_t size) {
apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu,
apir_buffer_type_host_handle_t host_handle,
size_t size) {
apir_encoder * encoder;
apir_decoder * decoder;
ApirForwardReturnCode ret;
@@ -84,7 +86,9 @@ apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, apir_buffer_t
return buffer_context;
}
size_t apir_buffer_type_get_alloc_size(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle, const ggml_tensor * op) {
size_t apir_buffer_type_get_alloc_size(virtgpu * gpu,
apir_buffer_type_host_handle_t host_handle,
const ggml_tensor * op) {
apir_encoder * encoder;
apir_decoder * decoder;
ApirForwardReturnCode ret;
@@ -35,8 +35,8 @@ void apir_buffer_set_tensor(virtgpu * gpu,
apir_encode_ggml_tensor(encoder, tensor);
virtgpu_shmem temp_shmem; // Local storage for large buffers
virtgpu_shmem * shmem = &temp_shmem;
bool using_shared_shmem = false;
virtgpu_shmem * shmem = &temp_shmem;
bool using_shared_shmem = false;
if (size <= gpu->data_shmem.mmap_size) {
// Lock mutex before using shared data_shmem buffer
@@ -44,7 +44,7 @@ void apir_buffer_set_tensor(virtgpu * gpu,
GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__);
}
using_shared_shmem = true;
shmem = &gpu->data_shmem;
shmem = &gpu->data_shmem;
} else if (virtgpu_shmem_create(gpu, size, shmem)) {
GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__);
@@ -86,8 +86,8 @@ void apir_buffer_get_tensor(virtgpu * gpu,
apir_encode_ggml_tensor(encoder, tensor);
virtgpu_shmem temp_shmem; // Local storage for large buffers
virtgpu_shmem * shmem = &temp_shmem;
bool using_shared_shmem = false;
virtgpu_shmem * shmem = &temp_shmem;
bool using_shared_shmem = false;
if (size <= gpu->data_shmem.mmap_size) {
// Lock mutex before using shared data_shmem buffer
@@ -95,7 +95,7 @@ void apir_buffer_get_tensor(virtgpu * gpu,
GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__);
}
using_shared_shmem = true;
shmem = &gpu->data_shmem;
shmem = &gpu->data_shmem;
} else if (virtgpu_shmem_create(gpu, size, shmem)) {
GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__);
@@ -26,7 +26,7 @@ char * apir_device_get_name(virtgpu * gpu) {
REMOTE_CALL(gpu, encoder, decoder, ret);
const size_t string_size = apir_decode_array_size_unchecked(decoder);
char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size);
char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size);
if (!string) {
GGML_LOG_ERROR(GGML_VIRTGPU "%s: Could not allocate the device name buffer\n", __func__);
return NULL;
@@ -173,7 +173,7 @@ apir_buffer_context_t apir_device_buffer_from_ptr(virtgpu * gpu, size_t size, si
REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR);
if (virtgpu_shmem_create(gpu, size, &buffer_context.shmem)) {
GGML_ABORT(GGML_VIRTGPU "Couldn't allocate the guest-host shared buffer");
GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate %ldb of guest-host shared buffer", __func__, size);
}
apir_encode_virtgpu_shmem_res_id(encoder, buffer_context.shmem.res_id);
+27 -20
View File
@@ -1,29 +1,36 @@
#include "virtgpu.h"
#pragma once
// clang-format off
#include "virtgpu.h"
#include "ggml-remoting.h"
#include "backend/shared/apir_backend.h"
#include "backend/shared/apir_cs_ggml.h"
#include "ggml-backend-impl.h"
// clang-format on
#define REMOTE_CALL_PREPARE(gpu_dev_name, encoder_name, apir_command_type__) \
do { \
int32_t forward_flag = (int32_t) apir_command_type__; \
encoder_name = remote_call_prepare(gpu_dev_name, APIR_COMMAND_TYPE_FORWARD, forward_flag); \
if (!encoder_name) { \
GGML_ABORT(GGML_VIRTGPU "%s: failed to prepare the remote call encoder", __func__); \
} \
#define REMOTE_CALL_PREPARE(gpu_dev_name, encoder_name, apir_command_type__) \
int32_t REMOTE_CALL_PREPARE_forward_flag = (int32_t) apir_command_type__; \
const char * REMOTE_CALL_PREPARE_command_name = apir_dispatch_command_name(apir_command_type__); \
do { \
encoder_name = remote_call_prepare(gpu_dev_name, APIR_COMMAND_TYPE_FORWARD, REMOTE_CALL_PREPARE_forward_flag); \
if (!encoder_name) { \
GGML_ABORT(GGML_VIRTGPU "%s: failed to prepare the remote call encoder", __func__); \
} \
} while (0)
#define REMOTE_CALL(gpu_dev_name, encoder_name, decoder_name, ret_name) \
do { \
ret_name = (ApirForwardReturnCode) remote_call(gpu_dev_name, encoder_name, &decoder_name, 0, NULL); \
if (!decoder_name) { \
GGML_ABORT(GGML_VIRTGPU "%s: failed to kick the remote call", __func__); \
} \
if (ret_name < APIR_FORWARD_BASE_INDEX) { \
GGML_ABORT(GGML_VIRTGPU "%s: failed to forward the API call: %s: code %d", __func__, \
apir_forward_error(ret_name), ret_name); \
} \
ret_name = (ApirForwardReturnCode) (ret_name - APIR_FORWARD_BASE_INDEX); \
#define REMOTE_CALL(gpu_dev_name, encoder_name, decoder_name, ret_name) \
do { \
ret_name = (ApirForwardReturnCode) remote_call(gpu_dev_name, encoder_name, &decoder_name, 0, NULL); \
if (!decoder_name) { \
GGML_ABORT(GGML_VIRTGPU "%s: failed to kick the remote call", __func__); \
} \
if (ret_name < APIR_FORWARD_BASE_INDEX) { \
GGML_ABORT(GGML_VIRTGPU "%s: failed to forward the API call: %s: code %d", __func__, \
apir_forward_error(ret_name), ret_name); \
} \
ret_name = (ApirForwardReturnCode) (ret_name - APIR_FORWARD_BASE_INDEX); \
if (ret_name != 0) { \
GGML_ABORT(GGML_VIRTGPU "backend function '%s' failed (return code: %d)", \
REMOTE_CALL_PREPARE_command_name, ret_name); \
} \
} while (0)
@@ -20,6 +20,7 @@ apir_buffer_context_t apir_device_buffer_from_ptr(struct virtgpu * gpu,
char * apir_buffer_type_get_name(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle);
size_t apir_buffer_type_get_alignment(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle);
size_t apir_buffer_type_get_max_size(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle);
/* apir_buffer_type_is_host is deprecated. */
apir_buffer_context_t apir_buffer_type_alloc_buffer(struct virtgpu * gpu,
apir_buffer_type_host_handle_t host_handle,
size_t size);
+35 -50
View File
@@ -53,9 +53,9 @@ static int virtgpu_handshake(virtgpu * gpu) {
if (!decoder) {
GGML_ABORT(GGML_VIRTGPU
"%s: failed to initiate the communication with the virglrenderer library. "
"Most likely, the wrong virglrenderer library was loaded in the hypervisor.",
__func__);
"%s: failed to initiate the communication with the virglrenderer library. "
"Most likely, the wrong virglrenderer library was loaded in the hypervisor.",
__func__);
return 1;
}
@@ -65,8 +65,7 @@ static int virtgpu_handshake(virtgpu * gpu) {
uint32_t host_minor;
if (ret_magic != APIR_HANDSHAKE_MAGIC) {
GGML_ABORT(GGML_VIRTGPU
"%s: handshake with the virglrenderer failed (code=%d | %s)", __func__, ret_magic,
GGML_ABORT(GGML_VIRTGPU "%s: handshake with the virglrenderer failed (code=%d | %s)", __func__, ret_magic,
apir_backend_initialize_error(ret_magic));
} else {
apir_decode_uint32_t(decoder, &host_major);
@@ -140,15 +139,13 @@ static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) {
"Make sure virglrenderer is correctly configured by the hypervisor. (%s) ",
__func__, apir_load_library_error(ret));
} else {
GGML_ABORT(GGML_VIRTGPU
"%s: virglrenderer could not load the API Remoting backend library. (%s - code %d)", __func__,
apir_load_library_error(ret), ret);
GGML_ABORT(GGML_VIRTGPU "%s: virglrenderer could not load the API Remoting backend library. (%s - code %d)",
__func__, apir_load_library_error(ret), ret);
}
return ret;
}
GGML_LOG_INFO(GGML_VIRTGPU
"%s: virglrenderer successfully loaded the API Remoting backend library.\n", __func__);
GGML_LOG_INFO(GGML_VIRTGPU "%s: virglrenderer successfully loaded the API Remoting backend library.\n", __func__);
ApirLoadLibraryReturnCode apir_ret = (ApirLoadLibraryReturnCode) (ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX);
@@ -158,10 +155,11 @@ static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) {
"Make sure virglrenderer is correctly configured by the hypervisor. (%s)",
__func__, apir_load_library_error(apir_ret));
} else if (apir_ret == APIR_LOAD_LIBRARY_SYMBOL_MISSING) {
GGML_ABORT(GGML_VIRTGPU
"%s: the API Remoting backend library couldn't load the GGML backend library, some symbols are missing. "
"Make sure virglrenderer is correctly configured by the hypervisor. (%s)",
__func__, apir_load_library_error(apir_ret));
GGML_ABORT(
GGML_VIRTGPU
"%s: the API Remoting backend library couldn't load the GGML backend library, some symbols are missing. "
"Make sure virglrenderer is correctly configured by the hypervisor. (%s)",
__func__, apir_load_library_error(apir_ret));
} else if (apir_ret < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) {
GGML_ABORT(GGML_VIRTGPU
"%s: the API Remoting backend library couldn't load the GGML backend library: apir code=%d | %s)",
@@ -169,8 +167,8 @@ static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) {
} else {
uint32_t lib_ret = apir_ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX;
GGML_ABORT(GGML_VIRTGPU
"%s: the API Remoting backend library initialize its backend library: apir code=%d)", __func__,
lib_ret);
"%s: the API Remoting backend library failed to initialize its backend library: apir code=%d)",
__func__, lib_ret);
}
return ret;
}
@@ -184,55 +182,49 @@ virtgpu * create_virtgpu() {
// Initialize mutex to protect shared data_shmem buffer
if (mtx_init(&gpu->data_shmem_mutex, mtx_plain) != thrd_success) {
delete gpu;
GGML_ABORT(GGML_VIRTGPU
"%s: failed to initialize data_shmem mutex", __func__);
GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize data_shmem mutex", __func__);
return NULL;
}
if (virtgpu_open(gpu) != APIR_SUCCESS) {
GGML_LOG_ERROR(GGML_VIRTGPU
"%s: failed to open the virtgpu device\n", __func__);
GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to open the virtgpu device\n", __func__);
return NULL;
}
if (virtgpu_init_capset(gpu) != APIR_SUCCESS) {
if (gpu->use_apir_capset) {
GGML_ABORT(GGML_VIRTGPU
"%s: failed to initialize the virtgpu APIR capset. Make sure that the virglrenderer library supports it.", __func__);
"%s: failed to initialize the virtgpu APIR capset. Make sure that the virglrenderer library "
"supports it.",
__func__);
} else {
GGML_ABORT(GGML_VIRTGPU
"%s: failed to initialize the virtgpu Venus capset", __func__);
GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu Venus capset", __func__);
}
return NULL;
}
if (virtgpu_init_context(gpu) != APIR_SUCCESS) {
GGML_ABORT(GGML_VIRTGPU
"%s: failed to initialize the GPU context", __func__);
GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the GPU context", __func__);
return NULL;
}
if (virtgpu_shmem_create(gpu, SHMEM_REPLY_SIZE, &gpu->reply_shmem)) {
GGML_ABORT(GGML_VIRTGPU
"%s: failed to create the shared reply memory pages", __func__);
GGML_ABORT(GGML_VIRTGPU "%s: failed to create the shared reply memory pages", __func__);
return NULL;
}
if (virtgpu_shmem_create(gpu, SHMEM_DATA_SIZE, &gpu->data_shmem)) {
GGML_ABORT(GGML_VIRTGPU
"%s: failed to create the shared data memory pages", __func__);
GGML_ABORT(GGML_VIRTGPU "%s: failed to create the shared data memory pages", __func__);
return NULL;
}
if (virtgpu_handshake(gpu)) {
GGML_ABORT(GGML_VIRTGPU
"%s: failed to handshake with the virglrenderer library", __func__);
GGML_ABORT(GGML_VIRTGPU "%s: failed to handshake with the virglrenderer library", __func__);
return NULL;
}
if (virtgpu_load_library(gpu) != APIR_LOAD_LIBRARY_SUCCESS) {
GGML_ABORT(GGML_VIRTGPU
"%s: failed to load the backend library", __func__);
GGML_ABORT(GGML_VIRTGPU "%s: failed to load the backend library", __func__);
return NULL;
}
@@ -243,8 +235,7 @@ static virt_gpu_result_t virtgpu_open(virtgpu * gpu) {
drmDevicePtr devs[8];
int count = drmGetDevices2(0, devs, ARRAY_SIZE(devs));
if (count < 0) {
GGML_LOG_ERROR(GGML_VIRTGPU
"%s: failed to enumerate DRM devices\n", __func__);
GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to enumerate DRM devices\n", __func__);
return APIR_ERROR_INITIALIZATION_FAILED;
}
@@ -266,19 +257,17 @@ static virt_gpu_result_t virtgpu_open_device(virtgpu * gpu, const drmDevicePtr d
int fd = open(node_path, O_RDWR | O_CLOEXEC);
if (fd < 0) {
GGML_ABORT(GGML_VIRTGPU
"%s: failed to open %s", __func__, node_path);
GGML_ABORT(GGML_VIRTGPU "%s: failed to open %s", __func__, node_path);
return APIR_ERROR_INITIALIZATION_FAILED;
}
drmVersionPtr version = drmGetVersion(fd);
if (!version || strcmp(version->name, "virtio_gpu") || version->version_major != 0) {
if (version) {
GGML_LOG_ERROR(GGML_VIRTGPU
"%s: unknown DRM driver %s version %d\n", __func__, version->name, version->version_major);
GGML_LOG_ERROR(GGML_VIRTGPU "%s: unknown DRM driver %s version %d\n", __func__, version->name,
version->version_major);
} else {
GGML_LOG_ERROR(GGML_VIRTGPU
"%s: failed to get DRM driver version\n", __func__);
GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to get DRM driver version\n", __func__);
}
if (version) {
@@ -322,9 +311,8 @@ static virt_gpu_result_t virtgpu_init_capset(virtgpu * gpu) {
virtgpu_ioctl_get_caps(gpu, gpu->capset.id, gpu->capset.version, &gpu->capset.data, sizeof(gpu->capset.data));
if (ret) {
GGML_LOG_ERROR(GGML_VIRTGPU
"%s: failed to get APIR v%d capset: %s\n",
__func__, gpu->capset.version, strerror(errno));
GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to get APIR v%d capset: %s\n", __func__, gpu->capset.version,
strerror(errno));
return APIR_ERROR_INITIALIZATION_FAILED;
}
@@ -547,13 +535,10 @@ static void log_call_duration(long long call_duration_ns, const char * name) {
double call_duration_s = (double) call_duration_ns / 1e9; // 1 second = 1e9 nanoseconds
if (call_duration_s > 1) {
GGML_LOG_INFO(GGML_VIRTGPU
"waited %.2fs for the %s host reply...\n", call_duration_s, name);
GGML_LOG_INFO(GGML_VIRTGPU "waited %.2fs for the %s host reply...\n", call_duration_s, name);
} else if (call_duration_ms > 1) {
GGML_LOG_INFO(GGML_VIRTGPU
"waited %.2fms for the %s host reply...\n", call_duration_ms, name);
GGML_LOG_INFO(GGML_VIRTGPU "waited %.2fms for the %s host reply...\n", call_duration_ms, name);
} else {
GGML_LOG_INFO(GGML_VIRTGPU
"waited %lldns for the %s host reply...\n", call_duration_ns, name);
GGML_LOG_INFO(GGML_VIRTGPU "waited %lldns for the %s host reply...\n", call_duration_ns, name);
}
}
+5 -3
View File
@@ -1,5 +1,6 @@
#pragma once
// clang-format off
#include "virtgpu-utils.h"
#include "virtgpu-shm.h"
#include "virtgpu-apir.h"
@@ -23,20 +24,21 @@
#include "apir_hw.h"
#include <drm/virtgpu_drm.h>
#include "venus_hw.h"
// clang-format on
#ifndef VIRTGPU_DRM_CAPSET_APIR
// Will be defined include/drm/virtgpu_drm.h when
// https://gitlab.freedesktop.org/virgl/virglrenderer/-/merge_requests/1590/diffs
// is merged
#define VIRTGPU_DRM_CAPSET_APIR 10
# define VIRTGPU_DRM_CAPSET_APIR 10
#endif
// Mesa/Virlgrenderer Venus internal. Only necessary during the
// Venus->APIR transition in Virglrenderer
#define VENUS_COMMAND_TYPE_LENGTH 331
#ifndef VIRTGPU_DRM_CAPSET_VENUS // only available with Linux >= v6.16
#define VIRTGPU_DRM_CAPSET_VENUS 4
#ifndef VIRTGPU_DRM_CAPSET_VENUS // only available with Linux >= v6.16
# define VIRTGPU_DRM_CAPSET_VENUS 4
#endif
typedef uint32_t virgl_renderer_capset;
File diff suppressed because it is too large Load Diff
@@ -3,9 +3,13 @@
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#ifdef FLOAT16
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_subgroup_extended_types_float16 : require
#endif
#extension GL_KHR_shader_subgroup_shuffle : enable
#extension GL_KHR_shader_subgroup_vote : enable
@@ -15,8 +19,10 @@
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
const uint32_t num_subgroups = SubGroupSize == 0 ? 0 : WorkGroupSize / SubGroupSize;
layout (binding = 0) readonly buffer Q {float data_q[];};
@@ -27,20 +33,22 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}
// If SubGroupSize is set to 0 then only use shmem reductions
const uint32_t tmpsh_size = (SubGroupSize > 0) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize;
shared float tmpsh[tmpsh_size];
shared FLOAT_TYPEV4 tmpshv4[tmpsh_size];
shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize];
const uint32_t masksh_stride = Br + 1;
shared FLOAT_TYPE masksh[Bc * masksh_stride];
shared float masksh[Bc][Br];
shared vec4 Qf[Br][HSK / 4];
const uint32_t qf_stride = HSK / 4 + 1;
shared FLOAT_TYPEV4 Qf[Br * qf_stride];
const uint32_t D = HSK > HSV ? HSK : HSV;
const uint32_t kvsh_stride = D / 4 + 1;
shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];
shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1];
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
@@ -50,8 +58,24 @@ void main() {
init_indices();
const uint32_t tid = gl_LocalInvocationIndex;
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
const uint32_t rowgroup_tid = gl_LocalInvocationIndex % threads_per_rowgroup;
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
if (LIMIT_OCCUPANCY_SHMEM > 0) {
// This just exists to avoid the occupancy_limiter array getting optimized out
occupancy_limiter[tid] = vec4(tid);
barrier();
if (occupancy_limiter[tid] == vec4(99999.0)) {
data_ov4[0] = D_TYPEV4(occupancy_limiter[tid]);
}
}
#define tile_row(r) (row_tid * rows_per_thread + (r))
uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4;
@@ -60,37 +84,37 @@ void main() {
uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < HSK / 4 &&
i * Br + r < N) {
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
}
}
barrier();
vec4 Of[Br][HSV_per_thread / 4];
FLOAT_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = vec4(0.0);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = FLOAT_TYPEV4(0.0);
}
}
float Lf[Br], Mf[Br];
float Lf[rows_per_thread], Mf[rows_per_thread];
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lf[r] = 0;
Mf[r] = NEG_FLT_MAX_OVER_2;
}
float slope[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
slope[r] = 1.0;
ACC_TYPE slope[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
slope[r] = ACC_TYPE(1.0);
}
// ALiBi
if (p.max_bias > 0.0f) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
slope[r] = perElemOpComputeSlope(tile_row(r), col_tid, ACC_TYPE(0), iq2);
}
}
@@ -113,75 +137,141 @@ void main() {
uint32_t mask_opt = 0;
uint32_t mask_opt_idx = ~0;
uint32_t mask_opt_bits = 0;
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
if (MASK_ENABLE) {
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
mask_opt_idx = j / 16;
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
}
mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
// skip this block
continue;
}
// Only load if the block is not all zeros
if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
mask_opt_idx = j / 16;
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
float max_mask = NEG_FLT_MAX_OVER_2;
barrier();
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
masksh[c * masksh_stride + r] = m;
max_mask = max(max_mask, float(m));
} else {
masksh[c * masksh_stride + r] = FLOAT_TYPE(0);
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
}
uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
// skip this block
continue;
}
// Only load if the block is not all zeros
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
float max_mask = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
masksh[c][r] = m;
max_mask = max(max_mask, m);
ACC_TYPE Sf[rows_per_thread][cols_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Sf[r][c] = ACC_TYPE(0.0);
}
}
if (SHMEM_STAGING != 0) {
barrier();
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) {
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
}
kvsh[c * kvsh_stride + d] = K_Tf;
}
}
barrier();
}
// More d iterations means Q register caching becomes relevant
// Few iterations means the additional registers needed are worse than the speed-up from caching
if (HSK_per_thread / 4 > 4) {
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
FLOAT_TYPEV4 Q_cache[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid];
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
FLOAT_TYPEV4 K_Tf;
if (SHMEM_STAGING != 0) {
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
} else {
masksh[c][r] = float(0);
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
}
}
}
// skip the block if the mask is entirely -inf
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
barrier();
if (gl_SubgroupInvocationID == 0) {
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
}
barrier();
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
max_mask = max(max_mask, tmpsh[s]);
}
if (max_mask <= NEG_FLT_MAX_OVER_2) {
continue;
}
}
float Sf[Br][cols_per_thread];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
} else {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Sf[r][c] = 0.0;
}
}
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
FLOAT_TYPEV4 K_Tf;
if (SHMEM_STAGING != 0) {
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
} else {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf));
}
}
}
}
@@ -189,89 +279,109 @@ void main() {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
// Compute sum across the D_split
[[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
}
}
}
if (LOGIT_SOFTCAP) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
Sf[r][c] = ACC_TYPE(p.logit_softcap * tanh(Sf[r][c]));
}
}
}
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float mvf = masksh[c * cols_per_iter + col_tid][r];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
FLOAT_TYPE mvf = masksh[(c * cols_per_iter + col_tid) * masksh_stride + tile_row(r)];
Sf[r][c] += slope[r]*mvf;
}
}
barrier();
}
float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
rowmaxf[r] = NEG_FLT_MAX_OVER_2;
float eMf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float rowmaxf = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
rowmaxf = max(rowmaxf, float(Sf[r][c]));
}
Moldf[r] = Mf[r];
float Moldf = Mf[r];
// M = max(rowmax, Mold)
// P = e^(S - M)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf[r], Moldf[r]);
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Pf[r][c] = exp(Sf[r][c] - Mf[r]);
}
eMf[r] = exp(Moldf[r] - Mf[r]);
// Compute sum across row of P
rowsumf[r] = 0.0;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowsumf[r] += Pf[r][c];
}
Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
Mf[r] = max(rowmaxf, Moldf);
eMf[r] = exp(Moldf - Mf[r]);
Lf[r] = eMf[r]*Lf[r];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = eMf[r] * Of[r][d];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = FLOAT_TYPE(eMf[r]) * Of[r][d];
}
}
if (SHMEM_STAGING != 0) {
barrier();
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSV / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSV / 4);
uint32_t c = (idx + tid) / (HSV / 4);
if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) {
FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
#endif
}
kvsh[c * kvsh_stride + d] = V_Tf;
}
}
barrier();
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
FLOAT_TYPE Pf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Pf[r] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r]));
Lf[r] += Pf[r];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
FLOAT_TYPEV4 Vf;
if (SHMEM_STAGING != 0) {
Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
} else {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] += Pf[r][c] * Vf;
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf);
}
}
}
barrier();
}
// prevent race on tmpsh
@@ -279,58 +389,108 @@ void main() {
// reduce across threads
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float rowmaxf, eMf;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float rowmaxf = Mf[r];
tmpsh[tid] = Mf[r];
// Compute max across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
if (SubGroupSize > 0) {
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s));
}
if (row_split == 1) {
// Reduce inside workgroup with shmem
barrier();
if (gl_SubgroupInvocationID == d_tid) {
tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;
}
barrier();
rowmaxf = tmpsh[d_tid];
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);
}
}
} else {
barrier();
tmpsh[tid] = rowmaxf;
barrier();
[[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
if (rowgroup_tid < s) {
tmpsh[tid] = max(tmpsh[tid], tmpsh[tid ^ s]);
}
barrier();
}
rowmaxf = tmpsh[row_tid * threads_per_rowgroup + d_tid];
}
rowmaxf = tmpsh[d_tid];
barrier();
float Moldf = Mf[r];
// M = max(rowmax, Mold)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf, Moldf);
eMf = exp(Moldf - Mf[r]);
float eMf = exp(Moldf - Mf[r]);
Lf[r] = eMf*Lf[r];
tmpsh[tid] = Lf[r];
// Compute sum across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
if (SubGroupSize > 0) {
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
Lf[r] += subgroupShuffleXor(Lf[r], s);
}
if (row_split == 1) {
barrier();
if (gl_SubgroupInvocationID == d_tid) {
tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];
}
barrier();
Lf[r] = tmpsh[d_tid];
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
Lf[r] += tmpsh[s * D_split + d_tid];
}
}
} else {
barrier();
}
Lf[r] = tmpsh[d_tid];
barrier();
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = eMf * Of[r][d];
tmpshv4[tid] = Of[r][d];
tmpsh[tid] = Lf[r];
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
Of[r][d] += tmpshv4[tid + s];
tmpshv4[tid] = Of[r][d];
[[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
if (rowgroup_tid < s) {
tmpsh[tid] = tmpsh[tid] + tmpsh[tid ^ s];
}
barrier();
}
Of[r][d] = tmpshv4[d_tid];
barrier();
Lf[r] = tmpsh[row_tid * threads_per_rowgroup + d_tid];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = FLOAT_TYPE(eMf) * Of[r][d];
if (SubGroupSize > 0) {
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
Of[r][d] += subgroupShuffleXor(Of[r][d], s);
}
if (row_split == 1) {
barrier();
if (gl_SubgroupInvocationID == d_tid) {
tmpshv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];
}
barrier();
Of[r][d] = tmpshv4[d_tid];
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
Of[r][d] += tmpshv4[s * D_split + d_tid];
}
}
} else {
barrier();
tmpshv4[tid] = Of[r][d];
barrier();
[[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
if (rowgroup_tid < s) {
Of[r][d] += tmpshv4[tid ^ s];
tmpshv4[tid] = Of[r][d];
}
barrier();
}
Of[r][d] = tmpshv4[row_tid * threads_per_rowgroup + d_tid];
}
}
}
@@ -338,33 +498,53 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
if (p.gqa_ratio > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
if (row < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
}
}
}
}
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
if (row < N) {
perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
const uint global_row = i * Br + row;
if (global_row < N) {
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
data_ov4[o_offset + iq2 * HSV/4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]);
}
}
if (global_row < N && d_tid == 0 && col_tid == 0) {
uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
}
}
}
return;
}
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);
float ms = 1.0f;
float vs = 1.0f;
@@ -373,7 +553,7 @@ void main() {
ms = exp(Mf[r] - sink);
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] *= ms;
Of[r][d] *= FLOAT_TYPE(ms);
}
} else {
vs = exp(sink - Mf[r]);
@@ -383,39 +563,37 @@ void main() {
}
}
float Lfrcp[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float Lfrcp[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] *= Lfrcp[r];
#if defined(ACC_TYPE_MAX)
Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= FLOAT_TYPE(Lfrcp[r]);
#if defined(FLOAT_TYPE_MAX)
Of[r][d] = clamp(Of[r][d], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);
#endif
}
}
uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
if (row < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
}
gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
}
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (i * Br + r < N) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
if (i * Br + row < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
}
data_ov4[o_offset + (iq2 * HSV + (i * Br + row) * p.ne1 * HSV) / 4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]);
}
}
}
@@ -1,16 +1,18 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t HSK = 32;
layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
layout (constant_id = 7) const uint32_t SubGroupSize = 32;
layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
layout (constant_id = 9) const uint32_t Flags = 0;
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t HSK = 32;
layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
layout (constant_id = 7) const uint32_t row_split = 1;
layout (constant_id = 8) const uint32_t SubGroupSize = 32;
layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0;
layout (constant_id = 10) const uint32_t Flags = 0;
layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0;
const bool USE_MASK_OPT = (Flags & 1) != 0;
const bool MASK_ENABLE = (Flags & 2) != 0;
@@ -69,6 +71,7 @@ layout (push_constant) uniform parameter {
layout (binding = 4) readonly buffer S {float data_s[];};
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
layout (binding = 5) writeonly buffer OV4 {D_TYPEV4 data_ov4[];};
layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};
@@ -94,12 +97,12 @@ layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16
#define BLOCK_SIZE 4
#define BLOCK_BYTE_SIZE 16
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
// iqs is currently always zero in the flash attention shaders
if (binding_idx == BINDING_IDX_K) {
return k_packed.k_data_packed[a_offset + ib];
return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]);
} else {
return v_packed.v_data_packed[a_offset + ib];
return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]);
}
}
#endif
@@ -107,7 +110,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
@@ -115,7 +118,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
vui_lo >>= shift;
vui_hi >>= shift;
return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));
} else {
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
@@ -123,24 +126,24 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
vui_lo >>= shift;
vui_hi >>= shift;
return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));
}
}
#endif
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
} else {
const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
}
}
#endif
@@ -189,10 +192,16 @@ void init_indices()
KV = p.KV;
if (p.k_num > 1) {
i = 0;
// batch and split_k share gl_WorkGroupID.x
gqa_iq1 = gl_WorkGroupID.x / p.k_num;
split_k_index = gl_WorkGroupID.x % p.k_num;
if (p.gqa_ratio > 1) {
i = 0;
// batch and split_k share gl_WorkGroupID.x
gqa_iq1 = gl_WorkGroupID.x / p.k_num;
split_k_index = gl_WorkGroupID.x % p.k_num;
} else {
gqa_iq1 = 0;
split_k_index = gl_WorkGroupID.x % p.k_num;
i = gl_WorkGroupID.x / p.k_num;
}
} else if (p.gqa_ratio > 1) {
i = 0;
gqa_iq1 = gl_WorkGroupID.x;
@@ -244,3 +253,11 @@ void init_indices()
// Bias applied to softmax to stay in fp16 range.
// Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606
const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f;
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
void gqaStore(const in uint32_t r, const in uint32_t c, const in FLOAT_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * HSV / 4 + c;
data_ov4[o_offset + offset] = D_TYPEV4(elems);
}
@@ -19,7 +19,6 @@
const uint32_t MatBr = 16;
const uint32_t MatBc = 16;
const uint32_t row_split = Bc / MatBc;
const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
@@ -33,15 +32,6 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}
shared float tmpsh[row_split];
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
@@ -54,10 +44,14 @@ shared f16vec4 Psh[Bc * psh_stride];
const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4;
shared ACC_TYPEV4 sfsh[Bc * sfshstride];
const uint32_t kshstride = (K_LOAD_SHMEM != 0 ? HSK_pad : MatBr) / 4 + 2; // in units of f16vec4
const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad;
const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4
const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups
const uint vsh_stride = v_cols;
shared f16vec4 ksh[(kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)];
shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)];
const uint32_t osh_stride = row_split * MatBr / 4;
shared f16vec4 pvsh[MatBc * osh_stride];
shared ACC_TYPE slope[Br];
@@ -84,11 +78,6 @@ void main() {
Qf[i + tid] = f16vec4(0);
}
}
[[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
if (i + tid < Bc * kshstride) {
ksh[i + tid] = f16vec4(0);
}
}
barrier();
}
@@ -104,10 +93,10 @@ void main() {
}
barrier();
ACC_TYPEV4 Of[rows_per_thread][d_per_thread];
f16vec4 Of[rows_per_thread][d_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) {
Of[r][d] = ACC_TYPEV4(0.0);
Of[r][d] = f16vec4(0.0);
}
}
@@ -153,22 +142,22 @@ void main() {
uint32_t mask_opt = 0;
uint32_t mask_opt_idx = ~0;
uint32_t mask_opt_bits = 0;
f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
[[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) {
mask_cache[idx] = f16vec4(0);
}
if (MASK_ENABLE) {
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
mask_opt_idx = j / 16;
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
}
uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
// skip this block
continue;
@@ -231,24 +220,24 @@ void main() {
}
}
if (K_LOAD_SHMEM != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
if (c < Bc && d < HSK / 4) {
if (SHMEM_STAGING != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK_pad / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK_pad / 4);
uint32_t c = (idx + tid) / (HSK_pad / 4);
if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) {
f16vec4 K_Tf = f16vec4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
}
ksh[c * kshstride + d] = K_Tf;
kvsh[c * kvsh_stride + d] = K_Tf;
}
}
barrier();
@@ -262,7 +251,11 @@ void main() {
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
[[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
if (K_LOAD_SHMEM == 0) {
// If SHMEM_STAGING is set, a Bc * HSK_pad size tile of K is loaded to shmem
// If not, f16 K is loaded directly from global memory if aligned, otherwise
// staged through a Bc * MatBr size staging buffer.
// If K is not type f16, then it is always staged for dequantization.
if (SHMEM_STAGING == 0) {
#if BLOCK_SIZE == 1
if (KV_bounds_check || d * 16 + 16 > HSK) {
#endif
@@ -277,13 +270,13 @@ void main() {
uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
#endif
}
ksh[row * kshstride + col_vec] = K_Tf;
kvsh[row * kvsh_stride + col_vec] = K_Tf;
}
}
barrier();
@@ -295,8 +288,8 @@ void main() {
if (KV_bounds_check || d * 16 + 16 > HSK)
#endif
{
uint coord = (gl_SubgroupID * MatBc) * kshstride;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
uint coord = (gl_SubgroupID * MatBc) * kvsh_stride;
coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
#if BLOCK_SIZE == 1
else {
@@ -305,8 +298,8 @@ void main() {
}
#endif
} else {
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4;
coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
@@ -329,7 +322,7 @@ void main() {
barrier();
}
if (MASK_ENABLE) {
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / (Br / 4);
uint32_t r = (idx + tid) % (Br / 4);
@@ -374,7 +367,7 @@ void main() {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d_local] = ACC_TYPE(eMf[r]) * Of[r][d_local];
Of[r][d_local] = float16_t(eMf[r]) * Of[r][d_local];
}
}
@@ -397,19 +390,47 @@ void main() {
}
}
if (SHMEM_STAGING != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSV_pad / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSV_pad / 4);
uint32_t c = (idx + tid) / (HSV_pad / 4);
if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) {
f16vec4 V_Tf = f16vec4(0);
if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
#endif
}
kvsh[c * kvsh_stride + d] = V_Tf;
}
}
}
barrier();
const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up
// Each subgroup handles HSV/4 columns
[[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) {
const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16;
SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> PVMat = coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
// Preload V tiles for [Bc, 16 * num subgroups]
const uint v_rows = Bc;
const uint v_total = v_rows * v_cols;
const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x;
// If SHMEM_STAGING is set, a Bc * HSV_pad size tile of V is loaded to shmem.
// If not, f16 V is loaded directly from global memory if aligned, otherwise
// staged through a Bc * MatBr size staging buffer.
// If V is not type f16, then it is always staged for dequantization.
if (SHMEM_STAGING == 0) {
#if BLOCK_SIZE == 1
// For f16, only preload if not aligned
if (KV_bounds_check) {
@@ -428,44 +449,52 @@ void main() {
if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
#if BLOCK_SIZE > 1
ksh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
ksh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
#endif
} else {
ksh[row * vsh_stride + col] = f16vec4(0.0f);
kvsh[row * vsh_stride + col] = f16vec4(0.0f);
}
}
#if BLOCK_SIZE == 1
}
#endif
}
barrier();
[[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
const uint o_offset = gl_SubgroupID * MatBr / 4;
if (hsv_offset < HSV_pad) {
[[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
if (SHMEM_STAGING == 0) {
#if BLOCK_SIZE == 1
if (!KV_bounds_check) {
// F16 values can be loaded directly from global memory
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
} else
if (!KV_bounds_check) {
// F16 values can be loaded directly from global memory
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
} else
#endif
{
const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
{
const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
} else {
const uint v_tile_offset = bc_chunk * MatBc * kvsh_stride + (hsv_tile * row_split + gl_SubgroupID) * (MatBc / 4);
coopMatLoad(QMat, kvsh, v_tile_offset, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
PVMat = coopMatMulAdd(KMat, QMat, PVMat);
}
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
// Store PVMat to pvsh and load into Of
coopMatStore(PVMat, pvsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
}
// Store SfMat to sfsh and load into Of
const uint osh_stride = row_split * MatBc / 4;
const uint o_offset = gl_SubgroupID * MatBc / 4;
coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
barrier();
const uint hsv_per_tile = row_split * MatBc;
@@ -484,7 +513,7 @@ void main() {
if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) {
const uint local_hsv = (hsv_col - hsv_base) / 4;
Of[r][d_local] += ACC_TYPEV4(sfsh[row * osh_stride + local_hsv]);
Of[r][d_local] += pvsh[row * osh_stride + local_hsv];
}
}
}
@@ -500,27 +529,48 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
if (p.gqa_ratio > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d = d0 + col_tid;
if (d >= HSV/4) break;
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d = d0 + col_tid;
if (d >= HSV/4) break;
const uint d_local = d0 / threads_per_rowgroup;
gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
}
}
}
}
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
const uint row = tile_row(r);
const uint global_row = i * Br + row;
if (global_row < N) {
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d = d0 + col_tid;
if (d >= HSV/4) break;
data_ov4[o_offset + iq2 * HSV/4 + d] = D_TYPEV4(Of[r][d/threads_per_rowgroup]);
}
}
if (global_row < N && col_tid == 0) {
uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
}
}
}
@@ -539,7 +589,7 @@ void main() {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d_local = d0 / threads_per_rowgroup;
Of[r][d_local] *= ACC_TYPE(ms);
Of[r][d_local] *= float16_t(ms);
}
} else {
vs = exp(sink - Mf[r]);
@@ -557,14 +607,14 @@ void main() {
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d_local] *= ACC_TYPE(Lfrcp[r]);
#if defined(ACC_TYPE_MAX)
Of[r][d_local] = clamp(Of[r][d_local], -ACC_TYPE_MAX, ACC_TYPE_MAX);
Of[r][d_local] *= float16_t(Lfrcp[r]);
#if defined(FLOAT_TYPE_MAX)
Of[r][d_local] = clamp(Of[r][d_local], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);
#endif
}
}
uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
@@ -573,9 +623,7 @@ void main() {
const uint d = d0 + col_tid;
if (d >= HSV / 4) break;
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
}
gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
}
}
}
@@ -586,9 +634,7 @@ void main() {
const uint d = d0 + col_tid;
if (d >= HSV / 4) break;
const uint d_local = d0 / threads_per_rowgroup;
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4 * d + comp] = D_TYPE(Of[r][d_local][comp]);
}
data_ov4[o_offset + (iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV) / 4 + d] = D_TYPEV4(Of[r][d_local]);
}
}
}
@@ -72,6 +72,28 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
return elem;
}
// Store O values for non-GQA split_k. Rows are tokens, not heads.
D_TYPE perElemOpNonGqaSplitKStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t unused, const in uint32_t iq2, const in uint32_t N) {
uint32_t global_row = i * Br + r;
if (global_row < N && c < HSV) {
uint32_t o_off = HSV * p.ne1
* (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
data_o[o_off + iq2 * HSV + c] = D_TYPE(elem);
}
return elem;
}
// Store L/M values for non-GQA split_k.
ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t lm_base, const in uint32_t iq2, const in uint32_t N) {
uint32_t global_row = i * Br + r;
if (global_row < N && c == 0) {
uint32_t lm_off = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3
+ p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
data_o[lm_off + lm_base + iq2] = D_TYPE(elem);
}
return elem;
}
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
@@ -290,13 +312,19 @@ void main() {
if (p.k_num > 1) {
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
if (p.gqa_ratio > 1) {
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
} else {
coopMatPerElementNV(O_D, O_D, perElemOpNonGqaSplitKStore, 0u, iq2, N);
coopMatPerElementNV(L, L, perElemOpNonGqaSplitKStoreCol0, 0u, iq2, N);
coopMatPerElementNV(M, M, perElemOpNonGqaSplitKStoreCol0, p.ne1, iq2, N);
}
return;
}
@@ -595,8 +595,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
void process_shaders() {
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
// matmul
for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
// No coopmats
@@ -622,49 +620,63 @@ void process_shaders() {
}
}
// flash attention
for (const auto& f16acc : {false, true}) {
std::map<std::string, std::string> fa_base_dict = base_dict;
fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
if (f16acc) {
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
for (const bool& fp16 : {false, true}) {
std::map<std::string, std::string> base_dict;
if (fp16) {
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
} else {
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}};
}
for (const auto& tname : type_names) {
if (tname == "bf16") continue;
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
} else {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
// flash attention
for (const bool& f16acc : {false, true}) {
std::map<std::string, std::string> fa_base_dict = base_dict;
fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float";
fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4";
if (fp16 && f16acc) {
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
}
for (const auto& tname : type_names) {
if (tname == "bf16") continue;
if (fp16) {
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
} else {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc);
}
#endif
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
}
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
}
#endif
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
}
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
}
}
}
}
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
for (const auto& tname : type_names) {
// mul mat vec
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+27 -6
View File
@@ -899,7 +899,8 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
};
const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) {
GGML_ASSERT(type < GGML_TYPE_COUNT);
assert(type >= 0);
assert(type < GGML_TYPE_COUNT);
return &type_traits[type];
}
@@ -1265,27 +1266,33 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {
}
int64_t ggml_blck_size(enum ggml_type type) {
assert(type >= 0);
assert(type < GGML_TYPE_COUNT);
return type_traits[type].blck_size;
}
size_t ggml_type_size(enum ggml_type type) {
assert(type >= 0);
assert(type < GGML_TYPE_COUNT);
return type_traits[type].type_size;
}
size_t ggml_row_size(enum ggml_type type, int64_t ne) {
assert(type >= 0);
assert(type < GGML_TYPE_COUNT);
assert(ne % ggml_blck_size(type) == 0);
return ggml_type_size(type)*ne/ggml_blck_size(type);
}
double ggml_type_sizef(enum ggml_type type) {
return ((double)(type_traits[type].type_size))/type_traits[type].blck_size;
}
const char * ggml_type_name(enum ggml_type type) {
return type < GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE";
assert(type >= 0);
assert(type < GGML_TYPE_COUNT);
return type_traits[type].type_name;
}
bool ggml_is_quantized(enum ggml_type type) {
assert(type >= 0);
assert(type < GGML_TYPE_COUNT);
return type_traits[type].is_quantized;
}
@@ -1629,11 +1636,23 @@ static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml
const size_t cur_end = cur_offs + cur_size;
// align to GGML_MEM_ALIGN
GGML_ASSERT(size <= SIZE_MAX - (GGML_MEM_ALIGN - 1));
size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN);
char * const mem_buffer = ctx->mem_buffer;
struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
// integer overflow checks
if (cur_end > SIZE_MAX - size_needed) {
GGML_LOG_WARN("%s: overflow detected in cur_end (%zu) + size_needed (%zu)\n", __func__, cur_end, size_needed);
return NULL;
}
if (cur_end + size_needed > SIZE_MAX - GGML_OBJECT_SIZE) {
GGML_LOG_WARN("%s: overflow detected in cur_end (%zu) + size_needed (%zu) + GGML_OBJECT_SIZE (%zu)\n", __func__,
cur_end, size_needed, (size_t) GGML_OBJECT_SIZE);
return NULL;
}
if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
__func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
@@ -1702,6 +1721,8 @@ static struct ggml_tensor * ggml_new_tensor_impl(
obj_alloc_size = data_size;
}
GGML_ASSERT(GGML_TENSOR_SIZE <= SIZE_MAX - obj_alloc_size);
struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size);
GGML_ASSERT(obj_new);
+116 -14
View File
@@ -15,6 +15,17 @@
#include <string>
#include <vector>
#define GGUF_MAX_STRING_LENGTH (1024*1024*1024)
#define GGUF_MAX_ARRAY_ELEMENTS (1024*1024*1024)
#ifdef _WIN32
# define gguf_ftell _ftelli64
# define gguf_fseek _fseeki64
#else
# define gguf_ftell ftello
# define gguf_fseek fseeko
#endif
template <typename T>
struct type_to_gguf_type;
@@ -217,17 +228,64 @@ struct gguf_context {
};
struct gguf_reader {
FILE * file;
gguf_reader(FILE * file) : file(file) {
// read the remaining bytes once and update on each read
nbytes_remain = file_remain(file);
}
gguf_reader(FILE * file) : file(file) {}
// helper for remaining bytes in a file
static uint64_t file_remain(FILE * file) {
const int64_t cur = gguf_ftell(file);
if (cur < 0) {
return 0;
}
if (gguf_fseek(file, 0, SEEK_END) != 0) {
gguf_fseek(file, cur, SEEK_SET);
return 0;
}
const int64_t end = gguf_ftell(file);
if (end < 0) {
gguf_fseek(file, cur, SEEK_SET);
return 0;
}
gguf_fseek(file, cur, SEEK_SET);
return static_cast<uint64_t>(end - cur);
}
template <typename T>
bool read(T & dst) const {
return fread(&dst, 1, sizeof(dst), file) == sizeof(dst);
const size_t size = sizeof(dst);
if (nbytes_remain < size) {
return false;
}
const size_t nread = fread(&dst, 1, size, file);
nbytes_remain -= nread;
return nread == size;
}
template <typename T>
bool read(std::vector<T> & dst, const size_t n) const {
if (n > GGUF_MAX_ARRAY_ELEMENTS) {
return false;
}
if constexpr (std::is_same<T, std::string>::value) {
// strings are prefixed with their length, so we need to account for that
if (n > SIZE_MAX / sizeof(uint64_t)) {
return false;
}
if (nbytes_remain < n * sizeof(uint64_t)) {
return false;
}
} else {
if (n > SIZE_MAX / sizeof(T)) {
return false;
}
if (nbytes_remain < n * sizeof(T)) {
return false;
}
}
dst.resize(n);
for (size_t i = 0; i < dst.size(); ++i) {
if constexpr (std::is_same<T, bool>::value) {
@@ -277,13 +335,33 @@ struct gguf_reader {
if (!read(size)) {
return false;
}
dst.resize(size);
return fread(dst.data(), 1, dst.length(), file) == dst.length();
if (size > GGUF_MAX_STRING_LENGTH) {
GGML_LOG_ERROR("%s: string length %" PRIu64 " exceeds maximum %" PRIu64 "\n", __func__, size, (uint64_t) GGUF_MAX_STRING_LENGTH);
return false;
}
if (size > nbytes_remain) {
GGML_LOG_ERROR("%s: string length %" PRIu64 " exceeds remaining file size %" PRIu64 " bytes\n", __func__, size, nbytes_remain);
return false;
}
dst.resize(static_cast<size_t>(size));
const size_t nread = fread(dst.data(), 1, size, file);
nbytes_remain -= nread;
return nread == size;
}
bool read(void * dst, const size_t size) const {
return fread(dst, 1, size, file) == size;
if (size > nbytes_remain) {
return false;
}
const size_t nread = fread(dst, 1, size, file);
nbytes_remain -= nread;
return nread == size;
}
private:
FILE * file;
mutable uint64_t nbytes_remain;
};
struct gguf_context * gguf_init_empty(void) {
@@ -568,8 +646,8 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
// check that tensor type is within defined range
if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) {
GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d (%s)\n",
__func__, info.t.name, info.t.type, ggml_type_name(info.t.type));
GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d. should be in [0, %d)\n",
__func__, info.t.name, info.t.type, GGML_TYPE_COUNT);
ok = false;
break;
}
@@ -618,14 +696,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
GGML_ASSERT(int64_t(ctx->info.size()) == n_tensors);
// we require the data section to be aligned, so take into account any padding
if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) {
if (gguf_fseek(file, GGML_PAD(gguf_ftell(file), ctx->alignment), SEEK_SET) != 0) {
GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__);
gguf_free(ctx);
return nullptr;
}
// store the current file offset - this is where the data section starts
ctx->offset = ftell(file);
ctx->offset = gguf_ftell(file);
// compute the total size of the data section, taking into account the alignment
{
@@ -657,10 +735,34 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
// the ggml_tensor structs to the appropriate locations in the binary blob
// compute the exact size needed for the new ggml_context
const size_t mem_size =
params.no_alloc ?
(n_tensors )*ggml_tensor_overhead() :
(n_tensors + 1)*ggml_tensor_overhead() + ctx->size;
size_t mem_size = 0;
if (params.no_alloc) {
if (n_tensors != 0 && SIZE_MAX / n_tensors < ggml_tensor_overhead()) {
GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__);
gguf_free(ctx);
return nullptr;
}
const size_t overhead = n_tensors * ggml_tensor_overhead();
mem_size = overhead;
} else {
if ((n_tensors + 1) != 0 && SIZE_MAX / (n_tensors + 1) < ggml_tensor_overhead()) {
GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__);
gguf_free(ctx);
return nullptr;
}
const size_t overhead = (n_tensors + 1) * ggml_tensor_overhead();
if (SIZE_MAX - overhead < ctx->size) {
GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__);
gguf_free(ctx);
return nullptr;
}
mem_size = overhead + ctx->size;
}
struct ggml_init_params pdata = {
/*mem_size =*/ mem_size,
+20
View File
@@ -379,6 +379,7 @@ class MODEL_ARCH(IntEnum):
NEO_BERT = auto()
JINA_BERT_V2 = auto()
JINA_BERT_V3 = auto()
EUROBERT = auto()
BLOOM = auto()
STABLELM = auto()
QWEN = auto()
@@ -531,6 +532,7 @@ class MODEL_TENSOR(IntEnum):
FFN_GATE_EXP = auto()
FFN_DOWN_EXP = auto()
FFN_UP_EXP = auto()
FFN_GATE_UP_EXP = auto()
FFN_GATE_SHEXP = auto()
FFN_DOWN_SHEXP = auto()
FFN_UP_SHEXP = auto()
@@ -820,6 +822,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.NEO_BERT: "neo-bert",
MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
MODEL_ARCH.JINA_BERT_V3: "jina-bert-v3",
MODEL_ARCH.EUROBERT: "eurobert",
MODEL_ARCH.BLOOM: "bloom",
MODEL_ARCH.STABLELM: "stablelm",
MODEL_ARCH.QWEN: "qwen",
@@ -978,6 +981,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
MODEL_TENSOR.FFN_GATE_UP_EXP: "blk.{bid}.ffn_gate_up_exps",
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n
@@ -1587,6 +1591,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.EUROBERT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_DOWN,
],
MODEL_ARCH.MPT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@@ -1805,6 +1822,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_GATE_UP_EXP,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
@@ -1894,6 +1912,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_GATE_UP_EXP,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
@@ -2595,6 +2614,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
+7 -4
View File
@@ -175,6 +175,9 @@ class GGUFReader:
if new_align.types != [GGUFValueType.UINT32]:
raise ValueError('Bad type for general.alignment field')
self.alignment = new_align.parts[-1][0]
# Ensure alignment is a non-zero power of two
if self.alignment == 0 or (self.alignment & (self.alignment - 1)) != 0:
raise ValueError('Invalid alignment: must be a non-zero power of two')
padding = offs % self.alignment
if padding != 0:
offs += self.alignment - padding
@@ -202,11 +205,11 @@ class GGUFReader:
def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
if field.name in self.fields:
# TODO: add option to generate error on duplicate keys
# raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')
# TODO: add option to make this a warning and accept duplicate keys like below
raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')
logger.warning(f'Duplicate key {field.name} at offset {field.offset}')
self.fields[field.name + '_{}'.format(field.offset)] = field
# logger.warning(f'Duplicate key {field.name} at offset {field.offset}')
# self.fields[field.name + '_{}'.format(field.offset)] = field
else:
self.fields[field.name] = field
return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)
+2
View File
@@ -501,6 +501,8 @@ class GGUFWriter:
self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version)
def add_custom_alignment(self, alignment: int) -> None:
if alignment <= 0 or (alignment & (alignment - 1)) != 0:
raise ValueError('Invalid alignment: must be a non-zero power of two')
self.data_alignment = alignment
self.add_uint32(Keys.General.ALIGNMENT, alignment)
+4
View File
@@ -567,6 +567,10 @@ class TensorNameMap:
"model.layers.{bid}.mlp.chunk_experts.gate_proj", # grovemoe
),
MODEL_TENSOR.FFN_GATE_UP_EXP: (
"model.layers.{bid}.mlp.experts.gate_up_proj",
),
# Feed-forward down
MODEL_TENSOR.FFN_DOWN: (
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
+19 -21
View File
@@ -25,16 +25,12 @@ Example usage:
"""
def generate_input_prompt(length: int) -> list[str]:
CORPUS = """
You are an advanced AI assistant capable of using tools to gather information, perform calculations, or execute tasks. Always think step by step before responding. If a user's query requires external data, computation, or actions beyond your internal knowledge, use the appropriate tools via function calls.
### Tool Call Format:
When you need to use a tool, output the call in this exact XML format. Include the opening and closing tags. Do not escape arguments; they will be parsed as plain text.
You can make multiple calls in one go by placing them one after another.
"""
words = [w.strip() for w in CORPUS.strip().split(" ")]
def get_remote_corpus(url: str, length: int) -> list[str]:
response = requests.get(url)
response.raise_for_status()
corpus = response.text
words = [w.strip() for w in corpus.strip().split(" ")]
words = [w for w in words if "<" not in w] # make sure nothing looks like special tokens
words = [w for w in words if len(w) > 0] # filter out empty strings
while len(words) < length:
words += words
@@ -226,9 +222,9 @@ def parse_args() -> argparse.Namespace:
)
parser_dump.add_argument(
"--file",
type=Path,
default=None,
help="File containing prompt to use instead of the default",
type=str,
default="https://raw.githubusercontent.com/ggml-org/llama.cpp/eaba92c3dcc980ebe753348855d4a5d75c069997/tools/server/README.md",
help="File containing prompt to use instead of the default (can also be an URL)",
)
parser_dump.add_argument(
"--pattern",
@@ -259,17 +255,19 @@ def main():
if args.verb == "dump":
pattern = parse_pattern(args.pattern)
input_length = sum(n for _, n in pattern)
input_words = generate_input_prompt(input_length)
if args.file is not None:
with args.file.open("r") as f:
required_words = sum(n for _, n in pattern)
if args.file.startswith("http"):
input_words = get_remote_corpus(args.file, required_words)
logger.info(f"Fetched {len(input_words)} words from remote {args.file}")
else:
with open(args.file, "r") as f:
input_words = f.read().strip().split(" ")
if input_length < sum(n for _, n in pattern):
input_words = [w for w in input_words if len(w) > 0] # filter out empty strings
if len(input_words) < required_words:
raise ValueError(
f"Input file has only {input_length} words, but pattern requires at least {input_length} words."
f"Input file has only {len(input_words)} words, but pattern requires at least {required_words} words."
)
input_length = len(input_words)
logger.info(f"Using {input_length} words")
logger.info(f"Using {len(input_words)} words")
dump_logits(args.endpoint, args.output, input_words, pattern, args.api_key)
elif args.verb == "compare":
compare_logits(args.input1, args.input2, args.output)
+1
View File
@@ -62,6 +62,7 @@ add_library(llama
models/dream.cpp
models/ernie4-5-moe.cpp
models/ernie4-5.cpp
models/eurobert.cpp
models/exaone-moe.cpp
models/exaone.cpp
models/exaone4.cpp
+20
View File
@@ -26,6 +26,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_NEO_BERT, "neo-bert" },
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
{ LLM_ARCH_JINA_BERT_V3, "jina-bert-v3" },
{ LLM_ARCH_EUROBERT, "eurobert" },
{ LLM_ARCH_BLOOM, "bloom" },
{ LLM_ARCH_STABLELM, "stablelm" },
{ LLM_ARCH_QWEN, "qwen" },
@@ -348,6 +349,7 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_GATE_UP_EXPS, "blk.%d.ffn_gate_up_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
@@ -819,6 +821,20 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_CLS,
LLM_TENSOR_CLS_OUT,
};
case LLM_ARCH_EUROBERT:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_K,
LLM_TENSOR_ATTN_V,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_DOWN,
};
case LLM_ARCH_MODERN_BERT:
return {
LLM_TENSOR_TOKEN_EMBD,
@@ -989,6 +1005,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_GATE_UP_EXPS,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_GATE_SHEXP,
LLM_TENSOR_FFN_DOWN_SHEXP,
@@ -1046,6 +1063,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_GATE_UP_EXPS,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_GATE_SHEXP,
LLM_TENSOR_FFN_DOWN_SHEXP,
@@ -1586,6 +1604,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_GATE_UP_EXPS,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_GATE_SHEXP,
LLM_TENSOR_FFN_DOWN_SHEXP,
@@ -2670,6 +2689,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_GATE_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_DOWN_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_GATE_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
+2
View File
@@ -30,6 +30,7 @@ enum llm_arch {
LLM_ARCH_NEO_BERT,
LLM_ARCH_JINA_BERT_V2,
LLM_ARCH_JINA_BERT_V3,
LLM_ARCH_EUROBERT,
LLM_ARCH_BLOOM,
LLM_ARCH_STABLELM,
LLM_ARCH_QWEN,
@@ -372,6 +373,7 @@ enum llm_tensor {
LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_UP_EXPS,
LLM_TENSOR_FFN_GATE_UP_EXPS,
LLM_TENSOR_FFN_DOWN_SHEXP,
LLM_TENSOR_FFN_GATE_SHEXP,
LLM_TENSOR_FFN_UP_SHEXP,
+49 -21
View File
@@ -1165,7 +1165,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
float w_scale,
llama_expert_gating_func_type gating_op,
int il,
ggml_tensor * probs_in) const {
ggml_tensor * probs_in,
ggml_tensor * gate_up_exps) const {
return build_moe_ffn(
cur,
gate_inp, /* gate_inp_b */ nullptr,
@@ -1181,7 +1182,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
w_scale,
gating_op,
il,
probs_in
probs_in,
gate_up_exps
);
}
@@ -1204,7 +1206,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
float w_scale,
llama_expert_gating_func_type gating_op,
int il,
ggml_tensor * probs_in) const {
ggml_tensor * probs_in,
ggml_tensor * gate_up_exps,
ggml_tensor * gate_up_exps_b) const {
const int64_t n_embd = cur->ne[0];
const int64_t n_tokens = cur->ne[1];
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
@@ -1343,26 +1347,48 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cb(cur, "ffn_moe_weighted", il);
}
ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(up, "ffn_moe_up", il);
if (up_exps_b) {
up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
cb(up, "ffn_moe_up_biased", il);
}
ggml_tensor * up = nullptr;
ggml_tensor * experts = nullptr;
if (gate_exps) {
cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
if (gate_up_exps) {
// merged gate_up path: one mul_mat_id, then split into gate and up views
ggml_tensor * gate_up = build_lora_mm_id(gate_up_exps, cur, selected_experts); // [n_ff*2, n_expert_used, n_tokens]
cb(gate_up, "ffn_moe_gate_up", il);
if (gate_up_exps_b) {
gate_up = ggml_add_id(ctx0, gate_up, gate_up_exps_b, selected_experts);
cb(gate_up, "ffn_moe_gate_up_biased", il);
}
const int64_t n_ff = gate_up->ne[0] / 2;
cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0);
cb(cur, "ffn_moe_gate", il);
up = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], n_ff * gate_up->nb[0]);
cb(up, "ffn_moe_up", il);
} else {
cur = up;
// separate gate and up path
up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(up, "ffn_moe_up", il);
if (up_exps_b) {
up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
cb(up, "ffn_moe_up_biased", il);
}
if (gate_exps) {
cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(cur, "ffn_moe_gate", il);
} else {
cur = up;
}
if (gate_exps_b) {
cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
cb(cur, "ffn_moe_gate_biased", il);
}
}
if (gate_exps_b) {
cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
cb(cur, "ffn_moe_gate_biased", il);
}
const bool has_gate = gate_exps || gate_up_exps;
switch (type_op) {
case LLM_FFN_SILU:
@@ -1385,7 +1411,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
break;
}
}
}
if (has_gate) {
cur = ggml_swiglu_split(ctx0, cur, up);
cb(cur, "ffn_moe_swiglu", il);
} else {
@@ -1393,7 +1421,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cb(cur, "ffn_moe_silu", il);
} break;
case LLM_FFN_GELU:
if (gate_exps) {
if (has_gate) {
cur = ggml_geglu_split(ctx0, cur, up);
cb(cur, "ffn_moe_geglu", il);
} else {
@@ -1409,7 +1437,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cb(cur, "ffn_moe_swiglu_oai", il);
} break;
case LLM_FFN_RELU:
if (gate_exps) {
if (has_gate) {
cur = ggml_reglu_split(ctx0, cur, up);
cb(cur, "ffn_moe_reglu", il);
} else {
@@ -1417,7 +1445,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
cb(cur, "ffn_moe_relu", il);
} break;
case LLM_FFN_RELU_SQR:
if (gate_exps) {
if (has_gate) {
// TODO: add support for gated squared relu
GGML_ABORT("fatal error: gated squared relu not implemented");
} else {
+5 -2
View File
@@ -814,7 +814,8 @@ struct llm_graph_context {
float w_scale,
llama_expert_gating_func_type gating_op,
int il,
ggml_tensor * probs_in = nullptr) const;
ggml_tensor * probs_in = nullptr,
ggml_tensor * gate_up_exps = nullptr) const;
ggml_tensor * build_moe_ffn(
ggml_tensor * cur,
@@ -835,7 +836,9 @@ struct llm_graph_context {
float w_scale,
llama_expert_gating_func_type gating_op,
int il,
ggml_tensor * probs_in = nullptr) const;
ggml_tensor * probs_in = nullptr,
ggml_tensor * gate_up_exps = nullptr,
ggml_tensor * gate_up_exps_b = nullptr) const;
//
// inputs
+3
View File
@@ -978,6 +978,9 @@ bool llama_kv_cache::get_can_shift() const {
if (model.arch == LLM_ARCH_STEP35) {
return false;
}
if (hparams.n_pos_per_embd() > 1) {
return false;
}
return true;
}
+1 -1
View File
@@ -163,7 +163,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
const auto & cell = cells[tail_id];
// partial intersection is invalid if it includes the final pos
if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1);
return false;
}
// invalidate tails which will be cleared
+57 -7
View File
@@ -123,6 +123,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_8B_A1B: return "8B.A1B";
case LLM_TYPE_16B_A1B: return "16B.A1B";
case LLM_TYPE_21B_A3B: return "21B.A3B";
case LLM_TYPE_24B_A2B: return "24B.A2B";
case LLM_TYPE_30B_A3B: return "30B.A3B";
case LLM_TYPE_31B_A3_5B: return "31B.A3.5B";
case LLM_TYPE_35B_A3B: return "35B.A3B";
@@ -978,6 +979,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
type = LLM_TYPE_250M;
}
} break;
case LLM_ARCH_EUROBERT:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
if (hparams.n_layer == 12) {
type = LLM_TYPE_SMALL; // 0.2B
}
} break;
case LLM_ARCH_BLOOM:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2381,7 +2392,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0;
}
type = LLM_TYPE_8B_A1B;
switch (hparams.n_layer) {
case 24: type = LLM_TYPE_8B_A1B; break;
case 40: type = LLM_TYPE_24B_A2B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_SMALLTHINKER:
{
@@ -2965,6 +2980,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// TODO: move to a separate function
const auto tn = LLM_TN(arch);
// helper: try merged gate_up_exps first, fall back to separate gate and up
auto create_tensor_gate_up_exps = [&](llama_layer & layer, int bid, int64_t n_embd_, int64_t n_ff_, int64_t n_expert_, int flags) {
layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", bid), {n_embd_, n_ff_ * 2, n_expert_}, TENSOR_NOT_REQUIRED);
if (layer.ffn_gate_up_exps == nullptr) {
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags);
}
};
switch (arch) {
case LLM_ARCH_LLAMA:
case LLM_ARCH_REFACT:
@@ -3565,6 +3589,29 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
}
} break;
case LLM_ARCH_EUROBERT:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
}
} break;
case LLM_ARCH_JINA_BERT_V2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings
@@ -5183,9 +5230,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
// MoE branch
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0);
// Shared expert branch
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
@@ -7387,9 +7433,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0);
// Shared experts
layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0);
@@ -7453,9 +7498,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0);
// Shared experts
const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
@@ -8176,6 +8220,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
case LLM_ARCH_NEO_BERT:
case LLM_ARCH_EUROBERT:
case LLM_ARCH_WAVTOKENIZER_DEC:
case LLM_ARCH_MODERN_BERT:
case LLM_ARCH_GEMMA_EMBEDDING:
@@ -8373,6 +8418,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_neo_bert>(*this, params);
} break;
case LLM_ARCH_EUROBERT:
{
llm = std::make_unique<llm_build_eurobert>(*this, params);
} break;
case LLM_ARCH_BLOOM:
{
llm = std::make_unique<llm_build_bloom>(*this, params);
@@ -8999,6 +9048,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_MODERN_BERT:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
case LLM_ARCH_EUROBERT:
case LLM_ARCH_STABLELM:
case LLM_ARCH_BITNET:
case LLM_ARCH_QWEN:
+11 -8
View File
@@ -116,6 +116,7 @@ enum llm_type {
LLM_TYPE_8B_A1B, // lfm2moe
LLM_TYPE_16B_A1B,
LLM_TYPE_21B_A3B, // Ernie MoE small
LLM_TYPE_24B_A2B, // lfm2moe
LLM_TYPE_30B_A3B,
LLM_TYPE_31B_A3_5B,
LLM_TYPE_35B_A3B, // Qwen3.5
@@ -279,14 +280,16 @@ struct llama_layer {
struct ggml_tensor * ffn_up_enc = nullptr;
// ff MoE
struct ggml_tensor * ffn_gate_inp = nullptr;
struct ggml_tensor * ffn_gate_exps = nullptr;
struct ggml_tensor * ffn_down_exps = nullptr;
struct ggml_tensor * ffn_up_exps = nullptr;
struct ggml_tensor * ffn_gate_inp_b = nullptr;
struct ggml_tensor * ffn_gate_exps_b = nullptr;
struct ggml_tensor * ffn_down_exps_b = nullptr;
struct ggml_tensor * ffn_up_exps_b = nullptr;
struct ggml_tensor * ffn_gate_inp = nullptr;
struct ggml_tensor * ffn_gate_exps = nullptr;
struct ggml_tensor * ffn_down_exps = nullptr;
struct ggml_tensor * ffn_up_exps = nullptr;
struct ggml_tensor * ffn_gate_up_exps = nullptr;
struct ggml_tensor * ffn_gate_inp_b = nullptr;
struct ggml_tensor * ffn_gate_exps_b = nullptr;
struct ggml_tensor * ffn_down_exps_b = nullptr;
struct ggml_tensor * ffn_up_exps_b = nullptr;
struct ggml_tensor * ffn_gate_up_exps_b = nullptr;
// ff shared expert (shexp)
struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
+2 -1
View File
@@ -1890,7 +1890,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "falcon-h1" ||
tokenizer_pre == "pixtral" ||
tokenizer_pre == "midm-2.0" ||
tokenizer_pre == "lfm2") {
tokenizer_pre == "lfm2" ||
tokenizer_pre == "jina-v5-nano") {
pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
ignore_merges = true;
add_bos = true;
+3 -1
View File
@@ -218,7 +218,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
LLM_FFN_SILU, hparams.expert_weights_norm,
hparams.expert_weights_scale, hparams.expert_weights_scale,
(llama_expert_gating_func_type) hparams.expert_gating_func,
il);
il,
nullptr,
model.layers[il].ffn_gate_up_exps);
cb(moe_out, "ffn_moe_out", il);
// FFN shared expert
+97
View File
@@ -0,0 +1,97 @@
#include "models.h"
llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
ggml_tensor * cur;
ggml_tensor * inpL;
ggml_tensor * inp_pos = build_inp_pos();
inpL = build_inp_embd(model.tok_embd);
cb(inpL, "inp_embd", -1);
auto * inp_attn = build_attn_inp_no_cache();
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * cur = inpL;
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
{
ggml_tensor * Qcur;
ggml_tensor * Kcur;
ggml_tensor * Vcur;
Qcur = build_lora_mm(model.layers[il].wq, cur);
Kcur = build_lora_mm(model.layers[il].wk, cur);
Vcur = build_lora_mm(model.layers[il].wv, cur);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn,
model.layers[il].wo, nullptr,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
cb(cur, "kqv_out", il);
}
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
cur = ggml_add(ctx0, cur, inpL);
ggml_tensor * ffn_inp = cur;
cb(ffn_inp, "ffn_inp", il);
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_embd", -1);
res->t_embd = cur;
ggml_build_forward_expand(gf, cur);
}
+2
View File
@@ -116,6 +116,8 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
cur = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
ggml_build_forward_expand(gf, cur);
// Check layer type by checking which tensors exist
// KDA layers have ssm_a_log tensor, MLA layers have wkv_a_mqa tensor
bool is_kda = (layer.ssm_a != nullptr);
+4
View File
@@ -424,6 +424,10 @@ struct llm_build_neo_bert : public llm_graph_context {
llm_build_neo_bert(const llama_model & model, const llm_graph_params & params);
};
struct llm_build_eurobert : public llm_graph_context {
llm_build_eurobert(const llama_model & model, const llm_graph_params & params);
};
template <bool iswa>
struct llm_build_olmo2 : public llm_graph_context {
llm_build_olmo2(const llama_model & model, const llm_graph_params & params);
+2 -1
View File
@@ -29,6 +29,8 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
ggml_build_forward_expand(gf, cur);
// Determine layer type and build appropriate attention mechanism
if (hparams.is_recurrent(il)) {
// Linear attention layer (gated delta net)
@@ -269,7 +271,6 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
cb(state_update_target, "state_update_target", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
cb(conv_states_all, "conv_states_updated", il);
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
+4 -2
View File
@@ -29,6 +29,8 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
ggml_build_forward_expand(gf, cur);
// Determine layer type and build appropriate attention mechanism
if (hparams.is_recurrent(il)) {
// Linear attention layer (gated delta net)
@@ -269,7 +271,6 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
cb(state_update_target, "state_update_target", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
cb(conv_states_all, "conv_states_updated", il);
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
@@ -379,7 +380,8 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_ffn(ggml_tensor * cur, const int
model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used, LLM_FFN_SILU,
true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il,
nullptr, model.layers[il].ffn_gate_up_exps);
cb(moe_out, "ffn_moe_out", il);
// Add shared experts if present - following Qwen3Next reference implementation
+4 -2
View File
@@ -21,6 +21,8 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
ggml_build_forward_expand(gf, cur);
// Determine layer type and build appropriate attention mechanism
if (hparams.is_recurrent(il)) {
// Linear attention layer (gated delta net)
@@ -354,7 +356,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
cb(state_update_target, "state_update_target", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
cb(conv_states_all, "conv_states_updated", il);
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
@@ -478,7 +479,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
nullptr,
n_expert, n_expert_used, LLM_FFN_SILU,
true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il,
nullptr, model.layers[il].ffn_gate_up_exps);
cb(moe_out, "ffn_moe_out", il);
// Add shared experts if present - following Qwen3Next reference implementation
+15 -1
View File
@@ -48,6 +48,7 @@ enum handcrafted_file_type {
HANDCRAFTED_DATA_NOT_ENOUGH_DATA = 10 + offset_has_data,
HANDCRAFTED_DATA_BAD_ALIGN = 15 + offset_has_data,
HANDCRAFTED_DATA_INCONSISTENT_ALIGN = 20 + offset_has_data,
HANDCRAFTED_DATA_MEM_SIZE_OVERFLOW = 30 + offset_has_data,
HANDCRAFTED_DATA_SUCCESS = 800 + offset_has_data,
HANDCRAFTED_DATA_CUSTOM_ALIGN = 810 + offset_has_data,
};
@@ -84,6 +85,7 @@ static std::string handcrafted_file_type_name(const enum handcrafted_file_type h
case HANDCRAFTED_DATA_NOT_ENOUGH_DATA: return "DATA_NOT_ENOUGH_DATA";
case HANDCRAFTED_DATA_BAD_ALIGN: return "DATA_BAD_ALIGN";
case HANDCRAFTED_DATA_INCONSISTENT_ALIGN: return "DATA_INCONSISTENT_ALIGN";
case HANDCRAFTED_DATA_MEM_SIZE_OVERFLOW: return "DATA_MEM_SIZE_OVERFLOW";
case HANDCRAFTED_DATA_SUCCESS: return "DATA_SUCCESS";
case HANDCRAFTED_DATA_CUSTOM_ALIGN: return "DATA_CUSTOM_ALIGN";
}
@@ -196,6 +198,13 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
tensor_configs = get_tensor_configs(rng);
}
if (hft == HANDCRAFTED_DATA_MEM_SIZE_OVERFLOW) {
tensor_configs.resize(2);
tensor_configs[0] = { GGML_TYPE_I8, { 0x7FFFFFFFFFFFFFC0, 1, 1, 1 } };
tensor_configs[1] = { GGML_TYPE_I8, { 0x7FFFFFFFFFFFFFC0, 1, 1, 1 } };
}
if (hft == HANDCRAFTED_HEADER_BAD_N_TENSORS) {
const uint64_t n_tensors = -1;
helper_write(file, n_tensors);
@@ -397,7 +406,8 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
for (uint32_t i = 1; i < n_dims; ++i) {
ne *= shape[i];
}
offset += GGML_PAD(ggml_row_size(type, ne), alignment);
offset += GGML_PAD(ggml_row_size(type, ne), (uint64_t) alignment);
}
while (ftell(file) % alignment != 0) {
@@ -411,6 +421,9 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
if (hft == HANDCRAFTED_DATA_NOT_ENOUGH_DATA) {
nbytes -= 1;
}
if (hft == HANDCRAFTED_DATA_MEM_SIZE_OVERFLOW) {
nbytes = 32;
}
for (uint64_t i = 0; i < nbytes; ++i) {
const uint8_t random_byte = i % 256;
helper_write(file, random_byte);
@@ -704,6 +717,7 @@ static std::pair<int, int> test_handcrafted_file(const unsigned int seed) {
HANDCRAFTED_DATA_NOT_ENOUGH_DATA,
HANDCRAFTED_DATA_BAD_ALIGN,
HANDCRAFTED_DATA_INCONSISTENT_ALIGN,
HANDCRAFTED_DATA_MEM_SIZE_OVERFLOW,
HANDCRAFTED_DATA_SUCCESS,
HANDCRAFTED_DATA_CUSTOM_ALIGN,
};
+7 -2
View File
@@ -13,7 +13,12 @@ fi
name=$1
input=$2
make -j tests/test-tokenizer-0
# Build using CMake if binary doesn't exist
if [ ! -f ./build/bin/test-tokenizer-0 ]; then
printf "Building test-tokenizer-0 with CMake...\n"
cmake -B build -DLLAMA_BUILD_TESTS=ON
cmake --build build --target test-tokenizer-0 -j
fi
printf "Testing %s on %s ...\n" $name $input
@@ -23,7 +28,7 @@ printf "Tokenizing using (py) Python AutoTokenizer ...\n"
python3 ./tests/test-tokenizer-0.py ./models/tokenizers/$name --fname-tok $input > /tmp/test-tokenizer-0-$name-py.log 2>&1
printf "Tokenizing using (cpp) llama.cpp ...\n"
./tests/test-tokenizer-0 ./models/ggml-vocab-$name.gguf $input > /tmp/test-tokenizer-0-$name-cpp.log 2>&1
./build/bin/test-tokenizer-0 ./models/ggml-vocab-$name.gguf $input > /tmp/test-tokenizer-0-$name-cpp.log 2>&1
cat /tmp/test-tokenizer-0-$name-py.log | grep "tokenized in"
cat /tmp/test-tokenizer-0-$name-cpp.log | grep "tokenized in"
+3 -1
View File
@@ -912,7 +912,9 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
const bool add_bos = llama_vocab_get_add_bos(vocab);
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_LAST) {
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
}
auto tim1 = std::chrono::high_resolution_clock::now();
LOG_INF("%s: tokenizing the input ..\n", __func__);
+1 -1
View File
@@ -248,7 +248,7 @@ int32_t mtmd_helper_decode_image_chunk(
int32_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
int32_t i_batch = 0;
int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
int32_t n_img_batches = (n_tokens + n_batch - 1) / n_batch;
decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd);
if (mtmd_decode_use_mrope(ctx)) {
+1 -1
View File
@@ -1510,7 +1510,7 @@ version = 1
; If the same key is defined in a specific preset, it will override the value in this global section.
[*]
c = 8192
n-gpu-layer = 8
n-gpu-layers = 8
; If the key corresponds to an existing model on the server,
; this will be used as the default config for that model
+65 -7
View File
@@ -231,19 +231,77 @@ server_tokens::server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) :
server_tokens::server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
}
llama_pos server_tokens::pos_next() const {
llama_pos server_tokens::pos_next(int64_t n_tokens) const {
if (!has_mtmd) {
return tokens.size();
if (n_tokens < 0) {
return tokens.size();
}
return n_tokens;
}
llama_pos res = tokens.size();
if (n_tokens < 0) {
llama_pos res = tokens.size();
for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
const auto & chunk = it->second;
res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
const auto & chunk = it->second;
res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
}
return res;
}
return res;
int64_t idx = 0;
llama_pos pos = 0;
GGML_ASSERT(n_tokens <= (int64_t)tokens.size());
while (idx < n_tokens) {
const auto media_it = map_idx_to_media.find(idx);
if (media_it != map_idx_to_media.end()) {
const auto & chunk = media_it->second;
const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get());
pos += n_pos;
idx += n_tok;
} else {
pos++;
idx++;
}
}
return pos;
}
size_t server_tokens::size_up_to_pos(llama_pos max_pos) const {
if (!has_mtmd) {
return std::min((size_t)(max_pos + 1), tokens.size());
}
size_t idx = 0;
llama_pos pos = 0;
while (idx < tokens.size()) {
const auto media_it = map_idx_to_media.find(idx);
if (media_it != map_idx_to_media.end()) {
const auto & chunk = media_it->second;
const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get());
pos += n_pos;
idx += n_tok;
} else {
pos++;
idx++;
}
if (pos > max_pos) {
break;
}
}
return idx;
}
std::string server_tokens::str() const {
+6 -1
View File
@@ -167,7 +167,12 @@ public:
// for debugging
std::string str() const;
llama_pos pos_next() const;
// the next position after n_tokens. if n_tokens < 0, return the next position after all tokens.
llama_pos pos_next(int64_t n_tokens = -1) const;
// number of tokens with position <= max_pos
size_t size_up_to_pos(llama_pos max_pos) const;
const mtmd::input_chunk_ptr & find_chunk(size_t idx) const;
void push_back(llama_token tok);
+26 -29
View File
@@ -995,9 +995,6 @@ private:
// don't update the cache if the slot's context is empty
update_cache = update_cache && tokens.size() > 0;
// TODO: mtmd does not support prompt cache
update_cache = update_cache && (ret->mctx == nullptr);
if (update_cache) {
SRV_WRN("%s", "updating prompt cache\n");
@@ -1442,7 +1439,7 @@ private:
res->id = slot.task->id;
res->id_slot = slot.id;
res->index = slot.task->index;
res->index = slot.task->index;
// keep copy of last generated text for debugging purposes
if (slots_debug) {
@@ -2282,15 +2279,15 @@ private:
n_past = 0;
}
llama_pos pos_next = slot.prompt.tokens.pos_next(n_past);
// note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
const auto n_swa = std::max(1, llama_model_n_swa(model));
// the largest pos_min required for a checkpoint to be useful
const auto pos_min_thold = std::max(0, n_past - n_swa);
const auto pos_min_thold = std::max(0, pos_next - n_swa);
// note: disallow with mtmd contexts for now
// https://github.com/ggml-org/llama.cpp/issues/17043
if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) {
if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
if (pos_min == -1) {
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
@@ -2341,9 +2338,6 @@ private:
}
if (pos_min > pos_min_thold) {
// TODO: support can be added in the future when corresponding vision models get released
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
// search for a context checkpoint
@@ -2364,18 +2358,20 @@ private:
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != checkpoint_size) {
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
do_reset = true;
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else {
n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
}
}
if (do_reset) {
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
pos_next = 0;
n_past = 0;
}
}
@@ -2386,7 +2382,7 @@ private:
for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
const auto & cur = *it;
if (cur.pos_min > pos_min_thold) {
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, (float) cur.data.size() / 1024 / 1024);
it = slot.prompt.checkpoints.erase(it);
} else {
++it;
@@ -2402,7 +2398,7 @@ private:
SLT_WRN(slot, "n_past was set to %d\n", n_past);
}
slot.n_prompt_tokens_cache = n_past;
slot.n_prompt_tokens_cache = n_past;
slot.n_prompt_tokens_processed = 0;
slot.prompt.tokens.keep_first(n_past);
@@ -2520,10 +2516,6 @@ private:
}
}
// SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());
SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
// entire prompt has been processed
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
slot.state = SLOT_STATE_DONE_PROMPT;
@@ -2536,8 +2528,6 @@ private:
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
slot.init_sampler();
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
@@ -2549,13 +2539,15 @@ private:
// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
// yet processed and therefore it is not part of the checkpoint.
if (do_checkpoint) {
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
const auto & cur = slot.prompt.checkpoints.front();
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}
@@ -2563,16 +2555,21 @@ private:
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
});
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
}
SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
} else {
SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
}
}
+13
View File
@@ -339,6 +339,17 @@ static std::map<std::string, std::string> get_headers(const httplib::Request & r
return headers;
}
static std::string build_query_string(const httplib::Request & req) {
std::string qs;
for (const auto & [key, value] : req.params) {
if (!qs.empty()) {
qs += '&';
}
qs += httplib::encode_query_component(key) + "=" + httplib::encode_query_component(value);
}
return qs;
}
// using unique_ptr for request to allow safe capturing in lambdas
using server_http_req_ptr = std::unique_ptr<server_http_req>;
@@ -382,6 +393,7 @@ void server_http_context::get(const std::string & path, const server_http_contex
get_params(req),
get_headers(req),
req.path,
build_query_string(req),
req.body,
req.is_connection_closed
});
@@ -396,6 +408,7 @@ void server_http_context::post(const std::string & path, const server_http_conte
get_params(req),
get_headers(req),
req.path,
build_query_string(req),
req.body,
req.is_connection_closed
});
+2 -1
View File
@@ -36,7 +36,8 @@ using server_http_res_ptr = std::unique_ptr<server_http_res>;
struct server_http_req {
std::map<std::string, std::string> params; // path_params + query_params
std::map<std::string, std::string> headers; // reserved for future use
std::string path; // reserved for future use
std::string path;
std::string query_string; // query parameters string (e.g. "action=save")
std::string body;
const std::function<bool()> & should_stop;
+8 -2
View File
@@ -291,7 +291,9 @@ void server_models::load_models() {
for (const auto & [name, inst] : mapping) {
std::string val;
if (inst.meta.preset.get_option(COMMON_ARG_PRESET_LOAD_ON_STARTUP, val)) {
models_to_load.push_back(name);
if (common_arg_utils::is_truthy(val)) {
models_to_load.push_back(name);
}
}
}
if ((int)models_to_load.size() > base_params.models_max) {
@@ -697,11 +699,15 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co
mapping[name].meta.last_used = ggml_time_ms();
}
SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port);
std::string proxy_path = req.path;
if (!req.query_string.empty()) {
proxy_path += '?' + req.query_string;
}
auto proxy = std::make_unique<server_http_proxy>(
method,
CHILD_ADDR,
meta->port,
req.path,
proxy_path,
req.headers,
req.body,
req.should_stop,
+3 -3
View File
@@ -204,7 +204,8 @@ task_params server_task::params_from_json_cmpl(
params.cache_prompt = json_value(data, "cache_prompt", defaults.cache_prompt);
params.return_tokens = json_value(data, "return_tokens", false);
params.return_progress = json_value(data, "return_progress", false);
params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
auto max_tokens = json_value(data, "max_tokens", defaults.n_predict);
params.n_predict = json_value(data, "n_predict", json_value(data, "max_completion_tokens", max_tokens));
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
@@ -1899,10 +1900,9 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t
return nullptr;
}
// TODO: for some reason we can't copy server_tokens, so we have to do this workaround
auto & cur = states.emplace_back();
cur = {
/*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false),
/*.tokens =*/ prompt.tokens.clone(),
/*.data =*/ std::move(state_data),
/*.checkpoints =*/ prompt.checkpoints,
};
+2
View File
@@ -557,6 +557,8 @@ struct server_prompt_checkpoint {
llama_pos pos_min;
llama_pos pos_max;
int64_t n_tokens;
std::vector<uint8_t> data;
size_t size() const {