Compare commits

..

28 Commits

Author SHA1 Message Date
Francis Couture-Harpin 2ff3354c33 memory : fix broken batch splits for recurrent cache
Splits producing more than one ubatch per batch for recurrent models
were broken with #14512.

This fixes it by moving the completeness check after the ubatch split loop.
2025-07-07 21:23:14 -04:00
Sigbjørn Skjæret e1a7059053 llama : fix incorrect minicpm3 v_states shape (#14571) 2025-07-07 23:35:35 +02:00
Sigbjørn Skjæret 12f55c302b llama : remove ggml_cont where possible (#14568) 2025-07-07 21:35:08 +02:00
Aman Gupta b9c3eefde1 CUDA: add bf16 and i32 to getrows (#14529) 2025-07-07 21:45:43 +08:00
Eve 6491d6e4f1 vulkan: increase LOAD_VEC_A to 8 (IQ1/IQ2) or 4 (IQ3) (#14485)
Commit taken from remyoudompheng's PR https://github.com/ggml-org/llama.cpp/pull/12260

Co-authored-by: Rémy Oudompheng <remyoudompheng@gmail.com>
2025-07-06 12:29:36 +02:00
Jeff Bolz e592be1575 vulkan: fix rms_norm+mul fusion (#14545)
The fused operation was grabbing the epsilon value from the wrong place.

Add an env var to disable fusion.

Add some missing checks for supported shapes/types.

Handle fused rms_norm+mul in check_results.
2025-07-06 10:08:16 +02:00
Jeff Bolz a0374a67e2 vulkan: Handle updated FA dim2/3 definition (#14518)
* vulkan: Handle updated FA dim2/3 definition

Pack mask boolean and n_head_log2 into a single dword to keep the push
constant block under the 128B limit.

* handle null mask for gqa

* allow gqa with dim3>1
2025-07-05 09:26:04 +02:00
Sigbjørn Skjæret ddef99522d server : fix assistant prefilling when content is an array (#14360) 2025-07-05 09:17:14 +02:00
Sigbjørn Skjæret 6681688146 opencl: add GELU_ERF (#14476) 2025-07-04 23:24:56 -07:00
Georgi Gerganov bac8bed248 eval-callback : check for empty input (#14539) 2025-07-05 07:18:09 +03:00
R0CKSTAR b81510a7b7 test-backend-ops: add support for specifying output format (#14368)
* test-backend-ops: add support for specifying output format

Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>

* Address review comments

Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>

* Add build_commit and build_number in test_result

Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>

* Address review comments

Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>

* refactor

Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>

* Get build commit from ggml_commit()

Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>

* Merge errors into test_operation_info && address review comments

Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>

* Address review comments

Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>

* Address review comments

Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>

* remove visitor nonsense

* remove visitor comment

Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>

* Address review comments

Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>

---------

Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>
2025-07-05 12:10:53 +08:00
Georgi Gerganov ef797db357 metal : disable fast math in all quantize kernels (#14528)
ggml-ci
2025-07-04 19:19:09 +03:00
Georgi Gerganov 67d1ef23c6 batch : add optional for sequential equal split (#14511)
ggml-ci
2025-07-04 09:08:59 +03:00
Georgi Gerganov 7b50f7c025 graph : prepare for 4D mask (#14515)
ggml-ci
2025-07-04 09:05:36 +03:00
Georgi Gerganov c79184d2d1 batch : add n_used count (#14512)
ggml-ci
2025-07-04 09:04:59 +03:00
luyhcsu 499a8f5a78 CANN: Replace aclrtMemsetSync with aclnnInplaceZero operator (#14002)
Co-authored-by: luyuhong <luyuhong@kylinos.cn>
2025-07-04 11:50:07 +08:00
Sigbjørn Skjæret 28657a8229 ggml : implement GEGLU_ERF and GEGLU_QUICK ops (#14445) 2025-07-03 23:07:22 +02:00
lhez bee28421be opencl : broadcast for soft_max (#14510) 2025-07-03 20:22:24 +02:00
Jeff Bolz 2b72bedec1 vulkan: support mixed/deepseekR1 FA head sizes (#14509)
* vulkan: better parameterize FA by head sizes

* vulkan: support mixed/deepseekR1 FA head sizes
2025-07-03 20:21:14 +02:00
Johannes Gäßler c8c4495b8d ggml: backward pass for split swiglu (#14483) 2025-07-03 17:05:18 +02:00
Nicolò Scipione 7b63a71a6b Fix conditional enabling following arch checks for ggml-sycl (#14504)
Signed-off-by: nscipione <nicolo.scipione@codeplay.com>
2025-07-03 11:00:03 +02:00
Xuan-Son Nguyen 0c2ee38ab7 convert : correct gemma 3n conversion (#14450)
* convert : correct gemma 3n conversion

* rm redundant code
2025-07-03 10:03:06 +02:00
Georgi Gerganov a70c8a0c4b kv-cache : use ggml_set_rows (#14285)
* kv-cache : use ggml_set_rows

ggml-ci

* graph : separate k and v indices

ggml-ci

* cont : remove redundant ifs

ggml-ci

* kv-cache : improve find_slot impl

* kv-cache : bounds-check when accessing slot_info indices

* kv-cache : add comments

ggml-ci

* ggml : add TODOs for adding GGML_OP_SET_ROWS support in the backends

ggml-ci
2025-07-03 10:53:35 +03:00
Georgi Gerganov 9067487c44 ggml : fix FA mask dim 2 and 3 (#14505)
* ggml : fix FA mask dim 2 and 3

ggml-ci

* backends : unsupport batched FA in CUDA and Vulkan

ggml-ci

* vulkan : disable FA for mask->ne[2] != 1
2025-07-03 10:46:57 +03:00
Georgi Gerganov d4cdd9c1c3 ggml : remove kompute backend (#14501)
ggml-ci
2025-07-03 07:48:32 +03:00
Aman Gupta 55c2646b45 CUDA: add dynamic shared mem to softmax, refactor general usage (#14497) 2025-07-03 07:45:11 +08:00
Sigbjørn Skjæret e75ba4c043 gguf-py : add support for chat template jinja files (#14508)
* add support for chat template jinja files

* remove gemma3n hack
2025-07-02 21:02:35 +02:00
compilade 5d46babdc2 llama : initial Mamba-2 support (#9126)
* llama : initial Mamba-2 support

* ggml : SIMD ggml_ssm_scan for Mamba-2

* ggml : improve ggml_mul speed when masking recurrent states

* llama : support running Mamba-Codestral-7B-v0.1

* llama : fix Mamba-2 conv state saving

* ggml : make the ggml_mul fast broadcast path more consistently formatted

* llama : remove unused variable

* llama : add missing break

* convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present

The tokenzier.json of Mamba-Codestral-7B-v0.1 otherwise requires
workarounds to work correctly.

* llama : avoid redundant state copy for Mamba 1 and 2

* metal : attempt to adapt SSM_SCAN for Mamba-2

* metal : fix SSM_SCAN pipeline scope

* metal : use log and exp instead of log1pf and expf in SSM_SCAN

* metal : remove unused arguments for SSM_SCAN

The max index is 31, so trimming the arguments is necessary.

* metal : add back n_seqs to SSM_SCAN args

Whoops, this is needed for the offset in the concatenated output.

* metal : fix SSM_SCAN state head offset

* metal : fix wrong number of tokens per sequence in SSM_SCAN

* ggml : remove unused fast broadcast path in GGML_MUL

This was initially added because states were masked with ggml_mul,
but this is no longer done and so this "optimisation" is no longer
necessary, or at least not worth the additional code complexity.

* ggml : avoid multiply by D in GGML_OP_SSM_SCAN

This makes the weight buft detection in src/llama.cpp simpler.

* convert : transpose Mamba-2 A, D and reshape SSM_NORM

This breaks existing conversions of Mamba-2 models
to avoid some reshapes.

Not sure if it's a good idea,
but it makes the graph slightly cleaner.

* llama : more appropriate SSM_SCAN and SSM_CONV buft support checks

* convert : fix flake8 lint

* metal : fix confusion between ; and ,

* metal : add missing args for nb references in ssm_scan_f32_group

* metal : single-user mamba2 inference works

* kv-cache : remove const_cast when setting inputs for s_copy

And also fix multi-user inference for recurrent models
by using cell_id instead of i as the kv cell index
when populating s_copy.

* convert : avoid AutoConfig for Mamba and Mamba2 hparams

* kv-cache : allow context shift for recurrent models

* graph : fix recurrent state copies when avoiding copies

Works, but using lambda functions might not be that clean.

* ggml : fix mamba2 ssm scan when compiled with SVE

* ggml-cpu : reorder SVE FMA for consistency with other SIMD arches

* cuda : implement ssm scan for Mamba2

There is still room for improvement, but it works!

* cuda : adapt Mamba1 ssm scan to shape changes from Mamba2

* mamba : fix mismatched new and delete size for llm_build_mamba

Subclasses of llm_graph_context cannot have extra fields,
because the called destructor is not the one from the subclass.
This otherwise would cause problems when runnning Mamba-(1|2) inference
when compiled -DGGML_SANITIZE_ADDRESS=ON

* cuda : graceful fallback for Mamba-1 models with weird embd size
2025-07-02 13:10:24 -04:00
121 changed files with 3876 additions and 5494 deletions
@@ -40,7 +40,7 @@ body:
attributes:
label: GGML backends
description: Which GGML backends do you know to be affected?
options: [AMX, BLAS, CPU, CUDA, HIP, Kompute, Metal, Musa, RPC, SYCL, Vulkan, OpenCL]
options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL]
multiple: true
validations:
required: true
+1 -1
View File
@@ -42,7 +42,7 @@ body:
attributes:
label: GGML backends
description: Which GGML backends do you know to be affected?
options: [AMX, BLAS, CPU, CUDA, HIP, Kompute, Metal, Musa, RPC, SYCL, Vulkan, OpenCL]
options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL]
multiple: true
validations:
required: true
-6
View File
@@ -1,10 +1,4 @@
# https://github.com/actions/labeler
Kompute:
- changed-files:
- any-glob-to-any-file:
- ggml/include/ggml-kompute.h
- ggml/src/ggml-kompute/**
- README-kompute.md
Apple Metal:
- changed-files:
- any-glob-to-any-file:
+1 -10
View File
@@ -740,9 +740,6 @@ jobs:
- build: 'llvm-arm64-opencl-adreno'
arch: 'arm64'
defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" -DGGML_OPENCL=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON'
# - build: 'kompute-x64'
# arch: 'x64'
# defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF -DGGML_KOMPUTE=ON -DKOMPUTE_OPT_DISABLE_VULKAN_VERSION_CHECK=ON'
steps:
- name: Clone
@@ -756,12 +753,6 @@ jobs:
variant: ccache
evict-old-files: 1d
- name: Clone Kompute submodule
id: clone_kompute
if: ${{ matrix.build == 'kompute-x64' }}
run: |
git submodule update --init ggml/src/ggml-kompute/kompute
- name: Download OpenBLAS
id: get_openblas
if: ${{ matrix.build == 'openblas-x64' }}
@@ -777,7 +768,7 @@ jobs:
- name: Install Vulkan SDK
id: get_vulkan
if: ${{ matrix.build == 'kompute-x64' || matrix.build == 'vulkan-x64' }}
if: ${{ matrix.build == 'vulkan-x64' }}
run: |
curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/vulkansdk-windows-X64-${env:VULKAN_VERSION}.exe"
& "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install
-3
View File
@@ -1,3 +0,0 @@
[submodule "kompute"]
path = ggml/src/ggml-kompute/kompute
url = https://github.com/nomic-ai/kompute.git
-1
View File
@@ -120,7 +120,6 @@ endfunction()
llama_option_depr(FATAL_ERROR LLAMA_CUBLAS GGML_CUDA)
llama_option_depr(WARNING LLAMA_CUDA GGML_CUDA)
llama_option_depr(WARNING LLAMA_KOMPUTE GGML_KOMPUTE)
llama_option_depr(WARNING LLAMA_METAL GGML_METAL)
llama_option_depr(WARNING LLAMA_METAL_EMBED_LIBRARY GGML_METAL_EMBED_LIBRARY)
llama_option_depr(WARNING LLAMA_NATIVE GGML_NATIVE)
+111 -4
View File
@@ -4408,9 +4408,6 @@ class Gemma3NModel(Gemma3Model):
]
def set_vocab(self):
with open(self.dir_model / "chat_template.jinja") as f:
# quick hack to make sure chat template is added
self.gguf_writer.add_chat_template(f.read())
super().set_vocab()
def set_gguf_parameters(self):
@@ -4781,6 +4778,14 @@ class ARwkv7Model(Rwkv7Model):
class MambaModel(TextModel):
model_arch = gguf.MODEL_ARCH.MAMBA
def __init__(self, dir_model: Path, *args, **kwargs):
# Avoid using AutoConfig for hparams
hparams = kwargs.pop("hparams", None)
if hparams is None:
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
def set_vocab(self):
vocab_size = self.hparams["vocab_size"]
# Round vocab size to next multiple of 8
@@ -4855,6 +4860,100 @@ class MambaModel(TextModel):
return [(new_name, data_torch)]
@ModelBase.register("Mamba2ForCausalLM")
class Mamba2Model(TextModel):
model_arch = gguf.MODEL_ARCH.MAMBA2
def __init__(self, dir_model: Path, *args, **kwargs):
# Avoid using AutoConfig for hparams
# It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
hparams = kwargs.pop("hparams", None)
if hparams is None:
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
def set_vocab(self):
vocab_size = self.hparams["vocab_size"]
# Round vocab size to next multiple of 16
pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
# pad using ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
self.hparams["vocab_size"] = vocab_size
if (self.dir_model / "tokenizer.model").is_file():
self._set_vocab_sentencepiece()
elif (self.dir_model / "tokenizer.model.v3").is_file():
# mamba-codestral
raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}")
elif (self.dir_model / "tokenizer.json").is_file():
self._set_vocab_gpt2()
else:
# Use the GPT-NeoX tokenizer when no tokenizer files are present
self._set_vocab_builtin("gpt-neox", vocab_size)
def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
head_dim = self.find_hparam(["head_dim"], optional=True) or 64
n_group = self.find_hparam(["n_groups"], optional=True) or 1
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
# Fail early for models which don't have a block expansion factor of 2
# TODO: does this really matter?
assert d_inner == 2 * d_model
assert d_inner % head_dim == 0
self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
self.gguf_writer.add_ssm_group_count(n_group)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_file_type(self.ftype)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("model.backbone") or name.startswith("model.lm_head"):
# map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2
name = name.removeprefix("model.")
if name.endswith(".dt_bias"):
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
new_name = self.map_tensor_name(name)
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
data_torch = data_torch.squeeze()
elif any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [
gguf.MODEL_TENSOR.SSM_A,
gguf.MODEL_TENSOR.SSM_D,
]):
# unsqueeze A to use similar shape semantics as Mamba-1
# (D is also unsqueezed, but for more straightforward broadcast internally)
data_torch = data_torch.reshape((*data_torch.shape, 1))
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
n_group = self.hparams.get("n_groups", 1)
data_torch = data_torch.reshape((n_group, d_inner // n_group))
if name.endswith(".A_log"):
logger.debug("A_log --> A ==> " + new_name)
data_torch = -torch.exp(data_torch)
yield (new_name, data_torch)
@ModelBase.register("CohereForCausalLM")
class CommandR2Model(TextModel):
model_arch = gguf.MODEL_ARCH.COMMAND_R
@@ -6615,12 +6714,20 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st
# maybe we should fallback to text model's arch in that case, since not many models have both
text_config = hparams.get("text_config", {})
vision_config = hparams.get("vision_config", {})
arch = hparams["architectures"][0]
arch = None
if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
arch = arches[0]
elif "ssm_cfg" in hparams:
# For non-hf Mamba and Mamba2 models
arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
# if "architectures" is found in the sub-config, use that instead
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
arch = text_config["architectures"][0]
elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None:
arch = vision_config["architectures"][0]
if arch is None:
raise ValueError("Failed to detect model architecture")
return arch
+5
View File
@@ -136,6 +136,11 @@ static bool run(llama_context * ctx, const common_params & params) {
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
if (tokens.empty()) {
LOG_ERR("%s : there are not input tokens to process - (try to provide a prompt with '-p')\n", __func__);
return false;
}
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
-2
View File
@@ -181,7 +181,6 @@ option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug ou
option(GGML_VULKAN_SHADER_DEBUG_INFO "ggml: enable Vulkan shader debug info" OFF)
option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF)
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
option(GGML_KOMPUTE "ggml: use Kompute" OFF)
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
option(GGML_METAL_USE_BF16 "ggml: use bfloat if available" OFF)
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
@@ -266,7 +265,6 @@ set(GGML_PUBLIC_HEADERS
include/ggml-cann.h
include/ggml-cpp.h
include/ggml-cuda.h
include/ggml-kompute.h
include/ggml-opt.h
include/ggml-metal.h
include/ggml-rpc.h
-50
View File
@@ -1,50 +0,0 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
#define GGML_KOMPUTE_MAX_DEVICES 16
struct ggml_vk_device {
int index;
int type; // same as VkPhysicalDeviceType
size_t heapSize;
const char * name;
const char * vendor;
int subgroupSize;
uint64_t bufferAlignment;
uint64_t maxAlloc;
};
struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count);
bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name);
bool ggml_vk_has_vulkan(void);
bool ggml_vk_has_device(void);
struct ggml_vk_device ggml_vk_current_device(void);
//
// backend API
//
// forward declaration
typedef struct ggml_backend * ggml_backend_t;
GGML_BACKEND_API ggml_backend_t ggml_backend_kompute_init(int device);
GGML_BACKEND_API bool ggml_backend_is_kompute(ggml_backend_t backend);
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_kompute_reg(void);
#ifdef __cplusplus
}
#endif
+37 -7
View File
@@ -557,6 +557,8 @@ extern "C" {
GGML_GLU_OP_REGLU,
GGML_GLU_OP_GEGLU,
GGML_GLU_OP_SWIGLU,
GGML_GLU_OP_GEGLU_ERF,
GGML_GLU_OP_GEGLU_QUICK,
GGML_GLU_OP_COUNT,
};
@@ -1147,6 +1149,22 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_geglu_erf(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_geglu_erf_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_geglu_quick(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_geglu_quick_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a);
// A: n columns, r rows,
// B: n columns, r rows,
GGML_API struct ggml_tensor * ggml_glu_split(
@@ -1170,6 +1188,16 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_geglu_erf_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_geglu_quick_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
// normalize along rows
GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
@@ -1983,15 +2011,16 @@ extern "C" {
#define GGML_KQ_MASK_PAD 64
// q: [n_embd_k, n_batch, n_head, ne3]
// k: [n_embd_k, n_kv, n_head_kv, ne3]
// v: [n_embd_v, n_kv, n_head_kv, ne3] !! not transposed !!
// mask: [n_kv, n_batch_pad, ne32, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
// res: [n_embd_v, n_head, n_batch, ne3] !! permuted !!
// q: [n_embd_k, n_batch, n_head, ne3 ]
// k: [n_embd_k, n_kv, n_head_kv, ne3 ]
// v: [n_embd_v, n_kv, n_head_kv, ne3 ] !! not transposed !!
// mask: [n_kv, n_batch_pad, ne32, ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
// res: [n_embd_v, n_head, n_batch, ne3 ] !! permuted !!
//
// broadcast:
// n_head % n_head_kv == 0
// ne3 % ne32 == 0
// n_head % ne32 == 0
// ne3 % ne33 == 0
//
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
struct ggml_context * ctx,
@@ -2031,7 +2060,8 @@ extern "C" {
struct ggml_tensor * dt,
struct ggml_tensor * A,
struct ggml_tensor * B,
struct ggml_tensor * C);
struct ggml_tensor * C,
struct ggml_tensor * ids);
// partition into non-overlapping windows with padding if needed
// example:
-1
View File
@@ -365,7 +365,6 @@ ggml_add_backend(BLAS)
ggml_add_backend(CANN)
ggml_add_backend(CUDA)
ggml_add_backend(HIP)
ggml_add_backend(Kompute)
ggml_add_backend(METAL)
ggml_add_backend(MUSA)
ggml_add_backend(RPC)
-8
View File
@@ -61,10 +61,6 @@
#include "ggml-cann.h"
#endif
#ifdef GGML_USE_KOMPUTE
#include "ggml-kompute.h"
#endif
// disable C++17 deprecation warning for std::codecvt_utf8
#if defined(__clang__)
# pragma clang diagnostic push
@@ -189,9 +185,6 @@ struct ggml_backend_registry {
#ifdef GGML_USE_RPC
register_backend(ggml_backend_rpc_reg());
#endif
#ifdef GGML_USE_KOMPUTE
register_backend(ggml_backend_kompute_reg());
#endif
#ifdef GGML_USE_CPU
register_backend(ggml_backend_cpu_reg());
#endif
@@ -575,7 +568,6 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
ggml_backend_load_best("cann", silent, dir_path);
ggml_backend_load_best("cuda", silent, dir_path);
ggml_backend_load_best("hip", silent, dir_path);
ggml_backend_load_best("kompute", silent, dir_path);
ggml_backend_load_best("metal", silent, dir_path);
ggml_backend_load_best("rpc", silent, dir_path);
ggml_backend_load_best("sycl", silent, dir_path);
+3 -1
View File
@@ -67,6 +67,7 @@
#include <aclnnop/aclnn_pow.h>
#include <aclnnop/aclnn_grouped_matmul_v3.h>
#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
#include <aclnnop/aclnn_zero.h>
#include <float.h>
#include <cmath>
@@ -804,10 +805,11 @@ static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer,
nb[i] = nb[i - 1] * ne[i - 1];
}
ggml_cann_async_memset(ctx, buffer, n_bytes, 0);
aclTensor* zero =
ggml_cann_create_tensor(buffer, type, type_size, ne, nb, dims);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, zero);
return zero;
GGML_UNUSED(n_bytes);
}
/**
+6
View File
@@ -2086,6 +2086,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
return false;
}
} break;
case GGML_OP_SET_ROWS:
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
return false;
} break;
case GGML_OP_CPY: {
ggml_tensor *src = op->src[0];
if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
+2
View File
@@ -2172,6 +2172,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
{
n_tasks = n_threads;
} break;
+472 -88
View File
@@ -3614,6 +3614,292 @@ static void ggml_compute_forward_swiglu(
}
}
// ggml_compute_forward_geglu_erf
static void ggml_compute_forward_geglu_erf_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
float * src0_p = (float *) (src0_d + i1*src0_o);
float * src1_p = (float *) (src1_d + i1*src1_o);
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
GGML_UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
}
#endif
}
}
static void ggml_compute_forward_geglu_erf_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
const float v = GGML_FP16_TO_FP32(x);
GGML_UNUSED(v);
assert(!isnan(v));
assert(!isinf(v));
}
#endif
}
}
static void ggml_compute_forward_geglu_erf(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_geglu_erf_f32(params, dst);
} break;
case GGML_TYPE_F16:
{
ggml_compute_forward_geglu_erf_f16(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_geglu_quick
static void ggml_compute_forward_geglu_quick_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
float * src0_p = (float *) (src0_d + i1*src0_o);
float * src1_p = (float *) (src1_d + i1*src1_o);
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
GGML_UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
}
#endif
}
}
static void ggml_compute_forward_geglu_quick_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
const float v = GGML_FP16_TO_FP32(x);
GGML_UNUSED(v);
assert(!isnan(v));
assert(!isinf(v));
}
#endif
}
}
static void ggml_compute_forward_geglu_quick(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_geglu_quick_f32(params, dst);
} break;
case GGML_TYPE_F16:
{
ggml_compute_forward_geglu_quick_f16(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_norm
static void ggml_compute_forward_norm_f32(
@@ -7799,7 +8085,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
memset(VKQ32, 0, DV*sizeof(float));
}
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
// k indices
const int ik3 = iq3 / rk3;
@@ -8337,120 +8623,210 @@ void ggml_compute_forward_ssm_conv(
static void ggml_compute_forward_ssm_scan_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // s
const ggml_tensor * src1 = dst->src[1]; // x
const ggml_tensor * src2 = dst->src[2]; // dt
const ggml_tensor * src3 = dst->src[3]; // A
const ggml_tensor * src4 = dst->src[4]; // B
const ggml_tensor * src5 = dst->src[5]; // C
const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
const int ith = params->ith;
const int nth = params->nth;
const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
const int64_t n_s = src0->ne[2]; // number of sequences in the batch
const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // dim
const int64_t nh = src1->ne[1]; // n_head
const int64_t ng = src4->ne[1];
const int64_t nt = src1->ne[2]; // number of tokens per sequence
const int64_t ns = src1->ne[3]; // number of sequences in the batch
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
// can't use ggml_nbytes because src1 is not necessarily contiguous
const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
// required for the dot product between s and C
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
// required for per-sequence offsets for states
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
// required to get correct offset for state destination (i.e. src1->nb[3])
GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
// allows optimizing the modulo since n_group should be a power of 2
GGML_ASSERT((ng & -ng) == ng);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// heads per thread
const int dh = (nh + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
const int ir = ir1 - ir0;
// head range for this thread
const int ih0 = dh*ith;
const int ih1 = MIN(ih0 + dh, nh);
#ifdef __ARM_FEATURE_SVE
for (int i3 = 0; i3 < n_s; ++i3) {
for (int i2 = 0; i2 < n_t; ++i2) {
const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
const int32_t * ids = (const int32_t *) src6->data;
// use the output as the source for the next token-wise iterations
if (i2 > 0) { s0 = s; }
for (int i3 = 0; i3 < ns; ++i3) {
const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
// d_inner
for (int i1 = 0; i1 < ir; ++i1) {
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
float x_dt = x[i1] * dt_soft_plus;
svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
for (int i2 = 0; i2 < nt; ++i2) {
const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
for (int64_t k = 0; k < nc; k += svcntw()) {
svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]);
svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]);
svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]);
svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]);
if (src3->ne[0] == 1) {
// Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
t1 = exp_ps_sve(svptrue_b32(), t1);
svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
const float dA = expf(dt_soft_plus * A[h]);
vs0 = GGML_F32_VEC_FMA(vs0, t1, t2);
r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
// dim
for (int i1 = 0; i1 < nr; ++i1) {
const int ii = i1 + h*nr;
const float x_dt = x[ii] * dt_soft_plus;
float sumf = 0.0f;
#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
const int ggml_f32_epr = svcntw();
const int ggml_f32_step = 1 * ggml_f32_epr;
GGML_F32_VEC_STORE(&s[i1*nc + k], vs0);
}
y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector);
}
}
}
const int np = (nc & ~(ggml_f32_step - 1));
GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
for (int i = 0; i < np; i += ggml_f32_step) {
// TODO: maybe unroll more?
for (int j = 0; j < 1; j++) {
GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
t0 = GGML_F32_VEC_MUL(t0, adA);
t1 = GGML_F32_VEC_MUL(t1, axdt);
t0 = GGML_F32_VEC_ADD(t0, t1);
sum = GGML_F32_VEC_FMA(sum, t0, t2);
GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
}
}
sumf = GGML_F32xt_REDUCE_ONE(sum);
#else
for (int i3 = 0; i3 < n_s; ++i3) {
for (int i2 = 0; i2 < n_t; ++i2) {
const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
const int np = (nc & ~(GGML_F32_STEP - 1));
// use the output as the source for the next token-wise iterations
if (i2 > 0) { s0 = s; }
GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
// d_inner
for (int i1 = 0; i1 < ir; ++i1) {
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
float x_dt = x[i1] * dt_soft_plus;
float sumf = 0.0f;
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
int i = i0 + i1*nc;
// state = prev_state * dA + dB * x
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[i0];
s[i] = state;
GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
GGML_F32_VEC ax[GGML_F32_ARR];
GGML_F32_VEC ay[GGML_F32_ARR];
GGML_F32_VEC az[GGML_F32_ARR];
for (int i = 0; i < np; i += GGML_F32_STEP) {
for (int j = 0; j < GGML_F32_ARR; j++) {
ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc);
ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
}
}
// reduce sum0..sum3 to sum0
GGML_F32_VEC_REDUCE(sumf, sum);
#endif
#else
const int np = 0;
#endif
// d_state
for (int i0 = np; i0 < nc; ++i0) {
const int i = i0 + ii*nc;
const int ig = i0 + (h & (ng - 1))*nc;
// state = prev_state * dA + dB * x
const float state = (s0[i] * dA) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[ig];
s[i] = state;
}
y[ii] = sumf;
}
}
} else {
// Mamba-1 has an element-wise decay factor for the states
// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
// dim
for (int i1 = 0; i1 < nr; ++i1) {
const int ii = i1 + h*nr;
const float x_dt = x[ii] * dt_soft_plus;
#if defined(__ARM_FEATURE_SVE)
svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
// d_state
// TODO: what happens when (d_state % svcntw()) != 0?
for (int64_t k = 0; k < nc; k += svcntw()) {
svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]);
svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]);
svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
t1 = exp_ps_sve(svptrue_b32(), t1);
svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
}
y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
#else
float sumf = 0.0f;
// NOTE: can't really use GGML_SIMD here because d_state is usually 16
// and also because expf is used within the loop.
// d_state
for (int i0 = 0; i0 < nc; ++i0) {
const int i = i0 + ii*nc;
const int ig = i0 + (h & (ng - 1))*nc;
// state = prev_state * dA + dB * x
const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
// y = rowwise_dotprod(state, C)
sumf += state * C[ig];
s[i] = state;
}
y[ii] = sumf;
#endif
}
y[i1] = sumf;
}
}
// use the output as the source when it's not the first token-wise iteration
s0 = s;
}
#endif
}
}
void ggml_compute_forward_ssm_scan(
@@ -8689,6 +9065,14 @@ void ggml_compute_forward_glu(
{
ggml_compute_forward_swiglu(params, dst);
} break;
case GGML_GLU_OP_GEGLU_ERF:
{
ggml_compute_forward_geglu_erf(params, dst);
} break;
case GGML_GLU_OP_GEGLU_QUICK:
{
ggml_compute_forward_geglu_quick(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
+1 -1
View File
@@ -189,7 +189,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
#define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
#define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c)
#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
#define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
#define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
#define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
+9 -9
View File
@@ -37,35 +37,35 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
for (int i = 0; i < np; i += ggml_f32_step) {
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2);
sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2);
ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3);
sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3);
ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4);
sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4);
ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5);
sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5);
ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6);
sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6);
ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7);
sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7);
ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8);
sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8);
}
// leftovers
// Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop
@@ -73,7 +73,7 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
for (int i = np; i < np2; i += ggml_f32_epr) {
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1);
sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
}
// maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
if (np2 < n) {
+49 -9
View File
@@ -163,49 +163,49 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
GGML_F32_VEC_STORE(y + i, ay1);
ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr);
ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr);
ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2);
ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx);
GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2);
ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr);
ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr);
ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3);
ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx);
GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3);
ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr);
ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr);
ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4);
ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx);
GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4);
ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr);
ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr);
ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5);
ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx);
GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5);
ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr);
ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr);
ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6);
ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx);
GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6);
ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr);
ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr);
ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7);
ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx);
GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7);
ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr);
ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr);
ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8);
ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx);
GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8);
}
@@ -215,7 +215,7 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
for (int i = np; i < np2; i += ggml_f32_epr) {
ax1 = GGML_F32_VEC_LOAD(x + i);
ay1 = GGML_F32_VEC_LOAD(y + i);
ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1);
ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx);
GGML_F32_VEC_STORE(y + i, ay1);
}
@@ -959,6 +959,46 @@ inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_
}
}
inline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
for (int i = 0; i < n; ++i) {
float xi = x[i];
y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];
}
}
inline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
for (int i = 0; i < n; ++i) {
float xi = GGML_CPU_FP16_TO_FP32(x[i]);
float gi = GGML_CPU_FP16_TO_FP32(g[i]);
y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi);
}
}
#ifdef GGML_GELU_QUICK_FP16
inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
uint16_t t;
for (int i = 0; i < n; ++i) {
ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
memcpy(&t, &fp16, sizeof(uint16_t));
y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]) * g[i];
}
}
#else
inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
for (int i = 0; i < n; ++i) {
y[i] = ggml_gelu_quick_f32(x[i]) * g[i];
}
}
#endif
inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
const uint16_t * i16 = (const uint16_t *) x;
for (int i = 0; i < n; ++i) {
float v = GGML_CPU_FP16_TO_FP32(g[i]);
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v);
}
}
inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
#ifndef GGML_USE_ACCELERATE
ggml_float sum = 0.0;
+14
View File
@@ -175,6 +175,20 @@ static const char * cu_get_error_str(CUresult err) {
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
#endif
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
do { \
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; \
const int id = ggml_cuda_get_device(); \
if (!shared_memory_limit_raised[id]) { \
CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
shared_memory_limit_raised[id] = true; \
} \
} while (0)
#else
#define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) do {} while (0)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
#else
+2 -14
View File
@@ -123,13 +123,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
if (nbytes_shared <= smpbo) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shared_memory_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
shared_memory_limit_raised[id] = true;
}
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);
cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
} else {
cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
@@ -175,13 +169,7 @@ void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_ten
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
if (nbytes_shared <= smpbo) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shared_memory_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32<true>, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
shared_memory_limit_raised[id] = true;
}
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
} else {
cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
+8
View File
@@ -168,6 +168,10 @@ static void ggml_cuda_get_rows_switch_src0_type(
get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_I32:
get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_BF16:
get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
@@ -210,6 +214,10 @@ void get_rows_cuda(
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_I32:
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_F16:
ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+27 -3
View File
@@ -2314,6 +2314,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_GLU_OP_SWIGLU:
ggml_cuda_op_swiglu(ctx, dst);
break;
case GGML_GLU_OP_GEGLU_ERF:
ggml_cuda_op_geglu_erf(ctx, dst);
break;
case GGML_GLU_OP_GEGLU_QUICK:
ggml_cuda_op_geglu_quick(ctx, dst);
break;
default:
return false;
}
@@ -3116,6 +3122,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]);
default:
return false;
@@ -3192,6 +3200,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
switch (op->src[0]->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
case GGML_TYPE_BF16:
case GGML_TYPE_I32:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -3321,9 +3331,22 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_LOG:
case GGML_OP_SSM_SCAN:
case GGML_OP_SSM_CONV:
return true;
case GGML_OP_SSM_SCAN: {
if (op->src[3]->ne[0] == 1) {
// Mamba2
// (kernel only supports d_state == 128 && d_head % 16 == 0)
return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
} else {
// Mamba
// (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
}
}
case GGML_OP_SSM_CONV: {
// assumes d_inner % threads == 0
return op->src[0]->ne[1] % 128 == 0;
}
case GGML_OP_CONT:
return op->src[0]->type != GGML_TYPE_BF16;
case GGML_OP_DIAG_MASK_INF:
@@ -3377,7 +3400,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return false;
}
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
if (op->src[0]->ne[3] != 1) {
return false;
}
+2 -8
View File
@@ -3016,14 +3016,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shared_memory_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
shared_memory_limit_raised[id] = true;
}
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, false>), nbytes_shared);
CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, MMQ_NWARPS, true>), nbytes_shared);
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x;
+38 -40
View File
@@ -2,6 +2,7 @@
#include "ggml.h"
#include "softmax.cuh"
#include <cstdint>
#include <utility>
template <typename T>
static __device__ __forceinline__ float t2f32(T val) {
@@ -181,6 +182,37 @@ static __global__ void soft_max_back_f32(
}
}
template<int... Ns, typename T>
static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
{
const int id = ggml_cuda_get_device();
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
auto launch_kernel = [=](auto I) -> bool {
constexpr int ncols = decltype(I)::value;
constexpr int block = (ncols > 1024 ? 1024 : ncols);
if (p.ncols == ncols) {
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, p);
return true;
}
return false;
};
// unary fold over launch_kernel
if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
return;
}
//default case
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
}
template<typename T>
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
int nth = WARP_SIZE;
@@ -193,46 +225,12 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
// FIXME: this limit could be raised by ~2-4x on Ampere or newer
if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
switch (ncols_x) {
case 32:
soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, params);
break;
case 64:
soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, params);
break;
case 128:
soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, params);
break;
case 256:
soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, params);
break;
case 512:
soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, params);
break;
case 1024:
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, params);
break;
case 2048:
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, params);
break;
case 4096:
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, params);
break;
default:
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, dst, params);
break;
}
const int id = ggml_cuda_get_device();
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
if (nbytes_shared <= smpbo) {
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
} else {
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
+180 -51
View File
@@ -4,16 +4,15 @@ template <size_t splitD, size_t N>
__global__ void __launch_bounds__(splitD, 2)
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
float * __restrict__ dst, const int64_t L) {
GGML_UNUSED(src1_nb0);
GGML_UNUSED(src2_nb0);
const int32_t * __restrict__ src6, float * __restrict__ dst,
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
const int64_t s_off, const int64_t d_inner, const int64_t L) {
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int bidx = blockIdx.x; // split along B
const int bidy = blockIdx.y; // split along D
const int bidx = blockIdx.x; // split along B (sequences)
const int bidy = blockIdx.y; // split along D (d_inner)
const int tid = threadIdx.x;
const int wid = tid / 32;
const int wtid = tid % 32;
@@ -24,23 +23,23 @@ __global__ void __launch_bounds__(splitD, 2)
float * smem_A = smem;
float * smem_s0 = smem_A + splitD * stride_sA;
const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2));
const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2));
float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3));
const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3));
float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
const int stride_s0 = src0_nb1 / sizeof(float);
const int stride_x = src1_nb1 / sizeof(float);
const int stride_s0 = src0_nb2 / sizeof(float);
const int stride_x = src1_nb2 / sizeof(float);
const int stride_dt = src2_nb1 / sizeof(float);
const int stride_A = src3_nb1 / sizeof(float);
const int stride_B = src4_nb1 / sizeof(float);
const int stride_C = src5_nb1 / sizeof(float);
const int stride_B = src4_nb2 / sizeof(float);
const int stride_C = src5_nb2 / sizeof(float);
const int stride_s = stride_s0;
const int stride_y = stride_x;
const int stride_y = d_inner;
// can N not be 16? for example 32?
if (N == 16) {
@@ -84,24 +83,156 @@ __global__ void __launch_bounds__(splitD, 2)
}
}
// assumes as many threads as d_state
template <int splitH, int d_state>
__global__ void __launch_bounds__(d_state, 1)
ssm_scan_f32_group(
const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
const int32_t * __restrict__ src6, float * __restrict__ dst,
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
const int head_idx = (blockIdx.x * splitH) / d_head;
const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
const int seq_idx = blockIdx.y;
const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1);
const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
// strides across n_seq_tokens
const int stride_x = src1_nb2 / sizeof(float);
const int stride_dt = src2_nb1 / sizeof(float);
const int stride_B = src4_nb2 / sizeof(float);
const int stride_C = src5_nb2 / sizeof(float);
const int stride_y = n_head * d_head;
float state[splitH];
// for the parallel accumulation
__shared__ float stateC[splitH * d_state];
#pragma unroll
for (int j = 0; j < splitH; j++) {
state[j] = s0_block[j * d_state + threadIdx.x];
}
for (int64_t i = 0; i < n_tok; i++) {
// TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
// TODO: only calculate B and C once per head group
// NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
float dt_soft_plus = dt_block[i * stride_dt];
if (dt_soft_plus <= 20.0f) {
dt_soft_plus = log1pf(expf(dt_soft_plus));
}
const float dA = expf(dt_soft_plus * A_block[0]);
const float B = B_block[i * stride_B + threadIdx.x];
const float C = C_block[i * stride_C + threadIdx.x];
// across d_head
#pragma unroll
for (int j = 0; j < splitH; j++) {
const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
state[j] = (state[j] * dA) + (B * x_dt);
stateC[j * d_state + threadIdx.x] = state[j] * C;
}
__syncthreads();
// parallel accumulation for stateC
// TODO: simplify
{
static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
// reduce until w matches the warp size
// TODO: does this work even when the physical warp size is 64?
#pragma unroll
for (int w = d_state; w > WARP_SIZE; w >>= 1) {
// (assuming there are d_state threads)
#pragma unroll
for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
// TODO: check for bank conflicts
const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
stateC[k] += stateC[k + (w >> 1)];
}
__syncthreads();
}
static_assert(splitH >= d_state / WARP_SIZE);
#pragma unroll
for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
y = warp_reduce_sum(y);
// store the above accumulations
if (threadIdx.x % WARP_SIZE == 0) {
const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
y_block[i * stride_y + k] = y;
}
}
}
}
// write back the state
#pragma unroll
for (int j = 0; j < splitH; j++) {
s_block[j * d_state + threadIdx.x] = state[j];
}
}
static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
const float * src4, const float * src5, const int src0_nb1, const int src0_nb2,
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
const float * src4, const float * src5, const int32_t * src6, float * dst,
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
cudaStream_t stream) {
const int threads = 128;
// todo: consider D cannot be divided,does this situation exist?
GGML_ASSERT(D % threads == 0);
const dim3 blocks(B, (D + threads - 1) / threads, 1);
const int smem_size = (threads * (N + 1) * 2) * sizeof(float);
if (N == 16) {
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
if (src3_nb1 == sizeof(float)) {
// Mamba-2
if (d_state == 128) {
GGML_ASSERT(d_state % threads == 0);
// NOTE: can be any power of two between 4 and 64
const int splitH = 16;
GGML_ASSERT(head_dim % splitH == 0);
const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
} else {
GGML_ABORT("doesn't support d_state!=128.");
}
} else {
GGML_ABORT("doesn't support N!=16.");
// Mamba-1
GGML_ASSERT(n_head % threads == 0);
GGML_ASSERT(head_dim == 1);
GGML_ASSERT(n_group == 1);
const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
if (d_state == 16) {
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
} else {
GGML_ABORT("doesn't support d_state!=16.");
}
}
}
@@ -112,30 +243,25 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const struct ggml_tensor * src3 = dst->src[3]; // A
const struct ggml_tensor * src4 = dst->src[4]; // B
const struct ggml_tensor * src5 = dst->src[5]; // C
// const int64_t d_state = src0->ne[0];
// const int64_t d_inner = src0->ne[1];
// const int64_t l = src1->ne[1];
// const int64_t b = src0->ne[2];
const struct ggml_tensor * src6 = dst->src[6]; // ids
const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
const int64_t n_s = src0->ne[2]; // number of sequences in the batch
const int64_t nr = src0->ne[1]; // head_dim or 1
const int64_t nh = src1->ne[1]; // n_head
const int64_t ng = src4->ne[1]; // n_group
const int64_t n_t = src1->ne[2]; // number of tokens per sequence
const int64_t n_s = src1->ne[3]; // number of sequences in the batch
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
const int64_t s_off = ggml_nelements(src1) * sizeof(float);
GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
// required for the dot product between s and C
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
// required for per-sequence offsets for states
GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float));
// required to get correct offset for state destination (i.e. src1->nb[3])
GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
@@ -143,13 +269,16 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const float * src3_d = (const float *) src3->data;
const float * src4_d = (const float *) src4->data;
const float * src5_d = (const float *) src5->data;
const int32_t * src6_d = (const int32_t *) src6->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src6->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0],
src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1],
src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream);
ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
s_off, nc, nr, nh, ng, n_t, n_s, stream);
}
+8
View File
@@ -285,6 +285,14 @@ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary_gated<op_silu>(ctx, dst);
}
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary_gated<op_gelu_erf>(ctx, dst);
}
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
}
/* silu_back */
static __device__ __forceinline__ float op_silu_back(float grad, float x) {
+4
View File
@@ -64,3 +64,7 @@ void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
-166
View File
@@ -1,166 +0,0 @@
find_package(Vulkan COMPONENTS glslc REQUIRED)
find_program(glslc_executable NAMES glslc HINTS Vulkan::glslc)
if (NOT glslc_executable)
message(FATAL_ERROR "glslc not found")
endif()
ggml_add_backend_library(ggml-kompute
ggml-kompute.cpp
../../include/ggml-kompute.h
)
target_link_libraries(ggml-kompute PRIVATE ggml-base kompute)
target_include_directories(ggml-kompute PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
add_compile_definitions(VULKAN_HPP_DISPATCH_LOADER_DYNAMIC=1)
function(compile_shader)
set(options)
set(oneValueArgs)
set(multiValueArgs SOURCES)
cmake_parse_arguments(compile_shader "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
foreach(source ${compile_shader_SOURCES})
get_filename_component(filename ${source} NAME)
set(spv_file ${filename}.spv)
add_custom_command(
OUTPUT ${spv_file}
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${source}
${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/common.comp
${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_getrows.comp
${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n_pre.comp
${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n.comp
COMMAND ${glslc_executable} --target-env=vulkan1.2 -o ${spv_file} ${CMAKE_CURRENT_SOURCE_DIR}/${source}
COMMENT "Compiling ${source} to ${spv_file}"
)
get_filename_component(RAW_FILE_NAME ${spv_file} NAME)
set(FILE_NAME "shader${RAW_FILE_NAME}")
string(REPLACE ".comp.spv" ".h" HEADER_FILE ${FILE_NAME})
string(TOUPPER ${HEADER_FILE} HEADER_FILE_DEFINE)
string(REPLACE "." "_" HEADER_FILE_DEFINE "${HEADER_FILE_DEFINE}")
set(OUTPUT_HEADER_FILE "${HEADER_FILE}")
message(STATUS "${HEADER_FILE} generating ${HEADER_FILE_DEFINE}")
if(CMAKE_GENERATOR MATCHES "Visual Studio")
add_custom_command(
OUTPUT ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
DEPENDS ${spv_file} xxd
COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd"
)
else()
add_custom_command(
OUTPUT ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_BINARY_DIR}/bin/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
DEPENDS ${spv_file} xxd
COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/xxd"
)
endif()
endforeach()
endfunction()
if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
message(STATUS "Kompute found")
set(KOMPUTE_OPT_LOG_LEVEL Error CACHE STRING "Kompute log level")
add_subdirectory(kompute)
# Compile our shaders
compile_shader(SOURCES
kompute-shaders/op_scale.comp
kompute-shaders/op_scale_8.comp
kompute-shaders/op_add.comp
kompute-shaders/op_addrow.comp
kompute-shaders/op_mul.comp
kompute-shaders/op_silu.comp
kompute-shaders/op_relu.comp
kompute-shaders/op_gelu.comp
kompute-shaders/op_softmax.comp
kompute-shaders/op_norm.comp
kompute-shaders/op_rmsnorm.comp
kompute-shaders/op_diagmask.comp
kompute-shaders/op_mul_mat_mat_f32.comp
kompute-shaders/op_mul_mat_f16.comp
kompute-shaders/op_mul_mat_q8_0.comp
kompute-shaders/op_mul_mat_q4_0.comp
kompute-shaders/op_mul_mat_q4_1.comp
kompute-shaders/op_mul_mat_q4_k.comp
kompute-shaders/op_mul_mat_q6_k.comp
kompute-shaders/op_getrows_f32.comp
kompute-shaders/op_getrows_f16.comp
kompute-shaders/op_getrows_q4_0.comp
kompute-shaders/op_getrows_q4_1.comp
kompute-shaders/op_getrows_q6_k.comp
kompute-shaders/op_rope_norm_f16.comp
kompute-shaders/op_rope_norm_f32.comp
kompute-shaders/op_rope_neox_f16.comp
kompute-shaders/op_rope_neox_f32.comp
kompute-shaders/op_cpy_f16_f16.comp
kompute-shaders/op_cpy_f16_f32.comp
kompute-shaders/op_cpy_f32_f16.comp
kompute-shaders/op_cpy_f32_f32.comp
)
# Create a custom target for our generated shaders
add_custom_target(generated_shaders DEPENDS
shaderop_scale.h
shaderop_scale_8.h
shaderop_add.h
shaderop_addrow.h
shaderop_mul.h
shaderop_silu.h
shaderop_relu.h
shaderop_gelu.h
shaderop_softmax.h
shaderop_norm.h
shaderop_rmsnorm.h
shaderop_diagmask.h
shaderop_mul_mat_mat_f32.h
shaderop_mul_mat_f16.h
shaderop_mul_mat_q8_0.h
shaderop_mul_mat_q4_0.h
shaderop_mul_mat_q4_1.h
shaderop_mul_mat_q4_k.h
shaderop_mul_mat_q6_k.h
shaderop_getrows_f32.h
shaderop_getrows_f16.h
shaderop_getrows_q4_0.h
shaderop_getrows_q4_1.h
shaderop_getrows_q6_k.h
shaderop_rope_norm_f16.h
shaderop_rope_norm_f32.h
shaderop_rope_neox_f16.h
shaderop_rope_neox_f32.h
shaderop_cpy_f16_f16.h
shaderop_cpy_f16_f32.h
shaderop_cpy_f32_f16.h
shaderop_cpy_f32_f32.h
)
# Create a custom command that depends on the generated_shaders
add_custom_command(
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
DEPENDS generated_shaders
COMMENT "Ensuring shaders are generated before compiling ggml-kompute.cpp"
)
# Add the stamp to the main sources to ensure dependency tracking
target_sources(ggml-kompute PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp)
else()
message(WARNING "Kompute not found")
endif()
File diff suppressed because it is too large Load Diff
@@ -1,112 +0,0 @@
#extension GL_EXT_shader_16bit_storage: require
#extension GL_EXT_shader_8bit_storage: require
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#extension GL_EXT_shader_explicit_arithmetic_types_int8: require
#extension GL_EXT_shader_explicit_arithmetic_types_int16: require
#extension GL_EXT_shader_explicit_arithmetic_types_int64: require
#extension GL_EXT_control_flow_attributes: enable
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_EXT_debug_printf : enable
#define QK4_0 32
#define QK4_1 32
#define GELU_COEF_A 0.044715
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876
#define TWOPI_F 6.283185307179586f
#define QK_K 256
#define K_SCALE_SIZE 12
#define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx])
#define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx)
#define u8BufToU32(buf, idx) (((uint32_t u8BufToU16(buf, idx + 2) << 8 | buf[idx + 1]) << 8) | buf[idx])
#define u8BufToFloat(buf, idx) uintBitsToFloat u8BufToU32(buf, idx)
#define sizeof_block_q4_0 0x12
struct block_q4_0 {
float16_t d;
uint8_t qs[QK4_0 / 2];
};
mat4 dequantize_q4_0(const block_q4_0 xb, uint il) {
const float d1 = il != 0 ? (xb.d / 16.f) : xb.d;
const float d2 = d1 / 256.f;
const float md = -8.f * xb.d;
const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F);
const uint16_t mask1 = mask0 << 8;
mat4 reg;
for (int i=0;i<8;i++) {
uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]);
reg[i/2][2*(i%2)+0] = d1 * (b & mask0) + md;
reg[i/2][2*(i%2)+1] = d2 * (b & mask1) + md;
}
return reg;
}
#define sizeof_block_q4_1 0x14
struct block_q4_1 {
float16_t d;
float16_t m;
uint8_t qs[QK4_1 / 2];
};
mat4 dequantize_q4_1(const block_q4_1 xb, uint il) {
const float d1 = il != 0 ? (xb.d / 16.f) : xb.d;
const float d2 = d1 / 256.f;
const float m = xb.m;
const uint16_t mask0 = il != 0 ? uint16_t(0x00F0) : uint16_t(0x000F);
const uint16_t mask1 = mask0 << 8;
mat4 reg;
for (int i=0;i<8;i++) {
uint16_t b = (uint16_t(xb.qs[2 * i + 1]) << 8) | uint16_t(xb.qs[2 * i]);
reg[i/2][2*(i%2)+0] = ((b & mask0) * d1) + m;
reg[i/2][2*(i%2)+1] = ((b & mask1) * d2) + m;
}
return reg;
}
#define sizeof_block_q4_k 144
struct block_q4_k {
float16_t d;
float16_t dmin;
uint8_t scales[K_SCALE_SIZE];
uint8_t qs[QK_K/2];
};
#define sizeof_block_q6_k 210
struct block_q6_k {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
uint8_t qh[QK_K/4]; // quants, upper 2 bits
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
float16_t d; // super-block scale
};
mat4 dequantize_q6_k(const block_q6_k xb, uint il) {
const float16_t d_all = xb.d;
const uint qlIndex = 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
const uint qhIndex = 32*(il/8) + 16*(il&1);
float16_t sc = xb.scales[(il%2) + 2 * ((il/2))];
il = (il/2) & 3;
const uint16_t kmask1 = il>1 ? uint16_t(il>2 ? 192 : 48) : uint16_t(il>0 ? 12 : 3);
const uint16_t kmask2 = il>1 ? uint8_t(0xF0) : uint8_t(0x0F);
const float16_t coef = il>1 ? float16_t(1.f/16.f) : float16_t(1.f);
const float16_t ml = float16_t(d_all * sc * 32.f);
const float16_t dl = float16_t(d_all * sc * coef);
mat4 reg;
for (int i = 0; i < 16; ++i) {
const float16_t q = (il&1) != 0 ? ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 2))
: ((xb.ql[qlIndex + i] & kmask2) | ((xb.qh[qhIndex + i] & kmask1) << 4));
reg[i/4][i%4] = dl * q - ml;
}
return reg;
}
#define QK8_0 32
// struct block_q8_0 {
// float16_t d; // delta
// int8_t qs[QK8_0]; // quants
// };
#define sizeof_block_q8_0 34
@@ -1,58 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1024) in;
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int nb00;
int nb01;
int nb02;
int nb03;
int ne10;
int ne11;
int ne12;
int ne13;
int nb10;
int nb11;
int nb12;
int nb13;
int ne0;
int nb0;
int nb1;
int nb2;
int nb3;
//int offs; // TODO: needed for GGML_OP_ACC, see metal code
} pcs;
// general-purpose kernel for addition of two tensors
// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
// cons: not very efficient
void main() {
const uint i03 = gl_WorkGroupID.z;
const uint i02 = gl_WorkGroupID.y;
const uint i01 = gl_WorkGroupID.x;
const uint i13 = i03 % pcs.ne13;
const uint i12 = i02 % pcs.ne12;
const uint i11 = i01 % pcs.ne11;
int offs = 0; // TMP (see above)
uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + offs) / 4);
uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11 ) / 4);
uint dst_off = uint((i03*pcs.nb3 + i02*pcs.nb2 + i01*pcs.nb1 + offs) / 4);
for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) {
const uint i10 = i0 % pcs.ne10;
out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] + inB[pcs.inBOff + src1_off + i10];
}
}
@@ -1,25 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inAOff;
uint inBOff;
uint outOff;
uint row;
} pcs;
void main() {
const uint baseIndex = gl_WorkGroupID.x * 4;
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = inA[i + pcs.inAOff] + inB[(i % pcs.row) + pcs.inBOff];
}
}
@@ -1,52 +0,0 @@
#version 450
#include "common.comp"
#define IN_TYPE float16_t
#define IN_TYPE_SIZE 2
#define OUT_TYPE float16_t
#define OUT_TYPE_SIZE 2
layout(local_size_x = 1024) in;
layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
layout (push_constant) uniform parameter {
uint inOff;
uint outOff;
int ne00;
int ne01;
int ne02;
uint nb00;
uint nb01;
uint nb02;
uint nb03;
int ne0;
int ne1;
int ne2;
uint nb0;
uint nb1;
uint nb2;
uint nb3;
} pcs;
void main() {
const uint i03 = gl_WorkGroupID.z;
const uint i02 = gl_WorkGroupID.y;
const uint i01 = gl_WorkGroupID.x;
const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
out_[dst_data+i00] = OUT_TYPE(in_[src]);
}
}
@@ -1,52 +0,0 @@
#version 450
#include "common.comp"
#define IN_TYPE float16_t
#define IN_TYPE_SIZE 2
#define OUT_TYPE float
#define OUT_TYPE_SIZE 4
layout(local_size_x = 1024) in;
layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
layout (push_constant) uniform parameter {
uint inOff;
uint outOff;
int ne00;
int ne01;
int ne02;
uint nb00;
uint nb01;
uint nb02;
uint nb03;
int ne0;
int ne1;
int ne2;
uint nb0;
uint nb1;
uint nb2;
uint nb3;
} pcs;
void main() {
const uint i03 = gl_WorkGroupID.z;
const uint i02 = gl_WorkGroupID.y;
const uint i01 = gl_WorkGroupID.x;
const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
out_[dst_data+i00] = OUT_TYPE(in_[src]);
}
}
@@ -1,52 +0,0 @@
#version 450
#include "common.comp"
#define IN_TYPE float
#define IN_TYPE_SIZE 4
#define OUT_TYPE float16_t
#define OUT_TYPE_SIZE 2
layout(local_size_x = 1024) in;
layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
layout (push_constant) uniform parameter {
uint inOff;
uint outOff;
int ne00;
int ne01;
int ne02;
uint nb00;
uint nb01;
uint nb02;
uint nb03;
int ne0;
int ne1;
int ne2;
uint nb0;
uint nb1;
uint nb2;
uint nb3;
} pcs;
void main() {
const uint i03 = gl_WorkGroupID.z;
const uint i02 = gl_WorkGroupID.y;
const uint i01 = gl_WorkGroupID.x;
const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
out_[dst_data+i00] = OUT_TYPE(in_[src]);
}
}
@@ -1,52 +0,0 @@
#version 450
#include "common.comp"
#define IN_TYPE float
#define IN_TYPE_SIZE 4
#define OUT_TYPE float
#define OUT_TYPE_SIZE 4
layout(local_size_x = 1024) in;
layout (binding = 0) readonly buffer tensorIn { IN_TYPE in_[]; };
layout (binding = 1) writeonly buffer tensorOut { OUT_TYPE out_[]; };
layout (push_constant) uniform parameter {
uint inOff;
uint outOff;
int ne00;
int ne01;
int ne02;
uint nb00;
uint nb01;
uint nb02;
uint nb03;
int ne0;
int ne1;
int ne2;
uint nb0;
uint nb1;
uint nb2;
uint nb3;
} pcs;
void main() {
const uint i03 = gl_WorkGroupID.z;
const uint i02 = gl_WorkGroupID.y;
const uint i01 = gl_WorkGroupID.x;
const int n = int(i03)*pcs.ne02*pcs.ne01*pcs.ne00 + int(i02)*pcs.ne01*pcs.ne00 + int(i01)*pcs.ne00;
const int i3 = n / (pcs.ne2*pcs.ne1*pcs.ne0);
const int i2 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0) / (pcs.ne1*pcs.ne0);
const int i1 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0) / pcs.ne0;
const int i0 = (n - i3*pcs.ne2*pcs.ne1*pcs.ne0 - i2*pcs.ne1*pcs.ne0 - i1*pcs.ne0);
const uint dst_data = (i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / OUT_TYPE_SIZE + pcs.outOff; // Based from out_
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
const uint src = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01 + i00*pcs.nb00) / IN_TYPE_SIZE) + pcs.inOff; // Based from in_
out_[dst_data+i00] = OUT_TYPE(in_[src]);
}
}
@@ -1,30 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
uint n_past;
int ne00;
int ne01;
} pcs;
void main() {
const uint i02 = gl_WorkGroupID.z;
const uint i01 = gl_WorkGroupID.y;
const uint i00 = gl_WorkGroupID.x;
const uint index = i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00 + i00;
if (i00 > pcs.n_past + i01) {
out_[index + pcs.outOff] = uintBitsToFloat(0xFF800000);
} else {
out_[index + pcs.outOff] = in_[index + pcs.inOff];
}
}
@@ -1,22 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
} pcs;
void main() {
const uint baseIndex = gl_WorkGroupID.x * 8;
for (uint x = 0; x < 8; x++) {
const uint i = baseIndex + x;
const float y = in_[i + pcs.inOff];
out_[i + pcs.outOff] = 0.5*y*(1.0 + tanh(clamp(SQRT_2_OVER_PI*y*(1.0 + GELU_COEF_A*y*y), -15.0, 15.0)));
}
}
@@ -1,17 +0,0 @@
void main() {
const uint i = gl_WorkGroupID.x;
const int r = inB[i + pcs.inBOff];
int z = 0;
for (uint ind = gl_LocalInvocationID.x; ind < pcs.ne00/16; ind += gl_WorkGroupSize.x) {
const uint inIndex = (r * pcs.nb01 + pcs.inAOff) + ind/NL * SIZE_OF_BLOCK;
const mat4 result = dequantize_block(inIndex, ind%NL);
for (uint j = 0; j < 4; ++j) {
for (uint k = 0; k < 4; ++k) {
const uint outIndex = i * pcs.nb1/BYTES_FOR_TYPE + pcs.outOff + z;
out_[outIndex] = result[j][k];
++z;
}
}
}
}
@@ -1,31 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { int inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int nb01;
int nb1;
} pcs;
void dequantize_row_f16(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
for (int j = 0; j < k; j++) {
out_[y + j] = inA[x + j];
}
}
void main() {
const uint i = gl_WorkGroupID.x;
const int r = inB[i + pcs.inBOff];
dequantize_row_f16(r*pcs.nb01/2/*bytes for float16*/ + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
}
@@ -1,31 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout (binding = 0) readonly buffer tensorInA { float inA[]; };
layout (binding = 1) readonly buffer tensorInB { int inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int nb01;
int nb1;
} pcs;
void dequantize_row_f32(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
for (int j = 0; j < k; j++) {
out_[y + j] = inA[x + j];
}
}
void main() {
const uint i = gl_WorkGroupID.x;
const int r = inB[i + pcs.inBOff];
dequantize_row_f32(r*pcs.nb01/4 + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
}
@@ -1,38 +0,0 @@
#version 450
#include "common.comp"
#define NL 2
#define BYTES_FOR_TYPE 4 /*bytes for float*/
#define SIZE_OF_BLOCK sizeof_block_q4_0
layout(local_size_x = 1) in;
layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { int inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int nb01;
int nb1;
} pcs;
block_q4_0 get_unaligned_block_q4_0(uint index) {
block_q4_0 fres;
fres.d = u8BufToFloat16(inA, index);
[[unroll]] for (uint it = 0; it != QK4_0 / 2; it++) {
fres.qs[it] = inA[index+2+it];
}
return fres;
}
mat4 dequantize_block(uint index, uint il) {
const block_q4_0 block = get_unaligned_block_q4_0(index);
return dequantize_q4_0(block, il);
}
#include "op_getrows.comp"
@@ -1,39 +0,0 @@
#version 450
#include "common.comp"
#define NL 2
#define BYTES_FOR_TYPE 4 /*bytes for float*/
#define SIZE_OF_BLOCK sizeof_block_q4_1
layout(local_size_x = 1) in;
layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { int inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int nb01;
int nb1;
} pcs;
block_q4_1 get_unaligned_block_q4_1(uint index) {
block_q4_1 fres;
fres.d = u8BufToFloat16(inA, index);
fres.m = u8BufToFloat16(inA, index+2);
[[unroll]] for (uint it = 0; it != QK4_1 / 2; it++) {
fres.qs[it] = inA[index+4+it];
}
return fres;
}
mat4 dequantize_block(uint index, uint il) {
const block_q4_1 block = get_unaligned_block_q4_1(index);
return dequantize_q4_1(block, il);
}
#include "op_getrows.comp"
@@ -1,44 +0,0 @@
#version 450
#include "common.comp"
#define NL 16
#define BYTES_FOR_TYPE 4 /*bytes for float*/
#define SIZE_OF_BLOCK sizeof_block_q6_k
layout(local_size_x = 1) in;
layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { int inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int nb01;
int nb1;
} pcs;
block_q6_k get_unaligned_block_q6_k(uint index) {
block_q6_k fres;
[[unroll]] for (uint it = 0; it != QK_K / 2; it++) {
fres.ql[it] = inA[index + it];
}
[[unroll]] for (uint it = 0; it != QK_K / 4; it++) {
fres.qh[it] = inA[index + QK_K/2 + it];
}
[[unroll]] for (uint it = 0; it != QK_K / 16; it++) {
fres.scales[it] = int8_t(inA[index + QK_K/2 + QK_K/4 + it]);
}
fres.d = u8BufToFloat16(inA, index + QK_K/2 + QK_K/4 + QK_K/16);
return fres;
}
mat4 dequantize_block(uint index, uint il) {
const block_q6_k block = get_unaligned_block_q6_k(index);
return dequantize_q6_k(block, il);
}
#include "op_getrows.comp"
@@ -1,52 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1024) in;
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int nb00;
int nb01;
int nb02;
int nb03;
int ne10;
int ne11;
int ne12;
int ne13;
int nb10;
int nb11;
int nb12;
int nb13;
int ne0;
int nb0;
int nb1;
int nb2;
int nb3;
} pcs;
void main() {
const uint i03 = gl_WorkGroupID.z;
const uint i02 = gl_WorkGroupID.y;
const uint i01 = gl_WorkGroupID.x;
const uint i13 = i03 % pcs.ne13;
const uint i12 = i02 % pcs.ne12;
const uint i11 = i01 % pcs.ne11;
uint src0_off = uint((i03*pcs.nb03 + i02*pcs.nb02 + i01*pcs.nb01) / 4);
uint src1_off = uint((i13*pcs.nb13 + i12*pcs.nb12 + i11*pcs.nb11) / 4);
uint dst_off = uint((i03*pcs.nb3 + i02*pcs.nb2 + i01*pcs.nb1) / 4);
for (uint i0 = gl_LocalInvocationID.x; i0 < pcs.ne0; i0 += gl_WorkGroupSize.x) {
const uint i10 = i0 % pcs.ne10;
out_[pcs.outOff + dst_off + i0] = inA[pcs.inAOff + src0_off + i0] * inB[pcs.inBOff + src1_off + i10];
}
}
@@ -1,69 +0,0 @@
#version 450
#include "common.comp"
#extension GL_KHR_shader_subgroup_arithmetic : require
layout(local_size_x_id = 0) in;
layout (binding = 0) readonly buffer tensorInA { float16_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int ne01;
int ne02;
uint nb00;
uint nb01;
uint nb02;
uint nb03;
int ne10;
int ne11;
int ne12;
uint nb10;
uint nb11;
uint nb12;
uint nb13;
int ne0;
int ne1;
uint r2;
uint r3;
} pcs;
#define N_F16_F32 4
void main() {
const uint r0 = gl_WorkGroupID.x;
const uint rb = gl_WorkGroupID.y*N_F16_F32;
const uint im = gl_WorkGroupID.z;
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;
const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb03;
const uint x = offset0 / 2 + pcs.inAOff; // Based from inA
for (uint row = 0; row < N_F16_F32; ++row) {
uint r1 = rb + row;
if (r1 >= pcs.ne11) {
break;
}
const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
float sumf = 0;
for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
sumf += float(inA[x+i]) * float(inB[y+i]);
}
const float all_sum = subgroupAdd(sumf);
if (subgroupElect()) {
out_[im*pcs.ne1*pcs.ne0 + r1*pcs.ne0 + r0 + pcs.outOff] = all_sum;
}
}
}
@@ -1,51 +0,0 @@
#version 450
#include "common.comp"
#extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_EXT_debug_printf : enable
// device subgroup size
layout (local_size_x_id = 0) in;
layout(binding = 0) readonly buffer tensorInA { float inA[]; };
layout(binding = 1) readonly buffer tensorInB { float inB[]; };
layout(binding = 2) writeonly buffer tensorOut { float out_[]; };
layout(push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int ne01;
int ne02;
int ne11;
int ne12;
uint nb01;
uint nb02;
uint nb11;
uint nb12;
uint nb1;
uint nb2;
}
pcs;
void main() {
uvec3 gid = gl_WorkGroupID;
uint bc_ab = pcs.ne12 > pcs.ne02 ? gid.z / (pcs.ne12 / pcs.ne02) : gid.z;
uint bc_ba = pcs.ne02 > pcs.ne12 ? gid.z / (pcs.ne02 / pcs.ne12) : gid.z;
const uint x = (gid.x*pcs.nb01 + bc_ab*pcs.nb02) / 4 + pcs.inAOff; // Based from inA
const uint y = (gid.y*pcs.nb11 + bc_ba*pcs.nb12) / 4 + pcs.inBOff; // based from inB
float sum = 0.0f;
for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
sum += float(inA[x+i]) * float(inB[y+i]);
}
const float all_sum = subgroupAdd(sum);
if (subgroupElect()) {
out_[gid.z*(pcs.nb2/4) + gid.y*(pcs.nb1/4) + gid.x + pcs.outOff] = all_sum;
}
}
@@ -1,33 +0,0 @@
#version 450
#include "common.comp"
#define BLOCKS_IN_QUANT QK4_0
#define SIZE_OF_BLOCK sizeof_block_q4_0
#define N_ROWS 4
#include "op_mul_mv_q_n_pre.comp"
// The q4_0 version of this function
float block_q_n_dot_y(uint block_index, uint yb, uint il) {
vec2 acc = vec2(0.0, 0.0);
const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff;
float d = float(u8BufToFloat16(inA, index));
float sumy = 0.0f;
for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) {
const uint16_t b = u8BufToU16(inA, index + 2 + il + i);
const float yl0 = inB[yb + i];
const float yl1 = inB[yb + i + 1];
const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2];
const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1];
sumy += yl0 + yl1 + yl8 + yl9;
acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00);
acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000);
}
return d * (sumy * -8.f + acc[0] + acc[1]);
}
#include "op_mul_mv_q_n.comp"
@@ -1,35 +0,0 @@
#version 450
#include "common.comp"
#define BLOCKS_IN_QUANT QK4_1
#define SIZE_OF_BLOCK sizeof_block_q4_1
#define N_ROWS 4
#include "op_mul_mv_q_n_pre.comp"
// The q4_1 version of this function
float block_q_n_dot_y(uint block_index, uint yb, uint il) {
vec2 acc = vec2(0.0, 0.0);
const uint index = (block_index) * SIZE_OF_BLOCK + pcs.inAOff;
float d = float(u8BufToFloat16(inA, index));
float m = float(u8BufToFloat16(inA, index+2));
float sumy = 0.0f;
for (int i = 0; i < BLOCKS_IN_QUANT/4; i+=2) {
const uint16_t b = u8BufToU16(inA, index + 4 + il + i);
const float yl0 = inB[yb + i];
const float yl1 = inB[yb + i + 1];
const float yl8 = inB[yb + i + BLOCKS_IN_QUANT/2];
const float yl9 = inB[yb + i + BLOCKS_IN_QUANT/2 + 1];
sumy += yl0 + yl1 + yl8 + yl9;
acc[0] += yl0 * (b & 0x000F) + yl1 / 256.f * (b & 0x0F00);
acc[1] += yl8 / 16.f * (b & 0x00F0) + yl9 / 4096.f * (b & 0xF000);
}
return d * (acc[0] + acc[1]) + sumy * m;
}
#include "op_mul_mv_q_n.comp"
@@ -1,140 +0,0 @@
#version 450
#include "common.comp"
#define N_DST 4
#define SIZE_OF_BLOCK sizeof_block_q4_k
layout(local_size_x = 4) in;
layout(local_size_y = 8) in;
layout(local_size_z = 1) in;
layout (binding = 0) readonly buffer tensorInA { block_q4_k inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int ne10;
int ne0;
int ne1;
int ne01;
int ne02;
int ne12;
uint nb01;
uint nb02;
uint nb03;
uint nb11;
uint nb12;
uint nb13;
uint r2;
uint r3;
} pcs;
void main() {
const uint16_t kmask1 = uint16_t(0x3f3f);
const uint16_t kmask2 = uint16_t(0x0f0f);
const uint16_t kmask3 = uint16_t(0xc0c0);
const uint ix = gl_SubgroupInvocationID/8; // 0...3
const uint it = gl_SubgroupInvocationID%8; // 0...7
const uint iq = it/4; // 0 or 1
const uint ir = it%4; // 0...3
const uint nb = pcs.ne00/QK_K;
const uint r0 = gl_WorkGroupID.x;
const uint r1 = gl_WorkGroupID.y;
const uint im = gl_WorkGroupID.z;
const uint first_row = r0 * N_DST;
const uint ib_row = first_row * nb;
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;
const uint offset0 = first_row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
const uint offset1 = r1*pcs.nb11 + (i12 )*pcs.nb12 + (i13 )*pcs.nb13;
const uint xblk = offset0 + pcs.inAOff;
const uint y = (offset1 / 4) + pcs.inBOff;
float yl[16];
float yh[16];
float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f};
float all_sum = 0.f;
uint y4 = y + ix * QK_K + 64 * iq + 8 * ir;
for (uint ib = ix; ib < nb; ib += 4) {
const uint blk_idx = ib + xblk;
float sumy[4] = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; ++i) {
yl[i+0] = inB[y4+i+ 0]; sumy[0] += yl[i+0];
yl[i+8] = inB[y4+i+ 32]; sumy[1] += yl[i+8];
yh[i+0] = inB[y4+i+128]; sumy[2] += yh[i+0];
yh[i+8] = inB[y4+i+160]; sumy[3] += yh[i+8];
}
for (int row = 0; row < N_DST; row++) {
uint row_idx = row * (pcs.nb01 / SIZE_OF_BLOCK);
uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);
uint16_t sc_2 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 4);
uint16_t sc_3 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 6);
uint16_t sc_4 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 8);
uint16_t sc16[4];
sc16[0] = sc_0 & kmask1;
sc16[1] = sc_2 & kmask1;
sc16[2] = ((sc_4 >> 0) & kmask2) | ((sc_0 & kmask3) >> 2);
sc16[3] = ((sc_4 >> 4) & kmask2) | ((sc_2 & kmask3) >> 2);
float acc1[4] = {0.f, 0.f, 0.f, 0.f};
float acc2[4] = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; i += 2) {
uint16_t q1 = u8BufToU16(inA[blk_idx + row_idx].qs, 32 * iq + 8 * ir + i);
uint16_t q2 = u8BufToU16(inA[blk_idx + row_idx].qs, 64 + 32 * iq + 8 * ir + i);
acc1[0] += yl[i+0] * (q1 & 0x000F);
acc1[1] += yl[i+1] * (q1 & 0x0F00);
acc1[2] += yl[i+8] * (q1 & 0x00F0);
acc1[3] += yl[i+9] * (q1 & 0xF000);
acc2[0] += yh[i+0] * (q2 & 0x000F);
acc2[1] += yh[i+1] * (q2 & 0x0F00);
acc2[2] += yh[i+8] * (q2 & 0x00F0);
acc2[3] += yh[i+9] * (q2 & 0xF000);
}
uint8_t sc8_0 = uint8_t(sc16[0] & 0xFF);
uint8_t sc8_1 = uint8_t(sc16[0] >> 8 );
uint8_t sc8_2 = uint8_t(sc16[1] & 0xFF);
uint8_t sc8_3 = uint8_t(sc16[1] >> 8 );
uint8_t sc8_4 = uint8_t(sc16[2] & 0xFF);
uint8_t sc8_5 = uint8_t(sc16[2] >> 8 );
uint8_t sc8_6 = uint8_t(sc16[3] & 0xFF);
uint8_t sc8_7 = uint8_t(sc16[3] >> 8 );
float dall = float(inA[blk_idx + row_idx].d);
float dmin = float(inA[blk_idx + row_idx].dmin);
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8_0 +
(acc1[2] + 1.f/256.f * acc1[3]) * sc8_1 * 1.f/16.f +
(acc2[0] + 1.f/256.f * acc2[1]) * sc8_4 +
(acc2[2] + 1.f/256.f * acc2[3]) * sc8_5 * 1.f/16.f) -
dmin * (sumy[0] * sc8_2 + sumy[1] * sc8_3 + sumy[2] * sc8_6 + sumy[3] * sc8_7);
}
y4 += 4 * QK_K;
}
for (int row = 0; row < N_DST; ++row) {
all_sum = subgroupAdd(sumf[row]);
if (subgroupElect()) {
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = all_sum;
}
}
}
@@ -1,106 +0,0 @@
#version 450
#include "common.comp"
#define SIZE_OF_BLOCK sizeof_block_q6_k
layout(local_size_x_id = 0) in;
layout(local_size_y_id = 1) in;
layout(local_size_z = 1) in;
layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int ne10;
int ne0;
int ne1;
int ne01;
int ne02;
int ne12;
uint nb01;
uint nb02;
uint nb03;
uint nb11;
uint nb12;
uint nb13;
uint r2;
uint r3;
} pcs;
void main() {
const uint8_t kmask1 = uint8_t(0x03);
const uint8_t kmask2 = uint8_t(0x0C);
const uint8_t kmask3 = uint8_t(0x30);
const uint8_t kmask4 = uint8_t(0xC0);
const uint nb = pcs.ne00/QK_K;
const uint r0 = gl_WorkGroupID.x;
const uint r1 = gl_WorkGroupID.y;
const uint im = gl_WorkGroupID.z;
const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID);
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;
const uint x = row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
const uint yy = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
float sumf = 0;
// bits of invocation ID for gl_SubgroupSize=32:
// x x x x x
// 4 3 2 1 0
// ( tid ) ix
// ip ( il )
const uint block_stride = gl_SubgroupSize / 16; // number of blocks each subgroup processes
const uint tid = gl_SubgroupInvocationID/block_stride; // first block_stride groups have tid=0
const uint ix = gl_SubgroupInvocationID%block_stride; // first block is 0..block_stride-1
const uint ip = tid/8; // first or second half of block (0 or 1)
const uint il = tid%8; // each half has 8 parts, one per scale
const uint n = 4; // 4 scales at a time (and 4 sums)
const uint l0 = n*il; // offset into half-block, 0..28
const uint is = 8*ip + l0/16; // 0, 1, 8, 9
const uint y_offset = 128*ip + l0;
const uint q_offset_l = 64*ip + l0;
const uint q_offset_h = 32*ip + l0;
for (uint i = ix; i < nb; i += block_stride) {
const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff;
const uint qlIndex = q_offset_l;
const uint q2Index = qlIndex + QK_K/8;
const uint qhIndex = q_offset_h;
const uint y = yy + i * QK_K + y_offset;
float sums[4] = {0.0f, 0.0f, 0.0f, 0.0f};
for (uint l = 0; l < n; ++l) {
const uint8_t currentQ1 = inA[baseIndex + qlIndex + l];
const uint8_t currentQ2 = inA[baseIndex + q2Index + l];
const uint8_t currentQh = inA[baseIndex + QK_K/2 + qhIndex + l];
sums[0] += inB[y+l+ 0] * (int8_t((currentQ1 & 0xF) | ((currentQh & kmask1) << 4)) - 32);
sums[1] += inB[y+l+32] * (int8_t((currentQ2 & 0xF) | ((currentQh & kmask2) << 2)) - 32);
sums[2] += inB[y+l+64] * (int8_t((currentQ1 >> 4) | ((currentQh & kmask3) << 0)) - 32);
sums[3] += inB[y+l+96] * (int8_t((currentQ2 >> 4) | ((currentQh & kmask4) >> 2)) - 32);
}
float d = u8BufToFloat16(inA, baseIndex + QK_K/2 + QK_K/4 + QK_K/16);
sumf += d * (sums[0] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + is]) + sums[1] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 2 + is]) + sums[2] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 4 + is]) + sums[3] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 6 + is]));
}
const float tot = subgroupAdd(sumf);
if (subgroupElect()) {
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
}
}
@@ -1,73 +0,0 @@
#version 450
#include "common.comp"
#include "op_mul_mv_q_n_pre.comp"
#define SIZE_OF_D 2
#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
#define NB_Q8_0 8
void main() {
// NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
if (gl_SubgroupInvocationID > 31)
return;
const int nr = N_DST;
const int nsg = N_SIMDGROUP;
const int nw = N_SIMDWIDTH;
const int nb = pcs.ne00/QK8_0;
const uint r0 = gl_WorkGroupID.x;
const uint r1 = gl_WorkGroupID.y;
const uint im = gl_WorkGroupID.z;
const uint first_row = (r0 * nsg + gl_SubgroupID) * nr;
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;
const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
const uint x = offset0*sizeof_block_q8_0 + pcs.inAOff; // Based from inA
const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff; // based from inB
float yl[NB_Q8_0];
float sumf[N_DST]={0.f, 0.f, 0.f, 0.f};
const uint ix = gl_SubgroupInvocationID.x/4;
const uint il = gl_SubgroupInvocationID.x%4;
uint yb = y + ix * QK8_0 + NB_Q8_0*il;
// each thread in a SIMD group deals with NB_Q8_0 quants at a time
for (uint ib = ix; ib < nb; ib += nw/4) {
for (int i = 0; i < NB_Q8_0; ++i) {
yl[i] = inB[yb + i];
}
for (int row = 0; row < nr; row++) {
const uint block_offset = (ib+row*nb) * sizeof_block_q8_0;
float sumq = 0.f;
for (int iq = 0; iq < NB_Q8_0; ++iq) {
const int8_t qs_iq = int8_t(inA[x + block_offset + SIZE_OF_D + NB_Q8_0*il + iq]);
sumq += qs_iq * yl[iq];
}
const float16_t d = u8BufToFloat16(inA, x + block_offset);
sumf[row] += sumq*d;
}
yb += NB_Q8_0 * nw;
}
for (int row = 0; row < nr; ++row) {
const float tot = subgroupAdd(sumf[row]);
if (subgroupElect() && first_row + row < pcs.ne01) {
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row] = tot;
}
}
}
@@ -1,52 +0,0 @@
void main() {
// NB: hack to make compatible with AMD GPUs that have a subgroup size of 64
if (gl_SubgroupInvocationID > 31)
return;
const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT);
const uint r0 = gl_WorkGroupID.x;
const uint r1 = gl_WorkGroupID.y;
const uint im = gl_WorkGroupID.z;
const uint first_row = (r0 * gl_NumSubgroups + gl_SubgroupID) * N_ROWS;
const uint i12 = im%pcs.ne12;
const uint i13 = im/pcs.ne12;
// pointers to src0 rows
uint ax[N_ROWS];
for (int row = 0; row < N_ROWS; ++row) {
const uint offset0 = (first_row + row)*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
ax[row] = offset0 + pcs.inAOff;
}
const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f};
const uint ix = gl_SubgroupInvocationID/2;
const uint il = (BLOCKS_IN_QUANT/4)*(gl_SubgroupInvocationID%2);
uint yb = y + ix * BLOCKS_IN_QUANT + il;
//debugPrintfEXT("gl_NumSubgroups=%d, gl_SubgroupID=%d, gl_SubgroupInvocationID=%d, glSubgroupSize=%d, gl_WorkGroupSize.x=%d, gl_WorkGroupSize.y=%d, gl_WorkGroupSize.z=%d\n",
// gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize,
// gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z);
for (uint ib = ix; ib < nb; ib += 16) {
for (int row = 0; row < N_ROWS; row++) {
sumf[row] += block_q_n_dot_y(ax[row] + ib, yb, il);
}
yb += BLOCKS_IN_QUANT * 16;
}
for (int row = 0; row < N_ROWS; ++row) {
const float tot = subgroupAdd(sumf[row]);
if (first_row + row < pcs.ne01 && subgroupElect()) {
out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = tot;
}
}
}
@@ -1,28 +0,0 @@
layout(local_size_x_id = 0) in;
layout(local_size_y = 8) in;
layout(local_size_z = 1) in;
layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
layout (binding = 1) readonly buffer tensorInB { float inB[]; };
layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int ne01;
int ne02;
int ne10;
int ne12;
int ne0;
int ne1;
uint nb01;
uint nb02;
uint nb03;
uint nb11;
uint nb12;
uint nb13;
uint r2;
uint r3;
} pcs;
@@ -1,84 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 256) in;
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
uint ne00;
uint nb01;
float eps;
} pcs;
shared float sum[gl_WorkGroupSize.x];
void main() {
const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_
// MEAN
// parallel sum
sum[gl_LocalInvocationID.x] = 0.0;
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
sum[gl_LocalInvocationID.x] += in_[x+i00];
}
// reduce
barrier();
memoryBarrierShared();
[[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
if (gl_LocalInvocationID.x < i) {
sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
}
barrier();
memoryBarrierShared();
}
// broadcast
if (gl_LocalInvocationID.x == 0) {
sum[0] /= float(pcs.ne00);
}
barrier();
memoryBarrierShared();
const float mean = sum[0];
// recenter
const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
out_[y+i00] = in_[x+i00] - mean;
}
// VARIANCE
// parallel sum
sum[gl_LocalInvocationID.x] = 0.0;
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
sum[gl_LocalInvocationID.x] += out_[y+i00] * out_[y+i00];
}
// reduce
barrier();
memoryBarrierShared();
[[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
if (gl_LocalInvocationID.x < i) {
sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
}
barrier();
memoryBarrierShared();
}
// broadcast
if (gl_LocalInvocationID.x == 0) {
sum[0] /= float(pcs.ne00);
}
barrier();
memoryBarrierShared();
const float variance = sum[0];
const float scale = 1.0f/sqrt(variance + pcs.eps);
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
out_[y+i00] *= scale;
}
}
@@ -1,21 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
} pcs;
void main() {
const uint baseIndex = gl_WorkGroupID.x * 4;
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = max(0.0, in_[i + pcs.inOff]);
}
}
@@ -1,53 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 512) in;
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
uint ne00;
uint nb01;
float eps;
} pcs;
shared float sum[gl_WorkGroupSize.x];
void main() {
const uint x = (gl_WorkGroupID.x*pcs.nb01/4) + pcs.inOff; // Based from in_
// parallel sum
sum[gl_LocalInvocationID.x] = 0.0;
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
sum[gl_LocalInvocationID.x] += in_[x+i00] * in_[x+i00];
}
// reduce
barrier();
memoryBarrierShared();
[[unroll]] for (uint i = gl_WorkGroupSize.x/2; i > 0; i /= 2) {
if (gl_LocalInvocationID.x < i) {
sum[gl_LocalInvocationID.x] += sum[gl_LocalInvocationID.x + i];
}
barrier();
memoryBarrierShared();
}
// broadcast
if (gl_LocalInvocationID.x == 0) {
sum[0] /= float(pcs.ne00);
}
barrier();
memoryBarrierShared();
const float scale = 1.0f/sqrt(sum[0] + pcs.eps);
const uint y = (gl_WorkGroupID.x*pcs.ne00) + pcs.outOff; // Based from out_
for (uint i00 = gl_LocalInvocationID.x; i00 < pcs.ne00; i00 += gl_WorkGroupSize.x) {
out_[y+i00] = in_[x+i00] * scale;
}
}
@@ -1,52 +0,0 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
float theta_base = float(inB[pcs.inBOff + i2]);
float inv_ndims = -1.f/pcs.n_dims;
float cos_theta;
float sin_theta;
for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
if (i0 < pcs.n_dims) {
uint ic = i0/2;
float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + ic*pcs.nb0) / 2) + pcs.outOff; // Based from out_
const float x0 = float(inA[src]);
const float x1 = float(inA[src+pcs.n_dims/2]);
out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta);
} else {
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
out_[dst_data] = inA[src];
out_[dst_data+1] = inA[src+1];
}
}
}
@@ -1,52 +0,0 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
float theta_base = float(inB[pcs.inBOff + i2]);
float inv_ndims = -1.f/pcs.n_dims;
float cos_theta;
float sin_theta;
for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
if (i0 < pcs.n_dims) {
uint ic = i0/2;
float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + ic*pcs.nb0) / 4) + pcs.outOff; // Based from out_
const float x0 = inA[src];
const float x1 = inA[src+pcs.n_dims/2];
out_[dst_data] = x0*cos_theta - x1*sin_theta;
out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
out_[dst_data] = inA[src];
out_[dst_data+1] = inA[src+1];
}
}
}
@@ -1,52 +0,0 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
float theta_base = float(inB[pcs.inBOff + i2]);
float inv_ndims = -1.f/pcs.n_dims;
float cos_theta;
float sin_theta;
for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
if (i0 < pcs.n_dims) {
uint ic = i0/2;
float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
const float x0 = float(inA[src]);
const float x1 = float(inA[src+1]);
out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta);
out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta);
} else {
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_
out_[dst_data] = inA[src];
out_[dst_data+1] = inA[src+1];
}
}
}
@@ -1,52 +0,0 @@
#version 450
#include "rope_common.comp"
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; };
layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; };
void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;
float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
float theta_base = float(inB[pcs.inBOff + i2]);
float inv_ndims = -1.f/pcs.n_dims;
float cos_theta;
float sin_theta;
for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
if (i0 < pcs.n_dims) {
uint ic = i0/2;
float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
const float x0 = inA[src];
const float x1 = inA[src+1];
out_[dst_data] = x0*cos_theta - x1*sin_theta;
out_[dst_data+1] = x0*sin_theta + x1*cos_theta;
} else {
const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
out_[dst_data] = inA[src];
out_[dst_data+1] = inA[src+1];
}
}
}
@@ -1,19 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
float scale;
} pcs;
void main() {
const uint i = gl_WorkGroupID.x;
out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
}
@@ -1,23 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
float scale;
} pcs;
void main() {
const uint baseIndex = gl_WorkGroupID.x * 8;
for (uint x = 0; x < 8; x++) {
const uint i = baseIndex + x;
out_[i + pcs.outOff] = in_[i + pcs.inOff] * pcs.scale;
}
}
@@ -1,22 +0,0 @@
#version 450
#include "common.comp"
layout(local_size_x = 1) in;
layout(binding = 0) buffer restrict readonly tensorIn { float in_[]; };
layout(binding = 1) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inOff;
uint outOff;
} pcs;
void main() {
const uint baseIndex = gl_WorkGroupID.x * 4;
for (uint x = 0; x < 4; x++) {
const uint i = baseIndex + x;
const float y = in_[i + pcs.inOff];
out_[i + pcs.outOff] = y / (1.0 + exp(-y));
}
}
@@ -1,72 +0,0 @@
// TODO: implement multi-simd softmax (llama.cpp commit e16b9fa4)
#version 450
#include "common.comp"
layout(local_size_x_id = 0) in;
layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { float inB[]; };
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
layout(push_constant) uniform PushConstants {
uint inAOff;
uint inBOff;
uint outOff;
int ne00;
int ne01;
int ne02;
float scale;
float max_bias;
float m0;
float m1;
uint n_head_log2;
int mask;
} pcs;
void main() {
if (gl_SubgroupInvocationID > 31)
return;
const uint i03 = gl_WorkGroupID.z;
const uint i02 = gl_WorkGroupID.y;
const uint i01 = gl_WorkGroupID.x;
const uint extra_off = i03*pcs.ne02*pcs.ne01*pcs.ne00 + i02*pcs.ne01*pcs.ne00 + i01*pcs.ne00;
const uint psrc0 = extra_off + pcs.inAOff; // Based from inA
const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
const uint pdst = extra_off + pcs.outOff; // Based from out_
float slope = 1.0f;
// ALiBi
if (pcs.max_bias > 0.0f) {
int64_t h = i02;
float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1;
int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1;
slope = pow(base, float(exp));
}
// parallel max
float localMax = uintBitsToFloat(0xFF800000);
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f));
}
float max_ = subgroupMax(localMax);
// parallel sum
float localSum = 0.0f;
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_);
localSum += exp_psrc0;
out_[pdst + i00] = exp_psrc0;
}
const float sum = subgroupAdd(localSum);
for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
out_[pdst + i00] /= sum;
}
}
@@ -1,71 +0,0 @@
#include "common.comp"
#define GGML_ROPE_TYPE_NEOX 2
// TODO: use a local size of 32 or more (Metal uses 1024)
layout(local_size_x = 1) in;
layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint inCOff;
uint outOff;
int n_dims;
int mode;
int n_ctx_orig;
float freq_base;
float freq_scale;
bool has_freq_factors;
float ext_factor;
float attn_factor;
float beta_fast;
float beta_slow;
uint nb00;
uint nb01;
uint nb02;
uint nb03;
int ne0;
uint nb0;
uint nb1;
uint nb2;
uint nb3;
} pcs;
float rope_yarn_ramp(const float low, const float high, const float i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
void rope_yarn(
float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale,
out float cos_theta, out float sin_theta
) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
}
cos_theta = cos(theta) * mscale;
sin_theta = sin(theta) * mscale;
}
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
return n_dims * log(n_ctx_orig / (n_rot * TWOPI_F)) / (2 * log(base));
}
void rope_yarn_corr_dims(
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, out float dims[2]
) {
// start and end correction dims
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
}
+7 -6
View File
@@ -230,8 +230,10 @@ typedef struct {
uint64_t nb22;
uint64_t nb23;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
int32_t ne1;
int32_t ne2;
float scale;
@@ -513,26 +515,25 @@ typedef struct {
typedef struct {
int64_t d_state;
int64_t d_inner;
int64_t n_head;
int64_t n_group;
int64_t n_seq_tokens;
int64_t n_seqs;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb10;
uint64_t nb03;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb20;
uint64_t nb21;
uint64_t nb22;
uint64_t nb30;
uint64_t nb31;
uint64_t nb40;
uint64_t nb41;
uint64_t nb42;
uint64_t nb50;
uint64_t nb43;
uint64_t nb51;
uint64_t nb52;
uint64_t nb53;
} ggml_metal_kargs_ssm_scan;
typedef struct {
+75 -32
View File
@@ -217,6 +217,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_NORM,
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
@@ -529,6 +530,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_REGLU,
GGML_METAL_KERNEL_TYPE_GEGLU,
GGML_METAL_KERNEL_TYPE_SWIGLU,
GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_MEAN,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@@ -1196,6 +1199,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
@@ -1508,6 +1512,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
@@ -1691,6 +1697,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
default:
return false;
@@ -2454,6 +2462,12 @@ static bool ggml_metal_encode_node(
case GGML_GLU_OP_SWIGLU:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
break;
case GGML_GLU_OP_GEGLU_ERF:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
break;
case GGML_GLU_OP_GEGLU_QUICK:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline;
break;
default:
GGML_ABORT("fatal error");
}
@@ -2809,71 +2823,91 @@ static bool ggml_metal_encode_node(
struct ggml_tensor * src3 = node->src[3];
struct ggml_tensor * src4 = node->src[4];
struct ggml_tensor * src5 = node->src[5];
struct ggml_tensor * src6 = node->src[6];
GGML_ASSERT(src3);
GGML_ASSERT(src4);
GGML_ASSERT(src5);
GGML_ASSERT(src6);
size_t offs_src3 = 0;
size_t offs_src4 = 0;
size_t offs_src5 = 0;
size_t offs_src6 = 0;
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
id<MTLBuffer> id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil;
const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
const int64_t ne30 = src3->ne[0];
const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
const uint64_t nb30 = src3->nb[0];
const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30);
const uint64_t nb31 = src3->nb[1];
const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
const int64_t ne41 = src4->ne[1];
const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43);
const uint64_t nb40 = src4->nb[0];
const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40);
const uint64_t nb41 = src4->nb[1];
const uint64_t nb42 = src4->nb[2];
const uint64_t nb43 = src4->nb[3];
const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53);
const uint64_t nb50 = src5->nb[0];
const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50);
const uint64_t nb51 = src5->nb[1];
const uint64_t nb52 = src5->nb[2];
const uint64_t nb53 = src5->nb[3];
const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60);
const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60);
const int64_t d_state = ne00;
const int64_t d_inner = ne01;
const int64_t n_seq_tokens = ne11;
const int64_t n_seqs = ne02;
const int64_t n_head = ne02;
const int64_t n_group = ne41;
const int64_t n_seq_tokens = ne12;
const int64_t n_seqs = ne13;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
id<MTLComputePipelineState> pipeline = nil;
if (ne30 == 1) {
// Mamba-2
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
}
ggml_metal_kargs_ssm_scan args = {
/*.d_state =*/ d_state,
/*.d_inner =*/ d_inner,
/*.d_state =*/ d_state,
/*.d_inner =*/ d_inner,
/*.n_head =*/ n_head,
/*.n_group =*/ n_group,
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.nb20 =*/ nb20,
/*.nb21 =*/ nb21,
/*.nb22 =*/ nb22,
/*.nb30 =*/ nb30,
/*.nb31 =*/ nb31,
/*.nb40 =*/ nb40,
/*.nb41 =*/ nb41,
/*.nb42 =*/ nb42,
/*.nb50 =*/ nb50,
/*.nb51 =*/ nb51,
/*.nb52 =*/ nb52,
/*.n_seqs =*/ n_seqs,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.nb21 =*/ nb21,
/*.nb22 =*/ nb22,
/*.nb31 =*/ nb31,
/*.nb41 =*/ nb41,
/*.nb42 =*/ nb42,
/*.nb43 =*/ nb43,
/*.nb51 =*/ nb51,
/*.nb52 =*/ nb52,
/*.nb53 =*/ nb53,
};
[encoder setComputePipelineState:pipeline];
@@ -2883,10 +2917,17 @@ static bool ggml_metal_encode_node(
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
[encoder setBytes:&args length:sizeof(args) atIndex:7];
[encoder setBuffer:id_src6 offset:offs_src6 atIndex:6];
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
[encoder setBytes:&args length:sizeof(args) atIndex:8];
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
if (ne30 == 1) {
// Mamba-2
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} else {
GGML_ASSERT(d_inner == 1);
[encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
}
} break;
case GGML_OP_RWKV_WKV6:
{
@@ -4989,8 +5030,10 @@ static bool ggml_metal_encode_node(
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
/*.ne32 =*/ ne32,
/*.ne33 =*/ ne33,
/*.nb31 =*/ nb31,
/*.nb32 =*/ nb32,
/*.nb33 =*/ nb33,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.scale =*/ scale,
+144 -24
View File
@@ -109,6 +109,7 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
}
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
float max = 0.0f;
@@ -167,6 +168,7 @@ void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
}
void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
float max = 0.0f;
@@ -461,6 +463,7 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
}
void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
@@ -1258,6 +1261,50 @@ kernel void kernel_swiglu(
}
}
kernel void kernel_geglu_erf(
device const char * src0,
device const char * src1,
device char * dst,
constant ggml_metal_kargs_glu & args,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
dst_row[i0] = gelu_erf*x1;
}
}
kernel void kernel_geglu_quick(
device const char * src0,
device const char * src1,
device char * dst,
constant ggml_metal_kargs_glu & args,
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
uint ntg[[threads_per_threadgroup]]) {
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
dst_row[i0] = gelu_quick*x1;
}
}
template <bool norm>
kernel void kernel_sum_rows(
constant ggml_metal_kargs_sum_rows & args,
@@ -1596,7 +1643,7 @@ kernel void kernel_ssm_conv_f32(
x[0] = sumf;
}
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
kernel void kernel_ssm_scan_f32(
device const void * src0,
device const void * src1,
@@ -1604,46 +1651,119 @@ kernel void kernel_ssm_scan_f32(
device const void * src3,
device const void * src4,
device const void * src5,
device const void * src6,
device float * dst,
constant ggml_metal_kargs_ssm_scan & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t ir = tgpig.x;
const int64_t i3 = tgpig.y;
const int64_t i1 = 0;
const int64_t ir = tgpig.x; // current head
const int64_t i3 = tgpig.y; // current seq
const uint64_t nb00 = sizeof(float);
const uint64_t nb10 = sizeof(float);
const uint64_t nb20 = sizeof(float);
const int64_t nc = args.d_state;
// const int64_t nr = args.d_inner;
const int64_t nr = args.d_inner;
const int64_t nh = args.n_head;
const int64_t ng = args.n_group;
const int64_t n_t = args.n_seq_tokens;
// const int64_t n_s = args.n_seqs;
const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
device const int32_t * ids = (device const int32_t *) src6;
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13);
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
if (i2 > 0) {
s0 = s;
}
// i1 == 0
float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
float x_dt = x[0] * dt_soft_plus;
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) {
int64_t i = i0;
float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
const int64_t i = i0 + i1*nc;
const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
sumf += state * C[i0];
s[i] = state;
}
y[0] = sumf;
// recurse
s0 = s;
}
}
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
// TODO: optimize (e.g. by parallelizing over d_state)
kernel void kernel_ssm_scan_f32_group(
device const void * src0,
device const void * src1,
device const void * src2,
device const void * src3,
device const void * src4,
device const void * src5,
device const void * src6,
device float * dst,
constant ggml_metal_kargs_ssm_scan & args,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i1 = tgpig.x;
const int64_t ir = tgpig.y; // current head
const int64_t i3 = tgpig.z; // current seq
const uint64_t nb00 = sizeof(float);
const uint64_t nb10 = sizeof(float);
const uint64_t nb20 = sizeof(float);
const int64_t nc = args.d_state;
const int64_t nr = args.d_inner;
const int64_t nh = args.n_head;
const int64_t ng = args.n_group;
const int64_t n_t = args.n_seq_tokens;
const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
device const int32_t * ids = (device const int32_t *) src6;
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
for (int64_t i2 = 0; i2 < n_t; ++i2) {
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
const float dA = exp(dt_soft_plus * A[0]);
float sumf = 0.0f;
for (int64_t i0 = 0; i0 < nc; ++i0) {
const int64_t i = i0 + i1*nc;
const float state = (s0[i] * dA) + (B[i0] * x_dt);
sumf += state * C[i0];
s[i] = state;
}
y[0] = sumf;
// recurse
s0 = s;
}
}
@@ -3784,7 +3904,7 @@ kernel void kernel_flash_attn_ext(
// load the mask in shared memory
#pragma unroll(Q)
for (short j = 0; j < Q; ++j) {
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq3%args.ne32)*args.nb32);
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
const float m = pm[ic + tiisg];
@@ -4270,7 +4390,7 @@ kernel void kernel_flash_attn_ext_vec(
const bool has_mask = mask != q;
// pointer to the mask
device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq3%args.ne32)*args.nb32);
device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
float slope = 1.0f;
+118 -23
View File
@@ -398,12 +398,13 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_scale;
cl_kernel kernel_silu, kernel_silu_4;
cl_kernel kernel_gelu, kernel_gelu_4;
cl_kernel kernel_gelu_erf, kernel_gelu_erf_4;
cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;
cl_kernel kernel_relu;
cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
cl_kernel kernel_clamp;
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu,
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16;
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
cl_kernel kernel_norm;
cl_kernel kernel_rms_norm;
cl_kernel kernel_group_norm;
@@ -736,6 +737,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu", &err), err));
CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_4", &err), err));
CL_CHECK((backend_ctx->kernel_gelu_erf = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_erf", &err), err));
CL_CHECK((backend_ctx->kernel_gelu_erf_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_erf_4", &err), err));
CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick", &err), err));
CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick_4", &err), err));
GGML_LOG_CONT(".");
@@ -753,12 +756,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
backend_ctx->program_glu =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
CL_CHECK((backend_ctx->kernel_reglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err));
CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err));
CL_CHECK((backend_ctx->kernel_geglu = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
CL_CHECK((backend_ctx->kernel_reglu = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
CL_CHECK((backend_ctx->kernel_swiglu = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
CL_CHECK((backend_ctx->kernel_geglu_erf = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf", &err), err));
CL_CHECK((backend_ctx->kernel_geglu_quick = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick", &err), err));
CL_CHECK((backend_ctx->kernel_geglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
CL_CHECK((backend_ctx->kernel_reglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err));
CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err));
CL_CHECK((backend_ctx->kernel_geglu_erf_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf_f16", &err), err));
CL_CHECK((backend_ctx->kernel_geglu_quick_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick_f16", &err), err));
GGML_LOG_CONT(".");
}
@@ -2222,6 +2229,12 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
default:
return false;
}
case GGML_OP_SET_ROWS:
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
return false;
} break;
case GGML_OP_CPY:
case GGML_OP_DUP:
case GGML_OP_CONT:
@@ -2256,6 +2269,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_UNARY_OP_SIGMOID:
@@ -2271,6 +2285,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
default:
return false;
@@ -3858,6 +3874,44 @@ static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
static void ggml_cl_gelu_erf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
UNUSED(src1);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
cl_kernel kernel;
int n = ggml_nelements(dst);
if (n % 4 == 0) {
kernel = backend_ctx->kernel_gelu_erf_4;
n /= 4;
} else {
kernel = backend_ctx->kernel_gelu_erf;
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
static void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@@ -5757,19 +5811,31 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
const int ne00 = src0 ? src0->ne[0] : 0;
const int ne01 = src0 ? src0->ne[1] : 0;
const int ne02 = src0 ? src0->ne[2] : 0;
const int ne03 = src0 ? src0->ne[3] : 0;
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];
const cl_long nb01 = src0->nb[1];
const cl_long nb02 = src0->nb[2];
const cl_long nb03 = src0->nb[3];
const int ne12 = src1 ? src1->ne[2] : 0;
const int ne13 = src1 ? src1->ne[3] : 0;
const cl_long nb11 = src1 ? src1->nb[1] : 0;
const cl_long nb12 = src1 ? src1->nb[2] : 0;
const cl_long nb13 = src1 ? src1->nb[3] : 0;
const cl_long nb1 = dst->nb[1];
const cl_long nb2 = dst->nb[2];
const cl_long nb3 = dst->nb[3];
float scale, max_bias;
memcpy(&scale, dst->op_params + 0, sizeof(float));
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
const int nrows_x = ggml_nrows(src0);
const int nrows_y = src0->ne[1];
const int n_head = nrows_x/nrows_y;
const int n_head = src0->ne[2];
const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -5814,13 +5880,22 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(float), &scale));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &max_bias));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &m0));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &m1));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &n_head_log2));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2));
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1};
@@ -6227,6 +6302,20 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
kernel = backend_ctx->kernel_swiglu_f16;
}
break;
case GGML_GLU_OP_GEGLU_ERF:
if (dst->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_geglu_erf;
} else {
kernel = backend_ctx->kernel_geglu_erf_f16;
}
break;
case GGML_GLU_OP_GEGLU_QUICK:
if (dst->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_geglu_quick;
} else {
kernel = backend_ctx->kernel_geglu_quick_f16;
}
break;
default:
GGML_ABORT("Unsupported glu op");
}
@@ -6341,6 +6430,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_gelu;
break;
case GGML_UNARY_OP_GELU_ERF:
if (!any_on_device) {
return false;
}
func = ggml_cl_gelu_erf;
break;
case GGML_UNARY_OP_GELU_QUICK:
if (!any_on_device) {
return false;
+27
View File
@@ -6,6 +6,7 @@
#define GELU_COEF_A 0.044715f
#define GELU_QUICK_COEF -1.702f
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
#define SQRT_2_INV 0.70710678118654752440084436210484f
kernel void kernel_gelu(
global float * src0,
@@ -35,6 +36,32 @@ kernel void kernel_gelu_4(
dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
kernel void kernel_gelu_erf(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
float x = src0[get_global_id(0)];
dst[get_global_id(0)] = 0.5f*x*(1.0f + erf(x*SQRT_2_INV));
}
kernel void kernel_gelu_erf_4(
global float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
float4 x = src0[get_global_id(0)];
dst[get_global_id(0)] = 0.5f*x*(1.0f + erf(x*SQRT_2_INV));
}
kernel void kernel_gelu_quick(
global float * src0,
ulong offset0,
+136
View File
@@ -1,7 +1,9 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define GELU_COEF_A 0.044715f
#define GELU_QUICK_COEF -1.702f
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
#define SQRT_2_INV 0.70710678118654752440084436210484f
//------------------------------------------------------------------------------
// geglu
@@ -199,3 +201,137 @@ kernel void kernel_swiglu_f16(
dst_row[i0] = silu*x1;
}
}
//------------------------------------------------------------------------------
// geglu_erf
//------------------------------------------------------------------------------
kernel void kernel_geglu_erf(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
dst_row[i0] = gelu_erf*x1;
}
}
kernel void kernel_geglu_erf_f16(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const half x0 = src0_row[i0];
const half x1 = src1_row[i0];
const half gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
dst_row[i0] = gelu_erf*x1;
}
}
//------------------------------------------------------------------------------
// geglu_quick
//------------------------------------------------------------------------------
kernel void kernel_geglu_quick(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global float * dst_row = (global float *) ((global char *) dst + get_group_id(0)*nb1);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const float x0 = src0_row[i0];
const float x1 = src1_row[i0];
const float gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
dst_row[i0] = gelu_quick*x1;
}
}
kernel void kernel_geglu_quick_f16(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb01,
ulong nb11,
int ne0,
ulong nb1,
int ne00_off,
int ne10_off
) {
src0 = (global char*)((global char*)src0 + offset0);
src1 = (global char*)((global char*)src1 + offset1);
dst = (global char*)((global char*)dst + offsetd);
global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
global half * dst_row = (global half *) ((global char *) dst + get_group_id(0)*nb1);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
const half x0 = src0_row[i0];
const half x1 = src1_row[i0];
const half gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
dst_row[i0] = gelu_quick*x1;
}
}
+24 -11
View File
@@ -22,32 +22,45 @@
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max_4_f16(
global float * src0,
global char * src0,
ulong offset0,
global half * src1,
global char * src1,
ulong offset1,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale,
float max_bias,
float m0,
float m1,
int n_head_log2
) {
src0 = (global float *)((global char *)src0 + offset0);
src1 = (global half *)((global char *)src1 + offset1);
dst = (global float *)((global char *)dst + offsetd);
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0;
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
int i13 = i03%ne13;
int i12 = i02%ne12;
int i11 = i01;
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f;
+24 -11
View File
@@ -22,32 +22,45 @@
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max_4(
global float * src0,
global char * src0,
ulong offset0,
global float * src1,
global char * src1,
ulong offset1,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale,
float max_bias,
float m0,
float m1,
int n_head_log2
) {
src0 = (global float*)((global char*)src0 + offset0);
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0;
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
int i13 = i03%ne13;
int i12 = i02%ne12;
int i11 = i01;
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f;
+24 -11
View File
@@ -22,32 +22,45 @@
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max_f16(
global float * src0,
global char * src0,
ulong offset0,
global half * src1,
global char * src1,
ulong offset1,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale,
float max_bias,
float m0,
float m1,
int n_head_log2
) {
src0 = (global float *)((global char *)src0 + offset0);
src1 = (global half *)((global char *)src1 + offset1);
dst = (global float *)((global char *)dst + offsetd);
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0;
global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
int i13 = i03%ne13;
int i12 = i02%ne12;
int i11 = i01;
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f;
+24 -11
View File
@@ -22,32 +22,45 @@
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_soft_max(
global float * src0,
global char * src0,
ulong offset0,
global float * src1,
global char * src1,
ulong offset1,
global float * dst,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale,
float max_bias,
float m0,
float m1,
int n_head_log2
) {
src0 = (global float*)((global char*)src0 + offset0);
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0;
global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
int i13 = i03%ne13;
int i12 = i02%ne12;
int i11 = i01;
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f;
+50
View File
@@ -383,6 +383,24 @@ static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint6
}
}
template<typename T>
static void gated_op_fused_geglu_erf(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
const int64_t j0 = (i / n) * o0 + (i % n);
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
dst[i] = op_gelu_erf(x[j0]) * g[j1];
}
}
template<typename T>
static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
const int64_t j0 = (i / n) * o0 + (i % n);
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
dst[i] = op_gelu_quick(x[j0]) * g[j1];
}
}
namespace ggml_sycl_detail {
static void acc_f32_sycl(const float *x, const float *y, float *dst,
const int n_elements, const int ne10, const int ne11,
@@ -978,6 +996,28 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
});
}
static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
sycl_parallel_for(main_stream,
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
}
static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
sycl_parallel_for(main_stream,
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
});
});
}
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
@@ -1118,3 +1158,13 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_swiglu(ctx, dst);
}
void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_geglu_erf(ctx, dst);
}
void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_geglu_quick(ctx, dst);
}
+2
View File
@@ -80,5 +80,7 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_ELEMENTWISE_HPP
+15 -1
View File
@@ -83,7 +83,7 @@ static ggml_sycl_device_info ggml_sycl_init() {
info.devices[i].cc =
100 * prop.get_major_version() + 10 * prop.get_minor_version();
info.devices[i].opt_feature.reorder = !device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
}
@@ -3687,6 +3687,12 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_GLU_OP_SWIGLU:
ggml_sycl_swiglu(ctx, dst);
break;
case GGML_GLU_OP_GEGLU_ERF:
ggml_sycl_geglu_erf(ctx, dst);
break;
case GGML_GLU_OP_GEGLU_QUICK:
ggml_sycl_geglu_quick(ctx, dst);
break;
default:
return false;
}
@@ -4232,6 +4238,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]);
default:
return false;
@@ -4285,6 +4293,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return false;
}
}
case GGML_OP_SET_ROWS:
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
return false;
} break;
case GGML_OP_CPY:
{
ggml_type src0_type = op->src[0]->type;
+242 -127
View File
@@ -224,6 +224,21 @@ enum vk_device_architecture {
INTEL_XE2,
};
// HSK x HSV
enum FaHeadSizes {
FA_HEAD_SIZE_64,
FA_HEAD_SIZE_80,
FA_HEAD_SIZE_96,
FA_HEAD_SIZE_112,
FA_HEAD_SIZE_128,
FA_HEAD_SIZE_192,
FA_HEAD_SIZE_192_128,
FA_HEAD_SIZE_256,
FA_HEAD_SIZE_576_512,
FA_HEAD_SIZE_UNSUPPORTED,
FA_HEAD_SIZE_COUNT = FA_HEAD_SIZE_UNSUPPORTED,
};
static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
vk::PhysicalDeviceProperties props = device.getProperties();
@@ -441,6 +456,8 @@ struct vk_device_struct {
vk_pipeline pipeline_geglu[2];
vk_pipeline pipeline_reglu[2];
vk_pipeline pipeline_swiglu[2];
vk_pipeline pipeline_geglu_erf[2];
vk_pipeline pipeline_geglu_quick[2];
vk_pipeline pipeline_leaky_relu_f32;
vk_pipeline pipeline_silu_back_f32;
@@ -467,26 +484,11 @@ struct vk_device_struct {
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_cm1[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
vk_pipeline pipeline_flash_attn_split_k_reduce;
@@ -499,6 +501,8 @@ struct vk_device_struct {
ggml_backend_buffer_type buffer_type;
bool disable_fusion;
#ifdef GGML_VULKAN_MEMORY_DEBUG
std::unique_ptr<vk_memory_logger> memory_logger;
#endif
@@ -634,6 +638,7 @@ struct vk_flash_attn_push_constants {
uint32_t nev3;
uint32_t nem1;
uint32_t nem2;
uint32_t nem3;
uint32_t nb01;
uint32_t nb02;
@@ -649,8 +654,7 @@ struct vk_flash_attn_push_constants {
float max_bias;
float logit_softcap;
uint32_t mask;
uint32_t n_head_log2;
uint32_t mask_n_head_log2;
float m0;
float m1;
@@ -1003,7 +1007,7 @@ struct ggml_backend_vk_context {
// number of additional consecutive nodes that are being fused with the
// node currently being processed
uint32_t num_additional_fused_ops {};
int num_additional_fused_ops {};
};
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
@@ -1089,8 +1093,8 @@ static size_t vk_skip_checks;
static size_t vk_output_tensor;
static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
static void ggml_vk_check_results_0(ggml_tensor * tensor);
static void ggml_vk_check_results_1(ggml_tensor * tensor);
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
#endif
typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
@@ -1699,6 +1703,35 @@ enum FaCodePath {
FA_COOPMAT2,
};
static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
if (hsk != 192 && hsk != 576 && hsk != hsv) {
return FA_HEAD_SIZE_UNSUPPORTED;
}
switch (hsk) {
case 64: return FA_HEAD_SIZE_64;
case 80: return FA_HEAD_SIZE_80;
case 96: return FA_HEAD_SIZE_96;
case 112: return FA_HEAD_SIZE_112;
case 128: return FA_HEAD_SIZE_128;
case 192:
if (hsv == 192) {
return FA_HEAD_SIZE_192;
} else if (hsv == 128) {
return FA_HEAD_SIZE_192_128;
} else {
return FA_HEAD_SIZE_UNSUPPORTED;
}
case 256: return FA_HEAD_SIZE_256;
case 576:
if (hsv == 512) {
return FA_HEAD_SIZE_576_512;
} else {
return FA_HEAD_SIZE_UNSUPPORTED;
}
default: return FA_HEAD_SIZE_UNSUPPORTED;
}
}
// number of rows/cols for flash attention shader
static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
@@ -1719,8 +1752,9 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
}
}
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) {
GGML_UNUSED(clamp);
GGML_UNUSED(hsv);
if (path == FA_SCALAR) {
if (small_rows) {
@@ -1744,7 +1778,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_
}
// small cols to reduce register count
if (ggml_is_quantized(type) || D == 256) {
if (ggml_is_quantized(type) || hsk >= 256) {
return {64, 32};
}
return {64, 64};
@@ -2037,19 +2071,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
};
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
};
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
// For large number of rows, 128 invocations seems to work best.
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
// can't use 256 for D==80.
// For scalar, use 128 (arbitrary)
// The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
const uint32_t D = (hsk|hsv);
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
? scalar_flash_attention_workgroup_size
: ((small_rows && (D % 32) == 0) ? 256 : 128);
auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows);
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@@ -2058,26 +2094,29 @@ static void ggml_vk_load_shaders(vk_device& device) {
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
};
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, HSK, HSV, HEAD_SIZES) \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][0][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,false), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,false), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][0][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][0], "flash_attn_f32_f16_" #HEAD_SIZES "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16 ## SUFFIX[TYPE][FA_HEAD_SIZE_##HEAD_SIZES][1][1][1], "flash_attn_f32_f16_" #HEAD_SIZES "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,true), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,true), fa_rows_cols(FAPATH,HSK,HSV,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64, 64, 64) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80, 80, 80) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96, 96, 96) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112, 112, 112) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128, 128, 128) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 192, 192) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 192, 128, 192_128) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256, 256, 256) \
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 576, 512, 576_512)
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
@@ -2786,6 +2825,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_GLU(geglu)
CREATE_GLU(reglu)
CREATE_GLU(swiglu)
CREATE_GLU(geglu_erf)
CREATE_GLU(geglu_quick)
#undef CREATE_GLU
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -3468,6 +3509,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->idx = idx;
device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
return device;
}
@@ -3688,7 +3731,6 @@ static void ggml_vk_instance_init() {
}
size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
// Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
@@ -6002,24 +6044,47 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
}
}
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) {
// Needs to be kept up to date on shader changes
GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = scalar_flash_attention_num_large_rows;
const uint32_t Bc = scalar_flash_attention_Bc;
const uint32_t tmpsh = wg_size * sizeof(float);
const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
const uint32_t masksh = Bc * Br * sizeof(float);
const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
return supported;
}
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
// Needs to be kept up to date on shader changes
GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = coopmat1_flash_attention_num_large_rows;
const uint32_t Bc = scalar_flash_attention_Bc;
const uint32_t acctype = f32acc ? 4 : 2;
const uint32_t f16vec4 = 8;
const uint32_t tmpsh = wg_size * sizeof(float);
const uint32_t tmpshv4 = wg_size * 4 * acctype;
const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
const uint32_t Qf = Br * (hsk / 4 + 2) * f16vec4;
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
const uint32_t sfsh = Bc * sfshstride * acctype;
const uint32_t kshstride = D / 4 + 2;
const uint32_t kshstride = hsk / 4 + 2;
const uint32_t ksh = Bc * kshstride * f16vec4;
const uint32_t slope = Br * sizeof(float);
@@ -6027,7 +6092,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
return supported;
}
@@ -6050,12 +6115,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const uint32_t nem1 = mask ? mask->ne[1] : 0;
const uint32_t nem2 = mask ? mask->ne[2] : 0;
const uint32_t nem3 = mask ? mask->ne[3] : 0;
const uint32_t D = neq0;
const uint32_t HSK = nek0;
const uint32_t HSV = nev0;
uint32_t N = neq1;
const uint32_t KV = nek1;
GGML_ASSERT(ne0 == D);
GGML_ASSERT(ne0 == HSV);
GGML_ASSERT(ne2 == N);
// input tensor rows must be contiguous
@@ -6063,12 +6130,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
GGML_ASSERT(neq0 == D);
GGML_ASSERT(nek0 == D);
GGML_ASSERT(nev0 == D);
GGML_ASSERT(neq0 == HSK);
GGML_ASSERT(neq1 == N);
GGML_ASSERT(nev0 == D);
GGML_ASSERT(nev1 == nek1);
@@ -6089,7 +6153,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
if (!coopmat_shape_supported || !coopmat_shmem_supported) {
path = FA_SCALAR;
@@ -6119,7 +6183,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
// and change addressing calculations to index Q's dimension 2.
@@ -6142,47 +6206,25 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
path = FA_SCALAR;
}
// with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
if (path == FA_SCALAR &&
!ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) {
small_rows = true;
}
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
FaHeadSizes head_sizes = fa_get_head_sizes(k->ne[0], v->ne[0]);
switch (path) {
case FA_SCALAR:
switch (D) {
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
default:
GGML_ASSERT(!"unsupported D value");
return;
}
pipelines = &ctx->device->pipeline_flash_attn_f32_f16[k->type][head_sizes][f32acc][small_rows][0];
break;
case FA_COOPMAT1:
switch (D) {
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
default:
GGML_ASSERT(!"unsupported D value");
return;
}
pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm1[k->type][head_sizes][f32acc][small_rows][0];
break;
case FA_COOPMAT2:
switch (D) {
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
default:
GGML_ASSERT(!"unsupported D value");
return;
}
pipelines = &ctx->device->pipeline_flash_attn_f32_f16_cm2[k->type][head_sizes][f32acc][small_rows][0];
break;
default:
GGML_ASSERT(0);
@@ -6212,7 +6254,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// Try to use split_k when KV is large enough to be worth the overhead
if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
// Try to run two workgroups per SM.
split_k = ctx->device->shader_core_count * 2 / (workgroups_y * workgroups_z);
split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
if (split_k > 1) {
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
// of "align", so recompute split_k based on that.
@@ -6224,7 +6266,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
if (split_k_size > ctx->device->max_memory_allocation_size) {
GGML_ABORT("Requested preallocation size is too large");
}
@@ -6311,17 +6353,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
}
uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
const vk_flash_attn_push_constants pc = { N, KV,
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
(uint32_t)neq2, (uint32_t)neq3,
(uint32_t)nek2, (uint32_t)nek3,
(uint32_t)nev2, (uint32_t)nev3,
nem1, nem2,
nem1, nem2, nem3,
q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
scale, max_bias, logit_softcap,
mask != nullptr, n_head_log2, m0, m1,
mask_n_head_log2, m0, m1,
gqa_ratio, split_kv, split_k };
ggml_vk_sync_buffers(subctx);
@@ -6342,7 +6386,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
ggml_vk_sync_buffers(subctx);
const std::array<uint32_t, 4> pc2 = { D, (uint32_t)ne1, (uint32_t)ne3, split_k };
const std::array<uint32_t, 4> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
{
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
@@ -6542,6 +6586,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_SWIGLU:
return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_GEGLU_ERF:
return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_GEGLU_QUICK:
return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
default:
break;
}
@@ -7610,8 +7658,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
}
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -8841,7 +8888,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
}
}
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
// Returns true if node has enqueued work into the queue, false otherwise
// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
@@ -8886,6 +8933,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
break;
default:
return false;
@@ -9100,9 +9149,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
// fused rms_norm + mul
ggml_tensor *mul = cgraph->nodes[node_idx + 1];
ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun);
} else {
ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun);
}
break;
case GGML_OP_RMS_NORM_BACK:
@@ -9133,6 +9182,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
break;
default:
@@ -9260,7 +9311,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
ctx->compute_ctx.reset();
bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
if (!ok) {
if (node->op == GGML_OP_UNARY) {
std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
@@ -9275,7 +9326,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
return true;
}
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
GGML_UNUSED(cgraph);
ggml_backend_buffer * buf = nullptr;
switch (tensor->op) {
@@ -9351,6 +9403,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
buf = tensor->buffer;
break;
default:
@@ -9383,7 +9437,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
// Only run if ctx hasn't been submitted yet
if (!subctx->seqs.empty()) {
#ifdef GGML_VULKAN_CHECK_RESULTS
ggml_vk_check_results_0(tensor);
ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
use_fence = true;
#endif
@@ -9403,7 +9457,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
ggml_vk_wait_for_fence(ctx);
}
#ifdef GGML_VULKAN_CHECK_RESULTS
ggml_vk_check_results_1(tensor);
ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
#endif
}
@@ -9850,6 +9904,37 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
}
static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
return false;
}
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
// additional constraints specific to this fusion
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
// rms_norm only supports f32
if (mul->src[0]->type != GGML_TYPE_F32 ||
mul->src[1]->type != GGML_TYPE_F32 ||
mul->type != GGML_TYPE_F32) {
return false;
}
// if rms_norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] &&
mul->src[0]->ne[1] != rms_norm->ne[1]) {
return false;
}
// rms_norm shader assumes contiguous rows
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
return false;
}
}
return true;
}
static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -9863,7 +9948,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
uint64_t total_mat_mul_bytes = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
}
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -9933,7 +10018,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
}
if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
}
@@ -10161,6 +10246,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous(op->src[0]) &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -10241,19 +10328,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
bool coopmat2 = device->coopmat2;
switch (op->src[0]->ne[0]) {
case 64:
case 80:
case 96:
case 112:
case 128:
case 256:
break;
default:
return false;
}
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
// different head sizes of K and V are not supported yet
FaHeadSizes head_sizes = fa_get_head_sizes(op->src[1]->ne[0], op->src[2]->ne[0]);
if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
return false;
}
if (op->src[0]->type != GGML_TYPE_F32) {
@@ -10333,6 +10409,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return false;
}
} break;
case GGML_OP_SET_ROWS:
{
// TODO: add support
// ref: https://github.com/ggml-org/llama.cpp/pull/14274
return false;
} break;
case GGML_OP_CONT:
case GGML_OP_CPY:
case GGML_OP_DUP:
@@ -10713,11 +10795,21 @@ void * comp_result;
size_t comp_size;
size_t comp_nb[GGML_MAX_DIMS];
size_t check_counter = 0;
static void ggml_vk_check_results_0(ggml_tensor * tensor) {
static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
ggml_tensor * tensor = cgraph->nodes[tensor_idx];
if (tensor->op == GGML_OP_TRANSPOSE) {
return;
}
bool fused_rms_norm_mul = false;
int rms_norm_idx = -1;
if (ctx->num_additional_fused_ops == 1 &&
tensor->op == GGML_OP_RMS_NORM &&
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
fused_rms_norm_mul = true;
tensor = cgraph->nodes[tensor_idx + 1];
}
check_counter++;
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
return;
@@ -10745,6 +10837,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
for (int i = 0; i < 6; i++) {
ggml_tensor * srci = tensor->src[i];
if (fused_rms_norm_mul) {
rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
ggml_tensor *rms_norm = tensor->src[rms_norm_idx];
switch (i) {
case 0: srci = rms_norm->src[0]; break;
case 1: srci = tensor->src[1 - rms_norm_idx]; break;
default: continue;
}
}
if (srci == nullptr) {
continue;
}
@@ -10802,7 +10903,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
} else if (tensor->op == GGML_OP_SUB) {
tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_MUL) {
tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
if (fused_rms_norm_mul) {
tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params);
tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]);
} else {
tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
}
} else if (tensor->op == GGML_OP_DIV) {
tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_CONCAT) {
@@ -10993,10 +11099,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
GGML_ABORT("fatal error");
}
ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
ggml_build_forward_expand(cgraph, tensor_clone);
ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
ggml_build_forward_expand(cgraph_cpu, tensor_clone);
ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
ggml_vk_print_tensor(tensor_clone, "tensor_clone");
@@ -11019,10 +11125,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
}
static void ggml_vk_check_results_1(ggml_tensor * tensor) {
static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
ggml_tensor * tensor = cgraph->nodes[tensor_idx];
if (tensor->op == GGML_OP_TRANSPOSE) {
return;
}
bool fused_rms_norm_mul = false;
if (ctx->num_additional_fused_ops == 1 &&
tensor->op == GGML_OP_RMS_NORM &&
cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
fused_rms_norm_mul = true;
tensor = cgraph->nodes[tensor_idx + 1];
}
if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
return;
}
@@ -11,7 +11,8 @@
#include "types.comp"
#include "flash_attn_base.comp"
const uint32_t D_per_thread = D / D_split;
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
@@ -29,7 +30,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * D + c;
uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}
@@ -38,7 +39,7 @@ shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize];
shared float masksh[Bc][Br];
shared vec4 Qf[Br][D / 4];
shared vec4 Qf[Br][HSK / 4];
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
@@ -53,18 +54,18 @@ void main() {
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (D / 4);
uint32_t r = (idx + tid) / (D / 4);
if (r < Br && d < D / 4 &&
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < HSK / 4 &&
i * Br + r < N) {
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
}
}
barrier();
vec4 Of[Br][D_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
vec4 Of[Br][HSV_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = vec4(0.0);
}
@@ -100,8 +101,8 @@ void main() {
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
if (p.nem2 != 1) {
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
}
[[dont_unroll]]
@@ -116,7 +117,7 @@ void main() {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
@@ -148,7 +149,7 @@ void main() {
}
}
if (p.mask != 0) {
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
@@ -195,14 +196,14 @@ void main() {
Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
}
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = eMf[r] * Of[r][d];
}
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
@@ -259,7 +260,7 @@ void main() {
Lf[r] = tmpsh[d_tid];
barrier();
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = eMf * Of[r][d];
tmpshv4[tid] = Of[r][d];
@@ -281,11 +282,11 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
}
@@ -293,7 +294,7 @@ void main() {
}
}
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -309,18 +310,18 @@ void main() {
Lfrcp[r] = 1.0 / Lf[r];
}
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] *= Lfrcp[r];
}
}
uint32_t o_offset = iq3*p.ne2*p.ne1*D;
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
}
@@ -330,9 +331,9 @@ void main() {
} else {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (i * Br + r < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
}
}
}
@@ -4,10 +4,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32;
layout (constant_id = 4) const uint32_t Clamp = 0;
layout (constant_id = 5) const uint32_t D_split = 16;
layout (constant_id = 3) const uint32_t HSK = 32;
layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
layout (push_constant) uniform parameter {
uint32_t N;
@@ -25,6 +25,7 @@ layout (push_constant) uniform parameter {
uint32_t nev3;
uint32_t nem1;
uint32_t nem2;
uint32_t nem3;
uint32_t nb01;
uint32_t nb02;
@@ -40,8 +41,7 @@ layout (push_constant) uniform parameter {
float max_bias;
float logit_softcap;
uint32_t mask;
uint32_t n_head_log2;
uint32_t mask_n_head_log2;
float m0;
float m1;
@@ -50,6 +50,9 @@ layout (push_constant) uniform parameter {
uint32_t k_num;
} p;
#define MASK_ENABLE_BIT (1<<16)
#define N_LOG2_MASK 0xFFFF
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
@@ -100,8 +103,10 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
@@ -13,7 +13,9 @@
#include "types.comp"
#include "flash_attn_base.comp"
const uint32_t D_per_thread = D / D_split;
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
const uint32_t row_split = 4;
const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
@@ -32,7 +34,7 @@ layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * D + c;
uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}
@@ -44,14 +46,14 @@ const uint32_t MatBc = 16;
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
const uint32_t qstride = D / 4 + 2; // in units of f16vec4
const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];
// Avoid padding for D==256 to make it fit in 48KB shmem.
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
// Avoid padding for hsk==256 to make it fit in 48KB shmem.
const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
shared ACC_TYPE sfsh[Bc * sfshstride];
const uint32_t kshstride = D / 4 + 2; // in units of f16vec4
const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
shared f16vec4 ksh[Bc * kshstride];
shared float slope[Br];
@@ -74,18 +76,18 @@ void main() {
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (D / 4);
uint32_t r = (idx + tid) / (D / 4);
if (r < Br && d < D / 4 &&
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < HSK / 4 &&
i * Br + r < N) {
Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
}
}
barrier();
ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = ACC_TYPEV4(0.0);
}
@@ -124,17 +126,17 @@ void main() {
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
if (p.nem2 != 1) {
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
}
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (D / 4);
uint32_t c = (idx + tid) / (D / 4);
if (c < Bc && d < D / 4) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
if (c < Bc && d < HSK / 4) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
@@ -149,14 +151,14 @@ void main() {
}
barrier();
// K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br
// Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
// K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
// Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
// This is written transposed in order to allow for N being 8 if implementations need it
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
for (uint32_t d = 0; d < D / 16; ++d) {
for (uint32_t d = 0; d < HSK / 16; ++d) {
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
@@ -180,7 +182,7 @@ void main() {
barrier();
}
if (p.mask != 0) {
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
@@ -206,7 +208,7 @@ void main() {
eMf[r] = exp(Moldf - Mf[r]);
}
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
}
@@ -221,7 +223,7 @@ void main() {
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
Lf[r] += Pf[r];
}
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
@@ -284,7 +286,7 @@ void main() {
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
tmpshv4[tid] = Of[r][d];
@@ -304,11 +306,11 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
}
@@ -316,7 +318,7 @@ void main() {
}
}
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -332,18 +334,18 @@ void main() {
Lfrcp[r] = 1.0 / Lf[r];
}
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= float16_t(Lfrcp[r]);
}
}
uint32_t o_offset = iq3*p.ne2*p.ne1*D;
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
}
@@ -353,9 +355,9 @@ void main() {
} else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (i * Br + tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
}
}
}
@@ -61,8 +61,8 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c < D) {
uint32_t offset = (iq2 + r) * D + c;
if (r < N && c < HSV) {
uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
@@ -86,9 +86,9 @@ void main() {
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif
tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D);
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);
// hint to the compiler that strides are aligned for the aligned variant of the shader
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
@@ -104,16 +104,16 @@ void main() {
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseAccumulator> Q;
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA> Qf16;
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D));
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q);
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA>(Q);
Qf16 *= float16_t(p.scale);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
@@ -131,8 +131,8 @@ void main() {
}
uint32_t m_offset = 0;
if (p.nem2 != 1) {
m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
}
[[dont_unroll]]
@@ -140,10 +140,10 @@ void main() {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T;
coopmat<float16_t, gl_ScopeWorkgroup, HSK, Bc, gl_MatrixUseB> K_T;
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC);
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
S = coopMatMulAdd(Qf16, K_T, S);
if (p.logit_softcap != 0.0f) {
@@ -153,7 +153,7 @@ void main() {
}
}
if (p.mask != 0) {
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
@@ -208,42 +208,42 @@ void main() {
rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
rowsum = coopMatMulAdd(P_A, One, rowsum);
coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V;
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV, gl_MatrixUseB> V;
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC);
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC);
L = eM*L + rowsum;
// This is the "diagonal" matrix in the paper, but since we do componentwise
// multiply rather than matrix multiply it has the diagonal element smeared
// across the row
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag;
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> eMdiag;
// resize eM by using smear/reduce
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
// multiply with fp16 accumulation, then add to O.
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
PV = coopMatMulAdd(P_A, V, PV);
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(PV);
}
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
return;
}
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Ldiag;
// resize L by using smear/reduce
coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
@@ -255,18 +255,18 @@ void main() {
O = Ldiag*O;
uint32_t o_offset = iq3*p.ne2*p.ne1*D;
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
if (p.gqa_ratio > 1) {
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
} else {
tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);
// permute dimensions
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute);
}
}
@@ -0,0 +1,27 @@
#version 450
#include "glu_head.comp"
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
const float p_erf = 0.3275911f;
const float a1_erf = 0.254829592f;
const float a2_erf = -0.284496736f;
const float a3_erf = 1.421413741f;
const float a4_erf = -1.453152027f;
const float a5_erf = 1.061405429f;
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
float op(float a, float b) {
const float a_div_sqr2 = a * SQRT_2_INV;
const float sign_x = sign(a_div_sqr2);
const float x = abs(a_div_sqr2);
const float t = 1.0f / (1.0f + p_erf * x);
const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
const float erf_approx = sign_x * y;
return 0.5f * a * (1.0f + erf_approx) * b;
}
#include "glu_main.comp"
@@ -0,0 +1,11 @@
#version 450
#include "glu_head.comp"
const float GELU_QUICK_COEF = -1.702f;
float op(float a, float b) {
return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b;
}
#include "glu_main.comp"
+77 -67
View File
@@ -500,10 +500,9 @@ void main() {
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7
const uint ib8 = (idx % 128) / 4;
const int i8 = 2 * int(idx % 4);
const uint ib = idx / 32; // 8 values per idx
const uint ib32 = (idx % 32) / 4; // 0..7
const uint ib8 = idx % 32;
const float d = float(data_a[ib].d);
const uint qh = data_a[ib].qh[ib32];
@@ -512,22 +511,16 @@ void main() {
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
const ivec2 gvec = ivec2(
bitfieldExtract(grid, 2 * (i8), 2),
bitfieldExtract(grid, 2 * (i8 + 1), 2)
);
const vec2 v = dl * (vec2(gvec) + delta);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
[[unroll]] for (int k = 0; k < 8; ++k) {
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
}
#elif defined(DATA_A_IQ1_M)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint ib8 = (idx % 128) / 4;
const uint ib = idx / 32; // 8 values per idx
const uint ib8 = idx % 32;
const uint ib16 = ib8 / 2;
const int i8 = 2 * int(idx % 4);
const uint16_t[4] scales = data_a[ib].scales;
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
@@ -538,21 +531,17 @@ void main() {
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
const ivec2 gvec = ivec2(
bitfieldExtract(grid, 2 * (i8), 2),
bitfieldExtract(grid, 2 * (i8 + 1), 2)
);
const vec2 v = dl * (vec2(gvec) + delta);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
[[unroll]] for (int k = 0; k < 8; ++k) {
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
}
#elif defined(DATA_A_IQ2_XXS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7
const uint ib8 = (idx / 4) % 4;
const uint ib = idx / 32; // 8 values per idx
const uint ib32 = (idx % 32) / 4; // 0..7
const uint ib8 = idx % 4;
const float d = float(data_a[ib].d);
const uint qs = data_a[ib].qs[8 * ib32 + ib8];
@@ -562,63 +551,81 @@ void main() {
data_a[ib].qs[8*ib32 + 6],
data_a[ib].qs[8*ib32 + 7]
));
const float db = d * 0.25 * (0.5 + (signs >> 28));
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
const uint sign = sign7 | (bitCount(sign7) << 7);
const uvec2 grid = iq2xxs_grid[qs];
const vec4 grid0 = vec4(unpack8(grid.x));
const vec4 grid1 = vec4(unpack8(grid.y));
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ2_XS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7
const uint ib8 = (idx / 4) % 4; // 0..3
const uint ib = idx / 32; // 8 values per idx
const uint ib32 = (idx % 32) / 4; // 0..7
const uint ib8 = idx % 4; // 0..3
const float d = float(data_a[ib].d);
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
const float db = d * 0.25 * (0.5 + scale);
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
const uint sign7 = qs >> 9;
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
const uint sign = sign7 | (bitCount(sign7) << 7);
const uvec2 grid = iq2xs_grid[qs & 511];
const vec4 grid0 = vec4(unpack8(grid.x));
const vec4 grid1 = vec4(unpack8(grid.y));
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ2_S)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint ib8 = (idx % 128) / 4; // 0..31
const uint ib32 = ib8 / 4; // 0..7
const uint ib = idx / 32; // 8 values per idx
const uint ib8 = idx % 32; // 0..31
const uint ib32 = ib8 / 4; // 0..7
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
const uint qs = data_a[ib].qs[ib8];
const uint qh = data_a[ib].qh[ib32];
const uint qhshift = 2 * (ib8 % 4);
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
const float d = float(data_a[ib].d);
const float db = d * 0.25 * (0.5 + scale);
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
const vec4 grid0 = vec4(unpack8(grid.x));
const vec4 grid1 = vec4(unpack8(grid.y));
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ3_XXS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = (idx % 128) / 2; // 0..63
const uint ib = idx / 64; // 4 values per idx
const uint iqs = idx % 64; // 0..63
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
const float d = float(data_a[ib].d);
@@ -631,33 +638,36 @@ void main() {
));
const float db = d * 0.5 * (0.5 + (signs >> 28));
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
const uint grid = iq3xxs_grid[qs];
const vec4 v = db * vec4(unpack8(grid));
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
#elif defined(DATA_A_IQ3_S)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint iqs = (idx % 128) / 2; // 0..63
const uint ib = idx / 64; // 4 values per idx
const uint iqs = idx % 64; // 0..63
const uint iqh = iqs / 8;
const float d = float(data_a[ib].d);
const uint qs = data_a[ib].qs[iqs];
const uint qh = data_a[ib].qh[iqh];
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4)));
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
const uint scale = data_a[ib].scales[iqs / 16];
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
const vec4 v = db * vec4(unpack8(grid));
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
#elif defined(DATA_A_IQ4_XS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
@@ -360,9 +360,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
for (const auto& tname : type_names) {
std::string load_vec_quant = "2";
if ((tname == "q4_0") || (tname == "q4_1"))
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
load_vec_quant = "8";
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl"))
load_vec_quant = "4";
if (tname == "bf16") {
@@ -593,6 +593,10 @@ void process_shaders() {
string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
+94 -23
View File
@@ -1140,9 +1140,11 @@ static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
"REGLU",
"GEGLU",
"SWIGLU",
"GEGLU_ERF",
"GEGLU_QUICK",
};
static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3");
static_assert(GGML_GLU_OP_COUNT == 5, "GGML_GLU_OP_COUNT != 5");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -2768,6 +2770,48 @@ struct ggml_tensor * ggml_swiglu_split(
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
}
// ggml_geglu_erf
struct ggml_tensor * ggml_geglu_erf(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, false);
}
struct ggml_tensor * ggml_geglu_erf_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, true);
}
struct ggml_tensor * ggml_geglu_erf_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_ERF, false);
}
// ggml_geglu_quick
struct ggml_tensor * ggml_geglu_quick(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, false);
}
struct ggml_tensor * ggml_geglu_quick_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, true);
}
struct ggml_tensor * ggml_geglu_quick_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false);
}
// ggml_norm
static struct ggml_tensor * ggml_norm_impl(
@@ -3674,7 +3718,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
if (mask) {
GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(ggml_is_3d(mask));
GGML_ASSERT(mask->ne[0] == a->ne[0]);
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
@@ -4704,12 +4747,12 @@ struct ggml_tensor * ggml_flash_attn_ext(
if (mask) {
GGML_ASSERT(ggml_is_contiguous(mask));
GGML_ASSERT(mask->ne[2] == q->ne[3]);
GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
"the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
GGML_ASSERT(q->ne[3] % mask->ne[2] == 0);
GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
}
if (max_bias > 0.0f) {
@@ -4837,7 +4880,6 @@ struct ggml_tensor * ggml_ssm_conv(
const int64_t n_s = sx->ne[2];
// TODO: maybe support other strides than 1?
// FIXME: this is always true?
GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
GGML_ASSERT(sx->ne[1] == d_inner);
GGML_ASSERT(n_t >= 0);
@@ -4860,36 +4902,49 @@ struct ggml_tensor * ggml_ssm_scan(
struct ggml_tensor * dt,
struct ggml_tensor * A,
struct ggml_tensor * B,
struct ggml_tensor * C) {
struct ggml_tensor * C,
struct ggml_tensor * ids) {
GGML_ASSERT(ggml_is_contiguous(s));
GGML_ASSERT(ggml_is_contiguous(x));
GGML_ASSERT(ggml_is_contiguous(dt));
GGML_ASSERT(ggml_is_contiguous(A));
GGML_ASSERT(ggml_is_matrix(A));
GGML_ASSERT(ggml_is_3d(B));
GGML_ASSERT(ggml_is_3d(s));
GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
GGML_ASSERT(ggml_are_same_shape(x, dt));
GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
GGML_ASSERT(ggml_are_same_shape(B, C));
GGML_ASSERT(ids->type == GGML_TYPE_I32);
{
const int64_t d_state = s->ne[0];
const int64_t d_inner = s->ne[1];
const int64_t n_seq_tokens = x->ne[1];
const int64_t n_seqs = x->ne[2];
const int64_t head_dim = x->ne[0];
const int64_t n_head = x->ne[1];
const int64_t n_seq_tokens = x->ne[2];
const int64_t n_seqs = x->ne[3];
GGML_ASSERT(s->ne[2] == n_seqs);
GGML_ASSERT(x->ne[0] == d_inner);
GGML_ASSERT(A->ne[0] == d_state);
GGML_ASSERT(A->ne[1] == d_inner);
GGML_ASSERT(dt->ne[0] == n_head);
GGML_ASSERT(dt->ne[1] == n_seq_tokens);
GGML_ASSERT(dt->ne[2] == n_seqs);
GGML_ASSERT(ggml_is_3d(dt));
GGML_ASSERT(s->ne[1] == head_dim);
GGML_ASSERT(s->ne[2] == n_head);
GGML_ASSERT(B->ne[0] == d_state);
GGML_ASSERT(B->ne[1] == n_seq_tokens);
GGML_ASSERT(B->ne[2] == n_seqs);
GGML_ASSERT(B->ne[2] == n_seq_tokens);
GGML_ASSERT(B->ne[3] == n_seqs);
GGML_ASSERT(ids->ne[0] == n_seqs);
GGML_ASSERT(ggml_is_vector(ids));
GGML_ASSERT(A->ne[1] == n_head);
GGML_ASSERT(ggml_is_matrix(A));
if (A->ne[0] != 1) {
// Mamba-1 has more granular decay factors
GGML_ASSERT(A->ne[0] == d_state);
}
}
// concatenated y + ssm_states
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]);
result->op = GGML_OP_SSM_SCAN;
result->src[0] = s;
@@ -4898,6 +4953,7 @@ struct ggml_tensor * ggml_ssm_scan(
result->src[3] = A;
result->src[4] = B;
result->src[5] = C;
result->src[6] = ids;
return result;
}
@@ -6038,13 +6094,28 @@ static void ggml_compute_backward(
}
GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
} break;
case GGML_OP_GLU: {
switch (ggml_get_glu_op(tensor)) {
case GGML_GLU_OP_SWIGLU: {
if (src0_needs_grads) {
GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
}
if (src1_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
}
} break;
default: {
GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
} //break;
}
} break;
case GGML_OP_NONE: {
// noop
} break;
case GGML_OP_COUNT:
default: {
fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
GGML_ABORT("fatal error");
GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
} //break;
}
+19
View File
@@ -170,6 +170,7 @@ class Keys:
INNER_SIZE = "{arch}.ssm.inner_size"
STATE_SIZE = "{arch}.ssm.state_size"
TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
GROUP_COUNT = "{arch}.ssm.group_count"
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"
class WKV:
@@ -327,6 +328,7 @@ class MODEL_ARCH(IntEnum):
RWKV7 = auto()
ARWKV7 = auto()
MAMBA = auto()
MAMBA2 = auto()
XVERSE = auto()
COMMAND_R = auto()
COHERE2 = auto()
@@ -429,6 +431,7 @@ class MODEL_TENSOR(IntEnum):
SSM_DT = auto()
SSM_A = auto()
SSM_D = auto()
SSM_NORM = auto()
SSM_OUT = auto()
TIME_MIX_W0 = auto()
TIME_MIX_W1 = auto()
@@ -628,6 +631,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.RWKV7: "rwkv7",
MODEL_ARCH.ARWKV7: "arwkv7",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.MAMBA2: "mamba2",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.COHERE2: "cohere2",
@@ -730,6 +734,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
@@ -1714,6 +1719,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_OUT,
],
MODEL_ARCH.MAMBA2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.SSM_IN,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_D,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_OUT,
],
MODEL_ARCH.XVERSE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@@ -2497,6 +2515,7 @@ KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL
KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE
KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE
KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT
KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS
# tokenization
+5 -2
View File
@@ -714,8 +714,8 @@ class GGUFWriter:
def add_clamp_kqv(self, value: float) -> None:
self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
def add_shared_kv_layers(self, value: float) -> None:
self.add_float32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)
def add_shared_kv_layers(self, value: int) -> None:
self.add_uint32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)
def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)
@@ -861,6 +861,9 @@ class GGUFWriter:
def add_ssm_time_step_rank(self, value: int) -> None:
self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)
def add_ssm_group_count(self, value: int) -> None:
self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value)
def add_ssm_dt_b_c_rms(self, value: bool) -> None:
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
+5 -1
View File
@@ -477,7 +477,7 @@ class TensorNameMap:
"encoder.layers.{bid}.norm2", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_3", # Grok
"encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2
"encoder.layer.{bid}.layer_norm_2" # jina-v2-code
"encoder.layer.{bid}.layer_norm_2", # jina-v2-code
),
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (
@@ -574,6 +574,10 @@ class TensorNameMap:
"backbone.layers.{bid}.mixer.D",
),
MODEL_TENSOR.SSM_NORM: (
"backbone.layers.{bid}.mixer.norm", # mamba2
),
MODEL_TENSOR.SSM_OUT: (
"model.layers.{bid}.out_proj",
"backbone.layers.{bid}.mixer.out_proj",
+12 -3
View File
@@ -245,9 +245,18 @@ class SpecialVocab:
if not tokenizer_config:
return True
chat_template_alt = None
chat_template_file = path / 'chat_template.json'
if chat_template_file.is_file():
with open(chat_template_file, encoding = 'utf-8') as f:
chat_template_json = path / 'chat_template.json'
chat_template_jinja = path / 'chat_template.jinja'
if chat_template_jinja.is_file():
with open(chat_template_jinja, encoding = 'utf-8') as f:
chat_template_alt = f.read()
if additional_templates := list((path / 'additional_chat_templates').glob('*.jinja')):
chat_template_alt = [{'name': 'default', 'template': chat_template_alt}]
for template_path in additional_templates:
with open(template_path, encoding = 'utf-8') as fp:
chat_template_alt.append({'name': template_path.stem, 'template': fp.read()})
elif chat_template_json.is_file():
with open(chat_template_json, encoding = 'utf-8') as f:
chat_template_alt = json.load(f).get('chat_template')
chat_template = tokenizer_config.get('chat_template', chat_template_alt)
if chat_template is None or isinstance(chat_template, (str, list)):
-3
View File
@@ -83,7 +83,6 @@ while read c; do
src/ggml-cpu/* \
src/ggml-cuda/* \
src/ggml-hip/* \
src/ggml-kompute/* \
src/ggml-metal/* \
src/ggml-musa/* \
src/ggml-opencl/* \
@@ -141,7 +140,6 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
# src/ggml-cpu/* -> ggml/src/ggml-cpu/*
# src/ggml-cuda/* -> ggml/src/ggml-cuda/*
# src/ggml-hip/* -> ggml/src/ggml-hip/*
# src/ggml-kompute/* -> ggml/src/ggml-kompute/*
# src/ggml-metal/* -> ggml/src/ggml-metal/*
# src/ggml-musa/* -> ggml/src/ggml-musa/*
# src/ggml-opencl/* -> ggml/src/ggml-opencl/*
@@ -174,7 +172,6 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
-e 's/([[:space:]]| [ab]\/)src\/ggml-cpu\//\1ggml\/src\/ggml-cpu\//g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml-cuda\//\1ggml\/src\/ggml-cuda\//g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml-hip\//\1ggml\/src\/ggml-hip\//g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml-kompute\//\1ggml\/src\/ggml-kompute\//g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml-metal\//\1ggml\/src\/ggml-metal\//g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml-opencl\//\1ggml\/src\/ggml-opencl\//g' \
-e 's/([[:space:]]| [ab]\/)src\/ggml-rpc\//\1ggml\/src\/ggml-rpc\//g' \
-1
View File
@@ -15,7 +15,6 @@ cp -rpv ../ggml/src/ggml-cann/* ./ggml/src/ggml-cann/
cp -rpv ../ggml/src/ggml-cpu/* ./ggml/src/ggml-cpu/
cp -rpv ../ggml/src/ggml-cuda/* ./ggml/src/ggml-cuda/
cp -rpv ../ggml/src/ggml-hip/* ./ggml/src/ggml-hip/
cp -rpv ../ggml/src/ggml-kompute/* ./ggml/src/ggml-kompute/
cp -rpv ../ggml/src/ggml-metal/* ./ggml/src/ggml-metal/
cp -rpv ../ggml/src/ggml-musa/* ./ggml/src/ggml-musa/
cp -rpv ../ggml/src/ggml-opencl/* ./ggml/src/ggml-opencl/
+20
View File
@@ -45,6 +45,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA3N, "gemma3n" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_MAMBA2, "mamba2" },
{ LLM_ARCH_XVERSE, "xverse" },
{ LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_COHERE2, "cohere2" },
@@ -170,6 +171,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
{ LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
{ LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
{ LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" },
{ LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },
{ LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },
@@ -1004,6 +1006,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{
LLM_ARCH_MAMBA2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{
LLM_ARCH_XVERSE,
{
@@ -1761,6 +1779,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
{LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
{LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@@ -1894,6 +1913,7 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
bool llm_arch_is_recurrent(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_MAMBA:
case LLM_ARCH_MAMBA2:
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_RWKV7:
+3
View File
@@ -49,6 +49,7 @@ enum llm_arch {
LLM_ARCH_GEMMA3N,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_MAMBA2,
LLM_ARCH_XVERSE,
LLM_ARCH_COMMAND_R,
LLM_ARCH_COHERE2,
@@ -174,6 +175,7 @@ enum llm_kv {
LLM_KV_SSM_CONV_KERNEL,
LLM_KV_SSM_STATE_SIZE,
LLM_KV_SSM_TIME_STEP_RANK,
LLM_KV_SSM_GROUP_COUNT,
LLM_KV_SSM_DT_B_C_RMS,
LLM_KV_WKV_HEAD_SIZE,
@@ -293,6 +295,7 @@ enum llm_tensor {
LLM_TENSOR_SSM_DT,
LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
LLM_TENSOR_TIME_MIX_W0,
LLM_TENSOR_TIME_MIX_W1,

Some files were not shown because too many files have changed in this diff Show More