Compare commits

...

57 Commits

Author SHA1 Message Date
Anton Mitkov 2bf9d539dd sycl: GGML_SYCL_DISABLE_OPT on by default for all Intel Devices (#13973) 2025-06-25 18:09:55 +02:00
lhez 73e53dc834 opencl: ref count ggml_backend_opencl_context and refactor profiling (#14254)
* Move profiling info into `ggml_backend_opencl_context`
* Add `enqueue_ndrange_kernel` to launch kernel
2025-06-24 11:46:25 -07:00
Georgi Gerganov 62af464227 batch : fix check for empty sequences in memory (#14364)
* batch : fix check for empty sequences in memory

ggml-ci

* cont : reuse the var

ggml-ci
2025-06-24 18:26:30 +03:00
Mathieu Baudier c148cf1946 cmake : use LLAMA_BUILD_NUMBER when defining LLAMA_INSTALL_VERSION (#14362) 2025-06-24 15:05:31 +02:00
Nigel Bosch 1b809cee22 server : move no API key doc to /health (#14352) 2025-06-24 10:59:11 +02:00
Sigbjørn Skjæret abf241045d main : honor --verbose-prompt on interactive prompts (#14350) 2025-06-24 09:31:00 +02:00
Bartowski 901e20bbe5 jinja : Add Mistral-Small-3.2-24B-Instruct-2506.jinja (#14349)
This will allow the use of tools on the llama-server
2025-06-24 09:17:58 +03:00
uvos 0142961a2e CUDA/HIP: optimize mmv paths taken for HIP devices (#14324)
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2025-06-24 01:12:56 +02:00
bandoti ce82bd0117 ci: add workflow for relocatable cmake package (#14346) 2025-06-23 15:30:51 -03:00
Jeff Bolz bf2a99e3cb vulkan: update windows SDK in release.yml (#14344) 2025-06-23 15:44:48 +02:00
Molly Sophia 72c6bc3f3d llama : better rwkv chat template and add missing inputs.use_jinja setting (#14336)
* llama-cli : add missing `inputs.use_jinja` setting

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

* llama : better legacy chat template for rwkv

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
2025-06-23 19:56:19 +08:00
Johannes Gäßler defe2158dd CUDA: mul_mat_v support for batch sizes > 1 (#14262)
* CUDA: mul_mat_v support for batch sizes > 1

* use 64 bit math for initial offset calculation
2025-06-23 13:11:31 +02:00
Georgi Gerganov 7b50d589a8 kv-cells : fix tracking of seq_pos (#14339)
* kv-cells : fix tracking of seq_pos during cache reuse

ggml-ci

* cont : improve error message

ggml-ci

* cont : add more comments
2025-06-23 12:27:35 +03:00
Jeff Bolz 3a9457df96 vulkan: update windows SDK in CI (#14334) 2025-06-23 10:19:24 +02:00
Ed Addario fa4a9f2a1c quantize : handle user-defined pruning of whole layers (blocks) (#13037) 2025-06-22 23:16:26 +02:00
Sigbjørn Skjæret 238005c2dc gguf-py : fix SpecialVocab parsing when post_processor is null (#14330) 2025-06-22 19:46:17 +02:00
Ruikai Peng 66aba7aca9 run : avoid double tokenization (#14327)
* run : avoid double tokenization by adopting common_tokenize heuristic

* build : fix windows gcc and clang warnings

* lint : fixed trailing whitepace

* run : fix is_first flag
2025-06-23 01:28:06 +08:00
Georgi Gerganov f1f5e82df6 examples : fix is_first logic for tokenization (#14329)
ggml-ci
2025-06-22 20:10:07 +03:00
uvos af3373f1ad HIP: enable vec fattn on RDNA4 (#14323) 2025-06-22 16:51:23 +02:00
yuiseki 5d5c066de8 mtmd : fix Pixtral OOM with large images by capping image_size to 1024 (#14326)
Mistral Small 2506 models using Pixtral vision encoder were running out
of GPU memory when processing images larger than 1024x1024 pixels due to
exponential memory growth from unlimited image size.

This fix applies the same 1024x1024 limit used by Qwen2VL models to
prevent OOM issues while maintaining compatibility with existing models.
2025-06-22 14:44:57 +02:00
Sigbjørn Skjæret 40bfa04c95 common : use std::string_view now that we target c++17 (#14319) 2025-06-22 08:37:43 +03:00
Aman Gupta aa064b2eb7 CUDA: add mean operation (#14313)
* CUDA: add mean operation

* add back sum_rows_f32_cuda

* Review: early exit if col!=0
2025-06-22 12:39:54 +08:00
Sigbjørn Skjæret aa0ef5c578 gguf-py : fix Qwen3-Embedding eos token (#14314) 2025-06-21 18:12:05 +02:00
Markus Tavenrath bb16041cae Add support for VK_EXT_debug_utils to add labels to Vulkan objects. (#13792)
* Add support for VK_EXT_debug_utils to add labels to Vulkan objects. In step 1 compute pipelines are getting labeled.

* remove #ifdef for debug utils and add queue marker.
2025-06-21 08:17:12 +02:00
Sigbjørn Skjæret 58cba76a9a gguf-py : fix TemplateProcessing pair when bos/eos is missing (#14312) 2025-06-21 07:33:21 +02:00
Georgi Gerganov 67ae5312e2 metal : fix thread-safety (#14300)
ggml-ci
2025-06-21 08:04:18 +03:00
Georgi Gerganov 692e3cdd0a memory : rename interface to llama_memory_context_i (#14296)
* memory : rename interface to llama_memory_context_i

ggml-ci

* cont : fix comments

* cont : use "mctx" for referencing a memory context

ggml-ci
2025-06-21 08:03:46 +03:00
Daniel Han b23fa0b3f4 convert : fix Llama 4 conversion (#14311) 2025-06-21 06:32:01 +02:00
Georgi Gerganov 06cbedfca1 sync : ggml
ggml-ci
2025-06-20 21:02:47 +03:00
Acly b7147673f2 Add ggml_roll (ggml/1274)
* ggml : add ggml_roll

* use set/get_op_params & std::min
2025-06-20 21:02:47 +03:00
David Chiu d860dd99a4 docs : fix the link to llama.h (#14293) 2025-06-20 19:43:35 +02:00
Aman Gupta c959f462a0 CUDA: add conv_2d_transpose (#14287)
* CUDA: add conv_2d_transpose

* remove direct include of cuda_fp16

* Review: add brackets for readability, remove ggml_set_param and add asserts
2025-06-20 22:48:24 +08:00
Sigbjørn Skjæret 22015b2092 lint : remove trailing whitepace (#14304) 2025-06-20 16:37:44 +02:00
Ruikai Peng dd6e6d0b6a vocab : prevent tokenizer overflow (#14301)
* vocab : prevent stack overflow in tokenize

* vocab : return error instead of aborting on oversized token count

* vocab : INT32_MIN from llama_tokenize on overflow
2025-06-20 07:13:06 -07:00
Nicolò Scipione 8308f98c7f sycl: add usage of enqueue_functions extension (#14244)
* Add header and namespace to use enqueue_functions extension

* Convert submit and parallel_for to use new extension in convert.cpp

* Convert submit and parallel_for to use extension in ggml-sycl.cpp

* Convert submit and parallel_for to use extension in gla.cpp

* Convert submit and parallel_for in mmq.cpp

* Convert submit and parallel_for in mmvq.cpp

* Convert submit and parallel_for in remaining files

* Convert all simple parallel_for to nd_launch from enqueue_functions
extension

* Wrapping extension in general function

Create a general function that enable the enqueue_functions extension if
it is enable in the compiler, otherwise call the general SYCL function
to launch kernels.

---------

Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
2025-06-20 15:07:21 +02:00
Christian Kastner 6369be0735 Implement GGML_CPU_ALL_VARIANTS for PowerPC (#14286)
* Add PowerPC feature detection and scoring

* ggml-cpu: Implement GGML_CPU_ALL_VARIANTS for PowerPC

* ggml-cpu: Delay some initializations until function is called

When using GGML_BACKEND_DL=ON, these initializations might use
instructions that are not supported by the current CPU.

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
2025-06-20 14:17:32 +02:00
Sigbjørn Skjæret 88fc854b4b llama : improve sep token handling (#14272) 2025-06-20 14:04:09 +02:00
Diego Devesa e28c1b93fd cuda : synchronize graph capture and cublas handle destruction (#14288)
Workarounds an issue that may cause CUDA graph capture to fail when a cuBLAS handle is destroyed in a different thread
2025-06-20 13:57:36 +02:00
Georgi Gerganov d27b3ca175 ggml : fix repack work size for mul_mat_id (#14292)
ggml-ci
2025-06-20 11:19:15 +03:00
Charles Xu 9230dbe2c7 ggml: Update KleidiAI to v1.9.0 (#14277) 2025-06-20 10:51:01 +03:00
Georgi Gerganov 812939a9e9 model : more uniform output id handling (#14275)
* model : more uniform output id handling

ggml-ci

* cont : revert n_outputs < n_tokens optimization

ggml-ci

* cont : fix out_ids initialization

ggml-ci
2025-06-20 10:50:27 +03:00
Georgi Gerganov 4c9fdfbe15 ubatch : new splitting logic (#14217)
ggml-ci
2025-06-20 10:14:14 +03:00
Aman Gupta 9eaa51e7f0 CUDA: add conv_2d_dw (#14265)
* CUDA: add conv_2d_dw

* better naming

* simplify using template

* Review: fix operation ordering in ggml-cuda, use __forceinline__, use more const
2025-06-20 09:50:24 +08:00
Diego Devesa 8f71d0f3e8 ggml-cpu : remove unnecesary arm feature detection (#14281)
Support for Arm runtime feature detection has now been added to GGML_CPU_ALL_VARIANTS. This removes the old and not very functional code.
2025-06-19 21:24:14 +02:00
Alex Trotta 381174bbda gguf-py : make sentencepiece optional (#14200)
* Make sentencepiece optional

* Bump to 0.18.0

* Bump patch instead of minor

Co-authored-by: compilade <git@compilade.net>

---------

Co-authored-by: compilade <git@compilade.net>
2025-06-19 15:56:12 +02:00
aa956 d67341dc18 server : add server parameters for draft model cache type (#13782)
Co-authored-by: aa956 <27946957+aa956@users.noreply.github.com>
2025-06-19 16:01:03 +03:00
fanyang 456af35eb7 build : suppress gcc15 compile warnings (#14261)
* Change _contains_any() substrs to std::string_view and fix the find comparison logic.
2025-06-19 14:49:48 +02:00
Anton Mitkov 600e3e9b50 sycl: Cleanup codepaths in Get Rows in sycl backend (#14215)
Addresses unused reorder path
2025-06-19 11:40:21 +01:00
bashayer hijji fffcce535e llama-bench : add --no-warmup flag (#14224) (#14270)
Add no_warmup parameter to cmd_params struct and command-line parsing to allow users to skip warmup runs before benchmarking.

- Add no_warmup boolean field to cmd_params struct

- Add --no-warmup command-line argument parsing

- Add help text documentation for the new flag

- Wrap existing warmup logic in conditional check

- Maintain full backward compatibility (warmup enabled by default)

Addresses #14224
2025-06-19 12:24:12 +02:00
pqnet 5fc7856815 convert : fix remote option in Windows (#14100) 2025-06-19 12:21:40 +02:00
Aaron Teo faed5a5f5d llamafile : support s390x SIMD instruction set (#14273) 2025-06-19 11:48:54 +02:00
0cc4m 10bb545c5b Vulkan: Set device max size for host memory to avoid OOM warning and fallback to CPU buffer (#14249) 2025-06-19 09:15:42 +02:00
Gabe Goodhart edc4a29eff memory : Hybrid recurrent cache (#13979)
* feat: Add llama_model_is_hybrid API call

Also, split llama_model_is_recurrent into llm_arch_is_recurrent in
llama-arch with llama_model_is_recurrent delegating to
llm_arch_is_recurrent. The same split is done for hybird. This is needed
because there are places where the llama_model has not yet been initialized
but we need to check if the model is recurrent (specifically for the
per-layer recurrent check array in hparams).

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Add c++ side constants for attention layer indices hparam

Branch: GraniteFour

* feat: Add support for distinguishing recurrent vs non-recurrent layers in hparams

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Auto-fill hparams.recurrent_layer_arr based on whether the model is recurrent

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: rename *_is_hybrid -> *_is_hybrid_recurrent

The implementation of the hybrid cache intentionally does not specify the
types of the child caches, so there was a naming mismatch with these
predicate functions that used "hybrid" to imply "hybrid recurrent."

Branch: HybridCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Add layer filter to recurrent cache

Branch: HybridCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Use per-layer sizing everywhere in kv caches

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: First pass at llama_kv_cache_hybrid_recurrent

This follows the pattern in iswa where the two child caches are held
explicitly to support the case where a model requires a single attention
cache and a single recurrent cache where each layer uses exactly one of the
caches.

This is a rewrite of the more generic approach in the original hybrid cache
PR: https://github.com/ggml-org/llama.cpp/pull/13276

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Construct hybrid recurrent cache for hybrid recurrent models

This includes a refactor of the create_memory logic to avoid needing to use
the arch enum explicitly unless a model needs explicit cache instantiation
logic beyond the standard logic for recurrent, hybrid, unified, and iswa.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Fix wrong bool condition for split equal in hybrid cache

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Fix shift logic to defer to unified cache

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Support hybrid recurrent in llama-graph

NOTE: I intentionally did not add support for s_mask since it will be going
away soon

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Fix logic for initializing inputs and attn layers for hybrid caches

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Update recurrent cache for changes to remove intermediate kv_cache interface

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Fix status for init_update sig for recurrent cache state

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Add missing padding to n_ctx for hybrid cache construction

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Update clear signature for data argument after rebase

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Remove errant virtual destructor leftover from previous impl attempt

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Use per-layer n_embd_k/v_s calls for mamba (1) layers

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Remove n_embd_k/v_s from unified cache

No longer needed now that unified isn't also supporting recurrent

https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140761069

Branch: HybridRecurrentCache

* refactor: Remove layer index from n_embd_k/v_s

Now that it's not used at all in the unified cache, we don't need to use
the layer index to zero it out for attention layers.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Remove n_embd_k/v_gqa from recurrent cache

This is no longer needed now that there are separate implementations

https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140825128

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Allow custom layer filters for hybrid recurrent

This should help support architectures like Falcon H1 where there is
overlap between layers that need attention and recurrent caches.

https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140748922

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Remove logits_all after rebase

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Remove llama_model_is_hybrid_Recurrent public API

https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2141728423

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Use llama_memory_state_ptr for child states in hybrid memory state

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Overhaul build_recurrent_state / build_inp_s_copy to match attention pattern

https://github.com/ggml-org/llama.cpp/pull/13979/files#r2141701738

This is a big overhaul to bring consistency between how inputs and per-
layer components are created for attention layers and recurrent layers. The
main changes are:

- Rename class llm_graph_input_s_copy -> llm_graph_input_rs
- Add a corresponding llm_graph_input_rs_hybrid_recurrent
- Rename build_inp_s_copy -> build_rs_inp_recurrent
- Add a corresponding build_rs_inp_hybrid_recurrent
- Rename build_recurrent_state -> build_rs to match build_attn w/
llm_graph_input_rs android-build AUTHORS bamba-9b-2.2T.gguf bamba-9b-2.2T.q4_k_m.gguf broken.log build build-rel build-xcframework.sh build.android build.android.bak ci cmake CMakeLists.txt CMakePresets.json CODEOWNERS common common.o CONTRIBUTING.md convert_hf_to_gguf_update.py convert_hf_to_gguf.py convert_llama_ggml_to_gguf.py convert_lora_to_gguf.py debug.log docs examples flake.lock flake.nix ggml ggml-alloc.o ggml-backend.o ggml-metal.o ggml-model-BF16.gguf ggml-model-Q4_K_M.gguf ggml-quants.o ggml.o gguf-py grammar-parser.o grammars include LICENSE licenses llama.log llama.o llamacpp_trace.log main.log Makefile media models mypy.ini pocs poetry.lock prompts pyproject.toml pyrightconfig.json q4_k_m_boot.log q8_0_boot.log quant.log quant2.log README.md requirements requirements.txt sampling.o scripts SECURITY.md src test-grammar-output.tmp test-json-schema-input.tmp tests tools vendor working.log as the first input
- Add a corresponding overload of build_rs w/
llm_graph_input_rs_hybrid_recurrent android-build AUTHORS bamba-9b-2.2T.gguf bamba-9b-2.2T.q4_k_m.gguf broken.log build build-rel build-xcframework.sh build.android build.android.bak ci cmake CMakeLists.txt CMakePresets.json CODEOWNERS common common.o CONTRIBUTING.md convert_hf_to_gguf_update.py convert_hf_to_gguf.py convert_llama_ggml_to_gguf.py convert_lora_to_gguf.py debug.log docs examples flake.lock flake.nix ggml ggml-alloc.o ggml-backend.o ggml-metal.o ggml-model-BF16.gguf ggml-model-Q4_K_M.gguf ggml-quants.o ggml.o gguf-py grammar-parser.o grammars include LICENSE licenses llama.log llama.o llamacpp_trace.log main.log Makefile media models mypy.ini pocs poetry.lock prompts pyproject.toml pyrightconfig.json q4_k_m_boot.log q8_0_boot.log quant.log quant2.log README.md requirements requirements.txt sampling.o scripts SECURITY.md src test-grammar-output.tmp test-json-schema-input.tmp tests tools vendor working.log as the first input
- Add a llm_graph_input_attn_kv_hybrid_recurrent analogous to
llm_graph_input_attn_kv_unified
- Add a build_attn override that takes
llm_graph_input_attn_kv_hybrid_recurrent android-build AUTHORS bamba-9b-2.2T.gguf bamba-9b-2.2T.q4_k_m.gguf broken.log build build-rel build-xcframework.sh build.android build.android.bak ci cmake CMakeLists.txt CMakePresets.json CODEOWNERS common common.o CONTRIBUTING.md convert_hf_to_gguf_update.py convert_hf_to_gguf.py convert_llama_ggml_to_gguf.py convert_lora_to_gguf.py debug.log docs examples flake.lock flake.nix ggml ggml-alloc.o ggml-backend.o ggml-metal.o ggml-model-BF16.gguf ggml-model-Q4_K_M.gguf ggml-quants.o ggml.o gguf-py grammar-parser.o grammars include LICENSE licenses llama.log llama.o llamacpp_trace.log main.log Makefile media models mypy.ini pocs poetry.lock prompts pyproject.toml pyrightconfig.json q4_k_m_boot.log q8_0_boot.log quant.log quant2.log README.md requirements requirements.txt sampling.o scripts SECURITY.md src test-grammar-output.tmp test-json-schema-input.tmp tests tools vendor working.log as the first input

This makes the two paradigms fully consistent. The main drawback is the
code duplication in the build_attn and build_rs implementations where the
only difference between implementations is how they cast the memory state.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix: Fix resize vs reserve and skip null tensors in size computation

https://github.com/ggml-org/llama.cpp/pull/13979/files#r2149469788

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Co-Authored-By: @younesbelkada

* fix: Fix initialization of child states

Since initially writing this PR, the logic in the child state types changed
such that using the "init full" signature and keeping the ubatches on the
parent struct no longer worked.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Use a common build_recurrent_state method that is cache-agnostic

This reduces the code duplication between the different build_rs impls and
also retains a similar signature to the previous build_recurrent_state
method while standardizing on the input-dispatched build_rs implementation.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* recurrent : rework graph inputs + add TODOs

ggml-ci

* refactor: Make status and child states const in hybrid and iswa

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: Rename llama_kv_cache_[recurrent|hybrid_recurrent] to remove kv cache

This removes the notion of "kv" from the interface names for these memory
types. There are still many references to kv in the implementation of the
recurrent memory which will need further adjustment.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor!: Rename all k/v related values for recurrent/hybrid to r/s

Anywhere that "kv_<state|cell|size|etc>" is used, I've used the more
generic "mem_" prefix. The specifics of "k" (key) translate to "r"
(recurrent state) and "v" (value) translate to "s" (state-space embedding
states).

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refacor: _recurrent -> _recr for brevity

It just _happens_ to have the same number of letters as _attn!

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* style: Fix spacing for ref

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* refactor: recurrent_layer() -> is_recurrent()

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* style: Fix spacing for size_s_bytes declaration

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-06-19 08:08:14 +03:00
Georgi Gerganov ed3290ab34 metal : add mean kernel (#14267)
* metal : add mean kernel

ggml-ci

* cont : dedup implementation

ggml-ci
2025-06-19 08:05:21 +03:00
Aaron Teo 8d94713654 docs: add s390x build documentation (#14264)
* docs: add s390x-specific build docs

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* docs: add s390x model conversion steps

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* docs: s390x build indent

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* docs: update hyperlinks for s390x docs

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* docs: update llama.h docs

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* docs: s390x add accelerator and perf optimizations

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* docs: s390x indent blocks

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* docs: revert block indentation

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* docs: add support information for s390x

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* docs: s390x reword

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* docs: remove indentation for accelerator section s390x

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* docs: remove redundant words s390x

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* docs: reword for s390x

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* docs: s390x reword simd

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

* docs: fix trailing whitespace for s390x

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>

---------

Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
2025-06-18 18:10:26 +01:00
Aaron Teo 50d2227953 ggml-cpu: reduce asm calls for hsum (#14037)
Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
2025-06-18 18:10:08 +01:00
Aaron Teo 6231c5cd6d ggml-cpu: fix uncaught underscore terminators (#14023)
Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
2025-06-18 18:06:49 +01:00
112 changed files with 6801 additions and 5097 deletions
+51
View File
@@ -0,0 +1,51 @@
name: Build relocatable cmake package
on:
workflow_dispatch:
workflow_call:
jobs:
linux:
runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Install dependencies
run: |
sudo apt update
sudo apt install -y build-essential tcl
- name: Build
run: |
PREFIX="$(pwd)"/inst
cmake -S . -B build -DCMAKE_PREFIX_PATH="$PREFIX" \
-DLLAMA_CURL=OFF -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_TOOLS=OFF \
-DLLAMA_BUILD_EXAMPLES=OFF -DCMAKE_BUILD_TYPE=Release
cmake --build build --config Release
cmake --install build --prefix "$PREFIX" --config Release
export LLAMA_CONFIG="$PREFIX"/lib/cmake/llama/llama-config.cmake
tclsh <<'EOF'
set build(commit) [string trim [exec git rev-parse --short HEAD]]
set build(number) [string trim [exec git rev-list --count HEAD]]
set build(version) "0.0.$build(number)"
set llamaconfig [read [open "$env(LLAMA_CONFIG)" r]]
set checks [list "set\\(LLAMA_VERSION \\s+$build(version)\\)" \
"set\\(LLAMA_BUILD_COMMIT\\s+$build(commit)\\)" \
"set\\(LLAMA_BUILD_NUMBER\\s+$build(number)\\)"]
puts -nonewline "Checking llama-config.cmake version... "
foreach check $checks {
if {![regexp -expanded -- $check $llamaconfig]} {
puts "\"$check\" failed!"
exit 1
}
}
puts "success."
EOF
cd examples/simple-cmake-pkg
cmake -S . -B build -DCMAKE_PREFIX_PATH="$PREFIX"/lib/cmake
cmake --build build
+40 -4
View File
@@ -5,10 +5,43 @@ on:
push:
branches:
- master
paths: ['.github/workflows/build.yml', '.github/workflows/build-linux-cross.yml', '**/CMakeLists.txt', '**/.cmake', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal', '**/*.comp']
paths: [
'.github/workflows/build.yml',
'.github/workflows/build-linux-cross.yml',
'.github/workflows/build-cmake-pkg.yml',
'**/CMakeLists.txt',
'**/.cmake',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp',
'**/*.cu',
'**/*.cuh',
'**/*.swift',
'**/*.m',
'**/*.metal',
'**/*.comp'
]
pull_request:
types: [opened, synchronize, reopened]
paths: ['.github/workflows/build.yml', '.github/workflows/build-linux-cross.yml', '**/CMakeLists.txt', '**/.cmake', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal', '**/*.comp']
paths: [
'.github/workflows/build.yml',
'.github/workflows/build-linux-cross.yml',
'.github/workflows/build-cmake-pkg.yml',
'**/CMakeLists.txt',
'**/.cmake',
'**/*.h',
'**/*.hpp',
'**/*.c',
'**/*.cpp',
'**/*.cu',
'**/*.cuh',
'**/*.swift',
'**/*.m',
'**/*.metal',
'**/*.comp'
]
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
@@ -478,6 +511,9 @@ jobs:
build-linux-cross:
uses: ./.github/workflows/build-linux-cross.yml
build-cmake-pkg:
uses: ./.github/workflows/build-cmake-pkg.yml
macOS-latest-cmake-ios:
runs-on: macos-latest
@@ -683,7 +719,7 @@ jobs:
env:
OPENBLAS_VERSION: 0.3.23
SDE_VERSION: 9.33.0-2024-01-07
VULKAN_VERSION: 1.4.309.0
VULKAN_VERSION: 1.4.313.2
strategy:
matrix:
@@ -736,7 +772,7 @@ jobs:
id: get_vulkan
if: ${{ matrix.build == 'kompute-x64' || matrix.build == 'vulkan-x64' }}
run: |
curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe"
curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/vulkansdk-windows-X64-${env:VULKAN_VERSION}.exe"
& "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install
Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}"
Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin"
+2 -2
View File
@@ -302,7 +302,7 @@ jobs:
env:
OPENBLAS_VERSION: 0.3.23
VULKAN_VERSION: 1.4.309.0
VULKAN_VERSION: 1.4.313.2
strategy:
matrix:
@@ -332,7 +332,7 @@ jobs:
id: get_vulkan
if: ${{ matrix.backend == 'vulkan' }}
run: |
curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe"
curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/vulkansdk-windows-X64-${env:VULKAN_VERSION}.exe"
& "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install
Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}"
Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin"
+1 -1
View File
@@ -95,7 +95,7 @@ endif()
if (NOT DEFINED LLAMA_BUILD_COMMIT)
set(LLAMA_BUILD_COMMIT ${BUILD_COMMIT})
endif()
set(LLAMA_INSTALL_VERSION 0.0.${BUILD_NUMBER})
set(LLAMA_INSTALL_VERSION 0.0.${LLAMA_BUILD_NUMBER})
# override ggml options
set(GGML_ALL_WARNINGS ${LLAMA_ALL_WARNINGS})
+1 -1
View File
@@ -779,7 +779,7 @@ function gg_run_rerank_tiny {
model_f16="${path_models}/ggml-model-f16.gguf"
# for this model, the SEP token is "</s>"
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?</s></s>hi\nwhat is panda?</s></s>it's a bear\nwhat is panda?</s></s>The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
# sample output
# rerank score 0: 0.029
+33
View File
@@ -2706,6 +2706,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.embd_sep = value;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
add_opt(common_arg(
{"--cls-separator"}, "STRING",
"separator of classification sequences (default \\t) for example \"<#seq#>\"",
[](common_params & params, const std::string & value) {
params.cls_sep = value;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
add_opt(common_arg(
{"--host"}, "HOST",
string_format("ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: %s)", params.hostname.c_str()),
@@ -3210,6 +3217,32 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.model.path = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));
add_opt(common_arg(
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
string_format(
"KV cache data type for K for the draft model\n"
"allowed values: %s\n"
"(default: %s)",
get_all_kv_cache_types().c_str(),
ggml_type_name(params.speculative.cache_type_k)
),
[](common_params & params, const std::string & value) {
params.speculative.cache_type_k = kv_cache_type_from_str(value);
}
).set_env("LLAMA_ARG_CACHE_TYPE_K_DRAFT"));
add_opt(common_arg(
{"-ctvd", "--cache-type-v-draft"}, "TYPE",
string_format(
"KV cache data type for V for the draft model\n"
"allowed values: %s\n"
"(default: %s)",
get_all_kv_cache_types().c_str(),
ggml_type_name(params.speculative.cache_type_v)
),
[](common_params & params, const std::string & value) {
params.speculative.cache_type_v = kv_cache_type_from_str(value);
}
).set_env("LLAMA_ARG_CACHE_TYPE_V_DRAFT"));
add_opt(common_arg(
{"-mv", "--model-vocoder"}, "FNAME",
+9
View File
@@ -706,11 +706,17 @@ bool fs_validate_filename(const std::string & filename) {
// disable C++17 deprecation warning for std::codecvt_utf8
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
#elif defined(__GNUC__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
#if defined(__clang__)
# pragma clang diagnostic pop
#elif defined(__GNUC__)
# pragma GCC diagnostic pop
#endif
filename_utf32 = converter.from_bytes(filename);
@@ -1284,6 +1290,9 @@ std::vector<llama_token> common_tokenize(
int n_tokens = text.length() + 2 * add_special;
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
if (n_tokens == std::numeric_limits<int32_t>::min()) {
throw std::runtime_error("Tokenization failed: input text too large, tokenization result exceeds int32_t limit");
}
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+4
View File
@@ -199,6 +199,9 @@ struct common_params_speculative {
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
@@ -355,6 +358,7 @@ struct common_params {
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
std::string embd_sep = "\n"; // separator of embeddings
std::string cls_sep = "\t"; // separator of classification sequences
// server params
int32_t port = 8080; // server listens on this network port
+3 -46
View File
@@ -41,49 +41,6 @@ static std::string build_repetition(const std::string & item_rule, int min_items
return result;
}
/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */
class string_view {
const std::string & _str;
const size_t _start;
const size_t _end;
public:
string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {}
size_t size() const {
return _end - _start;
}
size_t length() const {
return size();
}
operator std::string() const {
return str();
}
std::string str() const {
return _str.substr(_start, _end - _start);
}
string_view substr(size_t pos, size_t len = std::string::npos) const {
return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len);
}
char operator[](size_t pos) const {
auto index = _start + pos;
if (index >= _end) {
throw std::out_of_range("string_view index out of range");
}
return _str[_start + pos];
}
bool operator==(const string_view & other) const {
std::string this_str = *this;
std::string other_str = other;
return this_str == other_str;
}
};
static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
auto has_min = min_value != std::numeric_limits<int>::min();
auto has_max = max_value != std::numeric_limits<int>::max();
@@ -112,14 +69,14 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
}
out << "}";
};
std::function<void(const string_view &, const string_view &)> uniform_range =
[&](const string_view & from, const string_view & to) {
std::function<void(const std::string_view &, const std::string_view &)> uniform_range =
[&](const std::string_view & from, const std::string_view & to) {
size_t i = 0;
while (i < from.length() && i < to.length() && from[i] == to[i]) {
i++;
}
if (i > 0) {
out << "\"" << from.substr(0, i).str() << "\"";
out << "\"" << from.substr(0, i) << "\"";
}
if (i < from.length() && i < to.length()) {
if (i > 0) {
+12 -24
View File
@@ -2145,7 +2145,6 @@ class Llama4Model(LlamaModel):
def set_vocab(self):
self._set_vocab_gpt2()
self.gguf_writer.add_add_bos_token(True)
def set_gguf_parameters(self):
super().set_gguf_parameters()
@@ -2194,7 +2193,7 @@ class Llama4VisionModel(MmprojModel):
name += ".weight"
if "multi_modal_projector.linear_1" in name:
# despite the name with number postfix, this is a single fully connected layer
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC], data_torch)]
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC] + '.weight', data_torch)]
return [(self.map_tensor_name(name), data_torch)]
return []
@@ -3918,9 +3917,6 @@ class BertModel(TextModel):
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(True)
@ModelBase.register("DistilBertModel", "DistilBertForMaskedLM", "DistilBertForSequenceClassification")
class DistilBertModel(BertModel):
@@ -3962,8 +3958,6 @@ class RobertaModel(BertModel):
bpe_tok_path = self.dir_model / "tokenizer.json"
if bpe_tok_path.exists():
self._set_vocab_gpt2()
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(True)
# we need this to validate the size of the token_type embeddings
# though currently we are passing all zeros to the token_type embeddings
@@ -4848,8 +4842,6 @@ class JinaBertV2Model(BertModel):
self.gguf_writer.add_token_type_count(2)
else:
raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(True)
@ModelBase.register("OpenELMForCausalLM")
@@ -5451,9 +5443,6 @@ class T5Model(TextModel):
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
self.gguf_writer.add_add_bos_token(False)
self.gguf_writer.add_add_eos_token(True)
def set_gguf_parameters(self):
if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
logger.warning("Couldn't find context length in config.json, assuming default value of 512")
@@ -5591,9 +5580,6 @@ class T5EncoderModel(TextModel):
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
self.gguf_writer.add_add_bos_token(False)
self.gguf_writer.add_add_eos_token(True)
def set_gguf_parameters(self):
if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
logger.warning("Couldn't find context length in config.json, assuming default value of 512")
@@ -6389,8 +6375,8 @@ def parse_args() -> argparse.Namespace:
help="model is executed on big endian machine",
)
parser.add_argument(
"model", type=Path,
help="directory containing model file",
"model", type=str,
help="directory containing model file or huggingface repository ID (if --remote)",
nargs="?",
)
parser.add_argument(
@@ -6493,18 +6479,20 @@ def main() -> None:
else:
logging.basicConfig(level=logging.INFO)
dir_model = args.model
if args.remote:
hf_repo_id = args.model
from huggingface_hub import snapshot_download
local_dir = snapshot_download(
repo_id=str(dir_model),
repo_id=hf_repo_id,
allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
dir_model = Path(local_dir)
logger.info(f"Downloaded config and tokenizer to {local_dir}")
else:
hf_repo_id = None
dir_model = Path(args.model)
if not dir_model.is_dir():
logger.error(f'Error: {args.model} is not a directory')
logger.error(f'Error: {dir_model} is not a directory')
sys.exit(1)
ftype_map: dict[str, gguf.LlamaFileType] = {
@@ -6524,9 +6512,9 @@ def main() -> None:
if args.outfile is not None:
fname_out = args.outfile
elif args.remote:
elif hf_repo_id:
# if remote, use the model ID as the output file name
fname_out = Path("./" + str(args.model).replace("/", "-") + "-{ftype}.gguf")
fname_out = Path("./" + hf_repo_id.replace("/", "-") + "-{ftype}.gguf")
else:
fname_out = dir_model
@@ -6555,7 +6543,7 @@ def main() -> None:
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split,
remote_hf_model_id=str(args.model) if args.remote else None)
remote_hf_model_id=hf_repo_id)
if args.vocab_only:
logger.info("Exporting model vocab...")
+1 -1
View File
@@ -757,7 +757,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| Name | Value | Function |
|-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase |
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features for Intel GPUs. (Recommended to 1 for intel devices older than Gen 10) |
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
+157
View File
@@ -0,0 +1,157 @@
> [!IMPORTANT]
> This build documentation is specific only to IBM Z & LinuxONE mainframes (s390x). You can find the build documentation for other architectures: [build.md](build.md).
# Build llama.cpp locally (for s390x)
The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](../include/llama.h).
The project also includes many example programs and tools using the `llama` library. The examples range from simple, minimal code snippets to sophisticated sub-projects such as an OpenAI-compatible HTTP server.
**To get the code:**
```bash
git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp
```
## CPU Build with BLAS
Building llama.cpp with BLAS support is highly recommended as it has shown to provide performance improvements.
```bash
cmake -S . -B build \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_BLAS=ON \
-DGGML_BLAS_VENDOR=OpenBLAS
cmake --build build --config Release -j $(nproc)
```
**Notes**:
- For faster repeated compilation, install [ccache](https://ccache.dev/)
- By default, VXE/VXE2 is enabled. To disable it (not recommended):
```bash
cmake -S . -B build \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_BLAS=ON \
-DGGML_BLAS_VENDOR=OpenBLAS \
-DGGML_VXE=OFF
cmake --build build --config Release -j $(nproc)
```
- For debug builds:
```bash
cmake -S . -B build \
-DCMAKE_BUILD_TYPE=Debug \
-DGGML_BLAS=ON \
-DGGML_BLAS_VENDOR=OpenBLAS
cmake --build build --config Debug -j $(nproc)
```
- For static builds, add `-DBUILD_SHARED_LIBS=OFF`:
```bash
cmake -S . -B build \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_BLAS=ON \
-DGGML_BLAS_VENDOR=OpenBLAS \
-DBUILD_SHARED_LIBS=OFF
cmake --build build --config Release -j $(nproc)
```
## Getting GGUF Models
All models need to be converted to Big-Endian. You can achieve this in three cases:
1. **Use pre-converted models verified for use on IBM Z & LinuxONE (easiest)**
You can find popular models pre-converted and verified at [s390x Ready Models](hf.co/collections/taronaeo/s390x-ready-models-672765393af438d0ccb72a08).
These models and their respective tokenizers are verified to run correctly on IBM Z & LinuxONE.
2. **Convert safetensors model to GGUF Big-Endian directly (recommended)**
```bash
python3 convert_hf_to_gguf.py \
--outfile model-name-be.f16.gguf \
--outtype f16 \
--bigendian \
model-directory/
```
For example,
```bash
python3 convert_hf_to_gguf.py \
--outfile granite-3.3-2b-instruct-be.f16.gguf \
--outtype f16 \
--bigendian \
granite-3.3-2b-instruct/
```
3. **Convert existing GGUF Little-Endian model to Big-Endian**
```bash
python3 gguf-py/gguf/scripts/gguf_convert_endian.py model-name.f16.gguf BIG
```
For example,
```bash
python3 gguf-py/gguf/scripts/gguf_convert_endian.py granite-3.3-2b-instruct-le.f16.gguf BIG
mv granite-3.3-2b-instruct-le.f16.gguf granite-3.3-2b-instruct-be.f16.gguf
```
**Notes:**
- The GGUF endian conversion script may not support all data types at the moment and may fail for some models/quantizations. When that happens, please try manually converting the safetensors model to GGUF Big-Endian via Step 2.
## IBM Accelerators
### 1. SIMD Acceleration
Only available in IBM z15 or later system with the `-DGGML_VXE=ON` (turned on by default) compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z14 or EC13. In such systems, the APIs can still run but will use a scalar implementation.
### 2. zDNN Accelerator
*Only available in IBM z16 or later system. No direction at the moment.*
### 3. Spyre Accelerator
*No direction at the moment.*
## Performance Tuning
### 1. Virtualization Setup
It is strongly recommended to use only LPAR (Type-1) virtualization to get the most performance.
Note: Type-2 virtualization is not supported at the moment, while you can get it running, the performance will not be the best.
### 2. IFL (Core) Count
It is recommended to allocate a minimum of 8 shared IFLs assigned to the LPAR. Increasing the IFL count past 8 shared IFLs will only improve Prompt Processing performance but not Token Generation.
Note: IFL count does not equate to vCPU count.
### 3. SMT vs NOSMT (Simultaneous Multithreading)
It is strongly recommended to disable SMT via the kernel boot parameters as it negatively affects performance. Please refer to your Linux distribution's guide on disabling SMT via kernel boot parameters.
### 4. BLAS vs NOBLAS
IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongly recommended to use BLAS.
## Getting Help on IBM Z & LinuxONE
1. **Bugs, Feature Requests**
Please file an issue in llama.cpp and ensure that the title contains "s390x".
2. **Other Questions**
Please reach out directly to [aionz@us.ibm.com](mailto:aionz@us.ibm.com).
+1 -1
View File
@@ -1,6 +1,6 @@
# Build llama.cpp locally
The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](include/llama.h).
The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](../include/llama.h).
The project also includes many example programs and tools using the `llama` library. The examples range from simple, minimal code snippets to sophisticated sub-projects such as an OpenAI-compatible HTTP server.
+30 -4
View File
@@ -133,10 +133,36 @@ int main(int argc, char ** argv) {
// max batch size
const uint64_t n_batch = params.n_batch;
// get added sep and eos token, if any
const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
// tokenize the prompts and trim
std::vector<std::vector<int32_t>> inputs;
for (const auto & prompt : prompts) {
auto inp = common_tokenize(ctx, prompt, true, true);
std::vector<llama_token> inp;
// split classification pairs and insert expected separator tokens
if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
std::string final_prompt;
for (size_t i = 0; i < pairs.size(); i++) {
final_prompt += pairs[i];
if (i != pairs.size() - 1) {
if (!added_eos_token.empty()) {
final_prompt += added_eos_token;
}
if (!added_sep_token.empty()) {
final_prompt += added_sep_token;
}
}
}
inp = common_tokenize(ctx, final_prompt, true, true);
} else {
inp = common_tokenize(ctx, prompt, true, true);
}
if (inp.size() > n_batch) {
LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
__func__, (long long int) inp.size(), (long long int) n_batch);
@@ -145,11 +171,11 @@ int main(int argc, char ** argv) {
inputs.push_back(inp);
}
// check if the last token is SEP
// check if the last token is SEP/EOS
// it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
for (auto & inp : inputs) {
if (inp.empty() || inp.back() != llama_vocab_sep(vocab)) {
LOG_WRN("%s: last token in the prompt is not SEP\n", __func__);
if (inp.empty() || (inp.back() != llama_vocab_sep(vocab) && inp.back() != llama_vocab_eos(vocab))) {
LOG_WRN("%s: last token in the prompt is not SEP or EOS\n", __func__);
LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
}
}
+1 -1
View File
@@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
auto generate = [&](const std::string & prompt) {
std::string response;
const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == 0;
const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == -1;
// tokenize the prompt
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
+12
View File
@@ -489,6 +489,7 @@ extern "C" {
GGML_OP_UPSCALE, // nearest interpolate
GGML_OP_PAD,
GGML_OP_PAD_REFLECT_1D,
GGML_OP_ROLL,
GGML_OP_ARANGE,
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
@@ -1801,6 +1802,17 @@ extern "C" {
int p0,
int p1);
// Move tensor elements by an offset given for each dimension. Elements that
// are shifted beyond the last position are wrapped around to the beginning.
GGML_API struct ggml_tensor * ggml_roll(
struct ggml_context * ctx,
struct ggml_tensor * a,
int shift0,
int shift1,
int shift2,
int shift3);
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
// timesteps: [N,]
// return: [N, dim]
+17
View File
@@ -286,6 +286,10 @@ function(ggml_add_cpu_backend_variant tag_name)
foreach (feat ${ARGN})
set(GGML_INTERNAL_${feat} ON)
endforeach()
elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC")
foreach (feat ${ARGN})
set(GGML_INTERNAL_${feat} ON)
endforeach()
endif()
ggml_add_cpu_backend_variant_impl(${tag_name})
@@ -337,6 +341,19 @@ if (GGML_CPU_ALL_VARIANTS)
else()
message(FATAL_ERROR "Unsupported ARM target OS: ${CMAKE_SYSTEM_NAME}")
endif()
elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC")
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
ggml_add_cpu_backend_variant(power0)
ggml_add_cpu_backend_variant(power7_1 POWER7)
ggml_add_cpu_backend_variant(power7_2 POWER7 VSX)
ggml_add_cpu_backend_variant(power8_1 POWER8)
ggml_add_cpu_backend_variant(power8_2 POWER8 VSX)
ggml_add_cpu_backend_variant(power9 POWER9 VSX)
ggml_add_cpu_backend_variant(power10 POWER10 VSX)
ggml_add_cpu_backend_variant(power11 POWER11 VSX)
else()
message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
endif()
else()
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
endif()
+5
View File
@@ -69,6 +69,9 @@
#if defined(__clang__)
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
#elif defined(__GNUC__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
namespace fs = std::filesystem;
@@ -91,6 +94,8 @@ static std::string path_str(const fs::path & path) {
#if defined(__clang__)
# pragma clang diagnostic pop
#elif defined(__GNUC__)
# pragma GCC diagnostic pop
#endif
#ifdef _WIN32
+23 -2
View File
@@ -388,6 +388,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
else()
list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64)
endif()
elseif(GGML_CPU_ALL_VARIANTS)
# Begin with the lowest baseline
set(ARCH_DEFINITIONS "")
# When a feature is selected, bump the MCPU to the first
# version that supported it
foreach(PVER RANGE 7 11)
if(DEFINED GGML_INTERNAL_POWER${PVER})
set(POWERPC_MCPU "power${PVER}")
list(APPEND ARCH_DEFINITIONS GGML_USE_POWER${PVER})
endif()
endforeach()
if (GGML_INTERNAL_VSX)
list(APPEND ARCH_DEFINITIONS GGML_USE_VSX)
list(APPEND ARCH_FLAGS -mvsx)
endif()
if (DEFINED POWERPC_MCPU)
list(APPEND ARCH_FLAGS -mcpu=${POWERPC_MCPU})
endif()
ggml_add_cpu_backend_features(${GGML_CPU_NAME} powerpc ${ARCH_DEFINITIONS})
else()
if (GGML_CPU_POWERPC_CPUTYPE)
list(APPEND ARCH_FLAGS -mcpu=${GGML_CPU_POWERPC_CPUTYPE})
@@ -465,9 +486,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
# Fetch KleidiAI sources:
include(FetchContent)
set(KLEIDIAI_COMMIT_TAG "v1.6.0")
set(KLEIDIAI_COMMIT_TAG "v1.9.0")
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
set(KLEIDIAI_ARCHIVE_MD5 "75b4ad68f25ab673dcc01065e5a0b05f")
set(KLEIDIAI_ARCHIVE_MD5 "2a8e1bb55d201557553545536489a017")
if (POLICY CMP0135)
cmake_policy(SET CMP0135 NEW)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,82 @@
# include "ggml-backend-impl.h"
#if defined(__powerpc64__) || defined(__ppc64__) || defined(__PPC64__)
#if defined(__linux__)
#include <sys/auxv.h>
#endif
#include <string>
struct powerpc_features {
std::string platform = "";
int power_version = -1;
bool has_vsx = false;
powerpc_features() {
#if defined(__linux__)
unsigned long auxval = getauxval(AT_PLATFORM);
if (auxval) {
platform = std::string(reinterpret_cast<const char*>(auxval));
// TBD: Do systems exist that return this in uppercase?
if (platform.substr(0, 5) == "power") {
// Extractt a numeric suffix, if one exists
int vpos = -1;
for (int i = platform.length() - 1; i >= 0; i--) {
if (std::isdigit(platform[i])) {
vpos = i;
} else {
break;
}
}
if (vpos > -1) {
power_version = std::stoi(platform.substr(vpos));
}
}
}
#endif
if (power_version >= 9) {
has_vsx = true;
}
}
};
static int ggml_backend_cpu_powerpc_score() {
int score = 1;
powerpc_features pf;
// Platform scores
#if defined(GGML_USE_POWER7)
if (pf.power_version < 7) { return 0; }
score += 1<<1;
#endif
#if defined(GGML_USE_POWER8)
if (pf.power_version < 8) { return 0; }
score += 1<<2;
#endif
#if defined(GGML_USE_POWER9)
if (pf.power_version < 9) { return 0; }
score += 1<<3;
#endif
#if defined(GGML_USE_POWER10)
if (pf.power_version < 10) { return 0; }
score += 1<<4;
#endif
#if defined(GGML_USE_POWER11)
if (pf.power_version < 11) { return 0; }
score += 1<<5;
#endif
// Feature scores
#if defined(GGML_USE_VSX)
if (!pf.has_vsx) { return 0; }
score += 1<<6;
#endif
return score;
}
GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_powerpc_score)
#endif // defined(__powerpc64__) || defined(__ppc64__) || defined(__PPC64__)
+4 -4
View File
@@ -371,7 +371,7 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
#define vec_xor(a, b) ((a) ^ (b)) // Vector XOR
#endif
typedef signed char char8x16_t __attribute__((vector_size(16)));
typedef signed char char8x16_t __attribute__((vector_size(16)));
typedef unsigned char uchar8x16_t __attribute__((vector_size(16)));
typedef int8_t int8x16_t __attribute__((vector_size(16)));
@@ -382,10 +382,10 @@ typedef uint8_t uint8x16_t __attribute__((vector_size(16)));
typedef uint16_t uint16x8_t __attribute__((vector_size(16)));
typedef uint32_t uint32x4_t __attribute__((vector_size(16)));
typedef float float32x4_t __attribute__((vector_size(16)));
typedef double double64x2_t __attribute((vector_size(16)));
typedef float float32x4_t __attribute__((vector_size(16)));
typedef double double64x2_t __attribute__((vector_size(16)));
typedef signed long long long64x2_t __attribute((vector_size(16)));
typedef signed long long long64x2_t __attribute__((vector_size(16)));
typedef unsigned long long ulong64x2_t __attribute__((vector_size(16)));
typedef struct ggml_uint8x16x2_t {
+14 -86
View File
@@ -74,13 +74,8 @@
#if defined(__ARM_ARCH)
struct ggml_arm_arch_features_type {
int has_neon;
int has_dotprod;
int has_i8mm;
int has_sve;
int sve_cnt;
int has_sme;
} ggml_arm_arch_features = {-1, -1, -1, -1, 0, -1};
} ggml_arm_arch_features = { 0 };
#endif
@@ -678,87 +673,15 @@ bool ggml_is_numa(void) {
#if defined(__linux__) && defined(__aarch64__)
#include <sys/auxv.h>
#elif defined(__APPLE__)
#include <sys/sysctl.h>
#endif
#if !defined(HWCAP2_I8MM)
#define HWCAP2_I8MM (1 << 13)
#endif
#if !defined(HWCAP2_SME)
#define HWCAP2_SME (1 << 23)
#endif
static void ggml_init_arm_arch_features(void) {
#if defined(__linux__) && defined(__aarch64__)
uint32_t hwcap = getauxval(AT_HWCAP);
uint32_t hwcap2 = getauxval(AT_HWCAP2);
ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP);
ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
ggml_arm_arch_features.has_sme = !!(hwcap2 & HWCAP2_SME);
#if defined(__ARM_FEATURE_SVE)
#if defined(__linux__) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
#endif
#elif defined(__APPLE__)
int oldp = 0;
size_t size = sizeof(oldp);
if (sysctlbyname("hw.optional.AdvSIMD", &oldp, &size, NULL, 0) != 0) {
oldp = 0;
}
ggml_arm_arch_features.has_neon = oldp;
if (sysctlbyname("hw.optional.arm.FEAT_DotProd", &oldp, &size, NULL, 0) != 0) {
oldp = 0;
}
ggml_arm_arch_features.has_dotprod = oldp;
if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) {
oldp = 0;
}
ggml_arm_arch_features.has_i8mm = oldp;
if (sysctlbyname("hw.optional.arm.FEAT_SME", &oldp, &size, NULL, 0) != 0) {
oldp = 0;
}
ggml_arm_arch_features.has_sme = oldp;
ggml_arm_arch_features.has_sve = 0;
ggml_arm_arch_features.sve_cnt = 0;
#else
// Run-time CPU feature detection not implemented for this platform, fallback to compile time
#if defined(__ARM_NEON)
ggml_arm_arch_features.has_neon = 1;
#else
ggml_arm_arch_features.has_neon = 0;
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
ggml_arm_arch_features.has_i8mm = 1;
#else
ggml_arm_arch_features.has_i8mm = 0;
#endif
#if defined(__ARM_FEATURE_SVE)
ggml_arm_arch_features.has_sve = 1;
ggml_arm_arch_features.sve_cnt = 16;
#else
ggml_arm_arch_features.has_sve = 0;
ggml_arm_arch_features.sve_cnt = 0;
#endif
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_SME2)
ggml_arm_arch_features.has_sme = 1;
#else
ggml_arm_arch_features.has_sme = 0;
#endif
#endif
}
#endif
#endif // __ARM_ARCH
struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
GGML_ASSERT(!ggml_get_no_alloc(ctx));
@@ -1967,6 +1890,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_pad_reflect_1d(params, tensor);
} break;
case GGML_OP_ROLL:
{
ggml_compute_forward_roll(params, tensor);
} break;
case GGML_OP_ARANGE:
{
ggml_compute_forward_arange(params, tensor);
@@ -2291,6 +2218,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_ROLL:
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
@@ -3443,7 +3371,7 @@ int ggml_cpu_has_vxe(void) {
int ggml_cpu_has_neon(void) {
#if defined(__ARM_ARCH) && defined(__ARM_NEON)
return ggml_arm_arch_features.has_neon;
return 1;
#else
return 0;
#endif
@@ -3451,7 +3379,7 @@ int ggml_cpu_has_neon(void) {
int ggml_cpu_has_dotprod(void) {
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_DOTPROD)
return ggml_arm_arch_features.has_dotprod;
return 1;
#else
return 0;
#endif
@@ -3459,7 +3387,7 @@ int ggml_cpu_has_dotprod(void) {
int ggml_cpu_has_sve(void) {
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE)
return ggml_arm_arch_features.has_sve;
return 1;
#else
return 0;
#endif
@@ -3467,7 +3395,7 @@ int ggml_cpu_has_sve(void) {
int ggml_cpu_has_matmul_int8(void) {
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_MATMUL_INT8)
return ggml_arm_arch_features.has_i8mm;
return 1;
#else
return 0;
#endif
@@ -3483,7 +3411,7 @@ int ggml_cpu_get_sve_cnt(void) {
int ggml_cpu_has_sme(void) {
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME)
return ggml_arm_arch_features.has_sme;
return 1;
#else
return 0;
#endif
+54 -1
View File
@@ -62,7 +62,7 @@
#define NOINLINE __attribute__((__noinline__))
#endif
#if defined(__ARM_NEON) || defined(__AVX512F__)
#if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__VXE__) || defined(__VXE2__)
#define VECTOR_REGISTERS 32
#else
#define VECTOR_REGISTERS 16
@@ -109,6 +109,12 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if defined(__VXE__) || defined(__VXE2__)
inline float32x4_t add(float32x4_t x, float32x4_t y) { return vec_add(x, y); }
inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vec_sub(x, y); }
inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
#endif
#if defined(__MMA__)
typedef vector unsigned char vec_t;
typedef __vector_quad acc_t;
@@ -162,6 +168,13 @@ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
#endif
#endif
#if defined(__VXE__) || defined(__VXE2__)
template <>
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
return vec_madd(a, b, c);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED HORIZONTAL SUM
@@ -178,6 +191,13 @@ inline float hsum(float16x8_t x) {
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if defined(__VXE__) || defined(__VXE2__)
inline float hsum(float32x4_t x) {
float32x4_t tmp = x + vec_reve(x);
return tmp[0] + tmp[1];
}
#endif
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m128 x) {
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
@@ -227,6 +247,21 @@ template <> inline float32x4_t load(const ggml_fp16_t *p) {
#endif // _MSC_VER
#endif // __ARM_NEON
#if defined(__VXE__) || defined(__VXE2__)
template <> inline float32x4_t load(const ggml_fp16_t * p) {
float tmp[4];
for (int i = 0; i < 4; i++) {
tmp[i] = GGML_FP16_TO_FP32(p[i]);
}
return vec_xl(0, (const float *)(tmp));
}
template <> inline float32x4_t load(const float * p) {
return vec_xl(0, p);
}
#endif
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m128 load(const float *p) {
return _mm_loadu_ps(p);
@@ -3319,6 +3354,14 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
(const float *)B, ldb,
(float *)C, ldc};
return tb.matmul(m, n);
#elif defined(__VXE__) || defined(__VXE2__)
if (n < 4)
return false;
tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
k, (const float *)A, lda,
(const float *)B, ldb,
(float *)C, ldc};
return tb.matmul(m, n);
#elif defined(__MMA__)
if (k % 8)
return false;
@@ -3410,6 +3453,16 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
(float *)C, ldc};
return tb.matmul(m, n);
}
#elif defined(__VXE__) || defined(__VXE2__)
if (n < 4)
return false;
if (Btype == GGML_TYPE_F16) {
tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
k, (const ggml_fp16_t *)A, lda,
(const ggml_fp16_t *)B, ldb,
(float *)C, ldc};
return tb.matmul(m, n);
}
#endif
return false;
}
+5
View File
@@ -1,6 +1,11 @@
#pragma once
#include <stdint.h>
#include <stdbool.h>
#if defined(__VXE__) || defined(__VXE2__)
#include <vecintrin.h>
#endif
#ifdef __cplusplus
extern "C" {
#endif
+67
View File
@@ -6793,6 +6793,73 @@ void ggml_compute_forward_pad_reflect_1d(
}
}
// ggml_compute_forward_roll
static int64_t ggml_wrap_index(int64_t i, int64_t ne) {
if (i < 0) {
return i + ne;
} else if (i >= ne) {
return i - ne;
}
return i;
}
static void ggml_compute_forward_roll_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src_data = (const float *) src0->data;
float * dst_data = (float *) dst->data;
GGML_TENSOR_UNARY_OP_LOCALS
const int s0 = ggml_get_op_params_i32(dst, 0);
const int s1 = ggml_get_op_params_i32(dst, 1);
const int s2 = ggml_get_op_params_i32(dst, 2);
const int s3 = ggml_get_op_params_i32(dst, 3);
const int64_t total = ne1 * ne2 * ne3;
const int64_t per_thread = (total + params->nth) / params->nth;
const int64_t start = params->ith * per_thread;
const int64_t end = std::min(start + per_thread, total);
for (int64_t i = start; i < end; ++i) {
const int64_t i1 = i % ne1;
const int64_t i2 = (i / ne1) % ne2;
const int64_t i3 = i / (ne2 * ne1);
float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
const int64_t i01 = ggml_wrap_index(i1 - s1, ne01);
const int64_t i02 = ggml_wrap_index(i2 - s2, ne02);
const int64_t i03 = ggml_wrap_index(i3 - s3, ne03);
const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
const int64_t s = ggml_wrap_index(-s0, ne00);
const int64_t n = ne00 - s;
ggml_vec_cpy_f32(n, dst_row, src_row + s);
ggml_vec_cpy_f32(s, dst_row + n, src_row);
}
}
void ggml_compute_forward_roll(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_roll_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_arange
static void ggml_compute_forward_arange_f32(
+1
View File
@@ -72,6 +72,7 @@ void ggml_compute_forward_pool_2d_back(const struct ggml_compute_params * params
void ggml_compute_forward_upscale(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_pad(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_pad_reflect_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+41 -26
View File
@@ -1163,13 +1163,24 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
// not realy a GGML_TYPE_Q8_0 but same size.
switch (op->op) {
case GGML_OP_MUL_MAT:
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
return true;
{
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
return true;
}
case GGML_OP_MUL_MAT_ID:
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
return true;
{
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert
const int64_t ne12 = op->src[1]->ne[2]; // n_tokens
const size_t sizeof_mmid_row_mapping = sizeof(int64_t);
size += sizeof_mmid_row_mapping*ne02*(ne12 + 1);
return true;
}
default:
// GGML_ABORT("fatal error");
break;
@@ -1305,14 +1316,17 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
int32_t i2;
};
GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
n_as * ne12 * sizeof(mmid_row_mapping)));
GGML_ASSERT(params->wsize >=
(GGML_PAD(nbw3, sizeof(int64_t)) +
n_as*(ne12 + 1)*sizeof(mmid_row_mapping))
);
auto * wdata = (char *) params->wdata;
auto * wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
auto * wdata = (char *)params->wdata;
auto * wdata_src1_end = (char *)wdata + GGML_PAD(nbw3, sizeof(int64_t));
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
// total of [n_as][ne12 + 1] elemets of type mmid_row_mapping (2*int32_t = int64_t)
auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
// src1: float32 => param type
for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -1397,44 +1411,45 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
}
};
// instance for Q4
static const tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
static const tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
static const tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
static const tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
// instance for IQ4
static const tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
} // namespace ggml::cpu::repack
static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(const struct ggml_tensor * cur) {
// instance for Q4
static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
// instance for IQ4
static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
if (cur->type == GGML_TYPE_Q4_0) {
if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
if (cur->ne[1] % 8 == 0) {
return &ggml::cpu::repack::q4_0_8x8_q8_0;
return &q4_0_8x8_q8_0;
}
}
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
if (cur->ne[1] % 4 == 0) {
return &ggml::cpu::repack::q4_0_4x8_q8_0;
return &q4_0_4x8_q8_0;
}
}
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
if (cur->ne[1] % 4 == 0) {
return &ggml::cpu::repack::q4_0_4x4_q8_0;
return &q4_0_4x4_q8_0;
}
}
} else if (cur->type == GGML_TYPE_Q4_K) {
if (ggml_cpu_has_avx2()) {
if (cur->ne[1] % 8 == 0) {
return &ggml::cpu::repack::q4_K_8x8_q8_K;
return &q4_K_8x8_q8_K;
}
}
} else if (cur->type == GGML_TYPE_IQ4_NL) {
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
if (cur->ne[1] % 4 == 0) {
return &ggml::cpu::repack::iq4_nl_4x4_q8_0;
return &iq4_nl_4x4_q8_0;
}
}
}
+2 -4
View File
@@ -944,10 +944,8 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
for (int i = 0; i < offset; ++i) { \
x[i] = vec_add(x[i], x[offset + i]); \
} \
res = vec_extract(x[0], 0) + \
vec_extract(x[0], 1) + \
vec_extract(x[0], 2) + \
vec_extract(x[0], 3); \
float32x4_t tmp = x[0] + vec_reve(x[0]); \
res = tmp[0] + tmp[1]; \
}
#define GGML_F32_VEC GGML_F32x4
+42 -18
View File
@@ -19,10 +19,10 @@
#endif
#include "ggml-common.h"
#include <cstdio>
#include <array>
#include <cassert>
#include <cfloat>
#include <cstdio>
#include <string>
#include <vector>
@@ -241,8 +241,18 @@ static bool fp16_mma_available(const int cc) {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
return false;
#else
return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
return true;
} else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
return true;
#else
return false;
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
} else {
return false;
}
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
}
@@ -252,6 +262,14 @@ static bool fp16_mma_hardware_available(const int cc) {
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
}
static bool bf16_mma_hardware_available(const int cc) {
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
}
static bool fp32_mma_hardware_available(const int cc) {
return GGML_CUDA_CC_IS_CDNA(cc);
}
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static bool new_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
@@ -362,6 +380,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#endif // FP16_AVAILABLE
}
// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
template<bool norm>
static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
}
sum = warp_reduce_sum(sum);
if (col != 0) {
return;
}
dst[row] = norm ? sum / ncols : sum;
}
template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
@@ -767,21 +805,7 @@ struct ggml_backend_cuda_context {
name(GGML_CUDA_NAME + std::to_string(device)) {
}
~ggml_backend_cuda_context() {
if (copy_event != nullptr) {
CUDA_CHECK(cudaEventDestroy(copy_event));
}
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
if (streams[i][j] != nullptr) {
CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
}
}
if (cublas_handles[i] != nullptr) {
CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
}
}
}
~ggml_backend_cuda_context();
cudaStream_t stream(int device, int stream) {
if (streams[device][stream] == nullptr) {
+161
View File
@@ -0,0 +1,161 @@
#include "conv2d-dw.cuh"
struct conv_params {
int in_w, in_h;
int out_w, out_h;
int kernel_w, kernel_h;
int stride_x, stride_y;
int padding_x, padding_y;
int dilation_x, dilation_y;
int channels, batches;
};
struct kernel_bounds {
int y_min, y_max;
int x_min, x_max;
};
__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {
kernel_bounds bounds;
bounds.y_min = max(0, (params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
bounds.y_max =
min(params.kernel_h,
(params.in_h + params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
bounds.x_min = max(0, (params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
bounds.x_max =
min(params.kernel_w,
(params.in_w + params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
return bounds;
}
__device__ __forceinline__ int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
return out_coord * stride + kern_coord * dilation - padding;
}
struct whcn_layout {
__device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
return n * (params.channels * params.in_w * params.in_h) + c * params.in_w * params.in_h + y * params.in_w + x;
}
__device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
return c * params.kernel_h * params.kernel_w + ky * params.kernel_w + kx;
}
__device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
return n * (params.channels * params.out_w * params.out_h) + c * params.out_w * params.out_h +
y * params.out_w + x;
}
__device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
int & out_x) {
out_x = global_idx % params.out_w;
out_y = (global_idx / params.out_w) % params.out_h;
c = (global_idx / (params.out_w * params.out_h)) % params.channels;
n = global_idx / (params.out_w * params.out_h * params.channels);
}
};
struct cwhn_layout {
__device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
return n * (params.channels * params.in_w * params.in_h) + (y * params.in_w + x) * params.channels + c;
}
__device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
return (ky * params.kernel_w + kx) * params.channels + c;
}
__device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
return n * (params.channels * params.out_w * params.out_h) + y * (params.out_w * params.channels) +
x * params.channels + c;
}
__device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
int & out_x) {
c = global_idx % params.channels;
out_x = (global_idx / params.channels) % params.out_w;
out_y = (global_idx / (params.channels * params.out_w)) % params.out_h;
n = global_idx / (params.channels * params.out_w * params.out_h);
}
};
template <typename T, typename Layout>
__global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output,
const int in_w, const int in_h, const int out_w, const int out_h,
const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
const int channels, const int batches) {
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int total_elements = batches * channels * out_h * out_w;
if (global_idx >= total_elements) {
return;
}
conv_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x,
stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };
int batch_idx, channel_idx, out_y_idx, out_x_idx;
Layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
T accumulator = 0;
kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);
for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);
for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
int in_x_idx = calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);
const T input_val = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];
accumulator += input_val * kernel_val;
}
}
output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
}
void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * kernel = dst->src[0];
const ggml_tensor * input = dst->src[1];
GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
const float * w_d = (const float *) kernel->data;
const float * x_d = (const float *) input->data;
float * y_d = (float *) dst->data;
const int32_t * p = (const int32_t *) dst->op_params;
const int stride_x = p[0];
const int stride_y = p[1];
const int padding_x = p[2];
const int padding_y = p[3];
const int dilation_x = p[4];
const int dilation_y = p[5];
const int in_w = input->ne[0];
const int in_h = input->ne[1];
const int kernel_w = kernel->ne[0];
const int kernel_h = kernel->ne[1];
const int out_w = dst->ne[0];
const int out_h = dst->ne[1];
const int channels = dst->ne[2];
const int batches = dst->ne[3];
cudaStream_t st = ctx.stream();
const int total = batches * channels * out_h * out_w;
const int blocks = (total + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE;
if (ggml_is_contiguous(input)) {
conv2d_dw_kernel<float, whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
dilation_x, dilation_y, channels, batches);
} else if (ggml_is_contiguous_channels(input)) {
conv2d_dw_kernel<float, cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
dilation_x, dilation_y, channels, batches);
} else {
GGML_ABORT("Unsupported memory layout for conv_2d_dw");
}
}
+5
View File
@@ -0,0 +1,5 @@
#pragma once
#include "common.cuh"
#define CUDA_CONV2D_DW_BLOCK_SIZE 256
void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+91
View File
@@ -0,0 +1,91 @@
#include <algorithm>
#include "conv2d-transpose.cuh"
#include "ggml.h"
__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
float * __restrict__ output, const int in_w, const int in_h, const int out_w,
const int out_h, const int kernel_w, const int kernel_h, const int stride,
const int c_in, const int c_out, const int batches) {
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int total_elements = out_w * out_h * c_out * batches;
if (global_idx >= total_elements) {
return;
}
const int out_x_idx = global_idx % out_w;
const int out_y_idx = (global_idx / out_w) % out_h;
const int c_idx = (global_idx / (out_w * out_h)) % c_out;
const int n_idx = global_idx / (out_w * out_h * c_out);
float accumulator = 0;
// For each output idx, find the inputs that contribute to it by checking stride alignment and bounds
for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
for (int kh = 0; kh < kernel_h; ++kh) {
int in_y = out_y_idx - kh;
if (in_y < 0 || in_y % stride) continue;
in_y /= stride;
if (in_y >= in_h) continue;
for (int kw = 0; kw < kernel_w; ++kw) {
int in_x = out_x_idx - kw;
if (in_x < 0 || in_x % stride) continue;
in_x /= stride;
if (in_x >= in_w) continue;
const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;
const int kernel_idx =
(kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;
float input_val = input[input_idx];
half kern_val = kernel[kernel_idx];
accumulator += input_val * (float) kern_val;
}
}
}
output[(out_w * out_h * c_out) * n_idx + (out_w * out_h) * c_idx + (out_w) *out_y_idx + out_x_idx] = accumulator;
}
//input is (W, H, C_in, N), Kernel is (W, H, C_out, C_in)
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * kernel = dst->src[0];
const ggml_tensor * input = dst->src[1];
GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
const float * input_data = (const float *) input->data;
float * output_data = (float *) dst->data;
const half * kernel_data = (const half *) kernel->data;
const int input_w = input->ne[0];
const int input_h = input->ne[1];
const int output_w = dst->ne[0];
const int output_h = dst->ne[1];
const int channels_in = input->ne[2];
const int channels_out = kernel->ne[2];
const int kernel_w = kernel->ne[0];
const int kernel_h = kernel->ne[1];
const int stride = dst->op_params[0];
const int batches = input->ne[3];
GGML_ASSERT(channels_in == kernel->ne[3]);
GGML_ASSERT(stride > 0);
cudaStream_t st = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(input));
GGML_ASSERT(ggml_is_contiguous(kernel));
GGML_ASSERT(ggml_is_contiguous(dst));
const int total = (output_w * output_h * channels_out * batches);
const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;
conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
channels_in, channels_out, batches);
}
+4
View File
@@ -0,0 +1,4 @@
#include "common.cuh"
#define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+69 -21
View File
@@ -11,6 +11,8 @@
#include "ggml-cuda/clamp.cuh"
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/conv2d-dw.cuh"
#include "ggml-cuda/conv2d-transpose.cuh"
#include "ggml-cuda/convert.cuh"
#include "ggml-cuda/count-equal.cuh"
#include "ggml-cuda/cpy.cuh"
@@ -35,6 +37,7 @@
#include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/mean.cuh"
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
@@ -47,6 +50,7 @@
#include <atomic>
#include <charconv>
#include <cinttypes>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <float.h>
@@ -54,9 +58,8 @@
#include <map>
#include <memory>
#include <mutex>
#include <stdint.h>
#include <stdio.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string>
#include <vector>
@@ -97,8 +100,7 @@ int ggml_cuda_get_device() {
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
ggml_cuda_set_device(device);
cudaError_t err;
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
{
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
err = cudaMallocManaged(ptr, size);
#if defined(GGML_USE_HIP)
if (err == hipSuccess) {
@@ -116,9 +118,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
err = cudaMalloc(ptr, size);
}
#endif // defined(GGML_USE_HIP)
}
else
{
} else {
err = cudaMalloc(ptr, size);
}
return err;
@@ -514,6 +514,33 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
}
// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
// this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
static std::mutex ggml_cuda_lock;
static std::condition_variable ggml_cuda_lock_cv;
static std::atomic<int> ggml_cuda_lock_counter;
ggml_backend_cuda_context::~ggml_backend_cuda_context() {
std::unique_lock<std::mutex> lock(ggml_cuda_lock);
ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
if (copy_event != nullptr) {
CUDA_CHECK(cudaEventDestroy(copy_event));
}
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
if (streams[i][j] != nullptr) {
CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
}
}
if (cublas_handles[i] != nullptr) {
CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
}
}
}
// cuda buffer
struct ggml_backend_cuda_buffer_context {
@@ -1916,16 +1943,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
&& ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
bool any_gpus_with_slow_fp16 = false;
bool any_gpus_without_fp16_mma = false;
bool any_gpus_with_slow_fp16 = false;
if (split) {
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1936,16 +1961,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
continue;
}
const int cc = ggml_cuda_info().devices[id].cc;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
const int cc = ggml_cuda_info().devices[id].cc;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
}
} else {
const int cc = ggml_cuda_info().devices[ctx.device].cc;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
const int cc = ggml_cuda_info().devices[ctx.device].cc;
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
}
// debug helpers
@@ -1956,7 +1981,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
if (!split && use_mul_mat_vec) {
// the custom F16 vector kernel can be used over batched cuBLAS GEMM
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
@@ -2310,6 +2335,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst);
break;
case GGML_OP_CONV_2D_DW:
ggml_cuda_op_conv2d_dw(ctx, dst);
break;
case GGML_OP_CONV_TRANSPOSE_2D:
ggml_cuda_conv_2d_transpose_p0(ctx, dst);
break;
case GGML_OP_CONV_TRANSPOSE_1D:
ggml_cuda_op_conv_transpose_1d(ctx,dst);
break;
@@ -2322,6 +2353,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst);
break;
case GGML_OP_MEAN:
ggml_cuda_op_mean(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_cuda_op_ssm_conv(ctx, dst);
break;
@@ -2685,6 +2719,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
graph_evaluated_or_captured = true; // CUDA graph has been captured
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
ggml_cuda_lock_cv.notify_all();
}
} else {
graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
}
@@ -2760,7 +2799,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
}
}
if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
if (use_cuda_graph && cuda_graph_update_required) {
// Start CUDA graph capture
{
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
}
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
}
@@ -3209,9 +3254,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
}
case GGML_OP_IM2COL:
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
return true;
+19
View File
@@ -0,0 +1,19 @@
#include "mean.cuh"
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
}
+3
View File
@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+257 -87
View File
@@ -2,25 +2,26 @@
#include "common.cuh"
#include "mmv.cuh"
template <typename T, typename type_acc, int block_size>
template <typename T, typename type_acc, int ncols_dst, int block_size>
static __global__ void mul_mat_vec(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
const int64_t row = blockIdx.x;
const int64_t channel_dst = blockIdx.y;
const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
const int64_t channel_y = ids ? channel_dst % nchannels_y : channel_dst;
const int64_t sample_dst = blockIdx.z;
const int64_t sample_x = sample_dst / sample_ratio;
const int64_t sample_y = sample_dst;
const int tid = threadIdx.x;
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
const int row = blockIdx.x;
const int channel_dst = blockIdx.y;
const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
const int sample_dst = blockIdx.z;
const int sample_x = sample_dst / sample_ratio;
const int sample_y = sample_dst;
const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
y += sample_y *stride_sample_y + channel_y *stride_channel_y;
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
const float2 * y2 = (const float2 *) y;
@@ -34,81 +35,108 @@ static __global__ void mul_mat_vec(
__syncthreads();
}
float sumf = 0.0f;
float sumf[ncols_dst] = {0.0f};
if constexpr (std::is_same<T, float>::value) {
const float2 * x2 = (const float2 *) x;
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmpx = x2[col2];
const float2 tmpy = y2[col2];
sumf += tmpx.x*tmpy.x;
sumf += tmpx.y*tmpy.y;
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
sumf[j] += tmpx.x*tmpy.x;
sumf[j] += tmpx.y*tmpy.y;
}
}
} else if constexpr (std::is_same<T, half>::value) {
const half2 * x2 = (const half2 *) x;
if (std::is_same<type_acc, float>::value) {
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmpx = __half22float2(x2[col2]);
const float2 tmpy = y2[col2];
sumf += tmpx.x * tmpy.x;
sumf += tmpx.y * tmpy.y;
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
sumf[j] += tmpx.x * tmpy.x;
sumf[j] += tmpx.y * tmpy.y;
}
}
} else {
#ifdef FP16_AVAILABLE
half2 sumh2 = make_half2(0.0f, 0.0f);
half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmp = y2[col2];
sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const half2 tmpx = x2[col2];
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
}
}
sumf = __low2float(sumh2) + __high2float(sumh2);
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
}
#else
NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
}
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
const int * x2 = (const int *) x;
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
const int tmpx = x2[col2];
const float2 tmpy = y2[col2];
sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const int tmpx = x2[col2];
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
}
}
} else {
static_assert(std::is_same<T, void>::value, "unsupported type");
}
sumf = warp_reduce_sum<warp_size>(sumf);
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
if (block_size > warp_size) {
buf_iw[tid/warp_size] = sumf;
__syncthreads();
if (tid >= warp_size) {
return;
if (block_size > warp_size) {
buf_iw[tid/warp_size] = sumf[j];
__syncthreads();
if (tid < warp_size) {
sumf[j] = buf_iw[tid];
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
}
if (j < ncols_dst) {
__syncthreads();
}
}
sumf = buf_iw[tid];
sumf = warp_reduce_sum<warp_size>(sumf);
}
if (tid != 0) {
if (tid >= ncols_dst) {
return;
}
dst[row] = sumf;
dst[tid*stride_col_dst + row] = sumf[tid];
}
template <typename T, typename type_acc>
template <typename T, typename type_acc, int ncols_dst>
static void launch_mul_mat_vec_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t ncols, const int64_t nrows,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(ncols % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x;
@@ -138,44 +166,52 @@ static void launch_mul_mat_vec_cuda(
const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) {
case 32: {
mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
mul_mat_vec<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 64: {
mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
mul_mat_vec<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 96: {
mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
mul_mat_vec<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 128: {
mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
mul_mat_vec<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 160: {
mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
mul_mat_vec<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 192: {
mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
mul_mat_vec<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 224: {
mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
mul_mat_vec<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 256: {
mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
mul_mat_vec<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
default: {
GGML_ABORT("fatal error");
@@ -183,23 +219,91 @@ static void launch_mul_mat_vec_cuda(
}
}
template <typename T, typename type_acc>
static void mul_mat_vec_cuda_switch_ncols_dst(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
switch (ncols_dst) {
case 1:
launch_mul_mat_vec_cuda<T, type_acc, 1>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 2:
launch_mul_mat_vec_cuda<T, type_acc, 2>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 3:
launch_mul_mat_vec_cuda<T, type_acc, 3>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 4:
launch_mul_mat_vec_cuda<T, type_acc, 4>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 5:
launch_mul_mat_vec_cuda<T, type_acc, 5>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 6:
launch_mul_mat_vec_cuda<T, type_acc, 6>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 7:
launch_mul_mat_vec_cuda<T, type_acc, 7>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
case 8:
launch_mul_mat_vec_cuda<T, type_acc, 8>
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
break;
default:
GGML_ABORT("fatal error");
break;
}
}
template<typename T>
static void mul_mat_vec_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
enum ggml_prec prec, cudaStream_t stream) {
if constexpr(std::is_same<T, half>::value) {
if (prec == GGML_PREC_DEFAULT) {
launch_mul_mat_vec_cuda<T, half>
(x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
mul_mat_vec_cuda_switch_ncols_dst<T, half>
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
return;
}
}
launch_mul_mat_vec_cuda<T, float>
(x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
mul_mat_vec_cuda_switch_ncols_dst<T, float>
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
}
@@ -246,24 +350,24 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12;
GGML_ASSERT(ncols_dst == 1);
GGML_ASSERT(!ids || ncols_dst == 1);
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
@@ -282,16 +386,19 @@ void ggml_cuda_op_mul_mat_vec(
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t ne10 = src1->ne[0];
const int64_t ne0 = dst->ne[0];
const int64_t row_diff = row_high - row_low;
GGML_ASSERT(src1_ncols == 1);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
// ggml_cuda_op provides single, contiguous matrices
const int64_t stride_row = ne00;
const int64_t stride_col_y = ne10;
const int64_t stride_col_dst = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
const int64_t nchannels_x = 1;
const int64_t nchannels_y = 1;
const int64_t nchannels_dst = 1;
@@ -307,19 +414,19 @@ void ggml_cuda_op_mul_mat_vec(
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
@@ -334,3 +441,66 @@ void ggml_cuda_op_mul_mat_vec(
GGML_UNUSED(src1_ncols);
GGML_UNUSED(src1_padded_row_size);
}
bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
if (src0_ne[0] % 2 != 0) {
return false;
}
switch (type) {
case GGML_TYPE_F32:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return ne11 <= 8;
}
if (cc >= GGML_CUDA_CC_TURING) {
return ne11 <= 4;
}
return ne11 <= 3;
} else if (GGML_CUDA_CC_IS_AMD(cc)) {
if (fp32_mma_hardware_available(cc)) {
return ne11 <= 3;
}
return ne11 <= 8;
}
return ne11 <= 8;
case GGML_TYPE_F16:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return src0_small && ne11 <= 4;
}
if (fp16_mma_hardware_available(cc)) {
return src0_small && ne11 <= 3;
}
return ne11 <= 8;
} else if (GGML_CUDA_CC_IS_AMD(cc)) {
if (fp16_mma_hardware_available(cc)) {
if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
return ne11 <= 5;
}
return ne11 <= 2;
}
return ne11 <= 8;
}
return ne11 <= 8;
case GGML_TYPE_BF16:
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return src0_small && ne11 <= 4;
}
if (bf16_mma_hardware_available(cc)) {
return src0_small && ne11 <= 3;
}
return ne11 <= 8;
} else if (GGML_CUDA_CC_IS_AMD(cc)) {
if (bf16_mma_hardware_available(cc)) {
return ne11 <= 3;
}
return ne11 <= 8;
}
return ne11 <= 8;
default:
return false;
}
}
+2 -3
View File
@@ -1,8 +1,5 @@
#include "common.cuh"
// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
#define MMV_MAX_ROWS 512
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
void ggml_cuda_op_mul_mat_vec(
@@ -10,3 +7,5 @@ void ggml_cuda_op_mul_mat_vec(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
+5 -18
View File
@@ -1,25 +1,9 @@
#include "sumrows.cuh"
static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
}
sum = warp_reduce_sum(sum);
if (col == 0) {
dst[row] = sum;
}
}
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
}
-1
View File
@@ -1,5 +1,4 @@
#include "common.cuh"
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+88 -33
View File
@@ -48,22 +48,28 @@ static struct ggml_backend_metal_device_context {
int mtl_device_ref_count;
id<MTLLibrary> mtl_library;
NSLock * mtl_lock;
bool has_simdgroup_reduction;
bool has_simdgroup_mm;
bool has_residency_sets;
bool has_bfloat;
bool use_bfloat;
size_t max_size;
char name[128];
} g_ggml_ctx_dev_main = {
/*.mtl_device =*/ nil,
/*.mtl_device_ref_count =*/ 0,
/*.mtl_library =*/ nil,
/*.mtl_lock =*/ nil,
/*.has_simdgroup_reduction =*/ false,
/*.has_simdgroup_mm =*/ false,
/*.has_residency_sets =*/ false,
/*.has_bfloat =*/ false,
/*.use_bfloat =*/ false,
/*.max_size =*/ 0,
/*.name =*/ "",
};
@@ -71,6 +77,10 @@ static struct ggml_backend_metal_device_context {
static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
assert(ctx != NULL);
if (ctx->mtl_lock == nil) {
ctx->mtl_lock = [[NSLock alloc] init];
}
if (ctx->mtl_device == nil) {
ctx->mtl_device = MTLCreateSystemDefaultDevice();
}
@@ -94,6 +104,8 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
ctx->use_bfloat = false;
#endif
ctx->max_size = ctx->mtl_device.maxBufferLength;
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
}
@@ -110,6 +122,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
ctx->mtl_device_ref_count--;
if (ctx->mtl_device_ref_count == 0) {
if (ctx->mtl_lock) {
[ctx->mtl_lock release];
ctx->mtl_lock = nil;
}
if (ctx->mtl_library) {
[ctx->mtl_library release];
ctx->mtl_library = nil;
@@ -498,6 +515,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_NEG,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_MEAN,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -976,7 +994,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
id<MTLDevice> device = ctx_dev->mtl_device;
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
@@ -990,9 +1008,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
// load library
if (ctx_dev->mtl_library == nil) {
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
{
[ctx_dev->mtl_lock lock];
if (ctx_dev->mtl_library == nil) {
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
}
[ctx_dev->mtl_lock unlock];
}
id<MTLLibrary> metal_library = ctx_dev->mtl_library;
if (metal_library == nil) {
GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
@@ -1454,6 +1479,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1653,6 +1679,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_LOG:
return false; // TODO: implement
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2400,11 +2427,30 @@ static bool ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
{
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
id<MTLComputePipelineState> pipeline = nil;
switch (dst->op) {
case GGML_OP_SUM_ROWS:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
break;
case GGML_OP_MEAN:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
break;
default:
GGML_ABORT("fatal error");
}
int nth = 32; // SIMD width
while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
nth *= 2;
}
nth = MIN(nth, ne00);
ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00,
@@ -2434,11 +2480,12 @@ static bool ggml_metal_encode_node(
};
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&args length:sizeof(args) atIndex:2];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_SOFT_MAX:
{
@@ -5261,7 +5308,6 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
}
ggml_backend_metal_buffer_rset_free(ctx);
ggml_backend_metal_device_rel(buffer->buft->device->context);
if (ctx->owned) {
#if TARGET_OS_OSX
@@ -5370,7 +5416,10 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
}
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
GGML_ASSERT(ctx_dev->mtl_device != nil);
id<MTLDevice> device = ctx_dev->mtl_device;
ctx->all_data = ggml_metal_host_malloc(size_aligned);
ctx->all_size = size_aligned;
@@ -5393,14 +5442,12 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
free(ctx);
ggml_backend_metal_device_rel(ctx_dev);
return NULL;
}
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
free(ctx);
ggml_backend_metal_device_rel(ctx_dev);
return NULL;
}
@@ -5411,17 +5458,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
return 32;
GGML_UNUSED(buft);
}
static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
const size_t max_size = device.maxBufferLength;
ggml_backend_metal_device_rel(buft->device->context);
const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size;
return max_size;
GGML_UNUSED(buft);
}
static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
@@ -5494,7 +5538,10 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
}
struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
GGML_ASSERT(ctx_dev->mtl_device != nil);
id<MTLDevice> device = ctx_dev->mtl_device;
// the buffer fits into the max buffer size allowed by the device
if (size_aligned <= device.maxBufferLength) {
@@ -5550,7 +5597,6 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
free(ctx);
ggml_backend_metal_device_rel(ctx_dev);
return NULL;
}
@@ -5566,10 +5612,8 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
}
static void ggml_backend_metal_free(ggml_backend_t backend) {
struct ggml_backend_metal_context * ctx = backend->context;
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
struct ggml_backend_metal_context * ctx = backend->context;
ggml_backend_metal_device_rel(ctx_dev);
ggml_metal_free(ctx);
free(backend);
@@ -5709,6 +5753,8 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
GGML_ASSERT(ctx_dev->mtl_device != nil);
return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
}
@@ -5728,10 +5774,7 @@ static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
}
static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
// acq/rel just to populate ctx->name in case it hasn't been done yet
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
ggml_backend_metal_device_acq(ctx_dev);
ggml_backend_metal_device_rel(ctx_dev);
return ctx_dev->name;
}
@@ -5739,12 +5782,10 @@ static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t
static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
if (@available(macOS 10.12, iOS 16.0, *)) {
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
id<MTLDevice> device = ctx_dev->mtl_device;
*total = device.recommendedMaxWorkingSetSize;
*free = *total - device.currentAllocatedSize;
ggml_backend_metal_device_rel(ctx_dev);
} else {
*free = 1;
*total = 1;
@@ -5822,7 +5863,10 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
}
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
GGML_ASSERT(ctx_dev->mtl_device != nil);
id<MTLDevice> device = ctx_dev->mtl_device;
// the buffer fits into the max buffer size allowed by the device
if (size_aligned <= device.maxBufferLength) {
@@ -5878,7 +5922,6 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
free(ctx);
ggml_backend_metal_device_rel(ctx_dev);
return NULL;
}
@@ -5892,8 +5935,9 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
}
static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
return
buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
GGML_UNUSED(dev);
}
@@ -5978,8 +6022,19 @@ static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
/* .get_proc_address = */ ggml_backend_metal_get_proc_address,
};
// called upon program exit
static void ggml_metal_cleanup(void) {
ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);
}
// TODO: make thread-safe
ggml_backend_reg_t ggml_backend_metal_reg(void) {
// TODO: make this thread-safe somehow?
ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
// register cleanup callback
// TODO: not ideal, but not sure if there is a better way to do this in Objective-C
atexit(ggml_metal_cleanup);
{
g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
/* .api_version = */ GGML_BACKEND_API_VERSION,
+39 -9
View File
@@ -993,31 +993,61 @@ kernel void kernel_neg(
dst[tpig] = -src0[tpig];
}
template <bool norm>
kernel void kernel_sum_rows(
constant ggml_metal_kargs_sum_rows & args,
device const float * src0,
device float * dst,
constant ggml_metal_kargs_sum_rows & args,
uint3 tpig[[thread_position_in_grid]]) {
int64_t i3 = tpig.z;
int64_t i2 = tpig.y;
int64_t i1 = tpig.x;
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
int64_t i3 = tgpig.z;
int64_t i2 = tgpig.y;
int64_t i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
float row_sum = 0;
float sumf = 0;
for (int64_t i0 = 0; i0 < args.ne00; i0++) {
row_sum += src_row[i0];
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
sumf += src_row[i0];
}
dst_row[0] = row_sum;
sumf = simd_sum(sumf);
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) {
shmem_f32[sgitg] = sumf;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg];
sumf = simd_sum(sumf);
if (tpitg.x == 0) {
dst_row[0] = norm ? sumf / args.ne00 : sumf;
}
}
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
template<typename T>
kernel void kernel_soft_max(
device const char * src0,
File diff suppressed because it is too large Load Diff
+5 -6
View File
@@ -225,9 +225,9 @@ struct bin_bcast_sycl {
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
sycl::range<3>(1, 1, block_size),
sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * sycl::range<3>(1, 1, block_size),
sycl::range<3>(1, 1, block_size)),
[=](sycl::nd_item<3> item_ct1) {
k_bin_bcast_unravel<bin_op>(
@@ -246,9 +246,8 @@ struct bin_bcast_sycl {
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
ne2, ne3, ne10, ne11, ne12, ne13,
s1, s2, s3, s01, s02, s03, s11, s12, s13,
+1 -24
View File
@@ -199,7 +199,7 @@ struct sycl_device_info {
// size_t smpb; // max. shared memory per block
bool vmm; // virtual memory support
size_t total_vram;
sycl_hw_info hw_info;
//sycl_hw_info hw_info; \\ device id and aarch, currently not used
optimize_feature opt_feature;
};
@@ -286,29 +286,6 @@ struct ggml_tensor_extra_gpu {
void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams={});
inline optimize_feature check_gpu_optimize_feature(syclex::architecture &arch) {
optimize_feature opt;
opt.reorder =
(arch == syclex::architecture::intel_gpu_dg1 ||
arch == syclex::architecture::intel_gpu_acm_g10 ||
arch == syclex::architecture::intel_gpu_acm_g11 ||
arch == syclex::architecture::intel_gpu_acm_g12 ||
arch == syclex::architecture::intel_gpu_pvc ||
arch == syclex::architecture::intel_gpu_pvc_vg ||
arch == syclex::architecture::intel_gpu_mtl_u ||
arch == syclex::architecture::intel_gpu_mtl_s ||
arch == syclex::architecture::intel_gpu_mtl_h ||
arch == syclex::architecture::intel_gpu_arl_u ||
arch == syclex::architecture::intel_gpu_arl_s ||
arch == syclex::architecture::intel_gpu_arl_h ||
arch == syclex::architecture::intel_gpu_bmg_g21 ||
arch == syclex::architecture::intel_gpu_lnl_m
);
return opt;
}
namespace sycl_ex = sycl::ext::oneapi::experimental;
struct ggml_backend_sycl_context {
int device;
+28 -41
View File
@@ -89,33 +89,24 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
sycl::range<3> gridDim(ne2, ne1, num_blocks);
switch (dim) {
case 0:
stream->parallel_for(
sycl::nd_range<3>(gridDim *
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1);
});
break;
sycl_parallel_for(stream,
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1); });
break;
case 1:
stream->parallel_for(
sycl::nd_range<3>(gridDim *
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
});
break;
sycl_parallel_for(stream,
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1); });
break;
// dim >=2 will be dispatched to the default path
default:
stream->parallel_for(
sycl::nd_range<3>(gridDim *
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1);
});
break;
sycl_parallel_for(stream,
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1); });
break;
}
}
@@ -129,33 +120,29 @@ static void concat_f32_sycl_non_cont(
int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
uint64_t nb3, int32_t dim) {
sycl::range<3> gridDim(ne3, ne2, ne1);
stream->parallel_for(
sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
int64_t i3 = item_ct1.get_group(0);
int64_t i2 = item_ct1.get_group(1);
int64_t i1 = item_ct1.get_group(2);
sycl_parallel_for(stream, sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
int64_t i3 = item_ct1.get_group(0);
int64_t i2 = item_ct1.get_group(1);
int64_t i1 = item_ct1.get_group(2);
int64_t o[4] = {0, 0, 0, 0};
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
int64_t o[4] = { 0, 0, 0, 0 };
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
const float *x;
const float * x;
for (int i0 = item_ct1.get_local_id(2); i0 < ne0;
i0 += item_ct1.get_local_range(2)) {
for (int i0 = item_ct1.get_local_id(2); i0 < ne0; i0 += item_ct1.get_local_range(2)) {
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
x = (const float *)(src0 + (i3)*nb03 + (i2)*nb02 + (i1)*nb01 +
(i0)*nb00);
x = (const float *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00);
} else {
x = (const float *)(src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 +
(i1 - o[1]) * nb11 + (i0 - o[0]) * nb10);
x = (const float *) (src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 + (i1 - o[1]) * nb11 +
(i0 - o[0]) * nb10);
}
float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
*y = *x;
}
});
}
});
}
void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+4 -10
View File
@@ -59,16 +59,10 @@ static void conv_transpose_1d_f32_f32_sycl(
const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
const sycl::range<3> block_nums(1, 1, num_blocks);
stream->parallel_for(
sycl::nd_range<3>(
block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
conv_transpose_1d_kernel(
s0, output_size,
src0_ne0, src0_ne1, src0_ne2,
src1_ne0, dst_ne0,
src0, src1, dst, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
conv_transpose_1d_kernel(s0, output_size, src0_ne0, src0_ne1, src0_ne2, src1_ne0, dst_ne0, src0, src1, dst,
item_ct1);
});
}
void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+99 -166
View File
@@ -33,14 +33,11 @@ static void dequantize_block_sycl(const void *__restrict__ vx,
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(
sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1);
});
sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1); });
}
}
@@ -53,24 +50,18 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 64),
sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q2_K(vx, y, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q2_K(vx, y, item_ct1); });
}
#else
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q2_K(vx, y, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q2_K(vx, y, item_ct1); });
}
#endif
@@ -85,24 +76,18 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 64),
sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q3_K(vx, y, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q3_K(vx, y, item_ct1); });
}
#else
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q3_K(vx, y, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q3_K(vx, y, item_ct1); });
}
#endif
}
@@ -116,12 +101,9 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q4_0(vx, y, nb32, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q4_0(vx, y, nb32, item_ct1); });
}
}
@@ -135,13 +117,12 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int
int constexpr WARP_K = WARP_SIZE * QK4_0;
const int n_warp = (k + WARP_K - 1) / WARP_K;
GGML_ASSERT(k % 2 == 0);
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
sycl::range<3>(1, 1, WARP_SIZE),
sycl::range<3>(1, 1, WARP_SIZE)),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * sycl::range<3>(1, 1, WARP_SIZE),
sycl::range<3>(1, 1, WARP_SIZE)),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
});
}
template <typename dst_t>
@@ -153,12 +134,9 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q4_1(vx, y, nb32, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q4_1(vx, y, nb32, item_ct1); });
}
}
@@ -171,14 +149,13 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1);
});
sycl_parallel_for(
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1);
});
});
}
}
@@ -191,13 +168,13 @@ static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const i
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->submit([&](sycl::handler & cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
[=](sycl::nd_item<1> item_ct1) {
dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
});
sycl_parallel_for<1>(cgh, sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
[=](sycl::nd_item<1> item_ct1) {
dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
});
});
}
@@ -210,24 +187,18 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 64),
sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q5_K(vx, y, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q5_K(vx, y, item_ct1); });
}
#else
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q5_K(vx, y, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q5_K(vx, y, item_ct1); });
}
#endif
@@ -242,24 +213,18 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 64),
sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q6_K(vx, y, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K(vx, y, item_ct1); });
}
#else
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q6_K(vx, y, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K(vx, y, item_ct1); });
}
#endif
@@ -271,9 +236,9 @@ static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const i
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
}
template <typename dst_t>
@@ -284,15 +249,10 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq1_s(
vx, y, item_ct1, iq1s_grid_gpu
);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq1_s(vx, y, item_ct1, iq1s_grid_gpu); });
});
}
}
@@ -305,15 +265,10 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq1_m(
vx, y, item_ct1, iq1s_grid_gpu
);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq1_m(vx, y, item_ct1, iq1s_grid_gpu); });
});
}
}
@@ -326,15 +281,12 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq2_xxs(
vx, y, item_ct1, iq2xxs_grid,
ksigns_iq2xs, kmask_iq2xs);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq2_xxs(vx, y, item_ct1, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs);
});
});
}
}
@@ -347,15 +299,12 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq2_xs(
vx, y, item_ct1, iq2xs_grid,
ksigns_iq2xs, kmask_iq2xs);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq2_xs(vx, y, item_ct1, iq2xs_grid, ksigns_iq2xs, kmask_iq2xs);
});
});
}
}
@@ -368,13 +317,10 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq2_s(vx, y, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq2_s(vx, y, item_ct1); });
});
}
}
@@ -388,15 +334,12 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq3_xxs(
vx, y, item_ct1, iq3xxs_grid,
ksigns_iq2xs, kmask_iq2xs);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq3_xxs(vx, y, item_ct1, iq3xxs_grid, ksigns_iq2xs, kmask_iq2xs);
});
});
}
}
@@ -409,14 +352,10 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq3_s(
vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq3_s(vx, y, item_ct1, kmask_iq2xs, iq3s_grid); });
});
}
}
@@ -432,14 +371,11 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq4_xs(vx, y, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(
cgh,
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq4_xs(vx, y, item_ct1); });
});
}
#endif
@@ -453,14 +389,11 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_iq4_nl(vx, y, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(
cgh,
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq4_nl(vx, y, item_ct1); });
});
}
}
+94 -72
View File
@@ -413,7 +413,8 @@ static void ggml_cpy_f16_f32_sycl(const char * cx, char * cdst, const int ne, co
{
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->parallel_for(
sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
@@ -431,7 +432,8 @@ static void ggml_cpy_f32_f32_sycl(const char * cx, char * cdst, const int ne, co
{
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->parallel_for(
sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
@@ -449,7 +451,8 @@ static void ggml_cpy_f32_f16_sycl(const char * cx, char * cdst, const int ne, co
{
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->parallel_for(
sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
@@ -465,11 +468,11 @@ static void ggml_cpy_f32_q8_0_sycl(const char * cx, char * cdst, const int ne, c
const int nb12, const int nb13, queue_ptr stream) {
GGML_ASSERT(ne % QK8_0 == 0);
const int num_blocks = ne / QK8_0;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
});
}
static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -477,11 +480,11 @@ static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, c
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ne;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
});
}
static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -490,11 +493,11 @@ static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, c
const int nb12, const int nb13, queue_ptr stream) {
GGML_ASSERT(ne % QK4_0 == 0);
const int num_blocks = ne / QK4_0;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
});
}
static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -502,8 +505,9 @@ static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, c
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ne;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
item_ct1);
@@ -516,11 +520,11 @@ static void ggml_cpy_f32_q4_1_sycl(const char * cx, char * cdst, const int ne, c
const int nb12, const int nb13, queue_ptr stream) {
GGML_ASSERT(ne % QK4_1 == 0);
const int num_blocks = ne / QK4_1;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
});
}
static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -528,8 +532,9 @@ static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, c
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ne;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
item_ct1);
@@ -542,11 +547,11 @@ static void ggml_cpy_f32_q5_0_sycl(const char * cx, char * cdst, const int ne, c
const int nb12, const int nb13, queue_ptr stream) {
GGML_ASSERT(ne % QK5_0 == 0);
const int num_blocks = ne / QK5_0;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
});
}
static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -554,8 +559,9 @@ static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, c
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ne;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
item_ct1);
@@ -568,11 +574,11 @@ static void ggml_cpy_f32_q5_1_sycl(const char * cx, char * cdst, const int ne, c
const int nb12, const int nb13, queue_ptr stream) {
GGML_ASSERT(ne % QK5_1 == 0);
const int num_blocks = ne / QK5_1;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
});
}
static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -580,8 +586,9 @@ static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, c
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ne;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
item_ct1);
@@ -594,11 +601,11 @@ static void ggml_cpy_f32_iq4_nl_sycl(const char * cx, char * cdst, const int ne,
const int nb12, const int nb13, queue_ptr stream) {
GGML_ASSERT(ne % QK4_NL == 0);
const int num_blocks = ne / QK4_NL;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
ne12, nb10, nb11, nb12, nb13, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
[=](sycl::nd_item<3> item_ct1) {
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
});
}
static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
@@ -609,7 +616,8 @@ static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, co
{
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
stream->parallel_for(
sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
@@ -628,7 +636,8 @@ static void ggml_cpy_i16_i16_sycl(const char * cx, char * cdst, const int ne, co
// dpct::has_capability_or_fail(stream->get_device(),
// {sycl::aspect::fp16});
stream->parallel_for(
sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
@@ -647,7 +656,8 @@ static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, co
// dpct::has_capability_or_fail(stream->get_device(),
// {sycl::aspect::fp16});
stream->parallel_for(
sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
@@ -662,11 +672,13 @@ static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
ne12, nb10, nb11, nb12, nb13, item_ct1);
});
}
@@ -675,11 +687,13 @@ static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
ne12, nb10, nb11, nb12, nb13, item_ct1);
});
}
@@ -689,11 +703,13 @@ static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const
const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
ne12, nb10, nb11, nb12, nb13, item_ct1);
});
}
@@ -702,10 +718,13 @@ static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
ne12, nb10, nb11, nb12, nb13, item_ct1);
});
}
@@ -715,10 +734,13 @@ static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const
const int nb12, const int nb13, queue_ptr stream) {
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
ne12, nb10, nb11, nb12, nb13, item_ct1);
});
}
void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1) try {
+49 -67
View File
@@ -208,12 +208,10 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
nrows, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols, nrows, item_ct1);
});
}
}
@@ -877,12 +875,11 @@ static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloa
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
vx, y, dst, ncols, nrows, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(vx, y, dst, ncols,
nrows, item_ct1);
});
}
}
@@ -900,12 +897,10 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
vx, y, dst, ncols, nrows, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(vx, y, dst, ncols, nrows, item_ct1);
});
}
}
@@ -921,12 +916,10 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
vx, y, dst, ncols, nrows, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(vx, y, dst, ncols, nrows, item_ct1);
});
}
}
@@ -942,12 +935,10 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
vx, y, dst, ncols, nrows, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(vx, y, dst, ncols, nrows, item_ct1);
});
}
}
@@ -963,12 +954,10 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
vx, y, dst, ncols, nrows, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(vx, y, dst, ncols, nrows, item_ct1);
});
}
}
@@ -984,12 +973,10 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
vx, y, dst, ncols, nrows, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(vx, y, dst, ncols, nrows, item_ct1);
});
}
}
@@ -1002,11 +989,10 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
});
}
static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
@@ -1018,11 +1004,10 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
});
}
static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
@@ -1034,11 +1019,10 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
});
}
static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
@@ -1047,11 +1031,10 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_K == 0);
const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
});
}
static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
@@ -1063,11 +1046,10 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
});
}
void ggml_sycl_op_dequantize_mul_mat_vec(
+31 -1
View File
@@ -13,10 +13,10 @@
#ifndef GGML_SYCL_DPCT_HELPER_HPP
#define GGML_SYCL_DPCT_HELPER_HPP
#include <map>
#include <sycl/sycl.hpp>
#include <sycl/half_type.hpp>
#include <syclcompat/math.hpp>
#include <map>
#ifdef GGML_SYCL_USE_INTEL_ONEMKL
#include <oneapi/mkl.hpp>
@@ -118,6 +118,36 @@ inline auto get_onemath_backend(sycl::queue& queue)
#endif
}
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
namespace syclex = sycl::ext::oneapi::experimental;
#endif
template <int NR, typename Func>
__dpct_inline__ void sycl_parallel_for(sycl::handler & cgh, sycl::nd_range<NR> nd_range, Func && func) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
syclex::nd_launch(cgh, nd_range, func);
#else
cgh.parallel_for(nd_range, func);
#endif
}
template <int NR, typename Func>
__dpct_inline__ void sycl_parallel_for(sycl::queue * q, sycl::nd_range<NR> nd_range, Func && func) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
syclex::nd_launch(*q, nd_range, func);
#else
q->parallel_for(nd_range, func);
#endif
}
template <typename Func> __dpct_inline__ void sycl_launch(sycl::queue * stream, Func && func) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
syclex::submit(*stream, func);
#else
stream->submit(func);
#endif
}
namespace dpct
{
typedef sycl::queue *queue_ptr;
+99 -159
View File
@@ -329,60 +329,51 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
const int ne12, const int nb1, const int nb2,
const int offset, queue_ptr stream) {
int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, item_ct1);
});
}
template<typename T>
static void gelu_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
gelu(x, dst, k, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { gelu(x, dst, k, item_ct1); });
}
template<typename T>
static void silu_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
silu(x, dst, k, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { silu(x, dst, k, item_ct1); });
}
template<typename T>
static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
// hard code for now
const int num_blocks = ceil_div(k, 256);
stream->parallel_for(
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
sgn(x, dst, k, item_ct1);
});
sycl_parallel_for(
stream, sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)),
[=](sycl::nd_item<3> item_ct1) { sgn(x, dst, k, item_ct1); });
}
template<typename T>
static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
// hard code for now
const int num_blocks = ceil_div(k, 256);
stream->parallel_for(
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
abs_op(x, dst, k, item_ct1);
});
sycl_parallel_for(
stream,
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)),
[=](sycl::nd_item<3> item_ct1) { abs_op(x, dst, k, item_ct1); });
}
@@ -390,23 +381,20 @@ template<typename T>
static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
// hard code for now
const int num_blocks = ceil_div(k, 256);
stream->parallel_for(
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
elu_op(x, dst, k, item_ct1);
});
sycl_parallel_for(
stream,
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)),
[=](sycl::nd_item<3> item_ct1) { elu_op(x, dst, k, item_ct1); });
}
template<typename T>
static void gelu_quick_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
gelu_quick(x, dst, k, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { gelu_quick(x, dst, k, item_ct1); });
}
@@ -414,169 +402,133 @@ template<typename T>
static void gelu_erf_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
gelu_erf(x, dst, k, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { gelu_erf(x, dst, k, item_ct1); });
}
template<typename T>
static void tanh_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
tanh(x, dst, k, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { tanh(x, dst, k, item_ct1); });
}
template<typename T>
static void relu_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
relu(x, dst, k, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { relu(x, dst, k, item_ct1); });
}
template<typename T>
static void hardsigmoid_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
hardsigmoid(x, dst, k, item_ct1);
});
[=](sycl::nd_item<3> item_ct1) { hardsigmoid(x, dst, k, item_ct1); });
}
template<typename T>
static void hardswish_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
hardswish(x, dst, k, item_ct1);
});
[=](sycl::nd_item<3> item_ct1) { hardswish(x, dst, k, item_ct1); });
}
template<typename T>
static void exp_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
exp(x, dst, k, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { exp(x, dst, k, item_ct1); });
}
template<typename T>
static void log_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
log(x, dst, k, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { log(x, dst, k, item_ct1); });
}
template<typename T>
static void neg_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
neg(x, dst, k, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { neg(x, dst, k, item_ct1); });
}
template<typename T>
static void step_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
step(x, dst, k, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { step(x, dst, k, item_ct1); });
}
template<typename T>
static void sigmoid_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE),
sycl_parallel_for(
stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
sigmoid(x, dst, k, item_ct1);
});
[=](sycl::nd_item<3> item_ct1) { sigmoid(x, dst, k, item_ct1); });
}
template<typename T>
static void sqrt_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
sqrt(x, dst, k, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { sqrt(x, dst, k, item_ct1); });
}
template<typename T>
static void sin_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
sin(x, dst, k, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { sin(x, dst, k, item_ct1); });
}
template<typename T>
static void cos_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
cos(x, dst, k, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { cos(x, dst, k, item_ct1); });
}
template<typename T>
@@ -584,26 +536,20 @@ static void leaky_relu_sycl(const T *x, T *dst, const int k,
const float negative_slope,
queue_ptr stream) {
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
leaky_relu(x, dst, k, negative_slope, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { leaky_relu(x, dst, k, negative_slope, item_ct1); });
}
template<typename T>
static void sqr_sycl(const T *x, T *dst, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
sqr(x, dst, k, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { sqr(x, dst, k, item_ct1); });
}
template<typename T>
@@ -614,9 +560,8 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
int dst_size = ne10 * ne11 * ne12 * ne13;
int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
stream->parallel_for(
sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
[=](sycl::nd_item<1> item_ct1) {
sycl_parallel_for<1>(
stream, sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
});
}
@@ -627,12 +572,10 @@ static void pad_sycl(const T *x, T *dst, const int ne00,
const int ne1, const int ne2, queue_ptr stream) {
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
sycl::range<3> gridDim(ne2, ne1, num_blocks);
stream->parallel_for(
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
pad(x, dst, ne0, ne00, ne01, ne02, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
}
template<typename T>
@@ -640,13 +583,10 @@ static void clamp_sycl(const T *x, T *dst, const float min,
const float max, const int k,
queue_ptr stream) {
const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
clamp(x, dst, min, max, k, item_ct1);
});
sycl_parallel_for(stream,
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) { clamp(x, dst, min, max, k, item_ct1); });
}
inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+8 -105
View File
@@ -60,54 +60,6 @@ static void k_get_rows(
dst_row[iybs + iqs + y_offset] = v.y();
}
template<int qk, int qr, dequantize_kernel_t_reorder dequantize_kernel_recorder, typename dst_t>
static void k_get_rows_reorder(
const void * src0, const void *src0_dq, const int32_t * src1, dst_t * dst,
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
size_t s10, size_t s11, size_t s12,
const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
item_ct1.get_local_id(2)) *
2;
const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1);
const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
item_ct1.get_local_id(0)) /
ne12;
const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
item_ct1.get_local_id(0)) %
ne12;
if (i00 >= ne00) {
return;
}
auto ncols = ne00;
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const int src0_off = i01 * ncols + i00;
const int ib = src0_off / QK4_0; // block index
const int iqs = (i00%qk)/qr; // x quant index
const int iybs = i00 - i00%qk; // dst block start index
const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize
dfloat2 v;
dequantize_kernel_recorder((const void *)src0_dq, ib, (const void *)src0, src0_off/2, v);
dst_row[iybs + iqs + 0] = v.x();
dst_row[iybs + iqs + y_offset] = v.y();
GGML_UNUSED(nb01);
GGML_UNUSED(nb02);
GGML_UNUSED(nb03);
}
template<typename src0_t, typename dst_t>
static void k_get_rows_float(
const src0_t * src0, const int32_t * src1, dst_t * dst,
@@ -166,58 +118,15 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *sr
GGML_ASSERT(ne00 % 2 == 0);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_get_rows<qk, qr, dq>(
src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_get_rows<qk, qr, dq>(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, s3, nb01, nb02, nb03, s10, s11, s12,
item_ct1);
});
GGML_UNUSED(dst);
GGML_UNUSED(ctx);
}
template <int qk, int qr, dequantize_kernel_t_reorder dq_reorder>
static void get_rows_sycl_reorder(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const void *src0_dd,
const int32_t *src1_dd, float *dst_dd,
queue_ptr stream) {
GGML_TENSOR_BINARY_OP_LOCALS
const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
// strides in elements
//const size_t s0 = nb0 / ggml_element_size(dst);
const size_t s1 = nb1 / ggml_element_size(dst);
const size_t s2 = nb2 / ggml_element_size(dst);
const size_t s3 = nb3 / ggml_element_size(dst);
const size_t s10 = nb10 / ggml_element_size(src1);
const size_t s11 = nb11 / ggml_element_size(src1);
const size_t s12 = nb12 / ggml_element_size(src1);
//const size_t s13 = nb13 / ggml_element_size(src1);
GGML_ASSERT(ne00 % 2 == 0);
const uint8_t* src0_q = (const uint8_t*)src0_dd;
const size_t ncols = ne00;
const size_t nrows = ne01;
const sycl::half* src0_dq = (const sycl::half*)(src0_q + nrows * ncols / 2);
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
k_get_rows_reorder<qk, qr, dq_reorder>(
src0_dd, src0_dq, src1_dd, dst_dd, ne00, ne12, s1, s2,
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
});
GGML_UNUSED(dst);
GGML_UNUSED(ctx);
}
template <typename src0_t>
static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
@@ -245,9 +154,8 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
});
@@ -277,13 +185,8 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_Q4_0:
if (ctx.opt_feature.reorder && dst->op == GGML_OP_MUL_MAT) {
get_rows_sycl_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
} else {
get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
}
get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_Q4_1:
get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
+45 -54
View File
@@ -83,9 +83,7 @@ static ggml_sycl_device_info ggml_sycl_init() {
info.devices[i].cc =
100 * prop.get_major_version() + 10 * prop.get_minor_version();
info.devices[i].hw_info = get_device_hw_info(&device);
info.devices[i].opt_feature = check_gpu_optimize_feature(info.devices[i].hw_info.arch);
info.devices[i].opt_feature.reorder = !device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
}
@@ -195,7 +193,7 @@ static void ggml_check_sycl() try {
if (!initialized) {
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
g_ggml_sycl_disable_optimize = get_sycl_env("GGML_SYCL_DISABLE_OPT", 0);
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
@@ -1887,13 +1885,12 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
const size_t shared_mem = ncols_pad * sizeof(int);
if (order == GGML_SORT_ORDER_ASC) {
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
sycl::range<1>(shared_mem), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
x, dst, ncols, ncols_pad, item_ct1,
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
@@ -1901,13 +1898,12 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
});
});
} else if (order == GGML_SORT_ORDER_DESC) {
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
sycl::range<1>(shared_mem), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
x, dst, ncols, ncols_pad, item_ct1,
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
@@ -1925,50 +1921,47 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
const sycl::range<3> block_nums(1, nrows, 1);
const size_t shared_mem = 256 * sizeof(float);
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> shared_data(
sycl::range<1>(shared_mem/sizeof(float)), cgh);
sycl::local_accessor<int, 1> shared_indices(
sycl::range<1>(shared_mem/sizeof(float)), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
const int tid = item_ct1.get_local_id(2);
const int row = item_ct1.get_global_id(1);
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
const int tid = item_ct1.get_local_id(2);
const int row = item_ct1.get_global_id(1);
float max_val = -INFINITY;
int max_idx = -1;
float max_val = -INFINITY;
int max_idx = -1;
for (int col = tid; col < ncols; col += 256) {
float val = x[row * ncols + col];
if (val > max_val) {
max_val = val;
max_idx = col;
for (int col = tid; col < ncols; col += 256) {
float val = x[row * ncols + col];
if (val > max_val) {
max_val = val;
max_idx = col;
}
}
shared_data[tid] = max_val;
shared_indices[tid] = max_idx;
item_ct1.barrier(sycl::access::fence_space::local_space);
for (int stride = 256 / 2; stride > 0; stride >>= 1) {
if (tid < stride) {
float val1 = shared_data[tid];
float val2 = shared_data[tid + stride];
if (val2 > val1) {
shared_data[tid] = val2;
shared_indices[tid] = shared_indices[tid + stride];
}
}
shared_data[tid] = max_val;
shared_indices[tid] = max_idx;
item_ct1.barrier(sycl::access::fence_space::local_space);
}
for (int stride = 256/2; stride > 0; stride >>= 1) {
if (tid < stride) {
float val1 = shared_data[tid];
float val2 = shared_data[tid + stride];
if (val2 > val1) {
shared_data[tid] = val2;
shared_indices[tid] = shared_indices[tid + stride];
}
}
item_ct1.barrier(sycl::access::fence_space::local_space);
}
if (tid == 0) {
dst[row] = shared_indices[0];
}
});
if (tid == 0) {
dst[row] = shared_indices[0];
}
});
});
}
static void diag_mask_inf_f32_sycl(const float *x, float *dst,
@@ -2952,7 +2945,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
void ** ptrs_dst_get = ptrs_dst.get();
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
});
@@ -3456,7 +3449,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
{
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 0> src1_row_acc(cgh);
char *__restrict src1_contiguous_get =
@@ -3468,9 +3461,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
size_t ids_nb_ct6 = ids->nb[1];
size_t ids_nb_ct7 = ids->nb[0];
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_copy_src1_to_contiguous(
src1_original, src1_contiguous_get,
dev_cur_src1_row_get,
@@ -3501,15 +3493,14 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
{
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
sycl::range<3> grid_dims(1, 1, num_src1_rows);
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
const char *__restrict dst_contiguous_get =
dst_contiguous.get();
const mmid_row_mapping *__restrict dev_row_mapping_get =
dev_row_mapping.get();
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
k_copy_dst_from_contiguous(dst_original,
dst_contiguous_get,
dev_row_mapping_get,
+2 -2
View File
@@ -11,13 +11,13 @@ static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B,
const u_int n_seq_tokens = T / B;
sycl::range<1> block_dims((C / H));
sycl::range<1> grid_dims((B * H));
stream->submit([&](sycl::handler & cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
/* local memory accessors*/
auto _k = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
auto _r = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
auto _td = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
cgh.parallel_for(sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) {
sycl_parallel_for<1>(cgh, sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) {
u_int tid = item.get_local_id(0);
u_int bid = item.get_group(0);
+1 -1
View File
@@ -70,7 +70,7 @@ static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t I
const int64_t CHW = IC * KH * KW;
stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) {
im2col_kernel<T>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1,
p0, p1, d0, d1, item_ct1);
});
+60 -80
View File
@@ -1818,7 +1818,7 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
@@ -1829,9 +1829,8 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q4_0<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -1853,7 +1852,7 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
@@ -1864,9 +1863,8 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q4_0<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -1933,7 +1931,7 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
@@ -1944,9 +1942,8 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q4_1<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -1968,7 +1965,7 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
@@ -1979,9 +1976,8 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q4_1<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -2048,7 +2044,7 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
@@ -2059,9 +2055,8 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q5_0<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -2083,7 +2078,7 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
@@ -2094,9 +2089,8 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q5_0<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -2163,7 +2157,7 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
@@ -2174,9 +2168,8 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q5_1<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -2198,7 +2191,7 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
@@ -2209,9 +2202,8 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q5_1<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -2278,7 +2270,7 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
@@ -2289,9 +2281,8 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q8_0<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -2313,7 +2304,7 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
@@ -2324,9 +2315,8 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q8_0<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -2393,7 +2383,7 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
@@ -2406,9 +2396,8 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q2_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -2431,7 +2420,7 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
@@ -2444,9 +2433,8 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q2_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -2516,7 +2504,7 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
@@ -2531,9 +2519,8 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q3_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -2557,7 +2544,7 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
@@ -2572,9 +2559,8 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q3_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -2644,7 +2630,7 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
@@ -2657,9 +2643,8 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q4_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -2682,7 +2667,7 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
@@ -2695,9 +2680,8 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q4_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -2765,7 +2749,7 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
@@ -2778,9 +2762,8 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q5_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -2803,7 +2786,7 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
@@ -2816,9 +2799,8 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q5_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -2886,7 +2868,7 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
@@ -2899,9 +2881,8 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q6_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
@@ -2924,7 +2905,7 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
@@ -2937,9 +2918,8 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
mul_mat_q6_K<need_check>(
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
nrows_dst, item_ct1,
+132 -201
View File
@@ -544,12 +544,12 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy,
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
nd_item);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
nd_item);
});
});
}
@@ -561,12 +561,12 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float *
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
});
}
}
@@ -580,17 +580,12 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
});
}
}
@@ -604,17 +599,12 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
});
}
}
@@ -628,17 +618,12 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
});
}
}
@@ -652,17 +637,12 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
});
}
}
@@ -676,17 +656,12 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
});
}
}
@@ -700,17 +675,12 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
});
}
}
@@ -724,17 +694,12 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
});
}
}
@@ -750,12 +715,12 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy,
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols,
nrows, nd_item);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols, nrows,
nd_item);
});
});
}
@@ -769,17 +734,12 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
});
}
}
@@ -794,12 +754,12 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy,
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
nd_item);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
nd_item);
});
});
}
static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
@@ -811,17 +771,12 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
});
}
}
@@ -836,14 +791,12 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS / 2, block_iq2_xxs, 1>(vx, vy, dst, ncols,
nrows, item_ct1);
});
});
}
}
@@ -857,14 +810,12 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS / 2, block_iq2_xs, 1>(vx, vy, dst, ncols,
nrows, item_ct1);
});
});
}
}
@@ -878,15 +829,12 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S / 2, block_iq2_s, 1>(vx, vy, dst, ncols, nrows,
item_ct1);
});
});
}
}
@@ -900,15 +848,12 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS / 2, block_iq3_xxs, 1>(vx, vy, dst, ncols,
nrows, item_ct1);
});
});
}
}
@@ -922,15 +867,12 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S / 2, block_iq3_s, 1>(vx, vy, dst, ncols, nrows,
item_ct1);
});
});
}
}
@@ -944,15 +886,12 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(vx, vy, dst, ncols, nrows,
item_ct1);
});
});
}
}
@@ -966,14 +905,12 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(vx, vy, dst, ncols, nrows,
item_ct1);
});
});
}
}
@@ -987,15 +924,12 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(vx, vy, dst, ncols, nrows,
item_ct1);
});
});
}
}
@@ -1009,15 +943,12 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
{
stream->submit([&](sycl::handler &cgh) {
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
vx, vy, dst, ncols, nrows, item_ct1);
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS / 4, block_iq4_xs, 1>(vx, vy, dst, ncols,
nrows, item_ct1);
});
});
}
}
+55 -74
View File
@@ -254,14 +254,13 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
});
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
nullptr, WARP_SIZE);
});
});
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -272,16 +271,15 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
stream->submit([&](sycl::handler& cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
sycl::range<1>(work_group_size / WARP_SIZE), cgh);
cgh.parallel_for(
sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
});
});
sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
});
});
}
}
@@ -290,18 +288,14 @@ static void group_norm_f32_sycl(const float* x, float* dst,
const int ne_elements, queue_ptr stream, int device) {
if (group_size < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
const float eps_ct4 = eps;
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32(
x, dst, group_size, ne_elements, eps_ct4, item_ct1,
nullptr, WARP_SIZE);
});
});
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1, nullptr,
WARP_SIZE);
});
});
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -313,22 +307,18 @@ static void group_norm_f32_sycl(const float* x, float* dst,
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
stream->submit([&](sycl::handler& cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
cgh);
const float eps_ct4 = eps;
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32(x, dst, group_size, ne_elements,
eps_ct4, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
});
});
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
});
});
}
}
@@ -340,14 +330,13 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
});
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
nullptr, WARP_SIZE);
});
});
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -358,16 +347,15 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
stream->submit([&](sycl::handler& cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
cgh);
cgh.parallel_for(
sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
});
});
sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
});
});
}
}
@@ -378,16 +366,12 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
if (ncols < 1024) {
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1,
nullptr, WARP_SIZE);
});
});
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1, nullptr, WARP_SIZE);
});
});
}
else {
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
@@ -398,18 +382,15 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
stream->submit([&](sycl::handler& cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
cgh);
cgh.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1,
get_pointer(s_sum_acc_ct1), work_group_size);
});
});
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
l2_norm_f32(x, dst, ncols, eps, item_ct1, get_pointer(s_sum_acc_ct1),
work_group_size);
});
});
}
}
+24 -20
View File
@@ -235,20 +235,22 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
});
} else {
/*
DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
the limit. To get the device limit, query
info::device::max_work_group_size. Adjust the work-group size if needed.
*/
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
});
}
}
@@ -267,15 +269,17 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
if (freq_factors == nullptr) {
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
});
} else {
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, item_ct1);
});
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
});
}
}
@@ -298,12 +302,12 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
}
// launch kernel
if (freq_factors == nullptr) {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1);
});
} else {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1);
});
@@ -333,12 +337,12 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
}
// launch kernel
if (freq_factors == nullptr) {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1);
});
} else {
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
corr_dims, theta_scale, freq_factors, sections, item_ct1);
});
+3 -3
View File
@@ -127,11 +127,11 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst,
const int nrows_y, const float scale, const float max_bias, const float m0,
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
const size_t n_local_scratch, queue_ptr stream) {
stream->submit([&](sycl::handler &cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
cgh.parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
sycl_parallel_for(
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
nrows_y, scale, max_bias, m0,
+3 -1
View File
@@ -1,6 +1,7 @@
#include "sycl_hw.hpp"
// TODO: currently not used
/*
sycl_hw_info get_device_hw_info(sycl::device *device_ptr) {
sycl_hw_info res;
int32_t id = device_ptr->get_info<sycl::ext::intel::info::device::device_id>();
@@ -11,3 +12,4 @@ sycl_hw_info get_device_hw_info(sycl::device *device_ptr) {
return res;
}
*/
+3
View File
@@ -10,6 +10,8 @@
namespace syclex = sycl::ext::oneapi::experimental;
// TODO: currently not used
/*
struct sycl_hw_info {
syclex::architecture arch;
int32_t device_id;
@@ -18,6 +20,7 @@ struct sycl_hw_info {
bool is_in_vector(std::vector<int> &vec, int item);
sycl_hw_info get_device_hw_info(sycl::device *device_ptr);
*/
#endif // SYCL_HW_HPP
+3 -8
View File
@@ -45,14 +45,9 @@ static void timestep_embedding_f32_sycl(
int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE;
sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE);
sycl::range<3> gridDim(1, ne00, num_blocks);
stream->parallel_for(
sycl::nd_range<3>(
gridDim * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
timestep_embedding_f32(
x, dst, nb1, dim, max_period, item_ct1
);
});
sycl_parallel_for(stream, sycl::nd_range<3>(gridDim * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
timestep_embedding_f32(x, dst, nb1, dim, max_period, item_ct1);
});
}
void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+12 -16
View File
@@ -207,12 +207,11 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
// Submit kernel
if (C / H == WKV_BLOCK_SIZE) {
stream->submit([&](sycl::handler& cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@@ -220,12 +219,11 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
});
});
} else {
stream->submit([&](sycl::handler& cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@@ -264,12 +262,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
// Submit kernel
if (C / H == WKV_BLOCK_SIZE) {
stream->submit([&](sycl::handler& cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@@ -277,12 +274,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
});
});
} else {
stream->submit([&](sycl::handler& cgh) {
sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
cgh.parallel_for(
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
sycl_parallel_for(
cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
+64 -2
View File
@@ -1041,6 +1041,14 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
struct vk_instance_t {
vk::Instance instance;
bool debug_utils_support = false; // VK_EXT_debug_utils enabled
PFN_vkSetDebugUtilsObjectNameEXT pfn_vkSetDebugUtilsObjectNameEXT = {};
PFN_vkQueueBeginDebugUtilsLabelEXT pfn_vkQueueBeginDebugUtilsLabelEXT = {};
PFN_vkQueueEndDebugUtilsLabelEXT pfn_vkQueueEndDebugUtilsLabelEXT = {};
PFN_vkCmdBeginDebugUtilsLabelEXT pfn_vkCmdBeginDebugUtilsLabelEXT = {};
PFN_vkCmdEndDebugUtilsLabelEXT pfn_vkCmdEndDebugUtilsLabelEXT = {};
PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {};
std::vector<size_t> device_indices;
vk_device devices[GGML_VK_MAX_DEVICES];
};
@@ -1180,6 +1188,14 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
}
pipeline->compiled = true;
if (vk_instance.debug_utils_support) {
vk::DebugUtilsObjectNameInfoEXT duoni;
duoni.objectType = vk::ObjectType::ePipeline;
duoni.pObjectName = pipeline->name.c_str();
duoni.objectHandle = reinterpret_cast<uint64_t>(static_cast<VkPipeline_T*>(pipeline->pipeline));
vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast<VkDebugUtilsObjectNameInfoEXT &>(duoni));
}
{
std::lock_guard<std::mutex> guard(device->mutex);
device->pipelines.insert({ pipeline->name, pipeline });
@@ -3561,6 +3577,8 @@ static void ggml_vk_print_gpu_info(size_t idx) {
static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
static void ggml_vk_instance_init() {
if (vk_instance_initialized) {
return;
@@ -3581,7 +3599,7 @@ static void ggml_vk_instance_init() {
#ifdef __APPLE__
const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
#endif
const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr;
std::vector<const char*> layers;
if (validation_ext) {
@@ -3596,6 +3614,9 @@ static void ggml_vk_instance_init() {
extensions.push_back("VK_KHR_portability_enumeration");
}
#endif
if (debug_utils_ext) {
extensions.push_back("VK_EXT_debug_utils");
}
vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
#ifdef __APPLE__
if (portability_enumeration_ext) {
@@ -3619,6 +3640,18 @@ static void ggml_vk_instance_init() {
vk_instance.instance = vk::createInstance(instance_create_info);
vk_instance_initialized = true;
if (debug_utils_ext) {
vk_instance.debug_utils_support = true;
vk_instance.pfn_vkSetDebugUtilsObjectNameEXT = (PFN_vkSetDebugUtilsObjectNameEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkSetDebugUtilsObjectNameEXT");
vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT = (PFN_vkQueueBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueBeginDebugUtilsLabelEXT");
vk_instance.pfn_vkQueueEndDebugUtilsLabelEXT = (PFN_vkQueueEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueEndDebugUtilsLabelEXT");
vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT");
vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT");
vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT");
}
size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
// Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
@@ -9495,6 +9528,12 @@ static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer
UNUSED(buft);
}
static size_t ggml_backend_vk_host_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
return vk_instance.devices[0]->suballocation_block_size;
UNUSED(buft);
}
// Should be changed to return device-specific host buffer type
// but that probably requires changes in llama.cpp
ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
@@ -9503,7 +9542,7 @@ ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
/* .get_name = */ ggml_backend_vk_host_buffer_type_name,
/* .alloc_buffer = */ ggml_backend_vk_host_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_vk_host_buffer_type_get_alignment,
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
/* .get_max_size = */ ggml_backend_vk_host_buffer_type_get_max_size,
/* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
/* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
},
@@ -9650,6 +9689,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
if (vk_instance.debug_utils_support) {
vk::DebugUtilsLabelEXT dul = {};
dul.pLabelName = "ggml_backend_vk_graph_compute";
dul.color = std::array<float,4>{1.0f, 1.0f, 1.0f, 1.0f};
vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
}
uint64_t total_mat_mul_bytes = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
@@ -10339,6 +10385,22 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
UNUSED(instance_extensions);
}
// Extension availability
static bool ggml_vk_instance_debug_utils_ext_available(
const std::vector<vk::ExtensionProperties> & instance_extensions) {
// Check for portability enumeration extension for MoltenVK support
for (const auto & properties : instance_extensions) {
if (strcmp("VK_EXT_debug_utils", properties.extensionName) == 0) {
return true;
}
}
std::cerr << "ggml_vulkan: WARNING: Instance extension VK_EXT_debug_utils not found." << std::endl;
return false;
UNUSED(instance_extensions);
}
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
switch (props.vendorID) {
case VK_VENDOR_ID_INTEL:
+32 -2
View File
@@ -955,6 +955,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"UPSCALE",
"PAD",
"PAD_REFLECT_1D",
"ROLL",
"ARANGE",
"TIMESTEP_EMBEDDING",
"ARGSORT",
@@ -985,7 +986,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"OPT_STEP_ADAMW",
};
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -1050,6 +1051,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"upscale(x)",
"pad(x)",
"pad_reflect_1d(x)",
"roll(x)",
"arange(start, stop, step)",
"timestep_embedding(timesteps, dim, max_period)",
"argsort(x)",
@@ -1080,7 +1082,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"adamw(x)",
};
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -4341,6 +4343,34 @@ struct ggml_tensor * ggml_pad_reflect_1d(
return result;
}
// ggml_roll
struct ggml_tensor * ggml_roll(
struct ggml_context * ctx,
struct ggml_tensor * a,
int shift0,
int shift1,
int shift2,
int shift3) {
GGML_ASSERT(a->nb[0] == ggml_type_size(a->type));
GGML_ASSERT(abs(shift0) < a->ne[0]);
GGML_ASSERT(abs(shift1) < a->ne[1]);
GGML_ASSERT(abs(shift2) < a->ne[2]);
GGML_ASSERT(abs(shift3) < a->ne[3]);
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
ggml_set_op_params_i32(result, 0, shift0);
ggml_set_op_params_i32(result, 1, shift1);
ggml_set_op_params_i32(result, 2, shift2);
ggml_set_op_params_i32(result, 3, shift3);
result->op = GGML_OP_ROLL;
result->src[0] = a;
return result;
}
// ggml_arange
struct ggml_tensor * ggml_arange(
+1
View File
@@ -198,6 +198,7 @@ class Keys:
MASK_ID = "tokenizer.ggml.mask_token_id"
ADD_BOS = "tokenizer.ggml.add_bos_token"
ADD_EOS = "tokenizer.ggml.add_eos_token"
ADD_SEP = "tokenizer.ggml.add_sep_token"
ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces"
PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap"
+3
View File
@@ -891,6 +891,9 @@ class GGUFWriter:
def add_add_eos_token(self, value: bool) -> None:
self.add_bool(Keys.Tokenizer.ADD_EOS, value)
def add_add_sep_token(self, value: bool) -> None:
self.add_bool(Keys.Tokenizer.ADD_SEP, value)
def add_add_space_prefix(self, value: bool) -> None:
self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
+97 -4
View File
@@ -7,7 +7,10 @@ import os
from pathlib import Path
from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable
from sentencepiece import SentencePieceProcessor
try:
from sentencepiece import SentencePieceProcessor
except ImportError:
SentencePieceProcessor = None
import gguf
@@ -116,6 +119,7 @@ class SpecialVocab:
logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')
def _try_load_from_tokenizer_json(self, path: Path) -> bool:
tokenizer = None
tokenizer_file = path / 'tokenizer.json'
if tokenizer_file.is_file():
with open(tokenizer_file, encoding = 'utf-8') as f:
@@ -149,11 +153,97 @@ class SpecialVocab:
added_tokens = tokenizer.get('added_tokens', {})
else:
added_tokens = {}
tokenizer_config = None
tokenizer_config_file = path / 'tokenizer_config.json'
if not tokenizer_config_file.is_file():
if tokenizer_config_file.is_file():
with open(tokenizer_config_file, encoding = 'utf-8') as f:
tokenizer_config = json.load(f)
if tokenizer:
special_bos = (tokenizer_config or {}).get('bos_token')
special_cls = (tokenizer_config or {}).get('cls_token')
special_eos = (tokenizer_config or {}).get('eos_token')
special_sep = (tokenizer_config or {}).get('sep_token')
if not special_bos and special_cls and tokenizer_config:
tokenizer_config['bos_token'] = special_bos = special_cls
if not special_eos and special_sep and tokenizer_config:
tokenizer_config['eos_token'] = special_eos = special_sep
if post_processor := tokenizer.get('post_processor'):
for processor in post_processor.get('processors', [post_processor]):
if processor.get('type') == 'RobertaProcessing':
self.add_special_token['bos'] = True
self.add_special_token['eos'] = True
self.add_special_token['sep'] = True
if not special_cls and tokenizer_config:
special_cls = processor.get('cls', [special_bos])[0]
tokenizer_config['cls_token'] = special_cls
if not special_sep and tokenizer_config:
special_sep = processor.get('sep', [special_eos])[0]
tokenizer_config['sep_token'] = special_sep
continue
# Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
# Only works with simple templates, **will** get it wrong on unusual sequences
if processor.get('type') == 'TemplateProcessing':
tmpl_single = processor.get('single', [])
tmpl_pair = processor.get('pair', [])
special_first = None
special_last = None
if len(tmpl_single) > 1:
if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
if not tokenizer_config:
special_bos = special_first
self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
if special_first not in (special_bos, special_cls):
logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
if not tokenizer_config:
special_eos = special_last
elif special_last != special_eos:
if 'eot' not in self.special_token_types:
self.special_token_types = tuple(self.special_token_types) + ('eot', )
tokenizer_config['eot_token'] = special_eos
elif 'eom' not in self.special_token_types:
self.special_token_types = tuple(self.special_token_types) + ('eom', )
tokenizer_config['eom_token'] = special_eos
else:
logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
tokenizer_config['eos_token'] = special_eos = special_last
self.add_special_token['eos'] = True if special_last == special_eos else False
if special_last != special_eos:
logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
if tmpl_pair:
seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
if (special_first and seq_start == 0) or (special_last and seq_stop is None):
logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
if tmpl_a != 'A' or tmpl_b != 'B':
logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
# A [sep] [eos] B
if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
add_sep = False
if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
if special_entry in (special_sep, special_eos) and not special_last:
add_sep = True
if special_entry not in (special_sep, special_eos):
logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
else:
logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
if len(tmpl_pair) == 2:
if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
if special_entry in (special_sep, special_eos):
add_sep = True
if special_entry not in (special_sep, special_eos):
logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
else:
logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
self.add_special_token['sep'] = add_sep
if add_sep and not special_sep and tokenizer_config:
tokenizer_config['sep_token'] = special_eos
continue
if not tokenizer_config:
return True
with open(tokenizer_config_file, encoding = 'utf-8') as f:
tokenizer_config = json.load(f)
chat_template_alt = None
chat_template_file = path / 'chat_template.json'
if chat_template_file.is_file():
@@ -302,6 +392,9 @@ class SentencePieceVocab(Vocab):
name = "spm"
def __init__(self, base_path: Path):
if SentencePieceProcessor is None:
raise RuntimeError("sentencepiece is not installed")
added_tokens: dict[str, int] = {}
if (fname_tokenizer := base_path / 'tokenizer.model').exists():
# normal location
+2 -2
View File
@@ -1,6 +1,6 @@
[tool.poetry]
name = "gguf"
version = "0.17.0"
version = "0.17.1"
description = "Read and write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"]
packages = [
@@ -22,7 +22,7 @@ python = ">=3.8"
numpy = ">=1.17"
tqdm = ">=4.27"
pyyaml = ">=5.1"
sentencepiece = ">=0.1.98,<=0.2.0"
sentencepiece = { version = ">=0.1.98,<=0.2.0", optional = true }
PySide6 = { version = "^6.9", python = ">=3.9,<3.14", optional = true }
[tool.poetry.dev-dependencies]
+8 -3
View File
@@ -390,6 +390,7 @@ extern "C" {
void * imatrix; // pointer to importance matrix data
void * kv_overrides; // pointer to vector containing overrides
void * tensor_types; // pointer to vector containing tensor types
void * prune_layers; // pointer to vector containing layer indices to prune
} llama_model_quantize_params;
typedef struct llama_logit_bias {
@@ -943,12 +944,14 @@ extern "C" {
// Requires the context to have a memory.
// For encode-decoder contexts, processes the batch using the decoder.
// Positive return values does not mean a fatal error, but rather a warning.
// Upon non-zero return values, the memory state is restored to the state before this call
// Upon fatal-error or abort, the ubatches that managed to be been processed will remain in the memory state of the context
// To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
// Upon other return values, the memory state is restored to the state before this call
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
// 2 - aborted
// 2 - aborted (processed ubatches will remain in the context's memory)
// -1 - invalid input batch
// < -1 - error
// < -1 - fatal error (processed ubatches will remain in the context's memory)
LLAMA_API int32_t llama_decode(
struct llama_context * ctx,
struct llama_batch batch);
@@ -1044,6 +1047,7 @@ extern "C" {
LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
LLAMA_API bool llama_vocab_get_add_sep(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab);
@@ -1087,6 +1091,7 @@ extern "C" {
/// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
/// @return Returns the number of tokens on success, no more than n_tokens_max
/// @return Returns a negative number on failure - the number of tokens that would have been returned
/// @return Returns INT32_MIN on overflow (e.g., tokenization result size exceeds int32_t limit)
/// @param add_special Allow to add BOS and EOS tokens if model is configured to do so.
/// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
/// as plaintext. Does not insert a leading space.
@@ -0,0 +1,124 @@
{%- set today = strftime_now("%Y-%m-%d") %}
{%- set default_system_message = "You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.\nYour knowledge base was last updated on 2023-10-01. The current date is " + today + ".\n\nWhen you're not sure about some information or when the user's request requires up-to-date or specific data, you must use the available tools to fetch the information. Do not hesitate to use tools whenever they can provide a more accurate or complete response. If no relevant tools are available, then clearly state that you don't have the information and avoid making up anything.
If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. \"What are some good restaurants around me?\" => \"Where are you?\" or \"When is the next flight to Tokyo\" => \"Where do you travel from?\").
You are always very attentive to dates, and when asked about information at specific dates, you discard information that is at another date.
You follow these instructions in all languages, and always respond to the user in the language they use or request.
Next sections describe the capabilities that you have.
# WEB BROWSING INSTRUCTIONS
You cannot perform any web search or access internet to open URLs, links etc. If it seems like the user is expecting you to do so, you clarify the situation and ask the user to copy paste the text directly in the chat.
# MULTI-MODAL INSTRUCTIONS
You have the ability to read images, but you cannot generate images. You also cannot transcribe audio files or videos.
You cannot read nor transcribe audio files or videos.
# TOOL CALLING INSTRUCTIONS
You may have access to tools that you can use to fetch information or perform actions. You must use these tools in the following situations:
1. When the request requires up-to-date information.
2. When the request requires specific data that you do not have in your knowledge base.
3. When the request involves actions that you cannot perform without tools.
Always prioritize using tools to provide the most accurate and helpful response. If tools are not available, inform the user that you cannot perform the requested action at the moment." %}
{{- bos_token }}
{%- set system_prompt = default_system_message %}
{%- set loop_messages = messages %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}
{%- if messages|length > 0 and messages[0]['role'] == 'system' %}
{%- if messages[0]['content'] is string %}
{%- set system_prompt = messages[0]['content'] %}
{%- else %}
{%- set system_prompt = messages[0]['content'][0]['text'] %}
{%- endif %}
{%- set loop_messages = messages[1:] %}
{%- endif %}
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
{%- set ns = namespace(index=0) %}
{%- for message in loop_messages %}
{%- if not (message.role == "tool" or (message.get('tool_calls'))) %}
{%- if (message["role"] == "user") != (ns.index % 2 == 0) %}
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
{%- endif %}
{%- set ns.index = ns.index + 1 %}
{%- endif %}
{%- endfor %}
{{- '[SYSTEM_PROMPT]' + system_prompt + '[/SYSTEM_PROMPT]' }}
{%- for message in loop_messages %}
{%- if message['role'] == 'system' %}
{%- if message['content'] is string %}
{{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }}
{%- else %}
{{- '[SYSTEM_PROMPT]' + message['content'][0]['text'] + '[/SYSTEM_PROMPT]' }}
{%- endif %}
{%- elif message['role'] == 'user' %}
{%- if tools is not none and (message == user_messages[-1]) %}
{{- '[AVAILABLE_TOOLS]' + tools|tojson + '[/AVAILABLE_TOOLS]' }}
{%- endif %}
{{- '[INST]' }}
{%- if message['content'] is string %}
{{- message['content'] }}
{%- else %}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
{{- block['text'] }}
{%- elif block['type'] in ['image', 'image_url'] %}
{{- '[IMG]' }}
{%- else %}
{{- raise_exception('Only text and image blocks are supported in message content!') }}
{%- endif %}
{%- endfor %}
{%- endif %}
{{- '[/INST]' }}
{%- elif message['role'] == 'assistant' %}
{%- if message.get('tool_calls') %}
{%- for tool_call in message.tool_calls %}
{{- '[TOOL_CALLS]' + tool_call.function.name }}
{%- if not tool_call.id is defined or tool_call.id is not string or tool_call.id|length != 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}
{%- endif %}
{{- '[CALL_ID]' + tool_call.id }}
{{- '[ARGS]' + tool_call['function']['arguments']|tojson }}
{%- endfor %}
{{- eos_token }}
{%- elif message['content'] is string %}
{{- message['content'] + eos_token }}
{%- else %}
{%- for block in message['content'] %}
{%- if block['type'] == 'text' %}
{{- block['text'] }}
{%- elif block['type'] in ['image', 'image_url'] %}
{{- '[IMG]' }}
{%- else %}
{{- raise_exception('Only text and image blocks are supported in assistant content!') }}
{%- endif %}
{%- endfor %}
{{- eos_token }}
{%- endif %}
{%- elif message['role'] == 'tool_results' or message['role'] == 'tool' %}
{%- if message.content is defined and message.content.content is defined %}
{%- set content = message.content.content %}
{%- else %}
{%- set content = message.content %}
{%- endif %}
{%- if not message.tool_call_id is defined or message.tool_call_id is not string or message['tool_call_id']|length != 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}
{%- endif %}
{{- '[TOOL_RESULTS]' + message.tool_call_id + '[TOOL_CONTENT]' + content|string + '[/TOOL_RESULTS]' }}
{%- else %}
{{- raise_exception('Only system, user, assistant, and tool roles are supported!') }}
{%- endif %}
{%- endfor %}
+1 -1
View File
@@ -1 +1 @@
8cda0a3c19f2c7dc493887353c42f6956bc268b1
9e4bee1c5afc2d677a5b32ecb90cbdb483e81fff
+2 -1
View File
@@ -22,8 +22,9 @@ add_library(llama
llama-io.cpp
llama-kv-cache-unified.cpp
llama-kv-cache-unified-iswa.cpp
llama-kv-cache-recurrent.cpp
llama-memory.cpp
llama-memory-hybrid.cpp
llama-memory-recurrent.cpp
llama-mmap.cpp
llama-model-loader.cpp
llama-model-saver.cpp
+24
View File
@@ -147,6 +147,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
{ LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -197,6 +198,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" },
{ LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
{ LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
{ LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" },
{ LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
{ LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" },
{ LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },
@@ -1816,3 +1818,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
return LLM_TENSOR_INFOS.at(tensor);
}
bool llm_arch_is_recurrent(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_MAMBA:
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
return true;
default:
return false;
}
}
bool llm_arch_is_hybrid(const llm_arch & arch) {
// TODO: There are currently no hybrid models! Once there are, this will be
// the place to identify them
switch (arch) {
default:
return false;
}
}
+5
View File
@@ -151,6 +151,7 @@ enum llm_kv {
LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
LLM_KV_ATTENTION_LAYER_INDICES,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -193,6 +194,7 @@ enum llm_kv {
LLM_KV_TOKENIZER_MASK_ID,
LLM_KV_TOKENIZER_ADD_BOS,
LLM_KV_TOKENIZER_ADD_EOS,
LLM_KV_TOKENIZER_ADD_SEP,
LLM_KV_TOKENIZER_ADD_PREFIX,
LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
@@ -439,3 +441,6 @@ const char * llm_arch_name(llm_arch arch);
llm_arch llm_arch_from_string(const std::string & name);
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
bool llm_arch_is_recurrent(const llm_arch & arch);
bool llm_arch_is_hybrid (const llm_arch & arch);
+574 -363
View File
File diff suppressed because it is too large Load Diff
+97 -69
View File
@@ -2,86 +2,44 @@
#include "llama.h"
#include "llama-cparams.h"
#include <array>
#include <vector>
#include <set>
#include <bitset>
#include <unordered_map>
// very similar to llama_batch,
// but has more metadata about sequences
// keep this struct lightweight
// it points to data in `llama_batch_allocr`
struct llama_ubatch {
bool equal_seqs;
// TODO: whole_seqs for embeddings?
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
uint32_t n_seq_tokens; // tokens per sequence
uint32_t n_seqs;
uint32_t n_seq_tokens; // tokens per sequence set
uint32_t n_seqs; // sequence sets in the ubatch
uint32_t n_seqs_unq; // unique sequence ids in the ubatch
llama_token * token; // [n_tokens]
float * embd; // [n_embd, n_tokens]
llama_pos * pos; // [n_tokens]
int32_t * n_seq_id; // [n_seqs]
llama_seq_id ** seq_id; // [n_seqs]
int8_t * output; // [n_tokens]
// seq_id_unq: unique sequence ids in the ubatch
// seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
// used for extracting sequence pooled embeddings
// // size | idx | val
llama_token * token; // [n_tokens] | i | id, token
float * embd; // [n_embd, n_tokens] | i | embd
llama_pos * pos; // [n_tokens] | i | pos
int32_t * n_seq_id; // [n_tokens] | i | -
llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id
llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx
int8_t * output; // [n_tokens] | i | -
};
struct llama_sbatch_seq {
int32_t n_seq_id;
llama_seq_id * seq_id;
size_t offset;
size_t length;
};
// sequence-length-aware batch splitting
struct llama_sbatch {
// tokens left in this batch
size_t n_tokens;
size_t n_embd;
// sorted indices into the batch
std::vector<int64_t> ids;
// batch indices of the output
std::vector<int64_t> out_ids;
std::vector<llama_sbatch_seq> seq;
const llama_batch * batch = nullptr;
// buffers for the ubatches
// TODO: very hacky, this needs a complete rework
struct ubatch_data {
std::vector<llama_token> token;
std::vector<float> embd;
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<int8_t> output;
};
std::vector<ubatch_data> udatas;
llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
// simple split, unknown number of sequences of unequal lengths
llama_ubatch split_simple(size_t n_ubatch);
// make batches of equal-length sequences
llama_ubatch split_equal(size_t n_ubatch);
// sequence-wise split
llama_ubatch split_seq(size_t n_ubatch);
llama_sbatch() = default;
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
};
// a helper for sanitizing and fulfilling a batch
// a helper for sanitizing, fulfilling and splitting a batch
class llama_batch_allocr {
public:
llama_batch_allocr();
llama_batch_allocr(uint32_t n_pos_per_embd);
// sanitize and auto-gen missing data in the input batch
// memory is optional. if provided will be used to check for sequence continuity and to determine the positions
@@ -89,20 +47,57 @@ public:
const llama_batch & batch_inp,
const llama_vocab & vocab,
const llama_memory_i * memory,
bool embd_all);
uint32_t n_embd,
bool output_all);
const llama_batch & get_batch() const;
uint32_t get_n_tokens() const;
uint32_t get_n_outputs() const;
// the array of output indices in the order they were encountered during the ubatch splitting
std::vector<int32_t> & get_out_ids();
// min/max positions of each sequence in the current ubatch
llama_pos seq_pos_min(llama_seq_id seq_id) const;
llama_pos seq_pos_max(llama_seq_id seq_id) const;
// call once before splitting the batch to reset the internal state
void split_reset();
// simple split, unknown number of sequence sets of unequal lengths
llama_ubatch split_simple(uint32_t n_ubatch);
// make ubatches of equal-length sequences sets
llama_ubatch split_equal(uint32_t n_ubatch);
// sequence-set-wise split - each ubatch contains a single sequence-set
llama_ubatch split_seq(uint32_t n_ubatch);
// a helper method for creating a well-defined ubatch of tokens
// TODO: support embeddings if needed in the future
llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
private:
void clear();
// create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
// return llama_ubatch.n_tokens == 0 if the entire batch was consumed
llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
// for debugging, start with LLAMA_BATCH_DEBUG=2
void ubatch_print(const llama_ubatch & ubatch, int debug);
llama_batch batch;
// only for debugging purposes
const llama_vocab * vocab;
// TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
const uint32_t n_pos_per_embd;
uint32_t n_embd;
uint32_t n_outputs;
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
@@ -110,10 +105,43 @@ private:
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<llama_seq_id> seq_id_unq;
std::vector<int32_t> seq_idx;
std::vector<int8_t> output;
std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
std::vector<std::vector<bool>> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
using pos_set_t = std::set<llama_pos>;
using seq_cpl_t = std::vector<bool>;
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
using idx_vec_t = std::vector<int32_t>;
using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
// batch indices of the output
std::vector<int32_t> out_ids;
// used[i] indicates if token i has already been used in a previous ubatch
std::vector<bool> used;
// llama_ubatch points to this data:
struct ubatch {
std::vector<llama_token> token;
std::vector<float> embd;
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<llama_seq_id> seq_id_unq;
std::vector<int32_t> seq_idx;
std::vector<int8_t> output;
};
// current splitting state:
std::vector<ubatch> ubatches;
int debug;
};
+11 -6
View File
@@ -528,12 +528,17 @@ int32_t llm_chat_apply_template(
}
} else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
// this template requires the model to have "\n\n" as EOT token
for (auto message : chat) {
std::string role(message->role);
if (role == "user") {
ss << "User: " << message->content << "\n\nAssistant:";
} else {
ss << message->content << "\n\n";
for (size_t i = 0; i < chat.size(); i++) {
std::string role(chat[i]->role);
if (role == "system") {
ss << "System: " << trim(chat[i]->content) << "\n\n";
} else if (role == "user") {
ss << "User: " << trim(chat[i]->content) << "\n\n";
if (i == chat.size() - 1) {
ss << "Assistant:";
}
} else if (role == "assistant") {
ss << "Assistant: " << trim(chat[i]->content) << "\n\n";
}
}
} else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
+101 -107
View File
@@ -20,7 +20,7 @@ llama_context::llama_context(
const llama_model & model,
llama_context_params params) :
model(model),
batch_allocr(std::make_unique<llama_batch_allocr>()) {
balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
t_start_us = model.t_start_us;
@@ -280,8 +280,8 @@ llama_context::llama_context(
// simulate full KV cache
const auto mstate = memory->init_full();
if (!mstate) {
const auto mctx = memory->init_full();
if (!mctx) {
throw std::runtime_error("failed to initialize KV cache");
}
@@ -289,7 +289,7 @@ llama_context::llama_context(
// reserve pp graph first so that buffers are only allocated once
{
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
}
@@ -300,7 +300,7 @@ llama_context::llama_context(
// reserve with tg graph to get the number of splits and nodes
{
auto * gf = graph_reserve(1, 1, 1, mstate.get());
auto * gf = graph_reserve(1, 1, 1, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute tg buffers");
}
@@ -311,7 +311,7 @@ llama_context::llama_context(
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
{
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
}
@@ -444,8 +444,8 @@ bool llama_context::kv_self_update(bool optimize) {
optimize |= memory_force_optimize;
memory_force_optimize = false;
const auto mstate = memory->init_update(this, optimize);
switch (mstate->get_status()) {
const auto mctx = memory->init_update(this, optimize);
switch (mctx->get_status()) {
case LLAMA_MEMORY_STATUS_SUCCESS:
{
// noop
@@ -463,22 +463,22 @@ bool llama_context::kv_self_update(bool optimize) {
}
}
if (!mstate->apply()) {
if (!mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
}
}
// if the memory module did any computation, we have to reserve a new worst-case graph
{
const auto mstate = memory->init_full();
if (!mstate) {
throw std::runtime_error("failed to initialize memory state");
const auto mctx = memory->init_full();
if (!mctx) {
throw std::runtime_error("failed to initialize memory context");
}
const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) {
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
}
@@ -678,9 +678,9 @@ bool llama_context::apply_adapter_cvec(
return cvec.apply(model, data, len, n_embd, il_start, il_end);
}
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
if (mstate && !mstate->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
if (mctx && !mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
ret = GGML_STATUS_FAILED;
return nullptr;
}
@@ -692,7 +692,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
return nullptr;
}
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
if (!res) {
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
ret = GGML_STATUS_FAILED;
@@ -722,22 +722,26 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
}
int llama_context::encode(const llama_batch & batch_inp) {
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
if (batch_inp.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
const auto & hparams = model.hparams;
const int64_t n_embd = hparams.n_embd;
// note: during encode, we always pass the full sequence starting from pos = 0
if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
const llama_batch & batch = batch_allocr->get_batch();
const uint32_t n_tokens = balloc->get_n_tokens();
const uint32_t n_tokens = batch.n_tokens;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
const llama_ubatch ubatch = balloc->split_simple(n_tokens);
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
@@ -751,14 +755,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
n_queued_tokens += n_tokens;
const auto & hparams = model.hparams;
const int64_t n_embd = hparams.n_embd;
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
// reserve output buffer
if (output_reserve(n_tokens) < n_tokens) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
@@ -817,34 +813,28 @@ int llama_context::encode(const llama_batch & batch_inp) {
{
// extract sequence embeddings
auto & embd_seq_out = embd_seq;
embd_seq_out.clear();
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];
// TODO: fix indexing [UBATCH_IDX]
for (uint32_t i = 0; i < n_tokens; i++) {
const llama_seq_id seq_id = ubatch.seq_id[i][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
embd_seq_out[seq_id].resize(n_embd);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_RANK:
{
// extract the rerank score - n_cls_out floats per sequence
auto & embd_seq_out = embd_seq;
const uint32_t n_cls_out = hparams.n_cls_out;
// TODO: fix indexing [UBATCH_IDX]
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];
embd_seq_out[seq_id].resize(n_cls_out);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -869,12 +859,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
cross.v_embd.resize(cross.n_embd*cross.n_enc);
memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
const auto & batch = balloc->get_batch();
// remember the sequence ids used during the encoding - needed for cross attention later
cross.seq_ids_enc.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
cross.seq_ids_enc[i].clear();
for (int s = 0; s < batch.n_seq_id[i]; s++) {
llama_seq_id seq_id = batch.seq_id[i][s];
const llama_seq_id seq_id = batch.seq_id[i][s];
cross.seq_ids_enc[i].insert(seq_id);
}
}
@@ -884,6 +878,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
}
int llama_context::decode(const llama_batch & batch_inp) {
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
if (!memory) {
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
return encode(batch_inp);
@@ -894,29 +890,24 @@ int llama_context::decode(const llama_batch & batch_inp) {
return -1;
}
// when computing embeddings, all tokens are output
const bool embd_all = cparams.embeddings;
if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
const llama_batch & batch = batch_allocr->get_batch();
const auto & vocab = model.vocab;
const auto & hparams = model.hparams;
const int32_t n_vocab = vocab.n_tokens();
const int64_t n_embd = hparams.n_embd;
const uint32_t n_tokens_all = batch.n_tokens;
// when computing embeddings, all tokens are output
const bool output_all = cparams.embeddings;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
const uint32_t n_tokens_all = balloc->get_n_tokens();
const uint32_t n_outputs_all = balloc->get_n_outputs();
if (embd_all) {
if (output_all) {
// require that all tokens are output
if (n_outputs_all != n_tokens_all) {
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -942,21 +933,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
// handle any pending defrags/shifts
kv_self_update(false);
llama_memory_state_ptr mstate;
llama_memory_context_ptr mctx;
while (true) {
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
if (!mstate) {
mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
if (!mctx) {
return -2;
}
switch (mstate->get_status()) {
switch (mctx->get_status()) {
case LLAMA_MEMORY_STATUS_SUCCESS:
{
} break;
case LLAMA_MEMORY_STATUS_NO_UPDATE:
{
LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
return -2;
}
@@ -966,19 +957,19 @@ int llama_context::decode(const llama_batch & batch_inp) {
did_optimize = true;
if (kv_self_update(true)) {
LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
continue;
}
}
LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
return 1;
}
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
{
LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
return -2;
}
@@ -996,7 +987,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
int64_t n_outputs_prev = 0;
do {
const auto & ubatch = mstate->get_ubatch();
const auto & ubatch = mctx->get_ubatch();
// count the outputs in this ubatch
{
@@ -1005,7 +996,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
if (n_outputs_all == n_tokens_all) {
n_outputs_new = ubatch.n_tokens;
} else {
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
}
@@ -1019,7 +1009,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
ggml_status status;
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1028,7 +1018,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
pos_min[s] = std::numeric_limits<llama_pos>::max();
}
// TODO: fix sequence indexing
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
const auto & seq_id = ubatch.seq_id[i][0];
@@ -1105,27 +1094,27 @@ int llama_context::decode(const llama_batch & batch_inp) {
// extract sequence embeddings (cleared before processing each batch)
auto & embd_seq_out = embd_seq;
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];
embd_seq_out[seq_id].resize(n_embd);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_RANK:
{
// extract the rerank score - a single float per sequence
// extract the rerank score - n_cls_out floats per sequence
auto & embd_seq_out = embd_seq;
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
embd_seq_out[seq_id].resize(1);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
const uint32_t n_cls_out = hparams.n_cls_out;
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];
embd_seq_out[seq_id].resize(n_cls_out);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -1136,7 +1125,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
}
n_outputs_prev += n_outputs;
} while (mstate->next());
} while (mctx->next());
// set to total number of outputs in the batch, for use in llama_get_logits_ith
n_outputs = n_outputs_all;
@@ -1145,7 +1134,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
if (n_outputs > 0) {
bool sorted_output = true;
auto & out_ids = mstate->out_ids();
auto & out_ids = balloc->get_out_ids();
GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
@@ -1302,7 +1291,7 @@ ggml_cgraph * llama_context::graph_init() {
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
}
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
if (n_tokens % n_seqs != 0) {
@@ -1318,11 +1307,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
this->n_outputs = n_outputs;
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
this->n_outputs = save_n_outputs;
@@ -1343,11 +1332,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
}
llm_graph_result_ptr llama_context::graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype,
const llama_memory_state_i * mstate) {
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype,
const llama_memory_context_i * mctx) {
return model.build_graph(
{
/*.ctx =*/ ctx,
@@ -1359,7 +1348,7 @@ llm_graph_result_ptr llama_context::graph_build(
/*.backend_cpu =*/ backend_cpu,
/*.cvec =*/ &cvec,
/*.loras =*/ &loras,
/*.mstate =*/ mstate,
/*.mctx =*/ mctx,
/*.cross =*/ &cross,
/*.n_outputs =*/ n_outputs,
/*.cb =*/ graph_get_cb(),
@@ -2039,7 +2028,12 @@ void llama_context::opt_epoch_iter(
batch.logits [pos_batch] = true;
}
const auto n_tokens_all = batch.n_tokens;
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return;
}
const uint32_t n_tokens_all = balloc->get_n_tokens();
n_queued_tokens += n_tokens_all;
@@ -2047,8 +2041,8 @@ void llama_context::opt_epoch_iter(
uint32_t n_outputs_all = n_tokens_all;
auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
break;
}
@@ -2061,17 +2055,17 @@ void llama_context::opt_epoch_iter(
uint32_t pos_batch = 0;
do {
const auto & ubatch = mstate->get_ubatch();
const auto & ubatch = mctx->get_ubatch();
n_outputs = ubatch.n_tokens;
if (!mstate->apply()) {
LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
if (!mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
break;
}
auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
struct ggml_context * ctx_compute_opt;
{
@@ -2106,7 +2100,7 @@ void llama_context::opt_epoch_iter(
ggml_free(ctx_compute_opt);
pos_batch += ubatch.n_tokens;
} while (mstate->next());
} while (mctx->next());
}
}
+13 -13
View File
@@ -18,7 +18,7 @@ class llama_io_read_i;
class llama_io_write_i;
struct llama_memory_i;
struct llama_memory_state_i;
struct llama_memory_context_i;
struct llama_context {
// init scheduler and compute buffers, reserve worst-case graphs
@@ -93,14 +93,14 @@ struct llama_context {
int32_t il_end);
// process a single ubatch with a specific graph type
// if memory_state is provided, it will be applied first to the context's memory
// if memory_context is provided, it will be applied first to the context's memory
// ret contains the status of the graph computation
// returns nullptr only if ret != GGML_STATUS_SUCCESS
llm_graph_result_ptr process_ubatch(
const llama_ubatch & ubatch,
llm_graph_type gtype,
llama_memory_state_i * mstate,
ggml_status & ret);
const llama_ubatch & ubatch,
llm_graph_type gtype,
llama_memory_context_i * mctx,
ggml_status & ret);
int encode(const llama_batch & batch_inp);
int decode(const llama_batch & batch_inp);
@@ -197,15 +197,15 @@ public:
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
// reserve a graph with a dummy ubatch of the specified size
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
private:
llm_graph_result_ptr graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype,
const llama_memory_state_i * mstate);
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype,
const llama_memory_context_i * mctx);
llm_graph_cb graph_get_cb() const;
@@ -247,7 +247,7 @@ private:
std::map<llama_seq_id, std::vector<float>> embd_seq;
// reuse the batch_allocr to avoid unnecessary memory allocations
std::unique_ptr<llama_batch_allocr> batch_allocr;
std::unique_ptr<llama_batch_allocr> balloc;
uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
+337 -290
View File
@@ -6,7 +6,8 @@
#include "llama-kv-cache-unified.h"
#include "llama-kv-cache-unified-iswa.h"
#include "llama-kv-cache-recurrent.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-recurrent.h"
#include <cassert>
#include <cmath>
@@ -86,41 +87,33 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
if (pos_bucket) {
kv_state->set_input_pos_bucket(pos_bucket, ubatch);
mctx->set_input_pos_bucket(pos_bucket, ubatch);
}
}
void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
//GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
GGML_ASSERT(out_ids);
if (!out_ids) {
LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
} else {
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
int32_t * data = (int32_t *) out_ids->data;
GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
int32_t * data = (int32_t *) out_ids->data;
if (n_outputs == n_tokens) {
for (int i = 0; i < n_tokens; ++i) {
data[i] = i;
}
} else if (ubatch->output) {
int32_t n_outputs = 0;
for (int i = 0; i < n_tokens; ++i) {
if (ubatch->output[i]) {
data[n_outputs++] = i;
}
}
// the graph needs to have been passed the correct number of outputs
GGML_ASSERT(n_outputs == n_outputs);
} else if (n_outputs == 1) {
// only keep last output
data[0] = n_tokens - 1;
} else {
GGML_ASSERT(n_outputs == 0);
}
if (n_outputs == n_tokens) {
for (int i = 0; i < n_tokens; ++i) {
data[i] = i;
}
return;
}
GGML_ASSERT(ubatch->output);
int n_outputs = 0;
for (int i = 0; i < n_tokens; ++i) {
if (ubatch->output[i]) {
data[n_outputs++] = i;
}
}
}
@@ -129,127 +122,114 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
GGML_ASSERT(mean);
GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
float * data = (float *) mean->data;
memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
std::vector<uint64_t> sum(n_tokens, 0);
std::vector<uint64_t> sums(n_seqs_unq, 0);
for (int i = 0; i < n_tokens; i += n_seq_tokens) {
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[i][s];
const int32_t seq_idx = ubatch->seq_idx[seq_id];
// TODO: fix indexing [UBATCH_IDX]
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0];
// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
sum[seq_id] += ubatch->n_seq_tokens;
}
std::vector<float> div(n_tokens, 0.0f);
for (int i = 0; i < n_tokens; ++i) {
const uint64_t s = sum[i];
if (s > 0) {
div[i] = 1.0f/float(s);
sums[seq_idx] += ubatch->n_seq_tokens;
}
}
// TODO: fix indexing [UBATCH_IDX]
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0];
std::vector<float> div(n_seqs_unq, 0.0f);
for (int s = 0; s < n_seqs_unq; ++s) {
const uint64_t sum = sums[s];
if (sum > 0) {
div[s] = 1.0f/float(sum);
}
}
for (int i = 0; i < n_seq_tokens; ++i) {
data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
for (int i = 0; i < n_tokens; i += n_seq_tokens) {
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[i][s];
const int32_t seq_idx = ubatch->seq_idx[seq_id];
for (int j = 0; j < n_seq_tokens; ++j) {
data[seq_idx*n_tokens + i + j] = div[seq_idx];
}
}
}
}
}
void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
if (cparams.embeddings && (
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
if (cparams.embeddings && (
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
)) {
GGML_ASSERT(cls);
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
uint32_t * data = (uint32_t *) cls->data;
memset(cls->data, 0, n_tokens * ggml_element_size(cls));
memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
// TODO: fix indexing [UBATCH_IDX]
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0];
for (int i = 0; i < n_tokens; i += n_seq_tokens) {
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[i][s];
const int32_t seq_idx = ubatch->seq_idx[seq_id];
// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
for (int i = 0; i < n_seq_tokens; ++i) {
const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
if (pos == 0) {
data[seq_id] = s*n_seq_tokens + i;
}
data[seq_idx] = i;
}
}
}
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
GGML_ASSERT(cls);
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
uint32_t * data = (uint32_t *) cls->data;
memset(cls->data, 0, n_tokens * ggml_element_size(cls));
memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
std::vector<int> last_pos(n_tokens, -1);
std::vector<int> last_row(n_tokens, -1);
std::vector<int> last_pos(n_seqs_unq, -1);
std::vector<int> last_row(n_seqs_unq, -1);
// TODO: fix indexing [UBATCH_IDX]
for (int s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0];
for (int i = 0; i < n_tokens; ++i) {
const llama_pos pos = ubatch->pos[i];
// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[i][s];
const int32_t seq_idx = ubatch->seq_idx[seq_id];
for (int i = 0; i < n_seq_tokens; ++i) {
const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
if (pos >= last_pos[seq_id]) {
last_pos[seq_id] = pos;
last_row[seq_id] = s*n_seq_tokens + i;
if (pos >= last_pos[seq_idx]) {
last_pos[seq_idx] = pos;
last_row[seq_idx] = i;
}
}
}
for (int i = 0; i < n_tokens; ++i) {
if (last_row[i] >= 0) {
data[i] = last_row[i];
for (int s = 0; s < n_seqs_unq; ++s) {
if (last_row[s] >= 0) {
data[s] = last_row[s];
}
}
}
}
void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);
const int64_t n_kv = kv_state->get_n_kv();
const int64_t n_rs = mctx->get_n_rs();
if (s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
int32_t * data = (int32_t *) s_copy->data;
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_kv; ++i) {
data[i] = kv_state->s_copy(i);
for (uint32_t i = 0; i < n_rs; ++i) {
data[i] = mctx->s_copy(i);
}
}
}
@@ -265,89 +245,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
}
void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
if (kq_mask) {
if (cparams.causal_attn) {
const int64_t n_kv = ubatch->n_tokens;
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
const int64_t n_kv = ubatch->n_tokens;
const int64_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
float * data = (float *) kq_mask->data;
GGML_ASSERT(kq_mask);
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
for (int h = 0; h < 1; ++h) {
for (int s1 = 0; s1 < n_seqs; ++s1) {
const llama_seq_id seq_id = ubatch->seq_id[s1][0];
float * data = (float *) kq_mask->data;
for (int j = 0; j < n_seq_tokens; ++j) {
const int32_t tj = s1*n_seq_tokens + j;
for (int h = 0; h < 1; ++h) {
for (int i1 = 0; i1 < n_tokens; ++i1) {
const llama_seq_id s1 = ubatch->seq_id[i1][0];
for (int s0 = 0; s0 < n_seqs; ++s0) {
for (int i = 0; i < n_seq_tokens; ++i) {
const int32_t ti = s0*n_seq_tokens + i;
float f = -INFINITY;
for (int i0 = 0; i0 < n_tokens; ++i0) {
float f = -INFINITY;
// TODO: fix indexing [UBATCH_IDX]
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
if (hparams.use_alibi) {
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
} else {
f = 0.0f;
}
break;
}
}
for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
const llama_seq_id s0 = ubatch->seq_id[i0][0];
data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
}
// TODO: reimplement this like in llama_kv_cache_unified
if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
if (hparams.use_alibi) {
f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
} else {
f = 0.0f;
}
break;
}
}
}
} else {
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;
const int64_t n_stride = ubatch->n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
float * data = (float *) kq_mask->data;
for (int h = 0; h < 1; ++h) {
for (int s1 = 0; s1 < n_seqs; ++s1) {
const llama_seq_id seq_id = ubatch->seq_id[s1][0];
for (int j = 0; j < n_seq_tokens; ++j) {
const int32_t tj = s1*n_seq_tokens + j;
for (int s0 = 0; s0 < n_seqs; ++s0) {
for (int i = 0; i < n_seq_tokens; ++i) {
const int32_t ti = s0*n_seq_tokens + i;
float f = -INFINITY;
// TODO: fix indexing [UBATCH_IDX]
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
if (ubatch->seq_id[s0][s] == seq_id) {
if (hparams.use_alibi) {
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
} else {
f = 0.0f;
}
break;
}
}
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
}
}
for (int i = n_tokens; i < n_stride; ++i) {
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
}
}
}
data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
}
}
}
@@ -355,51 +282,71 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
}
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
if (self_kq_mask_swa) {
kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
}
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
if (cross_kq_mask) {
const int64_t n_enc = cross_kq_mask->ne[0];
const int64_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(cross_kq_mask);
GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
const int64_t n_enc = cross_kq_mask->ne[0];
const int64_t n_tokens = ubatch->n_tokens;
float * data = (float *) cross_kq_mask->data;
GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
for (int i = 0; i < n_enc; ++i) {
float f = -INFINITY;
// TODO: fix indexing [UBATCH_IDX]
for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[j][s];
if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
f = 0.0f;
}
float * data = (float *) cross_kq_mask->data;
for (int h = 0; h < 1; ++h) {
for (int i = 0; i < n_tokens; ++i) {
for (int j = 0; j < n_enc; ++j) {
float f = -INFINITY;
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[i][s];
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
f = 0.0f;
}
data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
}
}
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_enc; ++j) {
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
}
data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
}
}
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_enc; ++j) {
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
}
}
}
}
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
int32_t * data = (int32_t *) s_copy->data;
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_rs; ++i) {
data[i] = mctx->get_recr()->s_copy(i);
}
}
}
@@ -442,16 +389,12 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
backend_cpu (params.backend_cpu),
cvec (params.cvec),
loras (params.loras),
mstate (params.mstate),
mctx (params.mctx),
cross (params.cross),
cb_func (params.cb),
res (std::make_unique<llm_graph_result>()) {
}
int64_t llm_graph_context::n_pos_per_embd() const {
return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
}
void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
if (cb_func) {
cb_func(ubatch, cur, name, il);
@@ -896,11 +839,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
}
ggml_tensor * llm_graph_context::build_inp_pos() const {
auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
auto & cur = inp->pos;
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
ggml_set_input(cur);
res->add_input(std::move(inp));
@@ -923,6 +866,14 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
}
ggml_tensor * llm_graph_context::build_inp_out_ids() const {
// note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
// but this would make the graph topology depend on the number of output tokens, which can interere with
// features that require constant topology such as pipline parallelism
// ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
//if (n_outputs < n_tokens) {
// return nullptr;
//}
auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
auto & cur = inp->out_ids;
@@ -940,7 +891,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
auto & cur = inp->mean;
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
ggml_set_input(cur);
res->add_input(std::move(inp));
@@ -953,24 +904,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
auto & cur = inp->cls;
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_set_input(cur);
res->add_input(std::move(inp));
return cur;
}
ggml_tensor * llm_graph_context::build_inp_s_copy() const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
const auto n_kv = kv_state->get_n_kv();
auto & cur = inp->s_copy;
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
ggml_set_input(cur);
res->add_input(std::move(inp));
@@ -1016,11 +950,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
}
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
const auto n_kv = kv_state->get_n_kv();
const auto n_kv = mctx_cur->get_n_kv();
auto & cur = inp->pos_bucket;
@@ -1047,6 +981,33 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
return pos_bias;
}
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
const auto n_kv = inp->mctx->get_attn()->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
{
const auto n_rs = mctx_cur->get_recr()->get_n_rs();
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
ggml_set_input(inp->s_copy);
}
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_attn_mha(
ggml_cgraph * gf,
ggml_tensor * q,
@@ -1222,14 +1183,14 @@ ggml_tensor * llm_graph_context::build_attn(
}
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
const auto n_kv = kv_state->get_n_kv();
const auto n_kv = mctx_cur->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1259,19 +1220,19 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
// store to KV cache
{
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
}
const auto & kq_mask = inp->get_kq_mask();
ggml_tensor * q = q_cur;
ggml_tensor * k = kv_state->get_k(ctx0, il);
ggml_tensor * v = kv_state->get_v(ctx0, il);
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);
@@ -1291,36 +1252,6 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
{
const auto n_kv = kv_state->get_base()->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
{
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
const auto n_kv = kv_state->get_swa()->get_n_kv();
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
}
return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_unified_iswa * inp,
ggml_cgraph * gf,
@@ -1339,23 +1270,23 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);
const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
const bool is_swa = hparams.is_swa(il);
const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
// store to KV cache
{
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
}
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
ggml_tensor * q = q_cur;
ggml_tensor * k = kv_state->get_k(ctx0, il);
ggml_tensor * v = kv_state->get_v(ctx0, il);
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);
@@ -1430,20 +1361,99 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}
ggml_tensor * llm_graph_context::build_recurrent_state(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_mem_hybrid * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * v_mla,
float kq_scale,
int il) const {
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, q_cur);
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);
const auto n_kv = kv_state->get_n_kv();
const auto kv_head = kv_state->get_head();
const auto rs_zero = kv_state->get_rs_z();
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
// store to KV cache
{
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
}
const auto & kq_mask = inp->get_kq_mask();
ggml_tensor * q = q_cur;
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4) {
// GLM4 seems to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
if (wo_b) {
cur = ggml_add(ctx0, cur, wo_b);
}
return cur;
}
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
{
const auto n_kv = mctx_cur->get_base()->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
{
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
}
return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_rs(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
int32_t n_seqs,
uint32_t n_kv,
uint32_t kv_head,
uint32_t kv_size,
int32_t rs_zero,
bool avoid_copies) const {
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
// Clear a single state which will then be copied to the other cleared states.
// Note that this is a no-op when the view is zero-sized.
@@ -1474,22 +1484,59 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
return output_states;
}
llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
const auto n_rs = mctx_cur->get_n_rs();
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
ggml_set_input(inp->s_copy);
return (llm_graph_input_rs *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_rs(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies) const {
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
}
ggml_tensor * llm_graph_context::build_rs(
llm_graph_input_mem_hybrid * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies) const {
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
}
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
ggml_cgraph * gf,
ggml_tensor * state_copy,
const llama_ubatch & ubatch,
llm_graph_input_rs * inp,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
const auto token_shift_count = hparams.token_shift_count;
const int64_t n_seqs = ubatch.n_seqs;
ggml_tensor * token_shift_all = kv_state->get_k_l(il);
ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
ggml_tensor * token_shift = build_recurrent_state(
gf, token_shift_all, state_copy,
hparams.n_embd_k_s(), n_seqs);
ggml_tensor * token_shift = build_rs(
inp, gf, token_shift_all,
hparams.n_embd_r(), n_seqs);
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
@@ -1500,19 +1547,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
const auto token_shift_count = hparams.token_shift_count;
const auto n_embd = hparams.n_embd;
const int64_t n_seqs = ubatch.n_seqs;
const auto kv_head = kv_state->get_head();
const auto kv_head = mctx_cur->get_head();
return ggml_cpy(
ctx0,
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
);
}
+107 -39
View File
@@ -17,11 +17,12 @@ struct ggml_tensor;
struct llama_ubatch;
struct llama_cparams;
struct llama_memory_state_i;
struct llama_memory_context_i;
class llama_kv_cache_unified_state;
class llama_kv_cache_unified_iswa_state;
class llama_kv_cache_recurrent_state;
class llama_kv_cache_unified_context;
class llama_kv_cache_unified_iswa_context;
class llama_memory_recurrent_context;
class llama_memory_hybrid_context;
// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
@@ -94,14 +95,14 @@ public:
class llm_graph_input_pos : public llm_graph_input_i {
public:
llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
virtual ~llm_graph_input_pos() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * pos = nullptr; // I32 [n_batch]
const int64_t n_pos_per_embd = 1;
const uint32_t n_pos_per_embd = 1;
};
// temperature tuning, used by llama4
@@ -135,7 +136,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
public:
llm_graph_input_pos_bucket_kv(
const llama_hparams & hparams,
const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
virtual ~llm_graph_input_pos_bucket_kv() = default;
void set_input(const llama_ubatch * ubatch) override;
@@ -143,7 +144,8 @@ public:
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
const llama_hparams & hparams;
const llama_kv_cache_unified_state * kv_state;
const llama_kv_cache_unified_context * mctx;
};
class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -188,16 +190,16 @@ public:
const llama_cparams & cparams;
};
class llm_graph_input_s_copy : public llm_graph_input_i {
class llm_graph_input_rs : public llm_graph_input_i {
public:
llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
virtual ~llm_graph_input_s_copy() = default;
llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
virtual ~llm_graph_input_rs() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_copy; // I32 [kv_size]
const llama_kv_cache_recurrent_state * kv_state;
const llama_memory_recurrent_context * mctx;
};
class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -237,10 +239,10 @@ public:
llm_graph_input_attn_kv_unified(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified_state * kv_state) :
const llama_kv_cache_unified_context * mctx) :
hparams(hparams),
cparams(cparams),
kv_state(kv_state) {
mctx(mctx) {
}
~llm_graph_input_attn_kv_unified() = default;
@@ -254,7 +256,7 @@ public:
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_kv_cache_unified_state * kv_state;
const llama_kv_cache_unified_context * mctx;
};
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@@ -262,10 +264,10 @@ public:
llm_graph_input_attn_kv_unified_iswa(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified_iswa_state * kv_state) :
const llama_kv_cache_unified_iswa_context * mctx) :
hparams(hparams),
cparams(cparams),
kv_state(kv_state) {
mctx(mctx) {
}
~llm_graph_input_attn_kv_unified_iswa() = default;
@@ -282,7 +284,7 @@ public:
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_kv_cache_unified_iswa_state * kv_state;
const llama_kv_cache_unified_iswa_context * mctx;
};
class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -300,6 +302,33 @@ public:
const llama_cross * cross = nullptr;
};
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
llm_graph_input_mem_hybrid(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_memory_hybrid_context * mctx) :
hparams(hparams),
cparams(cparams),
mctx(mctx) {
}
virtual ~llm_graph_input_mem_hybrid() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_copy; // I32 [kv_size]
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_memory_hybrid_context * mctx;
};
//
// llm_graph_result
//
@@ -373,10 +402,10 @@ struct llm_graph_params {
ggml_backend_sched_t sched;
ggml_backend_t backend_cpu;
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_state_i * mstate;
const llama_cross * cross;
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_context_i * mctx;
const llama_cross * cross;
uint32_t n_outputs;
@@ -425,10 +454,10 @@ struct llm_graph_context {
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_state_i * mstate;
const llama_cross * cross;
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
const llama_memory_context_i * mctx;
const llama_cross * cross;
const llm_graph_cb & cb_func;
@@ -436,8 +465,6 @@ struct llm_graph_context {
llm_graph_context(const llm_graph_params & params);
int64_t n_pos_per_embd() const;
void cb(ggml_tensor * cur, const char * name, int il) const;
//
@@ -508,13 +535,14 @@ struct llm_graph_context {
ggml_tensor * build_inp_out_ids() const;
ggml_tensor * build_inp_mean() const;
ggml_tensor * build_inp_cls() const;
ggml_tensor * build_inp_s_copy() const;
ggml_tensor * build_inp_cross_embd() const;
ggml_tensor * build_inp_pos_bucket_enc() const;
ggml_tensor * build_inp_pos_bucket_dec() const;
ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
//
// attention
//
@@ -589,22 +617,62 @@ struct llm_graph_context {
float kq_scale,
int il) const;
ggml_tensor * build_attn(
llm_graph_input_mem_hybrid * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
//
// recurrent
//
ggml_tensor * build_recurrent_state(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies = false) const;
// TODO: avoid notion of "kv"
// TODO: move this implementation to llama_memory_recurrent.
// this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
// when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
// implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
// `llama_memory_recurrent`
ggml_tensor * build_rs(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
int32_t n_seqs,
uint32_t n_kv,
uint32_t kv_head,
uint32_t kv_size,
int32_t rs_zero,
bool avoid_copies = false) const;
llm_graph_input_rs * build_rs_inp() const;
ggml_tensor * build_rs(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies = false) const;
ggml_tensor * build_rs(
llm_graph_input_mem_hybrid * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies = false) const;
ggml_tensor * build_rwkv_token_shift_load(
ggml_cgraph * gf,
ggml_tensor * state_copy,
const llama_ubatch & ubatch,
llm_graph_input_rs * inp,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
int il) const;
ggml_tensor * build_rwkv_token_shift_store(
+10 -2
View File
@@ -65,7 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
return n_embd_head_v * n_head_kv;
}
uint32_t llama_hparams::n_embd_k_s() const {
uint32_t llama_hparams::n_embd_r() const {
if (wkv_head_size != 0) {
// for RWKV models
return token_shift_count * n_embd;
@@ -76,7 +76,7 @@ uint32_t llama_hparams::n_embd_k_s() const {
return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
}
uint32_t llama_hparams::n_embd_v_s() const {
uint32_t llama_hparams::n_embd_s() const {
if (wkv_head_size != 0) {
// corresponds to RWKV's wkv_states size
return n_embd * wkv_head_size;
@@ -86,6 +86,14 @@ uint32_t llama_hparams::n_embd_v_s() const {
return ssm_d_state * ssm_d_inner;
}
bool llama_hparams::is_recurrent(uint32_t il) const {
return recurrent_layer_arr[il];
}
uint32_t llama_hparams::n_pos_per_embd() const {
return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
}
bool llama_hparams::is_swa(uint32_t il) const {
if (il < n_layer) {
return swa_layers[il];
+10 -2
View File
@@ -115,6 +115,9 @@ struct llama_hparams {
uint32_t ssm_d_state = 0;
uint32_t ssm_dt_rank = 0;
// for hybrid state space models
std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
bool ssm_dt_b_c_rms = false;
float f_clamp_kqv = 0.0f;
@@ -181,10 +184,15 @@ struct llama_hparams {
// dimension of the rolling state embeddings
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
uint32_t n_embd_k_s() const;
uint32_t n_embd_r() const;
// dimension of the recurrent state embeddings
uint32_t n_embd_v_s() const;
uint32_t n_embd_s() const;
// whether or not the given layer is recurrent (for hybrid models)
bool is_recurrent(uint32_t il) const;
uint32_t n_pos_per_embd() const;
bool is_swa(uint32_t il) const;
};
+55 -61
View File
@@ -95,19 +95,22 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
return kv_swa->seq_pos_max(seq_id);
}
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
GGML_UNUSED(embd_all);
// first try simple split
do {
auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
balloc.split_reset();
std::vector<llama_ubatch> ubatches;
while (true) {
auto ubatch = balloc.split_simple(n_ubatch);
while (sbatch.n_tokens > 0) {
auto ubatch = sbatch.split_simple(n_ubatch);
if (ubatch.n_tokens == 0) {
break;
}
ubatches.push_back(ubatch);
ubatches.push_back(std::move(ubatch)); // NOLINT
}
auto heads_base = kv_base->prepare(ubatches);
@@ -122,20 +125,23 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
assert(heads_base.size() == heads_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_state>(
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
return std::make_unique<llama_kv_cache_unified_iswa_context>(
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
} while (false);
// if it fails, try equal split
do {
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
balloc.split_reset();
std::vector<llama_ubatch> ubatches;
while (true) {
auto ubatch = balloc.split_equal(n_ubatch);
while (sbatch.n_tokens > 0) {
auto ubatch = sbatch.split_equal(n_ubatch);
if (ubatch.n_tokens == 0) {
break;
}
ubatches.push_back(ubatch);
ubatches.push_back(std::move(ubatch)); // NOLINT
}
auto heads_base = kv_base->prepare(ubatches);
@@ -150,22 +156,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
assert(heads_base.size() == heads_swa.size());
return std::make_unique<llama_kv_cache_unified_iswa_state>(
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
return std::make_unique<llama_kv_cache_unified_iswa_context>(
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
} while (false);
// TODO: if we fail again, we should attempt different splitting strategies
// but to do that properly, we first have to refactor the batches to be more flexible
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
}
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
}
bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -191,52 +197,46 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
}
//
// llama_kv_cache_unified_iswa_state
// llama_kv_cache_unified_iswa_context
//
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
state_base = kv->get_base()->init_full();
state_swa = kv->get_swa ()->init_full();
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv) :
ctx_base(kv->get_base()->init_full()),
ctx_swa (kv->get_swa ()->init_full()),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
llama_context * lctx,
bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
state_base = kv->get_base()->init_update(lctx, optimize);
state_swa = kv->get_swa ()->init_update(lctx, optimize);
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
bool optimize) :
ctx_base(kv->get_base()->init_update(lctx, optimize)),
ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
llama_sbatch sbatch,
std::vector<uint32_t> heads_base,
std::vector<uint32_t> heads_swa,
std::vector<llama_ubatch> ubatches)
: status(LLAMA_MEMORY_STATUS_SUCCESS),
sbatch(std::move(sbatch)),
ubatches(std::move(ubatches)) {
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
bool llama_kv_cache_unified_iswa_state::next() {
bool llama_kv_cache_unified_iswa_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
state_base->next();
state_swa ->next();
ctx_base->next();
ctx_swa ->next();
if (++i_next >= ubatches.size()) {
return false;
@@ -245,41 +245,35 @@ bool llama_kv_cache_unified_iswa_state::next() {
return true;
}
bool llama_kv_cache_unified_iswa_state::apply() {
bool llama_kv_cache_unified_iswa_context::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
bool res = true;
res = res & state_base->apply();
res = res & state_swa ->apply();
res = res & ctx_base->apply();
res = res & ctx_swa ->apply();
return res;
}
std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return sbatch.out_ids;
}
llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
return status;
}
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next];
}
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
}
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
}
+21 -26
View File
@@ -31,14 +31,14 @@ public:
// llama_memory_i
//
llama_memory_state_ptr init_batch(
const llama_batch & batch,
llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;
llama_memory_state_ptr init_full() override;
llama_memory_context_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
@@ -72,62 +72,57 @@ private:
std::unique_ptr<llama_kv_cache_unified> kv_swa;
};
class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
public:
// used for errors
llama_kv_cache_unified_iswa_state(llama_memory_status status);
llama_kv_cache_unified_iswa_context(llama_memory_status status);
// used to create a full-cache state
llama_kv_cache_unified_iswa_state(
// used to create a full-cache context
llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv);
// used to create an update state
llama_kv_cache_unified_iswa_state(
// used to create an update context
llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
llama_context * lctx,
bool optimize);
// used to create a state from a batch
llama_kv_cache_unified_iswa_state(
// used to create a batch processing context from a batch
llama_kv_cache_unified_iswa_context(
llama_kv_cache_unified_iswa * kv,
llama_sbatch sbatch,
std::vector<uint32_t> heads_base,
std::vector<uint32_t> heads_swa,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_iswa_state();
virtual ~llama_kv_cache_unified_iswa_context();
//
// llama_memory_state_i
// llama_memory_context_i
//
bool next() override;
bool apply() override;
std::vector<int64_t> & out_ids() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
//
// llama_kv_cache_unified_iswa_state specific API
// llama_kv_cache_unified_iswa_context specific API
//
const llama_kv_cache_unified_state * get_base() const;
const llama_kv_cache_unified_state * get_swa() const;
const llama_kv_cache_unified_context * get_base() const;
const llama_kv_cache_unified_context * get_swa() const;
private:
llama_memory_status status;
//llama_kv_cache_unified_iswa * kv;
llama_sbatch sbatch;
// the index of the next ubatch to process
size_t i_next = 0;
std::vector<llama_ubatch> ubatches;
llama_memory_state_ptr state_base;
llama_memory_state_ptr state_swa;
const llama_memory_context_ptr ctx_base;
const llama_memory_context_ptr ctx_swa;
const llama_memory_status status;
};
+91 -111
View File
@@ -68,8 +68,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
continue;
}
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
const char * dev_name = "CPU";
@@ -307,18 +307,24 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
return cells.seq_pos_max(seq_id);
}
llama_memory_state_ptr llama_kv_cache_unified::init_batch(
const llama_batch & batch,
llama_memory_context_ptr llama_kv_cache_unified::init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) {
GGML_UNUSED(embd_all);
do {
auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
balloc.split_reset();
std::vector<llama_ubatch> ubatches;
while (sbatch.n_tokens > 0) {
ubatches.push_back(sbatch.split_simple(n_ubatch));
while (true) {
auto ubatch = balloc.split_simple(n_ubatch);
if (ubatch.n_tokens == 0) {
break;
}
ubatches.push_back(std::move(ubatch)); // NOLINT
}
auto heads = prepare(ubatches);
@@ -326,18 +332,18 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
break;
}
return std::make_unique<llama_kv_cache_unified_state>(
this, std::move(sbatch), std::move(heads), std::move(ubatches));
return std::make_unique<llama_kv_cache_unified_context>(
this, std::move(heads), std::move(ubatches));
} while (false);
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
llama_memory_state_ptr llama_kv_cache_unified::init_full() {
return std::make_unique<llama_kv_cache_unified_state>(this);
llama_memory_context_ptr llama_kv_cache_unified::init_full() {
return std::make_unique<llama_kv_cache_unified_context>(this);
}
llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
bool do_shift = get_has_shift();
defrag_info dinfo;
@@ -367,7 +373,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx,
}
}
return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
}
llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
@@ -644,12 +650,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
}
void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
if (debug > 0) {
LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
LLAMA_LOG_DEBUG("%s: n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
}
// keep track of the max sequence position that we would overwrite with this ubatch
// for non-SWA cache, this would be always empty
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@@ -657,27 +657,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
seq_pos_max_rm[s] = -1;
}
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
const uint32_t idx = s*ubatch.n_seq_tokens + j;
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
if (!cells.is_empty(head_cur + i)) {
assert(cells.seq_count(head_cur + i) == 1);
if (!cells.is_empty(head_cur + idx)) {
assert(cells.seq_count(head_cur + idx) == 1);
const llama_seq_id seq_id = cells.seq_get(head_cur + i);
const llama_pos pos = cells.pos_get(head_cur + i);
const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
const llama_pos pos = cells.pos_get(head_cur + idx);
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
cells.rm(head_cur + i);
}
cells.rm(head_cur + idx);
}
cells.pos_set(head_cur + i, ubatch.pos[i]);
cells.pos_set(head_cur + idx, ubatch.pos[idx]);
// TODO: fix indexing [UBATCH_IDX]
for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
}
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
}
}
@@ -696,6 +691,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
}
}
// move the head at the end of the slot
head = head_cur + ubatch.n_tokens;
}
@@ -792,9 +788,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
}
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
const uint32_t n_tokens = ubatch->n_tokens;
const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
const uint32_t n_seqs = ubatch->n_seqs;
const uint32_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
float * data = (float *) dst->data;
@@ -814,52 +808,48 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
// xxxxx-----
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
for (uint32_t h = 0; h < 1; ++h) {
for (uint32_t s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch->seq_id[s][0];
for (uint32_t i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = ubatch->seq_id[i][0];
for (uint32_t j = 0; j < n_seq_tokens; ++j) {
const uint32_t idx = s*n_seq_tokens + j;
const llama_pos p1 = ubatch->pos[i];
const llama_pos p1 = ubatch->pos[idx];
for (uint32_t j = 0; j < n_kv; ++j) {
float f = 0.0f;
for (uint32_t i = 0; i < n_kv; ++i) {
float f = 0.0f;
bool masked = false;
bool masked = false;
if (cells.is_empty(j)) {
masked = true;
} else {
const llama_pos p0 = cells.pos_get(j);
if (cells.is_empty(i)) {
masked = true;
} else {
const llama_pos p0 = cells.pos_get(i);
// mask the token if not the same sequence
masked = masked || (!cells.seq_has(j, seq_id));
// mask the token if not the same sequence
masked = masked || (!cells.seq_has(i, seq_id));
// mask future tokens
masked = masked || (causal_attn && p0 > p1);
// mask future tokens
masked = masked || (causal_attn && p0 > p1);
// apply SWA if any
masked = masked || (is_masked_swa(p0, p1));
// apply SWA if any
masked = masked || (is_masked_swa(p0, p1));
if (!masked && hparams.use_alibi) {
f = -std::abs(p0 - p1);
}
if (!masked && hparams.use_alibi) {
f = -std::abs(p0 - p1);
}
if (masked) {
f = -INFINITY;
}
data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
}
if (masked) {
f = -INFINITY;
}
data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
}
}
// mask padded tokens
if (data) {
for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
for (uint32_t i = 0; i < n_kv; ++i) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (uint32_t j = 0; j < n_kv; ++j) {
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
@@ -887,12 +877,12 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
const int32_t n_kv = dst->ne[0];
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
for (int i = 0; i < n_kv; ++i) {
for (int i = 0; i < n_tokens; ++i) {
for (int j = 0; j < n_kv; ++j) {
// the position when the cells is empty is irrelevant - it will be masked out later in the attention
const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
}
}
}
@@ -1430,7 +1420,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
for (const auto & layer : layers) {
const uint32_t il = layer.il;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
// Write key type
const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1452,7 +1442,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
for (const auto & layer : layers) {
const uint32_t il = layer.il;
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
// Write value type
const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1476,7 +1466,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
for (const auto & layer : layers) {
const uint32_t il = layer.il;
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
// Write value type
const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1509,12 +1499,9 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
seq_rm(dest_seq_id, -1, -1);
llama_sbatch sbatch;
llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
llama_batch_allocr balloc(hparams.n_pos_per_embd());
ubatch.n_tokens = cell_count;
ubatch.n_seq_tokens = cell_count;
ubatch.n_seqs = 1;
llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
for (uint32_t i = 0; i < cell_count; ++i) {
llama_pos pos;
@@ -1621,7 +1608,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
for (const auto & layer : layers) {
const uint32_t il = layer.il;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
// Read type of key
int32_t k_type_i_ref;
@@ -1651,7 +1638,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
for (const auto & layer : layers) {
const uint32_t il = layer.il;
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
// Read type of value
int32_t v_type_i_ref;
@@ -1681,7 +1668,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
for (const auto & layer : layers) {
const uint32_t il = layer.il;
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
// Read type of value
int32_t v_type_i_ref;
@@ -1723,18 +1710,18 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
}
//
// llama_kv_cache_unified_state
// llama_kv_cache_unified_context
//
llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
n_kv = kv->get_size();
head = 0;
}
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_context * lctx,
bool do_shift,
@@ -1744,16 +1731,15 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
}
}
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_sbatch sbatch,
llama_kv_cache_unified::ubatch_heads heads,
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
}
llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
bool llama_kv_cache_unified_state::next() {
bool llama_kv_cache_unified_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
if (++i_next >= ubatches.size()) {
@@ -1763,7 +1749,7 @@ bool llama_kv_cache_unified_state::next() {
return true;
}
bool llama_kv_cache_unified_state::apply() {
bool llama_kv_cache_unified_context::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
// no ubatches -> this is a KV cache update
@@ -1781,51 +1767,45 @@ bool llama_kv_cache_unified_state::apply() {
return true;
}
std::vector<int64_t> & llama_kv_cache_unified_state::out_ids() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return sbatch.out_ids;
}
llama_memory_status llama_kv_cache_unified_state::get_status() const {
llama_memory_status llama_kv_cache_unified_context::get_status() const {
return status;
}
const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const {
const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next];
}
uint32_t llama_kv_cache_unified_state::get_n_kv() const {
uint32_t llama_kv_cache_unified_context::get_n_kv() const {
return n_kv;
}
ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const {
ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
return kv->get_k(ctx, il, n_kv);
}
ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const {
ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
return kv->get_v(ctx, il, n_kv);
}
ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
return kv->cpy_k(ctx, k_cur, il, head);
}
ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
return kv->cpy_v(ctx, v_cur, il, head);
}
void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const {
void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
kv->set_input_k_shift(dst);
}
void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
kv->set_input_kq_mask(dst, ubatch, causal_attn);
}
void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_pos_bucket(dst, ubatch);
}
+17 -22
View File
@@ -56,14 +56,14 @@ public:
// llama_memory_i
//
llama_memory_state_ptr init_batch(
const llama_batch & batch,
llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;
llama_memory_state_ptr init_full() override;
llama_memory_context_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
@@ -208,49 +208,46 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
class llama_kv_cache_unified_state : public llama_memory_state_i {
class llama_kv_cache_unified_context : public llama_memory_context_i {
public:
// some shorthands
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
using defrag_info = llama_kv_cache_unified::defrag_info;
// used for errors
llama_kv_cache_unified_state(llama_memory_status status);
llama_kv_cache_unified_context(llama_memory_status status);
// used to create a full-cache state
llama_kv_cache_unified_state(
// used to create a full-cache context
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv);
// used to create an update state
llama_kv_cache_unified_state(
// used to create an update context
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_context * lctx,
bool do_shift,
defrag_info dinfo);
// used to create a decode state from a batch
llama_kv_cache_unified_state(
// used to create a batch procesing context from a batch
llama_kv_cache_unified_context(
llama_kv_cache_unified * kv,
llama_sbatch sbatch,
ubatch_heads heads,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_state();
virtual ~llama_kv_cache_unified_context();
//
// llama_memory_state_i
// llama_memory_context_i
//
bool next() override;
bool apply() override;
std::vector<int64_t> & out_ids() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
//
// llama_kv_cache_unified_state specific API
// llama_kv_cache_unified_context specific API
//
uint32_t get_n_kv() const;
@@ -275,7 +272,7 @@ private:
llama_context * lctx;
//
// update state
// update context
//
bool do_shift = false;
@@ -283,11 +280,9 @@ private:
defrag_info dinfo;
//
// batch processing state
// batch processing context
//
llama_sbatch sbatch;
// the index of the next ubatch to process
size_t i_next = 0;
+35 -11
View File
@@ -7,6 +7,7 @@
#include <cassert>
#include <vector>
#include <set>
#include <map>
// meta information about KV cells that can be part of multiple sequences at the same time
// TODO: add unit tests
@@ -164,7 +165,7 @@ public:
assert(seq_id >= 0);
seq[i].reset(seq_id);
seq_pos[seq_id].erase(pos[i]);
seq_pos_dec(seq_id, pos[i]);
if (seq[i].none()) {
pos[i] = -1;
@@ -187,7 +188,7 @@ public:
seq[i].reset();
seq[i].set(seq_id);
seq_pos[seq_id].insert(pos[i]);
seq_pos_inc(seq_id, pos[i]);
return false;
}
@@ -232,7 +233,7 @@ public:
assert(!seq[i].test(seq_id));
seq[i].set(seq_id);
seq_pos[seq_id].insert(pos[i]);
seq_pos_inc(seq_id, pos[i]);
}
// return the sequence id of this cell
@@ -259,7 +260,9 @@ public:
return -1;
}
return *seq_pos[seq_id].begin();
assert(seq_pos[seq_id].begin()->second > 0);
return seq_pos[seq_id].begin()->first;
}
// the maximum position of sequence seq_id currently present in any of the cells
@@ -272,7 +275,9 @@ public:
return -1;
}
return *seq_pos[seq_id].rbegin();
assert(seq_pos[seq_id].rbegin()->second > 0);
return seq_pos[seq_id].rbegin()->first;
}
// note: call only if the cell is not empty
@@ -384,22 +389,41 @@ private:
//
std::vector<llama_pos> shift;
using bits_t = std::bitset<LLAMA_MAX_SEQ>;
using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
std::vector<bits_t> seq;
std::vector<seq_set_t> seq;
// the set seq_pos[s] tells us which positions are currently present for sequence s
// the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
// if the position p is not present, seq_pos[s][p] is not set
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
//
// note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
// - during performing a cache reuse via (rm + add)
// - some vision models have input embeddings with repeating positions
//
std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
// helper functions for updating `seq_pos`, once cell at a time:
void seq_pos_dec(llama_seq_id s, llama_pos p) {
auto it = seq_pos[s].find(p);
assert(it != seq_pos[s].end());
if (--it->second == 0) {
seq_pos[s].erase(it);
}
}
void seq_pos_inc(llama_seq_id s, llama_pos p) {
seq_pos[s][p]++;
}
// remove cell i
void seq_pos_rm(uint32_t i) {
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq[i].test(s)) {
seq_pos[s].erase(pos[i]);
seq_pos_dec(s, pos[i]);
}
}
}
@@ -408,7 +432,7 @@ private:
void seq_pos_add(uint32_t i) {
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
if (seq[i].test(s)) {
seq_pos[s].insert(pos[i]);
seq_pos_inc(s, pos[i]);
}
}
}
+246
View File
@@ -0,0 +1,246 @@
#include "llama-memory-hybrid.h"
#include "llama-impl.h"
#include "llama-model.h"
#include "llama-context.h"
//
// llama_memory_hybrid
//
llama_memory_hybrid::llama_memory_hybrid(
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
/* layer filters */
layer_filter_cb && filter_attn,
layer_filter_cb && filter_recr) :
hparams(model.hparams),
mem_attn(new llama_kv_cache_unified(
model,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
type_k,
type_v,
v_trans,
offload,
kv_size,
n_seq_max,
n_pad,
n_swa,
swa_type
)),
mem_recr(new llama_memory_recurrent(
model,
filter_recr == nullptr ?
[&](int32_t il) { return hparams.is_recurrent(il); }
: filter_recr,
type_r,
type_s,
offload,
rs_size,
n_seq_max
)) {}
llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
do {
balloc.split_reset();
// follow the recurrent pattern for creating the ubatch splits
std::vector<llama_ubatch> ubatches;
while (true) {
llama_ubatch ubatch;
if (embd_all) {
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
ubatch = balloc.split_equal(n_ubatch);
}
if (ubatch.n_tokens == 0) {
break;
}
ubatches.push_back(std::move(ubatch)); // NOLINT
}
// prepare the recurrent batches first
if (!mem_recr->prepare(ubatches)) {
// TODO: will the recurrent cache be in an undefined context at this point?
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
// prepare the attention cache
auto heads_attn = mem_attn->prepare(ubatches);
if (heads_attn.empty()) {
LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
return std::make_unique<llama_memory_hybrid_context>(
this, std::move(heads_attn), std::move(ubatches));
} while(false);
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
llama_memory_context_ptr llama_memory_hybrid::init_full() {
return std::make_unique<llama_memory_hybrid_context>(this);
}
llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
}
bool llama_memory_hybrid::get_can_shift() const {
// Shifting is trivially supported for recurrent
return mem_attn->get_can_shift();
}
void llama_memory_hybrid::clear(bool data) {
mem_attn->clear(data);
mem_recr->clear(data);
}
bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
// Try removing from the recurrent cache first since it may fail. If it does
// fail, the cache will not have been mutated.
if (!mem_recr->seq_rm(seq_id, p0, p1)) {
return false;
}
return mem_attn->seq_rm(seq_id, p0, p1);
}
void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
void llama_memory_hybrid::seq_keep(llama_seq_id seq_id) {
mem_attn->seq_keep(seq_id);
mem_recr->seq_keep(seq_id);
}
void llama_memory_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
mem_attn->seq_add(seq_id, p0, p1, shift);
mem_recr->seq_add(seq_id, p0, p1, shift);
}
void llama_memory_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
mem_attn->seq_div(seq_id, p0, p1, d);
mem_recr->seq_div(seq_id, p0, p1, d);
}
llama_pos llama_memory_hybrid::seq_pos_min(llama_seq_id seq_id) const {
// the min of the total cache is the max of the two caches' min values
return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
}
llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
// the max of the total cache is the min of the two caches' max values
return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
}
void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
mem_attn->state_write(io, seq_id);
mem_recr->state_write(io, seq_id);
}
void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
mem_attn->state_read(io, seq_id);
mem_recr->state_read(io, seq_id);
}
llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
return mem_attn.get();
}
llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
return mem_recr.get();
}
llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}
llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
ctx_attn(mem->get_mem_attn()->init_full()),
ctx_recr(mem->get_mem_recr()->init_full()),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
llama_memory_hybrid_context::llama_memory_hybrid_context(
llama_memory_hybrid * mem,
llama_context * lctx,
bool optimize) :
ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
llama_memory_hybrid_context::llama_memory_hybrid_context(
llama_memory_hybrid * mem,
std::vector<uint32_t> heads_attn,
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
bool llama_memory_hybrid_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
ctx_attn->next();
ctx_recr->next();
if (++i_next >= ubatches.size()) {
return false;
}
return true;
}
bool llama_memory_hybrid_context::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
bool res = true;
res = res & ctx_attn->apply();
res = res & ctx_recr->apply();
return res;
}
llama_memory_status llama_memory_hybrid_context::get_status() const {
return status;
}
const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next];
}
const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
}
const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
}
+138
View File
@@ -0,0 +1,138 @@
#pragma once
#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache-unified.h"
#include "llama-memory.h"
#include "llama-memory-recurrent.h"
#include <memory>
#include <vector>
//
// llama_memory_hybrid
//
// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
// support models where each layer may be either attention-based or recurrent
class llama_memory_hybrid : public llama_memory_i {
public:
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
llama_memory_hybrid(
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
/* layer filters */
layer_filter_cb && filter_attn = nullptr,
layer_filter_cb && filter_recr = nullptr);
~llama_memory_hybrid() = default;
//
// llama_memory_i
//
llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;
llama_memory_context_ptr init_full() override;
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
void clear(bool data) override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
//
// llama_memory_hybrid specific API
//
llama_kv_cache_unified * get_mem_attn() const;
llama_memory_recurrent * get_mem_recr() const;
private:
const llama_hparams & hparams;
const std::unique_ptr<llama_kv_cache_unified> mem_attn;
const std::unique_ptr<llama_memory_recurrent> mem_recr;
};
class llama_memory_hybrid_context : public llama_memory_context_i {
public:
// init failure
explicit llama_memory_hybrid_context(llama_memory_status status);
// init full
explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
// init update
explicit llama_memory_hybrid_context(
llama_memory_hybrid * mem,
llama_context * lctx,
bool optimize);
// init success
llama_memory_hybrid_context(
llama_memory_hybrid * mem,
std::vector<uint32_t> heads_attn,
std::vector<llama_ubatch> ubatches);
~llama_memory_hybrid_context() = default;
bool next() override;
bool apply() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
//
// llama_memory_hybrid_context
//
const llama_kv_cache_unified_context * get_attn() const;
const llama_memory_recurrent_context * get_recr() const;
private:
// the index of the next ubatch to process
size_t i_next = 0;
std::vector<llama_ubatch> ubatches;
const llama_memory_context_ptr ctx_attn;
const llama_memory_context_ptr ctx_recr;
const llama_memory_status status;
};
@@ -1,4 +1,4 @@
#include "llama-kv-cache-recurrent.h"
#include "llama-memory-recurrent.h"
#include "llama-impl.h"
#include "llama-io.h"
@@ -12,27 +12,28 @@
#include <stdexcept>
//
// llama_kv_cache_recurrent
// llama_memory_recurrent
//
llama_kv_cache_recurrent::llama_kv_cache_recurrent(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size,
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
llama_memory_recurrent::llama_memory_recurrent(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_r,
ggml_type type_s,
bool offload,
uint32_t mem_size,
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
const int32_t n_layer = hparams.n_layer;
LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
__func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
__func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
head = 0;
size = kv_size;
size = mem_size;
used = 0;
cells.clear();
cells.resize(kv_size);
cells.resize(mem_size);
// create a context for each buffer type
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
@@ -59,12 +60,14 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
return it->second;
};
k_l.reserve(n_layer);
v_l.reserve(n_layer);
r_l.resize(n_layer);
s_l.resize(n_layer);
for (int i = 0; i < n_layer; i++) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
if (filter && !filter(i)) {
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
continue;
}
const char * dev_name = "CPU";
@@ -84,12 +87,12 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
throw std::runtime_error("failed to create ggml context for kv cache");
}
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
k_l.push_back(k);
v_l.push_back(v);
ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size);
ggml_format_name(r, "cache_r_l%d", i);
ggml_format_name(s, "cache_s_l%d", i);
r_l[i] = r;
s_l[i] = s;
}
// allocate tensors and initialize the buffers to avoid NaNs in the padding
@@ -107,17 +110,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
}
{
const size_t memory_size_k = size_k_bytes();
const size_t memory_size_v = size_v_bytes();
const size_t memory_size_r = size_r_bytes();
const size_t memory_size_s = size_s_bytes();
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
(float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
}
}
void llama_kv_cache_recurrent::clear(bool data) {
void llama_memory_recurrent::clear(bool data) {
for (int32_t i = 0; i < (int32_t) size; ++i) {
cells[i].pos = -1;
cells[i].seq_id.clear();
@@ -135,7 +138,7 @@ void llama_kv_cache_recurrent::clear(bool data) {
}
}
bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
uint32_t new_head = size;
if (p0 < 0) {
@@ -154,7 +157,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
if (0 <= seq_id) {
int32_t & tail_id = cells[seq_id].tail;
if (tail_id >= 0) {
const kv_cell & cell = cells[tail_id];
const auto & cell = cells[tail_id];
// partial intersection is invalid
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
return false;
@@ -202,7 +205,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
return true;
}
void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
if (seq_id_src == seq_id_dst) {
return;
}
@@ -216,11 +219,11 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
}
if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
kv_cell & tail_src = cells[seq_id_src];
kv_cell & tail_dst = cells[seq_id_dst];
auto & tail_src = cells[seq_id_src];
auto & tail_dst = cells[seq_id_dst];
if (tail_dst.tail >= 0) {
// clear destination seq_id if it wasn't empty
kv_cell & cell_dst = cells[tail_dst.tail];
auto & cell_dst = cells[tail_dst.tail];
cell_dst.seq_id.erase(seq_id_dst);
tail_dst.tail = -1;
@@ -231,7 +234,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
}
}
if (tail_src.tail >= 0) {
kv_cell & cell_src = cells[tail_src.tail];
auto & cell_src = cells[tail_src.tail];
cell_src.seq_id.insert(seq_id_dst);
tail_dst.tail = tail_src.tail;
@@ -239,7 +242,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
}
}
void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
void llama_memory_recurrent::seq_keep(llama_seq_id seq_id) {
uint32_t new_head = size;
for (uint32_t i = 0; i < size; ++i) {
@@ -271,7 +274,7 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
}
}
void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
void llama_memory_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
if (shift == 0) {
return;
}
@@ -293,7 +296,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
if (0 <= seq_id && seq_id < (int64_t) size) {
const int32_t tail_id = cells[seq_id].tail;
if (tail_id >= 0) {
kv_cell & cell = cells[tail_id];
auto & cell = cells[tail_id];
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
cell.pos += shift;
}
@@ -301,7 +304,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
}
}
void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
void llama_memory_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
if (d == 1) {
return;
}
@@ -323,7 +326,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
if (0 <= seq_id && seq_id < (int64_t) size) {
const int32_t tail_id = cells[seq_id].tail;
if (tail_id >= 0) {
kv_cell & cell = cells[tail_id];
auto & cell = cells[tail_id];
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
cell.pos /= d;
}
@@ -331,7 +334,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
}
}
llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
llama_pos llama_memory_recurrent::seq_pos_min(llama_seq_id seq_id) const {
llama_pos result = std::numeric_limits<llama_pos>::max();
for (uint32_t i = 0; i < size; ++i) {
@@ -347,7 +350,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
return result;
}
llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
llama_pos result = -1;
for (uint32_t i = 0; i < size; ++i) {
@@ -359,43 +362,45 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
return result;
}
llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
std::vector<llama_ubatch> ubatches;
while (sbatch.n_tokens > 0) {
while (true) {
llama_ubatch ubatch;
if (embd_all) {
// if all tokens are output, split by sequence
ubatch = sbatch.split_seq(n_ubatch);
ubatch = balloc.split_seq(n_ubatch);
} else {
ubatch = sbatch.split_equal(n_ubatch);
ubatch = balloc.split_equal(n_ubatch);
}
ubatches.push_back(ubatch);
if (ubatch.n_tokens == 0) {
break;
}
ubatches.push_back(std::move(ubatch)); // NOLINT
}
if (!prepare(ubatches)) {
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches));
return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
}
llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
llama_memory_context_ptr llama_memory_recurrent::init_full() {
return std::make_unique<llama_memory_recurrent_context>(this);
}
llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
GGML_UNUSED(lctx);
GGML_UNUSED(optimize);
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_NO_UPDATE);
}
bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
// simply remember the full state because it is very small for this type of cache
// TODO: optimize
auto org_cells = cells;
@@ -419,10 +424,9 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
return success;
}
bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
const uint32_t n_seqs = ubatch.n_seqs;
bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
const uint32_t n_seqs = ubatch.n_seqs;
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
@@ -442,9 +446,11 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
// everything should fit if all seq_ids are smaller than the max
for (uint32_t s = 0; s < n_seqs; ++s) {
const uint32_t n_seq_id = ubatch.n_seq_id[s];
const uint32_t i = s*n_seq_tokens; // first token of sequence set s
const uint32_t n_seq_id = ubatch.n_seq_id[i];
for (uint32_t j = 0; j < n_seq_id; ++j) {
const llama_seq_id seq_id = ubatch.seq_id[s][j];
const llama_seq_id seq_id = ubatch.seq_id[i][j];
if (seq_id < 0 || (uint32_t) seq_id >= size) {
// too big seq_id
@@ -453,9 +459,9 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
return false;
}
if (j > 0) {
kv_cell & seq = cells[seq_id];
auto & seq = cells[seq_id];
if (seq.tail >= 0) {
kv_cell & cell = cells[seq.tail];
auto & cell = cells[seq.tail];
// clear cells from seq_ids that become shared
// (should not normally happen, but let's handle it anyway)
cell.seq_id.erase(seq_id);
@@ -475,7 +481,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
std::vector<int32_t> tails_verif;
tails_verif.assign(size, -1);
for (uint32_t i = 0; i < size; ++i) {
kv_cell & cell = cells[i];
auto & cell = cells[i];
for (llama_seq_id seq_id : cell.seq_id) {
if (tails_verif[seq_id] != -1) {
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
@@ -496,28 +502,29 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
for (uint32_t i = 0; i < size; ++i) {
if (next_empty_cell >= size) { next_empty_cell -= size; }
kv_cell & cell = cells[next_empty_cell];
auto & cell = cells[next_empty_cell];
if (cell.is_empty()) { break; }
next_empty_cell += 1;
}
// find usable cell range
for (uint32_t s = 0; s < n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
kv_cell & seq_meta = cells[seq_id];
const uint32_t i = s*n_seq_tokens;
const llama_seq_id seq_id = ubatch.seq_id[i][0];
auto & seq_meta = cells[seq_id];
bool has_cell = false;
if (seq_meta.tail >= 0) {
kv_cell & cell = cells[seq_meta.tail];
auto & cell = cells[seq_meta.tail];
GGML_ASSERT(cell.has_seq_id(seq_id));
// does this seq_id "own" the cell?
if (cell.seq_id.size() == 1) { has_cell = true; }
}
if (!has_cell) {
kv_cell & empty_cell = cells[next_empty_cell];
auto & empty_cell = cells[next_empty_cell];
GGML_ASSERT(empty_cell.is_empty());
// copy old tail into the empty cell
if (seq_meta.tail >= 0) {
kv_cell & orig_cell = cells[seq_meta.tail];
auto & orig_cell = cells[seq_meta.tail];
empty_cell.pos = orig_cell.pos;
empty_cell.src = orig_cell.src;
orig_cell.seq_id.erase(seq_id);
@@ -527,10 +534,10 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
seq_meta.tail = next_empty_cell;
// find next empty cell
if (s + 1 < n_seqs) {
for (uint32_t i = 0; i < size; ++i) {
for (uint32_t j = 0; j < size; ++j) {
next_empty_cell += 1;
if (next_empty_cell >= size) { next_empty_cell -= size; }
kv_cell & cell = cells[next_empty_cell];
auto & cell = cells[next_empty_cell];
if (cell.is_empty()) { break; }
}
}
@@ -541,19 +548,20 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
// gather and re-order
for (uint32_t s = 0; s < n_seqs; ++s) {
const uint32_t i = s*n_seq_tokens;
const int32_t dst_id = s + min;
const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
const int32_t src_id = cells[ubatch.seq_id[i][0]].tail;
if (dst_id != src_id) {
kv_cell & dst_cell = cells[dst_id];
kv_cell & src_cell = cells[src_id];
auto & dst_cell = cells[dst_id];
auto & src_cell = cells[src_id];
std::swap(dst_cell.pos, src_cell.pos);
std::swap(dst_cell.src, src_cell.src);
std::swap(dst_cell.seq_id, src_cell.seq_id);
// swap tails
for (uint32_t i = 0; i < size; ++i) {
int32_t & tail = cells[i].tail;
for (uint32_t j = 0; j < size; ++j) {
int32_t & tail = cells[j].tail;
if (tail == src_id) {
tail = dst_id;
} else if (tail == dst_id) {
@@ -565,20 +573,21 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
// update the pos of the used seqs
for (uint32_t s = 0; s < n_seqs; ++s) {
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
const uint32_t i = s*n_seq_tokens;
const llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1];
const int32_t cell_id = s + min;
kv_cell & cell = cells[cell_id];
auto & cell = cells[cell_id];
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
// What should happen when the pos backtracks or skips a value?
// Clearing the state mid-batch would require special-casing which isn't done.
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
__func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
__func__, last_pos, cell.pos, ubatch.seq_id[i][0], n_seq_tokens);
}
cell.pos = last_pos;
cell.seq_id.clear();
for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
const llama_seq_id seq_id = ubatch.seq_id[s][j];
for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
const llama_seq_id seq_id = ubatch.seq_id[i][j];
cell.seq_id.insert(seq_id);
cells[seq_id].tail = cell_id;
}
@@ -620,18 +629,18 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
head = min;
n = max - min + 1;
used = std::count_if(cells.begin(), cells.end(),
[](const kv_cell & cell){ return !cell.is_empty(); });
[](const mem_cell & cell){ return !cell.is_empty(); });
// sanity check
return n >= n_seqs;
}
bool llama_kv_cache_recurrent::get_can_shift() const {
bool llama_memory_recurrent::get_can_shift() const {
// shifting the pos is trivial for recurrent models
return true;
}
size_t llama_kv_cache_recurrent::total_size() const {
size_t llama_memory_recurrent::total_size() const {
size_t size = 0;
for (const auto & buf : bufs) {
size += ggml_backend_buffer_get_size(buf.get());
@@ -640,27 +649,31 @@ size_t llama_kv_cache_recurrent::total_size() const {
return size;
}
size_t llama_kv_cache_recurrent::size_k_bytes() const {
size_t size_k_bytes = 0;
size_t llama_memory_recurrent::size_r_bytes() const {
size_t size_r_bytes = 0;
for (const auto & k : k_l) {
size_k_bytes += ggml_nbytes(k);
for (const auto & r : r_l) {
if (r != nullptr) {
size_r_bytes += ggml_nbytes(r);
}
}
return size_k_bytes;
return size_r_bytes;
}
size_t llama_kv_cache_recurrent::size_v_bytes() const {
size_t size_v_bytes = 0;
size_t llama_memory_recurrent::size_s_bytes() const {
size_t size_s_bytes = 0;
for (const auto & v : v_l) {
size_v_bytes += ggml_nbytes(v);
for (const auto & s : s_l) {
if (s != nullptr) {
size_s_bytes += ggml_nbytes(s);
}
}
return size_v_bytes;
return size_s_bytes;
}
void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
uint32_t cell_count = 0;
@@ -698,7 +711,7 @@ void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id s
state_write_data(io, cell_ranges);
}
void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
uint32_t cell_count;
io.read_to(&cell_count, sizeof(cell_count));
@@ -717,7 +730,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
}
}
void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
for (const auto & range : cell_ranges) {
for (uint32_t i = range.first; i < range.second; ++i) {
const auto & cell = cells[i];
@@ -736,98 +749,93 @@ void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std
}
}
void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
const uint32_t v_trans = 0;
void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
const uint32_t s_trans = 0;
const uint32_t n_layer = hparams.n_layer;
io.write(&v_trans, sizeof(v_trans));
io.write(&n_layer, sizeof(n_layer));
io.write(&s_trans, sizeof(s_trans));
io.write(&n_layer, sizeof(n_layer));
std::vector<uint8_t> tmp_buf;
// Iterate and write all the keys first, each row is a cell
// Get whole range at a time
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
// Write key type
const int32_t k_type_i = (int32_t)k_l[il]->type;
io.write(&k_type_i, sizeof(k_type_i));
const int32_t r_type_i = (int32_t)r_l[il]->type;
io.write(&r_type_i, sizeof(r_type_i));
// Write row size of key
const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
io.write(&k_size_row, sizeof(k_size_row));
const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
io.write(&r_size_row, sizeof(r_size_row));
// Read each range of cells of k_size length each into tmp_buf and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t buf_size = range_size * k_size_row;
io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
const size_t buf_size = range_size * r_size_row;
io.write_tensor(r_l[il], range.first * r_size_row, buf_size);
}
}
if (!v_trans) {
if (!s_trans) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
// Write value type
const int32_t v_type_i = (int32_t)v_l[il]->type;
io.write(&v_type_i, sizeof(v_type_i));
const int32_t s_type_i = (int32_t)s_l[il]->type;
io.write(&s_type_i, sizeof(s_type_i));
// Write row size of value
const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
io.write(&v_size_row, sizeof(v_size_row));
const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
io.write(&s_size_row, sizeof(s_size_row));
// Read each range of cells of v_size length each into tmp_buf and write out
// Read each range of cells of s_size length each into tmp_buf and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t buf_size = range_size * v_size_row;
io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
const size_t buf_size = range_size * s_size_row;
io.write_tensor(s_l[il], range.first * s_size_row, buf_size);
}
}
} else {
// When v is transposed, we also need the element size and get the element ranges from each row
const uint32_t kv_size = size;
const uint32_t mem_size = size;
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_s = hparams.n_embd_s();
// Write value type
const int32_t v_type_i = (int32_t)v_l[il]->type;
io.write(&v_type_i, sizeof(v_type_i));
const int32_t s_type_i = (int32_t)s_l[il]->type;
io.write(&s_type_i, sizeof(s_type_i));
// Write element size
const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
io.write(&v_size_el, sizeof(v_size_el));
const uint32_t s_size_el = ggml_type_size(s_l[il]->type);
io.write(&s_size_el, sizeof(s_size_el));
// Write GQA embedding size
io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
io.write(&n_embd_s, sizeof(n_embd_s));
// For each row, we get the element values of each cell
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
for (uint32_t j = 0; j < n_embd_s; ++j) {
// Read each range of cells of v_size_el length each into tmp_buf and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
const size_t buf_size = range_size * v_size_el;
io.write_tensor(v_l[il], src_offset, buf_size);
const size_t src_offset = (range.first + j * mem_size) * s_size_el;
const size_t buf_size = range_size * s_size_el;
io.write_tensor(s_l[il], src_offset, buf_size);
}
}
}
}
}
bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
if (dest_seq_id != -1) {
// single sequence
seq_rm(dest_seq_id, -1, -1);
llama_sbatch sbatch;
llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
llama_batch_allocr balloc(hparams.n_pos_per_embd());
batch.n_tokens = cell_count;
batch.n_seq_tokens = cell_count;
batch.n_seqs = 1;
llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
for (uint32_t i = 0; i < cell_count; ++i) {
llama_pos pos;
@@ -841,12 +849,12 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
return false;
}
batch.pos[i] = pos;
ubatch.pos[i] = pos;
}
batch.n_seq_id[0] = 1;
batch.seq_id[0] = &dest_seq_id;
ubatch.n_seq_id[0] = 1;
ubatch.seq_id[0] = &dest_seq_id;
if (!find_slot(batch)) {
if (!find_slot(ubatch)) {
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
return false;
}
@@ -854,8 +862,8 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
// Assume that this is one contiguous block of cells
GGML_ASSERT(head + cell_count <= size);
GGML_ASSERT(cells[head].pos == batch.pos[0]);
GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
GGML_ASSERT(cells[head].pos == ubatch.pos[0]);
GGML_ASSERT(cells[head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
} else {
@@ -869,7 +877,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
clear(true);
for (uint32_t i = 0; i < cell_count; ++i) {
kv_cell & cell = cells[i];
auto & cell = cells[i];
llama_pos pos;
uint32_t n_seq_id;
@@ -883,7 +891,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
llama_seq_id seq_id;
io.read_to(&seq_id, sizeof(seq_id));
// TODO: llama_kv_cache_recurrent should have a notion of max sequences
// TODO: llama_memory_recurrent should have a notion of max sequences
//if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
if (seq_id < 0) {
//LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
@@ -915,10 +923,10 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
return true;
}
bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
uint32_t v_trans;
bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
uint32_t s_trans;
uint32_t n_layer;
io.read_to(&v_trans, sizeof(v_trans));
io.read_to(&s_trans, sizeof(s_trans));
io.read_to(&n_layer, sizeof(n_layer));
if (n_layer != hparams.n_layer) {
@@ -929,102 +937,100 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
return false;
}
if (false != (bool) v_trans) {
LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
if (false != (bool) s_trans) {
LLAMA_LOG_ERROR("%s: incompatible s transposition\n", __func__);
return false;
}
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
// Read type of key
int32_t k_type_i_ref;
io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
const int32_t k_type_i = (int32_t) k_l[il]->type;
if (k_type_i != k_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
int32_t r_type_i_ref;
io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
const int32_t r_type_i = (int32_t) r_l[il]->type;
if (r_type_i != r_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
return false;
}
// Read row size of key
uint64_t k_size_row_ref;
io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
if (k_size_row != k_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
uint64_t r_size_row_ref;
io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
if (r_size_row != r_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
return false;
}
if (cell_count) {
// Read and set the keys for the whole cell range
ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
}
}
if (!v_trans) {
if (!s_trans) {
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
// Read type of value
int32_t v_type_i_ref;
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
const int32_t v_type_i = (int32_t)v_l[il]->type;
if (v_type_i != v_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
int32_t s_type_i_ref;
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
const int32_t s_type_i = (int32_t)s_l[il]->type;
if (s_type_i != s_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
return false;
}
// Read row size of value
uint64_t v_size_row_ref;
io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
if (v_size_row != v_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
uint64_t s_size_row_ref;
io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
if (s_size_row != s_size_row_ref) {
LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
return false;
}
if (cell_count) {
// Read and set the values for the whole cell range
ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
}
}
} else {
// For each layer, read the values for each cell (transposed)
for (uint32_t il = 0; il < n_layer; ++il) {
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
const uint32_t n_embd_s = hparams.n_embd_s();
// Read type of value
int32_t v_type_i_ref;
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
const int32_t v_type_i = (int32_t)v_l[il]->type;
if (v_type_i != v_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
int32_t s_type_i_ref;
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
const int32_t s_type_i = (int32_t)s_l[il]->type;
if (s_type_i != s_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
return false;
}
// Read element size of value
uint32_t v_size_el_ref;
io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
const size_t v_size_el = ggml_type_size(v_l[il]->type);
if (v_size_el != v_size_el_ref) {
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
uint32_t s_size_el_ref;
io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
const size_t s_size_el = ggml_type_size(s_l[il]->type);
if (s_size_el != s_size_el_ref) {
LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
return false;
}
// Read GQA embedding size
uint32_t n_embd_v_gqa_ref;
io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
// Read state embedding size
uint32_t n_embd_s_ref;
io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
if (n_embd_s != n_embd_s_ref) {
LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
return false;
}
if (cell_count) {
// For each row in the transposed matrix, read the values for the whole cell range
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
const size_t dst_offset = (head + j * size) * v_size_el;
ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
for (uint32_t j = 0; j < n_embd_s; ++j) {
const size_t dst_offset = (head + j * size) * s_size_el;
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
}
}
}
@@ -1034,25 +1040,22 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
}
//
// llama_kv_cache_recurrent_state
// llama_memory_recurrent_context
//
llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {}
llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
llama_memory_status status,
llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
llama_memory_recurrent_context::llama_memory_recurrent_context(
llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
}
llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
llama_memory_status status,
llama_kv_cache_recurrent * kv,
llama_sbatch sbatch,
std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
llama_memory_recurrent_context::llama_memory_recurrent_context(
llama_memory_recurrent * mem,
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default;
llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
bool llama_kv_cache_recurrent_state::next() {
bool llama_memory_recurrent_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
if (++i_next >= ubatches.size()) {
@@ -1062,54 +1065,48 @@ bool llama_kv_cache_recurrent_state::next() {
return true;
}
bool llama_kv_cache_recurrent_state::apply() {
bool llama_memory_recurrent_context::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
kv->find_slot(ubatches[i_next]);
mem->find_slot(ubatches[i_next]);
return true;
}
std::vector<int64_t> & llama_kv_cache_recurrent_state::out_ids() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return sbatch.out_ids;
}
llama_memory_status llama_kv_cache_recurrent_state::get_status() const {
llama_memory_status llama_memory_recurrent_context::get_status() const {
return status;
}
const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const {
const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next];
}
uint32_t llama_kv_cache_recurrent_state::get_n_kv() const {
return is_full ? kv->size : kv->n;
uint32_t llama_memory_recurrent_context::get_n_rs() const {
return is_full ? mem->size : mem->n;
}
uint32_t llama_kv_cache_recurrent_state::get_head() const {
return is_full ? 0 : kv->head;
uint32_t llama_memory_recurrent_context::get_head() const {
return is_full ? 0 : mem->head;
}
int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
return is_full ? 0 : kv->rs_z;
int32_t llama_memory_recurrent_context::get_rs_z() const {
return is_full ? 0 : mem->rs_z;
}
uint32_t llama_kv_cache_recurrent_state::get_size() const {
return kv->size;
uint32_t llama_memory_recurrent_context::get_size() const {
return mem->size;
}
ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const {
return kv->k_l[il];
ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
return mem->r_l[il];
}
ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
return kv->v_l[il];
ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
return mem->s_l[il];
}
int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
return kv->cells[i + kv->head].src0;
int32_t llama_memory_recurrent_context::s_copy(int i) const {
return mem->cells[i + mem->head].src0;
}
@@ -8,35 +8,40 @@
#include <vector>
//
// llama_kv_cache_recurrent
// llama_memory_recurrent
//
// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it
class llama_kv_cache_recurrent : public llama_memory_i {
// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
// see the implementation of llama_kv_cache_unified_context_i for an example how to do it
class llama_memory_recurrent : public llama_memory_i {
public:
llama_kv_cache_recurrent(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size,
uint32_t n_seq_max);
~llama_kv_cache_recurrent() = default;
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
llama_memory_recurrent(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_r,
ggml_type type_s,
bool offload,
uint32_t mem_size,
uint32_t n_seq_max);
~llama_memory_recurrent() = default;
//
// llama_memory_i
//
llama_memory_state_ptr init_batch(
const llama_batch & batch,
llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;
llama_memory_state_ptr init_full() override;
llama_memory_context_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
void clear(bool data) override;
@@ -51,7 +56,7 @@ public:
bool prepare(const std::vector<llama_ubatch> & ubatches);
// find a contiguous slot of kv cells and emplace the ubatch there
// find a contiguous slot of memory cells and emplace the ubatch there
bool find_slot(const llama_ubatch & ubatch);
bool get_can_shift() const override;
@@ -72,7 +77,7 @@ public:
int32_t rs_z = -1;
// TODO: optimize for recurrent state needs
struct kv_cell {
struct mem_cell {
llama_pos pos = -1;
int32_t src = -1; // used to know where states should be copied from
int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
@@ -88,15 +93,16 @@ public:
return seq_id.empty();
}
bool is_same_seq(const kv_cell & other) const {
bool is_same_seq(const mem_cell & other) const {
return seq_id == other.seq_id;
}
};
std::vector<kv_cell> cells;
std::vector<mem_cell> cells;
std::vector<ggml_tensor *> k_l; // per layer
std::vector<ggml_tensor *> v_l;
// per layer
std::vector<ggml_tensor *> r_l;
std::vector<ggml_tensor *> s_l;
private:
//const llama_model & model;
@@ -109,8 +115,8 @@ private:
size_t total_size() const;
size_t size_k_bytes() const;
size_t size_v_bytes() const;
size_t size_r_bytes() const;
size_t size_s_bytes() const;
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
@@ -119,57 +125,50 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
class llama_kv_cache_recurrent_state : public llama_memory_state_i {
class llama_memory_recurrent_context : public llama_memory_context_i {
public:
// used for errors
llama_kv_cache_recurrent_state(llama_memory_status status);
llama_memory_recurrent_context(llama_memory_status status);
// used to create a full-cache state
llama_kv_cache_recurrent_state(
llama_memory_status status,
llama_kv_cache_recurrent * kv);
// used to create a full-cache or update context
llama_memory_recurrent_context(
llama_memory_recurrent * mem);
// used to create a state from a batch
llama_kv_cache_recurrent_state(
llama_memory_status status,
llama_kv_cache_recurrent * kv,
llama_sbatch sbatch,
// used to create a batch processing context from a batch
llama_memory_recurrent_context(
llama_memory_recurrent * mem,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_recurrent_state();
virtual ~llama_memory_recurrent_context();
//
// llama_memory_state_i
// llama_memory_context_i
//
bool next() override;
bool apply() override;
std::vector<int64_t> & out_ids() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
//
// llama_kv_cache_recurrent_state specific API
// llama_memory_recurrent_context specific API
//
uint32_t get_n_kv() const;
uint32_t get_n_rs() const;
uint32_t get_head() const;
int32_t get_rs_z() const;
uint32_t get_size() const;
ggml_tensor * get_k_l(int32_t il) const;
ggml_tensor * get_v_l(int32_t il) const;
ggml_tensor * get_r_l(int32_t il) const;
ggml_tensor * get_s_l(int32_t il) const;
int32_t s_copy(int i) const;
private:
const llama_memory_status status;
llama_kv_cache_recurrent * kv;
llama_sbatch sbatch;
llama_memory_recurrent * mem;
size_t i_next = 0;
+18 -22
View File
@@ -3,10 +3,11 @@
#include "llama.h"
#include <memory>
#include <vector>
struct llama_ubatch;
class llama_batch_allocr;
class llama_io_write_i;
class llama_io_read_i;
@@ -26,23 +27,21 @@ enum llama_memory_status {
LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
};
// helper function for combining the status of two memory states
// helper function for combining the status of two memory contexts
// useful for implementing hybrid memory types (e.g. iSWA)
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
// the interface for managing the memory state during batch processing
// the interface for managing the memory context during batch processing
// this interface is implemented per memory type. see:
// - llama_kv_cache_unified_state
// - llama_kv_cache_unified_iswa_state
// - llama_kv_cache_unified_context
// - llama_kv_cache_unified_iswa_context
// ...
//
// the only method that can mutate the memory and the memory state is llama_memory_i::apply()
//
// TODO: rename to llama_memory_context_i ?
struct llama_memory_state_i {
virtual ~llama_memory_state_i() = default;
// the only method that should mutate the memory and the memory context is llama_memory_i::apply()
struct llama_memory_context_i {
virtual ~llama_memory_context_i() = default;
// consume the current ubatch from the state and proceed to the next one
// consume the current ubatch from the context and proceed to the next one
// return false if we are done
virtual bool next() = 0;
@@ -50,17 +49,14 @@ struct llama_memory_state_i {
// return false on failure
virtual bool apply() = 0;
// TODO: this might get reworked in the future when refactoring llama_batch
virtual std::vector<int64_t> & out_ids() = 0;
// get the current ubatch
virtual const llama_ubatch & get_ubatch() const = 0;
// get the status of the memory state - used for error handling and checking if any updates would be applied
// get the status of the memory context - used for error handling and checking if any updates would be applied
virtual llama_memory_status get_status() const = 0;
};
using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
@@ -68,19 +64,19 @@ struct llama_memory_i {
virtual ~llama_memory_i() = default;
// split the input batch into a set of ubatches and verify that they can fit into the cache
// return a state object containing the ubatches and KV cache state required to process them
// check the llama_memory_state_i::get_status() for the result
virtual llama_memory_state_ptr init_batch(
const llama_batch & batch,
// return a context object containing the ubatches and memory state required to process them
// check the llama_memory_context_i::get_status() for the result
virtual llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) = 0;
// simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_state_ptr init_full() = 0;
virtual llama_memory_context_ptr init_full() = 0;
// prepare for any pending memory updates, such as shifts, defrags, etc.
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
// getters
virtual bool get_can_shift() const = 0;
+1
View File
@@ -228,6 +228,7 @@ void llama_model_saver::add_kv_from_model() {
// add_kv(LLM_KV_TOKENIZER_MASK_ID, ???);
add_kv(LLM_KV_TOKENIZER_ADD_BOS, vocab.get_add_bos());
add_kv(LLM_KV_TOKENIZER_ADD_EOS, vocab.get_add_eos());
add_kv(LLM_KV_TOKENIZER_ADD_SEP, vocab.get_add_sep());
add_kv(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.get_add_space_prefix());
add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.get_remove_extra_whitespaces());
add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, vocab.get_precompiled_charsmap());
+554 -527
View File
File diff suppressed because it is too large Load Diff
+79 -4
View File
@@ -1,5 +1,4 @@
#include "llama-quant.h"
#include "llama-impl.h"
#include "llama-model.h"
#include "llama-model-loader.h"
@@ -27,6 +26,56 @@ static void zeros(std::ofstream & file, size_t n) {
}
}
static std::string remap_layer(const std::string & orig_name, const std::vector<int> & prune, std::map<int, std::string> & mapped, int & next_id) {
if (prune.empty()) {
return orig_name;
}
static const std::regex pattern(R"(blk\.(\d+)\.)");
if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
const int blk = std::stoi(match[1]);
std::string new_name = orig_name;
if (mapped.count(blk)) {
// Already mapped, do nothing
} else if (std::find(prune.begin(), prune.end(), blk) != prune.end()) {
mapped[blk] = "";
} else if (blk < prune.front()) {
mapped[blk] = std::to_string(blk);
next_id = blk + 1;
} else {
mapped[blk] = std::to_string(next_id);
++next_id;
}
return mapped[blk].empty() ? mapped[blk] : new_name.replace(match.position(1), match.length(1), mapped[blk]);
}
return orig_name;
}
static std::string remap_imatrix (const std::string & orig_name, const std::map<int, std::string> & mapped) {
if (mapped.empty()) {
return orig_name;
}
static const std::regex pattern(R"(blk\.(\d+)\.)");
if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
const std::string blk(match[1]);
std::string new_name = orig_name;
for (const auto & p : mapped) {
if (p.second == blk) {
LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
}
}
GGML_ABORT("\n%s: imatrix mapping error for %s\n", __func__, orig_name.c_str());
}
return orig_name;
}
struct quantize_state_impl {
const llama_model & model;
const llama_model_quantize_params * params;
@@ -568,6 +617,11 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
const size_t align = GGUF_DEFAULT_ALIGNMENT;
gguf_context_ptr ctx_out { gguf_init_empty() };
std::vector<int> prune_list = {};
if (params->prune_layers) {
prune_list = *static_cast<const std::vector<int> *>(params->prune_layers);
}
// copy the KV pairs from the input file
gguf_set_kv (ctx_out.get(), ml.meta.get());
gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
@@ -597,12 +651,32 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
}
}
std::map<int, std::string> mapped;
int blk_id = 0;
int pruned_attention_w = 0;
// make a list of weights
std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
tensors.reserve(ml.weights_map.size());
for (const auto & it : ml.weights_map) {
const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
if (remapped_name.empty()) {
if (it.first.find("attn_v.weight") != std::string::npos ||
it.first.find("attn_qkv.weight") != std::string::npos ||
it.first.find("attn_kv_b.weight") != std::string::npos) {
pruned_attention_w++;
}
LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
continue;
} else if (remapped_name != it.first) {
ggml_set_name(it.second.tensor, remapped_name.c_str());
LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
}
tensors.push_back(&it.second);
}
if (!prune_list.empty()) {
gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), blk_id);
}
// keep_split requires that the weights are sorted by split index
if (params->keep_split) {
@@ -640,7 +714,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
if (llama_model_has_encoder(&model)) {
n_attn_layer *= 3;
}
GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
}
size_t total_size_org = 0;
@@ -681,7 +755,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
for (size_t i = 0; i < ctx_outs.size(); ++i) {
gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors);
gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), (int32_t)tensors.size());
}
}
@@ -832,7 +906,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
const float * imatrix = nullptr;
if (imatrix_data) {
auto it = imatrix_data->find(tensor->name);
auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped));
if (it == imatrix_data->end()) {
LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
} else {
@@ -947,6 +1021,7 @@ llama_model_quantize_params llama_model_quantize_default_params() {
/*.imatrix =*/ nullptr,
/*.kv_overrides =*/ nullptr,
/*.tensor_type =*/ nullptr,
/*.prune_layers =*/ nullptr
};
return result;

Some files were not shown because too many files have changed in this diff Show More