mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-01 01:57:43 +02:00
Compare commits
47 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c80a7759da | |||
| 250d7953e8 | |||
| 403fbacbbc | |||
| a8a1f33567 | |||
| 1790e73157 | |||
| 0114a32da0 | |||
| a7724480fd | |||
| 1a85949067 | |||
| 6c02a032fa | |||
| f52d59d771 | |||
| 52de2e5949 | |||
| 2c3f8b850a | |||
| 4663bd353c | |||
| b3de7cac73 | |||
| 7242dd9675 | |||
| 492d7f1ff7 | |||
| d3f1f0acfb | |||
| 360dc22c00 | |||
| a62d7fa7a9 | |||
| e408d4351a | |||
| 3891e183c6 | |||
| af6ae1efb2 | |||
| 0bb2919335 | |||
| a69f846351 | |||
| d07a0d7a79 | |||
| 3714c3ee1a | |||
| b4ae50810e | |||
| b86f600723 | |||
| dd373dd3bf | |||
| 5d01670266 | |||
| ef03229ff4 | |||
| 13731766db | |||
| ab6ab8f809 | |||
| 2099a9d5db | |||
| 2969019837 | |||
| 5dec47dcd4 | |||
| f125b8dccf | |||
| 953c2a62cf | |||
| d5c6309d91 | |||
| 029c693fdc | |||
| 771d84371c | |||
| df0665a483 | |||
| 0306aad1ca | |||
| c7b43ab608 | |||
| 24feaec057 | |||
| f28bc4c286 | |||
| f17a3bb4e8 |
@@ -803,7 +803,7 @@ jobs:
|
||||
env:
|
||||
OPENBLAS_VERSION: 0.3.23
|
||||
SDE_VERSION: 9.33.0-2024-01-07
|
||||
VULKAN_VERSION: 1.4.304.1
|
||||
VULKAN_VERSION: 1.4.309.0
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
|
||||
@@ -112,6 +112,8 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [x] [RWKV-6](https://github.com/BlinkDL/RWKV-LM)
|
||||
- [x] [QRWKV-6](https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1)
|
||||
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
|
||||
- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
|
||||
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
|
||||
|
||||
#### Multimodal
|
||||
|
||||
|
||||
+1
-1
@@ -60,7 +60,7 @@ docker run --privileged -it \
|
||||
Inside the container, execute the following commands:
|
||||
|
||||
```bash
|
||||
apt update -y && apt install -y bc cmake git python3.10-venv time unzip wget
|
||||
apt update -y && apt install -y bc cmake ccache git python3.10-venv time unzip wget
|
||||
git config --global --add safe.directory /ws
|
||||
GG_BUILD_MUSA=1 bash ./ci/run.sh /ci-results /ci-cache
|
||||
```
|
||||
|
||||
@@ -69,7 +69,7 @@ fi
|
||||
if [ ! -z ${GG_BUILD_MUSA} ]; then
|
||||
# Use qy1 by default (MTT S80)
|
||||
MUSA_ARCH=${MUSA_ARCH:-21}
|
||||
CMAKE_EXTRA="-DGGML_MUSA=ON -DMUSA_ARCHITECTURES=${MUSA_ARCH}"
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_MUSA=ON -DMUSA_ARCHITECTURES=${MUSA_ARCH}"
|
||||
fi
|
||||
## helpers
|
||||
|
||||
|
||||
+1
-1
@@ -1979,7 +1979,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
|
||||
add_opt(common_arg(
|
||||
{"--host"}, "HOST",
|
||||
string_format("ip address to listen (default: %s)", params.hostname.c_str()),
|
||||
string_format("ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: %s)", params.hostname.c_str()),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.hostname = value;
|
||||
}
|
||||
|
||||
@@ -208,6 +208,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||
trigger_patterns_c.data(), trigger_patterns_c.size(),
|
||||
trigger_tokens.data(), trigger_tokens.size())
|
||||
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
|
||||
if (!grmr) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
auto * result = new common_sampler {
|
||||
|
||||
+134
-3
@@ -708,6 +708,12 @@ class Model:
|
||||
if chkhsh == "7dec86086fcc38b66b7bc1575a160ae21cf705be7718b9d5598190d7c12db76f":
|
||||
# ref: https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k
|
||||
res = "superbpe"
|
||||
if chkhsh == "1994ffd01900cfb37395608534236ecd63f2bd5995d6cb1004dda1af50240f15":
|
||||
# ref: https://huggingface.co/trillionlabs/Trillion-7B-preview
|
||||
res = "trillion"
|
||||
if chkhsh == "96a5f08be6259352137b512d4157e333e21df7edd3fcd152990608735a65b224":
|
||||
# ref: https://huggingface.co/inclusionAI/Ling-lite
|
||||
res = "bailingmoe"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
@@ -2269,7 +2275,7 @@ class Qwen2Model(Model):
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
|
||||
|
||||
|
||||
@Model.register("Qwen2VLForConditionalGeneration")
|
||||
@Model.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
|
||||
class Qwen2VLModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN2VL
|
||||
|
||||
@@ -3551,8 +3557,8 @@ class RWKV6Qwen2Model(Rwkv6Model):
|
||||
head_size = hidden_size // num_attention_heads
|
||||
rms_norm_eps = self.hparams["rms_norm_eps"]
|
||||
intermediate_size = self.hparams["intermediate_size"]
|
||||
time_mix_extra_dim = 64 if hidden_size >= 4096 else 32
|
||||
time_decay_extra_dim = 128 if hidden_size >= 4096 else 64
|
||||
time_mix_extra_dim = self.hparams.get("lora_rank_tokenshift", 64 if hidden_size >= 4096 else 32)
|
||||
time_decay_extra_dim = self.hparams.get("lora_rank_decay", 128 if hidden_size >= 4096 else 64)
|
||||
|
||||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
@@ -4419,6 +4425,29 @@ class DeepseekV2Model(Model):
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@Model.register("PLMForCausalLM")
|
||||
class PLMModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.PLM
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
|
||||
self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
|
||||
self.gguf_writer.add_value_length(hparams["v_head_dim"])
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
|
||||
|
||||
@Model.register("T5WithLMHeadModel")
|
||||
@Model.register("T5ForConditionalGeneration")
|
||||
@Model.register("MT5ForConditionalGeneration")
|
||||
@@ -5107,6 +5136,108 @@ class GraniteMoeModel(GraniteModel):
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@Model.register("BailingMoeForCausalLM")
|
||||
class BailingMoeModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.BAILINGMOE
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
if "head_dim" in hparams:
|
||||
rope_dim = hparams["head_dim"]
|
||||
else:
|
||||
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
|
||||
|
||||
self.gguf_writer.add_rope_dimension_count(rope_dim)
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
|
||||
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
|
||||
self.gguf_writer.add_expert_weights_scale(1.0)
|
||||
self.gguf_writer.add_expert_count(hparams["num_experts"])
|
||||
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
|
||||
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
@staticmethod
|
||||
def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
|
||||
if n_head_kv is not None and n_head != n_head_kv:
|
||||
n_head = n_head_kv
|
||||
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
||||
.swapaxes(1, 2)
|
||||
.reshape(weights.shape))
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
n_head = self.hparams["num_attention_heads"]
|
||||
n_kv_head = self.hparams.get("num_key_value_heads")
|
||||
n_embd = self.hparams["hidden_size"]
|
||||
head_dim = self.hparams.get("head_dim", n_embd // n_head)
|
||||
|
||||
output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
|
||||
|
||||
if name.endswith("attention.dense.weight"):
|
||||
return [(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), data_torch)]
|
||||
elif name.endswith("query_key_value.weight"):
|
||||
q, k, v = data_torch.split([n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim], dim=-2)
|
||||
|
||||
return [
|
||||
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), BailingMoeModel.permute(q, n_head, n_head)),
|
||||
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), BailingMoeModel.permute(k, n_head, n_kv_head)),
|
||||
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v)
|
||||
]
|
||||
elif name.find("mlp.experts") != -1:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
assert bid is not None
|
||||
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["down_proj", "gate_proj", "up_proj"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
|
||||
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
|
||||
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
|
||||
tensors.append((new_name, data_torch))
|
||||
|
||||
return tensors
|
||||
|
||||
new_name = self.map_tensor_name(name)
|
||||
|
||||
if new_name == output_name and self.hparams.get("norm_head"):
|
||||
data_torch = data_torch.float()
|
||||
data_torch /= torch.norm(data_torch, p=2, dim=0, keepdim=True) + 1e-7
|
||||
|
||||
return [(new_name, data_torch)]
|
||||
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
|
||||
if self._experts is not None:
|
||||
# flatten `list[dict[str, Tensor]]` into `list[str]`
|
||||
experts = [k for d in self._experts for k in d.keys()]
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@Model.register("ChameleonForConditionalGeneration")
|
||||
@Model.register("ChameleonForCausalLM") # obsolete
|
||||
class ChameleonModel(Model):
|
||||
|
||||
@@ -111,6 +111,8 @@ models = [
|
||||
{"name": "deepseek-r1-qwen", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"},
|
||||
{"name": "gpt-4o", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Xenova/gpt-4o", },
|
||||
{"name": "superbpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k", },
|
||||
{"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", },
|
||||
{"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", },
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1396,14 +1396,16 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
|
||||
const int n_kv = gguf_get_n_kv(ctx);
|
||||
const int ftype = get_u32(ctx, KEY_FTYPE);
|
||||
const std::string ftype_str = get_ftype(ftype);
|
||||
const int idx_desc = get_key_idx(ctx, KEY_DESCRIPTION);
|
||||
const std::string description = gguf_get_val_str(ctx, idx_desc);
|
||||
const int idx_name = gguf_find_key(ctx, KEY_NAME);
|
||||
if (idx_name != -1) { // make name optional temporarily as some of the uploaded models missing it due to a bug
|
||||
const std::string name = gguf_get_val_str(ctx, idx_name);
|
||||
LOG_INF("%s: model name: %s\n", __func__, name.c_str());
|
||||
}
|
||||
LOG_INF("%s: description: %s\n", __func__, description.c_str());
|
||||
const int idx_desc = gguf_find_key(ctx, KEY_DESCRIPTION);
|
||||
if (idx_desc != -1) { // ditto
|
||||
const std::string description = gguf_get_val_str(ctx, idx_desc);
|
||||
LOG_INF("%s: description: %s\n", __func__, description.c_str());
|
||||
}
|
||||
LOG_INF("%s: GGUF version: %d\n", __func__, gguf_get_version(ctx));
|
||||
LOG_INF("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx));
|
||||
LOG_INF("%s: n_tensors: %d\n", __func__, n_tensors);
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
add_executable(rpc-server rpc-server.cpp)
|
||||
target_link_libraries(rpc-server PRIVATE ggml llama)
|
||||
set(TARGET rpc-server)
|
||||
add_executable(${TARGET} rpc-server.cpp)
|
||||
target_link_libraries(${TARGET} PRIVATE ggml)
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -72,3 +72,14 @@ $ bin/llama-cli -m ../models/tinyllama-1b/ggml-model-f16.gguf -p "Hello, my name
|
||||
|
||||
This way you can offload model layers to both local and remote devices.
|
||||
|
||||
### Local cache
|
||||
|
||||
The RPC server can use a local cache to store large tensors and avoid transferring them over the network.
|
||||
This can speed up model loading significantly, especially when using large models.
|
||||
To enable the cache, use the `-c` option:
|
||||
|
||||
```bash
|
||||
$ bin/rpc-server -c
|
||||
```
|
||||
|
||||
By default, the cache is stored in the `$HOME/.cache/llama.cpp/rpc` directory and can be controlled via the `LLAMA_CACHE` environment variable.
|
||||
|
||||
+140
-6
@@ -1,3 +1,7 @@
|
||||
#if defined(_MSC_VER)
|
||||
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
|
||||
#endif
|
||||
|
||||
#include "ggml-cpu.h"
|
||||
|
||||
#ifdef GGML_USE_CUDA
|
||||
@@ -18,26 +22,142 @@
|
||||
|
||||
#include "ggml-rpc.h"
|
||||
#ifdef _WIN32
|
||||
# define DIRECTORY_SEPARATOR '\\'
|
||||
# include <locale>
|
||||
# include <windows.h>
|
||||
# include <fcntl.h>
|
||||
# include <io.h>
|
||||
#else
|
||||
# define DIRECTORY_SEPARATOR '/'
|
||||
# include <unistd.h>
|
||||
# include <sys/stat.h>
|
||||
#endif
|
||||
#include <codecvt>
|
||||
#include <string>
|
||||
#include <stdio.h>
|
||||
#include <vector>
|
||||
#include <filesystem>
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
// NOTE: this is copied from common.cpp to avoid linking with libcommon
|
||||
// returns true if successful, false otherwise
|
||||
static bool fs_create_directory_with_parents(const std::string & path) {
|
||||
#ifdef _WIN32
|
||||
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
||||
std::wstring wpath = converter.from_bytes(path);
|
||||
|
||||
// if the path already exists, check whether it's a directory
|
||||
const DWORD attributes = GetFileAttributesW(wpath.c_str());
|
||||
if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t pos_slash = 0;
|
||||
|
||||
// process path from front to back, procedurally creating directories
|
||||
while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
|
||||
const std::wstring subpath = wpath.substr(0, pos_slash);
|
||||
const wchar_t * test = subpath.c_str();
|
||||
|
||||
const bool success = CreateDirectoryW(test, NULL);
|
||||
if (!success) {
|
||||
const DWORD error = GetLastError();
|
||||
|
||||
// if the path already exists, ensure that it's a directory
|
||||
if (error == ERROR_ALREADY_EXISTS) {
|
||||
const DWORD attributes = GetFileAttributesW(subpath.c_str());
|
||||
if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
pos_slash += 1;
|
||||
}
|
||||
|
||||
return true;
|
||||
#else
|
||||
// if the path already exists, check whether it's a directory
|
||||
struct stat info;
|
||||
if (stat(path.c_str(), &info) == 0) {
|
||||
return S_ISDIR(info.st_mode);
|
||||
}
|
||||
|
||||
size_t pos_slash = 1; // skip leading slashes for directory creation
|
||||
|
||||
// process path from front to back, procedurally creating directories
|
||||
while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) {
|
||||
const std::string subpath = path.substr(0, pos_slash);
|
||||
struct stat info;
|
||||
|
||||
// if the path already exists, ensure that it's a directory
|
||||
if (stat(subpath.c_str(), &info) == 0) {
|
||||
if (!S_ISDIR(info.st_mode)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// create parent directories
|
||||
const int ret = mkdir(subpath.c_str(), 0755);
|
||||
if (ret != 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
pos_slash += 1;
|
||||
}
|
||||
|
||||
return true;
|
||||
#endif // _WIN32
|
||||
}
|
||||
|
||||
// NOTE: this is copied from common.cpp to avoid linking with libcommon
|
||||
static std::string fs_get_cache_directory() {
|
||||
std::string cache_directory = "";
|
||||
auto ensure_trailing_slash = [](std::string p) {
|
||||
// Make sure to add trailing slash
|
||||
if (p.back() != DIRECTORY_SEPARATOR) {
|
||||
p += DIRECTORY_SEPARATOR;
|
||||
}
|
||||
return p;
|
||||
};
|
||||
if (getenv("LLAMA_CACHE")) {
|
||||
cache_directory = std::getenv("LLAMA_CACHE");
|
||||
} else {
|
||||
#ifdef __linux__
|
||||
if (std::getenv("XDG_CACHE_HOME")) {
|
||||
cache_directory = std::getenv("XDG_CACHE_HOME");
|
||||
} else {
|
||||
cache_directory = std::getenv("HOME") + std::string("/.cache/");
|
||||
}
|
||||
#elif defined(__APPLE__)
|
||||
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
|
||||
#elif defined(_WIN32)
|
||||
cache_directory = std::getenv("LOCALAPPDATA");
|
||||
#endif // __linux__
|
||||
cache_directory = ensure_trailing_slash(cache_directory);
|
||||
cache_directory += "llama.cpp";
|
||||
}
|
||||
return ensure_trailing_slash(cache_directory);
|
||||
}
|
||||
|
||||
struct rpc_server_params {
|
||||
std::string host = "127.0.0.1";
|
||||
int port = 50052;
|
||||
size_t backend_mem = 0;
|
||||
bool use_cache = false;
|
||||
};
|
||||
|
||||
static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
|
||||
fprintf(stderr, "Usage: %s [options]\n\n", argv[0]);
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help show this help message and exit\n");
|
||||
fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str());
|
||||
fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port);
|
||||
fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n");
|
||||
fprintf(stderr, " -h, --help show this help message and exit\n");
|
||||
fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str());
|
||||
fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port);
|
||||
fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n");
|
||||
fprintf(stderr, " -c, --cache enable local file cache\n");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
@@ -58,6 +178,8 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
|
||||
if (params.port <= 0 || params.port > 65535) {
|
||||
return false;
|
||||
}
|
||||
} else if (arg == "-c" || arg == "--cache") {
|
||||
params.use_cache = true;
|
||||
} else if (arg == "-m" || arg == "--mem") {
|
||||
if (++i >= argc) {
|
||||
return false;
|
||||
@@ -164,8 +286,20 @@ int main(int argc, char * argv[]) {
|
||||
} else {
|
||||
get_backend_memory(&free_mem, &total_mem);
|
||||
}
|
||||
printf("Starting RPC server on %s, backend memory: %zu MB\n", endpoint.c_str(), free_mem / (1024 * 1024));
|
||||
ggml_backend_rpc_start_server(backend, endpoint.c_str(), free_mem, total_mem);
|
||||
const char * cache_dir = nullptr;
|
||||
std::string cache_dir_str = fs_get_cache_directory() + "rpc/";
|
||||
if (params.use_cache) {
|
||||
if (!fs_create_directory_with_parents(cache_dir_str)) {
|
||||
fprintf(stderr, "Failed to create cache directory: %s\n", cache_dir_str.c_str());
|
||||
return 1;
|
||||
}
|
||||
cache_dir = cache_dir_str.c_str();
|
||||
}
|
||||
printf("Starting RPC server\n");
|
||||
printf(" endpoint : %s\n", endpoint.c_str());
|
||||
printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
|
||||
printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024));
|
||||
ggml_backend_rpc_start_server(backend, endpoint.c_str(), cache_dir, free_mem, total_mem);
|
||||
ggml_backend_free(backend);
|
||||
return 0;
|
||||
}
|
||||
|
||||
+314
-248
File diff suppressed because it is too large
Load Diff
@@ -489,8 +489,12 @@ struct result_timings {
|
||||
double predicted_per_token_ms;
|
||||
double predicted_per_second;
|
||||
|
||||
// Optional speculative metrics - only included when > 0
|
||||
int32_t draft_n = 0;
|
||||
int32_t draft_n_accepted = 0;
|
||||
|
||||
json to_json() const {
|
||||
return {
|
||||
json base = {
|
||||
{"prompt_n", prompt_n},
|
||||
{"prompt_ms", prompt_ms},
|
||||
{"prompt_per_token_ms", prompt_per_token_ms},
|
||||
@@ -501,6 +505,13 @@ struct result_timings {
|
||||
{"predicted_per_token_ms", predicted_per_token_ms},
|
||||
{"predicted_per_second", predicted_per_second},
|
||||
};
|
||||
|
||||
if (draft_n > 0) {
|
||||
base["draft_n"] = draft_n;
|
||||
base["draft_n_accepted"] = draft_n_accepted;
|
||||
}
|
||||
|
||||
return base;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1299,6 +1310,10 @@ struct server_slot {
|
||||
|
||||
std::function<void(int)> callback_on_release;
|
||||
|
||||
// Speculative decoding stats
|
||||
int32_t n_draft_total = 0; // Total draft tokens generated
|
||||
int32_t n_draft_accepted = 0; // Draft tokens actually accepted
|
||||
|
||||
void reset() {
|
||||
SLT_DBG(*this, "%s", "\n");
|
||||
|
||||
@@ -1315,6 +1330,10 @@ struct server_slot {
|
||||
|
||||
generated_tokens.clear();
|
||||
generated_token_probs.clear();
|
||||
|
||||
// clear speculative decoding stats
|
||||
n_draft_total = 0;
|
||||
n_draft_accepted = 0;
|
||||
}
|
||||
|
||||
bool is_non_causal() const {
|
||||
@@ -1381,6 +1400,12 @@ struct server_slot {
|
||||
timings.predicted_per_token_ms = t_token_generation / n_decoded;
|
||||
timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
|
||||
|
||||
// Add speculative metrics
|
||||
if (n_draft_total > 0) {
|
||||
timings.draft_n = n_draft_total;
|
||||
timings.draft_n_accepted = n_draft_accepted;
|
||||
}
|
||||
|
||||
return timings;
|
||||
}
|
||||
|
||||
@@ -1428,6 +1453,15 @@ struct server_slot {
|
||||
t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
|
||||
t_token_generation, n_decoded, t_gen, n_gen_second,
|
||||
t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
|
||||
|
||||
if (n_draft_total > 0) {
|
||||
const float draft_ratio = (float) n_draft_accepted / n_draft_total;
|
||||
SLT_INF(*this,
|
||||
"\n"
|
||||
"draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n",
|
||||
draft_ratio, n_draft_accepted, n_draft_total
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
json to_json() const {
|
||||
@@ -3290,6 +3324,9 @@ struct server_context {
|
||||
|
||||
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
|
||||
|
||||
// keep track of total number of tokens generated in the draft
|
||||
slot.n_draft_total += draft.size();
|
||||
|
||||
// ignore small drafts
|
||||
if (slot.params.speculative.n_min > (int) draft.size()) {
|
||||
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
|
||||
@@ -3315,6 +3352,9 @@ struct server_context {
|
||||
slot.n_past += ids.size();
|
||||
slot.n_decoded += ids.size();
|
||||
|
||||
// update how many tokens out of draft was accepted
|
||||
slot.n_draft_accepted += ids.size() - 1;
|
||||
|
||||
slot.cache_tokens.push_back(id);
|
||||
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
|
||||
|
||||
@@ -4459,15 +4499,24 @@ int main(int argc, char ** argv) {
|
||||
llama_backend_free();
|
||||
};
|
||||
|
||||
// bind HTTP listen port
|
||||
bool was_bound = false;
|
||||
if (params.port == 0) {
|
||||
int bound_port = svr->bind_to_any_port(params.hostname);
|
||||
if ((was_bound = (bound_port >= 0))) {
|
||||
params.port = bound_port;
|
||||
}
|
||||
if (string_ends_with(std::string(params.hostname), ".sock")) {
|
||||
LOG_INF("%s: setting address family to AF_UNIX\n", __func__);
|
||||
svr->set_address_family(AF_UNIX);
|
||||
// bind_to_port requires a second arg, any value other than 0 should
|
||||
// simply get ignored
|
||||
was_bound = svr->bind_to_port(params.hostname, 8080);
|
||||
} else {
|
||||
was_bound = svr->bind_to_port(params.hostname, params.port);
|
||||
LOG_INF("%s: binding port with default address family\n", __func__);
|
||||
// bind HTTP listen port
|
||||
if (params.port == 0) {
|
||||
int bound_port = svr->bind_to_any_port(params.hostname);
|
||||
if ((was_bound = (bound_port >= 0))) {
|
||||
params.port = bound_port;
|
||||
}
|
||||
} else {
|
||||
was_bound = svr->bind_to_port(params.hostname, params.port);
|
||||
}
|
||||
}
|
||||
|
||||
if (!was_bound) {
|
||||
|
||||
@@ -699,11 +699,13 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
||||
const std::string voice_data = audio_data;
|
||||
|
||||
auto tmp = common_tokenize(vocab, voice_data, false, true);
|
||||
printf("\n\n");
|
||||
|
||||
std::ostringstream tokens_oss;
|
||||
for (size_t i = 0; i < tmp.size(); ++i) {
|
||||
printf("%d, ", tmp[i]);
|
||||
tokens_oss << tmp[i] << ", ";
|
||||
}
|
||||
printf("\n\n");
|
||||
LOG_INF("\n\n%s: llama tokens: %s\n\n", __func__, tokens_oss.str().c_str());
|
||||
|
||||
prompt_add(prompt_inp, tmp);
|
||||
#else
|
||||
prompt_add(prompt_inp, llama_tokens {
|
||||
|
||||
+7
-1
@@ -100,6 +100,10 @@ else()
|
||||
set(INS_ENB ON)
|
||||
endif()
|
||||
|
||||
message(DEBUG "GGML_NATIVE : ${GGML_NATIVE}")
|
||||
message(DEBUG "GGML_NATIVE_DEFAULT : ${GGML_NATIVE_DEFAULT}")
|
||||
message(DEBUG "INS_ENB : ${INS_ENB}")
|
||||
|
||||
option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
|
||||
option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
|
||||
option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if applicable" OFF)
|
||||
@@ -123,10 +127,12 @@ endif()
|
||||
option(GGML_LASX "ggml: enable lasx" ON)
|
||||
option(GGML_LSX "ggml: enable lsx" ON)
|
||||
option(GGML_RVV "ggml: enable rvv" ON)
|
||||
option(GGML_RV_ZFH "ggml: enable riscv zfh" OFF)
|
||||
option(GGML_VXE "ggml: enable vxe" ON)
|
||||
|
||||
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
|
||||
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
|
||||
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
|
||||
set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC")
|
||||
|
||||
|
||||
if (WIN32)
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
find_package(Git)
|
||||
|
||||
# the commit's SHA1
|
||||
execute_process(COMMAND
|
||||
"${GIT_EXECUTABLE}" describe --match=NeVeRmAtCh --always --abbrev=8
|
||||
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
||||
OUTPUT_VARIABLE GIT_SHA1
|
||||
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
|
||||
# the date of the commit
|
||||
execute_process(COMMAND
|
||||
"${GIT_EXECUTABLE}" log -1 --format=%ad --date=local
|
||||
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
||||
OUTPUT_VARIABLE GIT_DATE
|
||||
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
|
||||
# the subject of the commit
|
||||
execute_process(COMMAND
|
||||
"${GIT_EXECUTABLE}" log -1 --format=%s
|
||||
WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
|
||||
OUTPUT_VARIABLE GIT_COMMIT_SUBJECT
|
||||
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@")
|
||||
set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@")
|
||||
set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@")
|
||||
#set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@")
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
|
||||
@@ -17,7 +17,9 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c
|
||||
|
||||
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
|
||||
|
||||
GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
|
||||
GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
|
||||
const char * cache_dir,
|
||||
size_t free_mem, size_t total_mem);
|
||||
|
||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
|
||||
|
||||
|
||||
+5
-5
@@ -1791,11 +1791,11 @@ extern "C" {
|
||||
|
||||
#define GGML_KQ_MASK_PAD 64
|
||||
|
||||
// q: [n_embd, n_batch, n_head, 1]
|
||||
// k: [n_embd, n_kv, n_head_kv, 1]
|
||||
// v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !!
|
||||
// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
|
||||
// res: [n_embd, n_head, n_batch, 1] !! permuted !!
|
||||
// q: [n_embd_k, n_batch, n_head, 1]
|
||||
// k: [n_embd_k, n_kv, n_head_kv, 1]
|
||||
// v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !!
|
||||
// mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
|
||||
// res: [n_embd_v, n_head, n_batch, 1] !! permuted !!
|
||||
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * q,
|
||||
|
||||
@@ -65,7 +65,7 @@ if (GGML_LTO)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (GGML_CCACHE)
|
||||
if (GGML_CCACHE AND NOT CMAKE_C_COMPILER_LAUNCHER AND NOT CMAKE_CXX_COMPILER_LAUNCHER)
|
||||
find_program(GGML_CCACHE_FOUND ccache)
|
||||
find_program(GGML_SCCACHE_FOUND sccache)
|
||||
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
---
|
||||
Language: Cpp
|
||||
# BasedOnStyle: Google
|
||||
AccessModifierOffset: -1
|
||||
AlignAfterOpenBracket: Align
|
||||
AlignConsecutiveMacros: false
|
||||
AlignConsecutiveAssignments: false
|
||||
AlignConsecutiveDeclarations: false
|
||||
AlignEscapedNewlines: Left
|
||||
AlignOperands: true
|
||||
AlignTrailingComments: true
|
||||
AllowAllArgumentsOnNextLine: true
|
||||
AllowAllConstructorInitializersOnNextLine: true
|
||||
AllowAllParametersOfDeclarationOnNextLine: true
|
||||
AllowShortBlocksOnASingleLine: Never
|
||||
AllowShortCaseLabelsOnASingleLine: false
|
||||
AllowShortFunctionsOnASingleLine: All
|
||||
AllowShortLambdasOnASingleLine: All
|
||||
AllowShortIfStatementsOnASingleLine: WithoutElse
|
||||
AllowShortLoopsOnASingleLine: true
|
||||
AlwaysBreakAfterDefinitionReturnType: None
|
||||
AlwaysBreakAfterReturnType: None
|
||||
AlwaysBreakBeforeMultilineStrings: true
|
||||
AlwaysBreakTemplateDeclarations: Yes
|
||||
BinPackArguments: true
|
||||
BinPackParameters: true
|
||||
BraceWrapping:
|
||||
AfterCaseLabel: false
|
||||
AfterClass: false
|
||||
AfterControlStatement: false
|
||||
AfterEnum: false
|
||||
AfterFunction: false
|
||||
AfterNamespace: false
|
||||
AfterObjCDeclaration: false
|
||||
AfterStruct: false
|
||||
AfterUnion: false
|
||||
AfterExternBlock: false
|
||||
BeforeCatch: false
|
||||
BeforeElse: false
|
||||
IndentBraces: false
|
||||
SplitEmptyFunction: true
|
||||
SplitEmptyRecord: true
|
||||
SplitEmptyNamespace: true
|
||||
BreakBeforeBinaryOperators: None
|
||||
BreakBeforeBraces: Attach
|
||||
BreakBeforeInheritanceComma: false
|
||||
BreakInheritanceList: BeforeColon
|
||||
BreakBeforeTernaryOperators: true
|
||||
BreakConstructorInitializersBeforeComma: false
|
||||
BreakConstructorInitializers: BeforeColon
|
||||
BreakAfterJavaFieldAnnotations: false
|
||||
BreakStringLiterals: true
|
||||
ColumnLimit: 80
|
||||
CommentPragmas: '^ IWYU pragma:'
|
||||
CompactNamespaces: false
|
||||
ConstructorInitializerAllOnOneLineOrOnePerLine: true
|
||||
ConstructorInitializerIndentWidth: 4
|
||||
ContinuationIndentWidth: 4
|
||||
Cpp11BracedListStyle: true
|
||||
DeriveLineEnding: true
|
||||
DerivePointerAlignment: true
|
||||
DisableFormat: false
|
||||
ExperimentalAutoDetectBinPacking: false
|
||||
FixNamespaceComments: true
|
||||
ForEachMacros:
|
||||
- foreach
|
||||
- Q_FOREACH
|
||||
- BOOST_FOREACH
|
||||
IncludeBlocks: Regroup
|
||||
IncludeCategories:
|
||||
- Regex: '^<ext/.*\.h>'
|
||||
Priority: 2
|
||||
SortPriority: 0
|
||||
- Regex: '^<.*\.h>'
|
||||
Priority: 1
|
||||
SortPriority: 0
|
||||
- Regex: '^<.*'
|
||||
Priority: 2
|
||||
SortPriority: 0
|
||||
- Regex: '.*'
|
||||
Priority: 3
|
||||
SortPriority: 0
|
||||
IncludeIsMainRegex: '([-_](test|unittest))?$'
|
||||
IncludeIsMainSourceRegex: ''
|
||||
IndentCaseLabels: true
|
||||
IndentGotoLabels: true
|
||||
IndentPPDirectives: None
|
||||
IndentWidth: 4
|
||||
IndentWrappedFunctionNames: false
|
||||
JavaScriptQuotes: Leave
|
||||
JavaScriptWrapImports: true
|
||||
KeepEmptyLinesAtTheStartOfBlocks: false
|
||||
MacroBlockBegin: ''
|
||||
MacroBlockEnd: ''
|
||||
MaxEmptyLinesToKeep: 1
|
||||
NamespaceIndentation: None
|
||||
ObjCBinPackProtocolList: Never
|
||||
ObjCBlockIndentWidth: 2
|
||||
ObjCSpaceAfterProperty: false
|
||||
ObjCSpaceBeforeProtocolList: true
|
||||
PenaltyBreakAssignment: 2
|
||||
PenaltyBreakBeforeFirstCallParameter: 1
|
||||
PenaltyBreakComment: 300
|
||||
PenaltyBreakFirstLessLess: 120
|
||||
PenaltyBreakString: 1000
|
||||
PenaltyBreakTemplateDeclaration: 10
|
||||
PenaltyExcessCharacter: 1000000
|
||||
PenaltyReturnTypeOnItsOwnLine: 200
|
||||
PointerAlignment: Left
|
||||
RawStringFormats:
|
||||
- Language: Cpp
|
||||
Delimiters:
|
||||
- cc
|
||||
- CC
|
||||
- cpp
|
||||
- Cpp
|
||||
- CPP
|
||||
- 'c++'
|
||||
- 'C++'
|
||||
CanonicalDelimiter: ''
|
||||
BasedOnStyle: google
|
||||
- Language: TextProto
|
||||
Delimiters:
|
||||
- pb
|
||||
- PB
|
||||
- proto
|
||||
- PROTO
|
||||
EnclosingFunctions:
|
||||
- EqualsProto
|
||||
- EquivToProto
|
||||
- PARSE_PARTIAL_TEXT_PROTO
|
||||
- PARSE_TEST_PROTO
|
||||
- PARSE_TEXT_PROTO
|
||||
- ParseTextOrDie
|
||||
- ParseTextProtoOrDie
|
||||
CanonicalDelimiter: ''
|
||||
BasedOnStyle: google
|
||||
ReflowComments: true
|
||||
SortIncludes: true
|
||||
SortUsingDeclarations: true
|
||||
SpaceAfterCStyleCast: false
|
||||
SpaceAfterLogicalNot: false
|
||||
SpaceAfterTemplateKeyword: true
|
||||
SpaceBeforeAssignmentOperators: true
|
||||
SpaceBeforeCpp11BracedList: false
|
||||
SpaceBeforeCtorInitializerColon: true
|
||||
SpaceBeforeInheritanceColon: true
|
||||
SpaceBeforeParens: ControlStatements
|
||||
SpaceBeforeRangeBasedForLoopColon: true
|
||||
SpaceInEmptyBlock: false
|
||||
SpaceInEmptyParentheses: false
|
||||
SpacesBeforeTrailingComments: 2
|
||||
SpacesInAngles: false
|
||||
SpacesInConditionalStatement: false
|
||||
SpacesInContainerLiterals: true
|
||||
SpacesInCStyleCastParentheses: false
|
||||
SpacesInParentheses: false
|
||||
SpacesInSquareBrackets: false
|
||||
SpaceBeforeSquareBrackets: false
|
||||
Standard: Auto
|
||||
StatementMacros:
|
||||
- Q_UNUSED
|
||||
- QT_REQUIRE_VERSION
|
||||
TabWidth: 8
|
||||
UseCRLF: false
|
||||
UseTab: Never
|
||||
...
|
||||
|
||||
+12
-6
@@ -158,6 +158,12 @@ typedef sycl::half2 ggml_half2;
|
||||
|
||||
#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define GGML_EXTENSION
|
||||
#else // _MSC_VER
|
||||
#define GGML_EXTENSION __extension__
|
||||
#endif // _MSC_VER
|
||||
|
||||
#define QK4_0 32
|
||||
typedef struct {
|
||||
ggml_half d; // delta
|
||||
@@ -167,7 +173,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_half) + QK4_0 / 2, "wrong q4_0 b
|
||||
|
||||
#define QK4_1 32
|
||||
typedef struct {
|
||||
union {
|
||||
GGML_EXTENSION union {
|
||||
struct {
|
||||
ggml_half d; // delta
|
||||
ggml_half m; // min
|
||||
@@ -188,7 +194,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_half) + sizeof(uint32_t) + QK5_0
|
||||
|
||||
#define QK5_1 32
|
||||
typedef struct {
|
||||
union {
|
||||
GGML_EXTENSION union {
|
||||
struct {
|
||||
ggml_half d; // delta
|
||||
ggml_half m; // min
|
||||
@@ -209,7 +215,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_half) + QK8_0, "wrong q8_0 block
|
||||
|
||||
#define QK8_1 32
|
||||
typedef struct {
|
||||
union {
|
||||
GGML_EXTENSION union {
|
||||
struct {
|
||||
ggml_half d; // delta
|
||||
ggml_half s; // d * sum(qs[i])
|
||||
@@ -250,7 +256,7 @@ static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0
|
||||
typedef struct {
|
||||
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
|
||||
uint8_t qs[QK_K/4]; // quants
|
||||
union {
|
||||
GGML_EXTENSION union {
|
||||
struct {
|
||||
ggml_half d; // super-block scale for quantized scales
|
||||
ggml_half dmin; // super-block scale for quantized mins
|
||||
@@ -277,7 +283,7 @@ static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12
|
||||
// weight is represented as x = a * q + b
|
||||
// Effectively 4.5 bits per weight
|
||||
typedef struct {
|
||||
union {
|
||||
GGML_EXTENSION union {
|
||||
struct {
|
||||
ggml_half d; // super-block scale for quantized scales
|
||||
ggml_half dmin; // super-block scale for quantized mins
|
||||
@@ -294,7 +300,7 @@ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2,
|
||||
// weight is represented as x = a * q + b
|
||||
// Effectively 5.5 bits per weight
|
||||
typedef struct {
|
||||
union {
|
||||
GGML_EXTENSION union {
|
||||
struct {
|
||||
ggml_half d; // super-block scale for quantized scales
|
||||
ggml_half dmin; // super-block scale for quantized mins
|
||||
|
||||
@@ -23,6 +23,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
ggml-cpu/amx/mmq.cpp
|
||||
ggml-cpu/amx/mmq.h
|
||||
ggml-cpu/ggml-cpu-impl.h
|
||||
ggml-cpu/common.h
|
||||
ggml-cpu/binary-ops.h
|
||||
ggml-cpu/binary-ops.cpp
|
||||
ggml-cpu/unary-ops.h
|
||||
ggml-cpu/unary-ops.cpp
|
||||
)
|
||||
|
||||
target_compile_features(${GGML_CPU_NAME} PRIVATE c_std_11 cxx_std_17)
|
||||
@@ -289,23 +294,29 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
endif()
|
||||
elseif ("${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "ppc64le " OR "${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "powerpc ")
|
||||
message(STATUS "PowerPC detected")
|
||||
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
|
||||
file(READ "/proc/cpuinfo" POWER10_M)
|
||||
elseif(${CMAKE_SYSTEM_PROCESSOR} MATCHES "powerpc")
|
||||
execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M)
|
||||
endif()
|
||||
if (GGML_NATIVE)
|
||||
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
|
||||
file(READ "/proc/cpuinfo" POWER10_M)
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "powerpc")
|
||||
execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M)
|
||||
endif()
|
||||
|
||||
string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}")
|
||||
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
|
||||
string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}")
|
||||
string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}")
|
||||
|
||||
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64)
|
||||
elseif (EXTRACTED_NUMBER EQUAL 9)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64)
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
|
||||
list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
|
||||
if (EXTRACTED_NUMBER GREATER_EQUAL 10)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power10 -mpowerpc64)
|
||||
elseif (EXTRACTED_NUMBER EQUAL 9)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power9 -mpowerpc64)
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
|
||||
list(APPEND ARCH_FLAGS -mcpu=powerpc64le -mtune=native)
|
||||
else()
|
||||
list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64)
|
||||
endif()
|
||||
else()
|
||||
list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64)
|
||||
if (GGML_CPU_POWERPC_CPUTYPE)
|
||||
list(APPEND ARCH_FLAGS -mcpu=${GGML_CPU_POWERPC_CPUTYPE})
|
||||
endif()
|
||||
endif()
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
|
||||
message(STATUS "loongarch64 detected")
|
||||
@@ -320,7 +331,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "riscv64")
|
||||
message(STATUS "RISC-V detected")
|
||||
if (GGML_RVV)
|
||||
list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d)
|
||||
if (GGML_RV_ZFH)
|
||||
list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -DGGML_RV_ZFH -mabi=lp64d)
|
||||
else()
|
||||
list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d)
|
||||
endif()
|
||||
endif()
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
|
||||
message(STATUS "s390x detected")
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
#include "binary-ops.h"
|
||||
|
||||
#if defined(GGML_USE_ACCELERATE)
|
||||
#include <Accelerate/Accelerate.h>
|
||||
|
||||
using vDSP_fn_t = void (*)(const float *, vDSP_Stride, const float *, vDSP_Stride, float *, vDSP_Stride, vDSP_Length);
|
||||
#endif
|
||||
|
||||
static inline float op_add(float a, float b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
static inline float op_sub(float a, float b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
static inline float op_mul(float a, float b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
static inline float op_div(float a, float b) {
|
||||
return a / b;
|
||||
}
|
||||
|
||||
template <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static inline void vec_binary_op_contiguous(const int64_t n, dst_t * z, const src0_t * x, const src1_t * y) {
|
||||
constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
|
||||
constexpr auto src1_to_f32 = type_conversion_table<src1_t>::to_f32;
|
||||
constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
|
||||
|
||||
for (int i = 0; i < n; i++) {
|
||||
z[i] = f32_to_dst(op(src0_to_f32(x[i]), src1_to_f32(y[i])));
|
||||
}
|
||||
}
|
||||
|
||||
template <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static inline void vec_binary_op_non_contiguous(const int64_t n, const int64_t ne10, const int64_t nb10, dst_t * z, const src0_t * x, const src1_t * y) {
|
||||
constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
|
||||
constexpr auto src1_to_f32 = type_conversion_table<src1_t>::to_f32;
|
||||
constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
|
||||
|
||||
for (int i = 0; i < n; i++) {
|
||||
int i10 = i % ne10;
|
||||
const src1_t * y_ptr = (const src1_t *)((const char *)y + i10*nb10);
|
||||
z[i] = f32_to_dst(op(src0_to_f32(x[i]), src1_to_f32(*y_ptr)));
|
||||
}
|
||||
}
|
||||
|
||||
template <float (*op)(float, float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT( nb0 == sizeof(dst_t));
|
||||
GGML_ASSERT(nb00 == sizeof(src0_t));
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, src0);
|
||||
const bool is_src1_contiguous = (nb10 == sizeof(src1_t));
|
||||
|
||||
if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_ACCELERATE
|
||||
vDSP_fn_t vDSP_op = nullptr;
|
||||
// TODO - avoid the f32-only check using type 'trait' lookup tables and row-based src-to-float conversion functions
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
if (op == op_add) {
|
||||
vDSP_op = vDSP_vadd;
|
||||
} else if (op == op_sub) {
|
||||
vDSP_op = vDSP_vsub;
|
||||
} else if (op == op_mul) {
|
||||
vDSP_op = vDSP_vmul;
|
||||
} else if (op == op_div) {
|
||||
vDSP_op = vDSP_vdiv;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
const int64_t i13 = i03 % ne13;
|
||||
const int64_t i12 = i02 % ne12;
|
||||
const int64_t i11 = i01 % ne11;
|
||||
|
||||
dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
||||
const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
|
||||
|
||||
if (is_src1_contiguous) {
|
||||
// src1 is broadcastable across src0 and dst in i1, i2, i3
|
||||
const int64_t nr0 = ne00 / ne10;
|
||||
|
||||
for (int64_t r = 0; r < nr0; ++r) {
|
||||
#ifdef GGML_USE_ACCELERATE
|
||||
if constexpr (std::is_same_v<src0_t, float> && std::is_same_v<src1_t, float> && std::is_same_v<dst_t, float>) {
|
||||
if (vDSP_op != nullptr) {
|
||||
vDSP_op(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
vec_binary_op_contiguous<op>(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
|
||||
}
|
||||
} else {
|
||||
vec_binary_op_non_contiguous<op>(ne0, ne10, nb10, dst_ptr, src0_ptr, src1_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
|
||||
template <float (*op)(float, float)>
|
||||
static void binary_op(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
/* */ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
|
||||
apply_binary_op<op, float, float, float>(params, dst);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
|
||||
apply_binary_op<op, ggml_fp16_t, ggml_fp16_t, ggml_fp16_t>(params, dst);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
|
||||
apply_binary_op<op, ggml_bf16_t, ggml_bf16_t, ggml_bf16_t>(params, dst);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_BF16) {
|
||||
apply_binary_op<op, ggml_bf16_t, float, ggml_bf16_t>(params, dst);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
apply_binary_op<op, ggml_bf16_t, float, float>(params, dst);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
|
||||
apply_binary_op<op, ggml_fp16_t, float, ggml_fp16_t>(params, dst);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
apply_binary_op<op, ggml_fp16_t, float, float>(params, dst);
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
|
||||
ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_add_non_quantized(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
binary_op<op_add>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_sub(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
binary_op<op_sub>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_mul(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
binary_op<op_mul>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_div(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
binary_op<op_div>(params, dst);
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
void ggml_compute_forward_add_non_quantized(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_sub(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_mul(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_div(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -0,0 +1,72 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-cpu-traits.h"
|
||||
#include "ggml-cpu-impl.h"
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
#include <utility>
|
||||
|
||||
// convenience functions/macros for use in template calls
|
||||
// note: these won't be required after the 'traits' lookup table is used.
|
||||
static inline ggml_fp16_t f32_to_f16(float x) {
|
||||
return GGML_FP32_TO_FP16(x);
|
||||
}
|
||||
|
||||
static inline float f16_to_f32(ggml_fp16_t x) {
|
||||
return GGML_FP16_TO_FP32(x);
|
||||
}
|
||||
|
||||
static inline ggml_bf16_t f32_to_bf16(float x) {
|
||||
return GGML_FP32_TO_BF16(x);
|
||||
}
|
||||
|
||||
static inline float bf16_to_f32(ggml_bf16_t x) {
|
||||
return GGML_BF16_TO_FP32(x);
|
||||
}
|
||||
|
||||
static inline float f32_to_f32(float x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
// TODO - merge this into the traits table, after using row-based conversions
|
||||
template <class T>
|
||||
struct type_conversion_table;
|
||||
|
||||
template <>
|
||||
struct type_conversion_table<ggml_fp16_t> {
|
||||
static constexpr float (*to_f32)(ggml_fp16_t) = f16_to_f32;
|
||||
static constexpr ggml_fp16_t (*from_f32)(float) = f32_to_f16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_conversion_table<float> {
|
||||
static constexpr float (*to_f32)(float) = f32_to_f32;
|
||||
static constexpr float (*from_f32)(float) = f32_to_f32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_conversion_table<ggml_bf16_t> {
|
||||
static constexpr float (*to_f32)(ggml_bf16_t) = bf16_to_f32;
|
||||
static constexpr ggml_bf16_t (*from_f32)(float) = f32_to_bf16;
|
||||
};
|
||||
|
||||
static std::pair<int64_t, int64_t> get_thread_range(const struct ggml_compute_params * params, const struct ggml_tensor * src0) {
|
||||
const int64_t ith = params->ith;
|
||||
const int64_t nth = params->nth;
|
||||
|
||||
const int64_t nr = ggml_nrows(src0);
|
||||
|
||||
// rows per thread
|
||||
const int64_t dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int64_t ir0 = dr*ith;
|
||||
const int64_t ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
return {ir0, ir1};
|
||||
}
|
||||
|
||||
#endif
|
||||
+746
-396
File diff suppressed because it is too large
Load Diff
+46
-2009
File diff suppressed because it is too large
Load Diff
@@ -55,6 +55,7 @@
|
||||
|
||||
#include <atomic>
|
||||
#include <array>
|
||||
#include <type_traits>
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define NOINLINE __declspec(noinline)
|
||||
@@ -1092,13 +1093,403 @@ class tinyBLAS_Q0_PPC {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename VA, typename VB>
|
||||
void packNormal(const TA* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
|
||||
template<typename VA, typename VB, int size>
|
||||
void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array<int, size>& comparray) {
|
||||
int64_t i, j;
|
||||
TA *aoffset = NULL;
|
||||
VA *vecOffset = NULL;
|
||||
TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
|
||||
TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
|
||||
VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
|
||||
VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
|
||||
VB t1, t2, t3, t4, t5, t6, t7, t8;
|
||||
const vector signed char lowMask = vec_splats((signed char)0xF);
|
||||
const vector unsigned char v4 = vec_splats((unsigned char)0x4);
|
||||
const vector signed char v8 = vec_splats((signed char)0x8);
|
||||
aoffset = const_cast<TA*>(a);
|
||||
vecOffset = vec;
|
||||
vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
|
||||
vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
|
||||
vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
|
||||
vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
|
||||
vector signed int vsum = {0};
|
||||
vector signed int vsum2 = {0};
|
||||
|
||||
j = (rows >> 3);
|
||||
if (j > 0) {
|
||||
do {
|
||||
aoffset1 = aoffset;
|
||||
aoffset2 = aoffset1 + lda;
|
||||
aoffset3 = aoffset2 + lda;
|
||||
aoffset4 = aoffset3 + lda;
|
||||
aoffset5 = aoffset4 + lda;
|
||||
aoffset6 = aoffset5 + lda;
|
||||
aoffset7 = aoffset6 + lda;
|
||||
aoffset8 = aoffset7 + lda;
|
||||
aoffset += 8 * lda;
|
||||
|
||||
i = (cols >> 2);
|
||||
if (i > 0) {
|
||||
do {
|
||||
c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
|
||||
c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
|
||||
c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
|
||||
c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
|
||||
c5[1] = reinterpret_cast<VB>(vec_xl(0, aoffset5->qs));
|
||||
c6[1] = reinterpret_cast<VB>(vec_xl(0, aoffset6->qs));
|
||||
c7[1] = reinterpret_cast<VB>(vec_xl(0, aoffset7->qs));
|
||||
c8[1] = reinterpret_cast<VB>(vec_xl(0, aoffset8->qs));
|
||||
|
||||
c1[0] = vec_and(c1[1], lowMask);
|
||||
c1[1] = vec_sr(c1[1], v4);
|
||||
c1[0] = vec_sub(c1[0], v8);
|
||||
c1[1] = vec_sub(c1[1], v8);
|
||||
vsum = vec_sum4s(c1[0], vsum);
|
||||
vsum2 = vec_sum4s(c1[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
vsum = vec_splats(0);
|
||||
vsum2 = vec_splats(0);
|
||||
|
||||
c2[0] = vec_and(c2[1], lowMask);
|
||||
c2[1] = vec_sr(c2[1], v4);
|
||||
c2[0] = vec_sub(c2[0], v8);
|
||||
c2[1] = vec_sub(c2[1], v8);
|
||||
vsum = vec_sum4s(c2[0], vsum);
|
||||
vsum2 = vec_sum4s(c2[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
vsum = vec_splats(0);
|
||||
vsum2 = vec_splats(0);
|
||||
|
||||
c3[0] = vec_and(c3[1], lowMask);
|
||||
c3[1] = vec_sr(c3[1], v4);
|
||||
c3[0] = vec_sub(c3[0], v8);
|
||||
c3[1] = vec_sub(c3[1], v8);
|
||||
vsum = vec_sum4s(c3[0], vsum);
|
||||
vsum2 = vec_sum4s(c3[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
vsum = vec_splats(0);
|
||||
vsum2 = vec_splats(0);
|
||||
|
||||
c4[0] = vec_and(c4[1], lowMask);
|
||||
c4[1] = vec_sr(c4[1], v4);
|
||||
c4[0] = vec_sub(c4[0], v8);
|
||||
c4[1] = vec_sub(c4[1], v8);
|
||||
vsum = vec_sum4s(c4[0], vsum);
|
||||
vsum2 = vec_sum4s(c4[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
vsum = vec_splats(0);
|
||||
vsum2 = vec_splats(0);
|
||||
|
||||
c5[0] = vec_and(c5[1], lowMask);
|
||||
c5[1] = vec_sr(c5[1], v4);
|
||||
c5[0] = vec_sub(c5[0], v8);
|
||||
c5[1] = vec_sub(c5[1], v8);
|
||||
vsum = vec_sum4s(c5[0], vsum);
|
||||
vsum2 = vec_sum4s(c5[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
vsum = vec_splats(0);
|
||||
vsum2 = vec_splats(0);
|
||||
|
||||
c6[0] = vec_and(c6[1], lowMask);
|
||||
c6[1] = vec_sr(c6[1], v4);
|
||||
c6[0] = vec_sub(c6[0], v8);
|
||||
c6[1] = vec_sub(c6[1], v8);
|
||||
vsum = vec_sum4s(c6[0], vsum);
|
||||
vsum2 = vec_sum4s(c6[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
vsum = vec_splats(0);
|
||||
vsum2 = vec_splats(0);
|
||||
|
||||
c7[0] = vec_and(c7[1], lowMask);
|
||||
c7[1] = vec_sr(c7[1], v4);
|
||||
c7[0] = vec_sub(c7[0], v8);
|
||||
c7[1] = vec_sub(c7[1], v8);
|
||||
vsum = vec_sum4s(c7[0], vsum);
|
||||
vsum2 = vec_sum4s(c7[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
vsum = vec_splats(0);
|
||||
vsum2 = vec_splats(0);
|
||||
|
||||
c8[0] = vec_and(c8[1], lowMask);
|
||||
c8[1] = vec_sr(c8[1], v4);
|
||||
c8[0] = vec_sub(c8[0], v8);
|
||||
c8[1] = vec_sub(c8[1], v8);
|
||||
vsum = vec_sum4s(c8[0], vsum);
|
||||
vsum2 = vec_sum4s(c8[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
vsum = vec_splats(0);
|
||||
vsum2 = vec_splats(0);
|
||||
|
||||
t1 = vec_perm(c1[0], c2[0], swiz1);
|
||||
t2 = vec_perm(c1[0], c2[0], swiz2);
|
||||
t3 = vec_perm(c3[0], c4[0], swiz1);
|
||||
t4 = vec_perm(c3[0], c4[0], swiz2);
|
||||
t5 = vec_perm(t1, t3, swiz3);
|
||||
t6 = vec_perm(t1, t3, swiz4);
|
||||
t7 = vec_perm(t2, t4, swiz3);
|
||||
t8 = vec_perm(t2, t4, swiz4);
|
||||
vec_xst(t5, 0, vecOffset);
|
||||
vec_xst(t6, 0, vecOffset+16);
|
||||
vec_xst(t7, 0, vecOffset+32);
|
||||
vec_xst(t8, 0, vecOffset+48);
|
||||
|
||||
t1 = vec_perm(c1[1], c2[1], swiz1);
|
||||
t2 = vec_perm(c1[1], c2[1], swiz2);
|
||||
t3 = vec_perm(c3[1], c4[1], swiz1);
|
||||
t4 = vec_perm(c3[1], c4[1], swiz2);
|
||||
t5 = vec_perm(t1, t3, swiz3);
|
||||
t6 = vec_perm(t1, t3, swiz4);
|
||||
t7 = vec_perm(t2, t4, swiz3);
|
||||
t8 = vec_perm(t2, t4, swiz4);
|
||||
vec_xst(t5, 0, vecOffset+64);
|
||||
vec_xst(t6, 0, vecOffset+80);
|
||||
vec_xst(t7, 0, vecOffset+96);
|
||||
vec_xst(t8, 0, vecOffset+112);
|
||||
|
||||
t1 = vec_perm(c5[0], c6[0], swiz1);
|
||||
t2 = vec_perm(c5[0], c6[0], swiz2);
|
||||
t3 = vec_perm(c7[0], c8[0], swiz1);
|
||||
t4 = vec_perm(c7[0], c8[0], swiz2);
|
||||
t5 = vec_perm(t1, t3, swiz3);
|
||||
t6 = vec_perm(t1, t3, swiz4);
|
||||
t7 = vec_perm(t2, t4, swiz3);
|
||||
t8 = vec_perm(t2, t4, swiz4);
|
||||
vec_xst(t5, 0, vecOffset+128);
|
||||
vec_xst(t6, 0, vecOffset+144);
|
||||
vec_xst(t7, 0, vecOffset+160);
|
||||
vec_xst(t8, 0, vecOffset+176);
|
||||
|
||||
t1 = vec_perm(c5[1], c6[1], swiz1);
|
||||
t2 = vec_perm(c5[1], c6[1], swiz2);
|
||||
t3 = vec_perm(c7[1], c8[1], swiz1);
|
||||
t4 = vec_perm(c7[1], c8[1], swiz2);
|
||||
t5 = vec_perm(t1, t3, swiz3);
|
||||
t6 = vec_perm(t1, t3, swiz4);
|
||||
t7 = vec_perm(t2, t4, swiz3);
|
||||
t8 = vec_perm(t2, t4, swiz4);
|
||||
vec_xst(t5, 0, vecOffset+192);
|
||||
vec_xst(t6, 0, vecOffset+208);
|
||||
vec_xst(t7, 0, vecOffset+224);
|
||||
vec_xst(t8, 0, vecOffset+240);
|
||||
|
||||
aoffset1 += lda;
|
||||
aoffset2 += lda;
|
||||
aoffset3 += lda;
|
||||
aoffset4 += lda;
|
||||
aoffset5 += lda;
|
||||
aoffset6 += lda;
|
||||
aoffset7 += lda;
|
||||
aoffset8 += lda;
|
||||
vecOffset += 256;
|
||||
i--;
|
||||
} while (i > 0);
|
||||
}
|
||||
j--;
|
||||
} while (j > 0);
|
||||
}
|
||||
|
||||
if (rows & 4) {
|
||||
aoffset1 = aoffset;
|
||||
aoffset2 = aoffset1 + lda;
|
||||
aoffset3 = aoffset2 + lda;
|
||||
aoffset4 = aoffset3 + lda;
|
||||
aoffset += 4 * lda;
|
||||
|
||||
i = (cols >> 2);
|
||||
if (i > 0) {
|
||||
do {
|
||||
c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
|
||||
c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
|
||||
c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
|
||||
c4[1] = reinterpret_cast<VB>(vec_xl(0, aoffset4->qs));
|
||||
|
||||
c1[0] = vec_and(c1[1], lowMask);
|
||||
c1[1] = vec_sr(c1[1], v4);
|
||||
c1[0] = vec_sub(c1[0], v8);
|
||||
c1[1] = vec_sub(c1[1], v8);
|
||||
vsum = vec_sum4s(c1[0], vsum);
|
||||
vsum2 = vec_sum4s(c1[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
vsum = vec_splats(0);
|
||||
vsum2 = vec_splats(0);
|
||||
|
||||
c2[0] = vec_and(c2[1], lowMask);
|
||||
c2[1] = vec_sr(c2[1], v4);
|
||||
c2[0] = vec_sub(c2[0], v8);
|
||||
c2[1] = vec_sub(c2[1], v8);
|
||||
vsum = vec_sum4s(c2[0], vsum);
|
||||
vsum2 = vec_sum4s(c2[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
vsum = vec_splats(0);
|
||||
vsum2 = vec_splats(0);
|
||||
|
||||
c3[0] = vec_and(c3[1], lowMask);
|
||||
c3[1] = vec_sr(c3[1], v4);
|
||||
c3[0] = vec_sub(c3[0], v8);
|
||||
c3[1] = vec_sub(c3[1], v8);
|
||||
vsum = vec_sum4s(c3[0], vsum);
|
||||
vsum2 = vec_sum4s(c3[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
vsum = vec_splats(0);
|
||||
vsum2 = vec_splats(0);
|
||||
|
||||
c4[0] = vec_and(c4[1], lowMask);
|
||||
c4[1] = vec_sr(c4[1], v4);
|
||||
c4[0] = vec_sub(c4[0], v8);
|
||||
c4[1] = vec_sub(c4[1], v8);
|
||||
vsum = vec_sum4s(c4[0], vsum);
|
||||
vsum2 = vec_sum4s(c4[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
vsum = vec_splats(0);
|
||||
vsum2 = vec_splats( 0);
|
||||
|
||||
t1 = vec_perm(c1[0], c2[0], swiz1);
|
||||
t2 = vec_perm(c1[0], c2[0], swiz2);
|
||||
t3 = vec_perm(c3[0], c4[0], swiz1);
|
||||
t4 = vec_perm(c3[0], c4[0], swiz2);
|
||||
t5 = vec_perm(t1, t3, swiz3);
|
||||
t6 = vec_perm(t1, t3, swiz4);
|
||||
t7 = vec_perm(t2, t4, swiz3);
|
||||
t8 = vec_perm(t2, t4, swiz4);
|
||||
vec_xst(t5, 0, vecOffset);
|
||||
vec_xst(t6, 0, vecOffset+16);
|
||||
vec_xst(t7, 0, vecOffset+32);
|
||||
vec_xst(t8, 0, vecOffset+48);
|
||||
|
||||
t1 = vec_perm(c1[1], c2[1], swiz1);
|
||||
t2 = vec_perm(c1[1], c2[1], swiz2);
|
||||
t3 = vec_perm(c3[1], c4[1], swiz1);
|
||||
t4 = vec_perm(c3[1], c4[1], swiz2);
|
||||
t5 = vec_perm(t1, t3, swiz3);
|
||||
t6 = vec_perm(t1, t3, swiz4);
|
||||
t7 = vec_perm(t2, t4, swiz3);
|
||||
t8 = vec_perm(t2, t4, swiz4);
|
||||
vec_xst(t5, 0, vecOffset+64);
|
||||
vec_xst(t6, 0, vecOffset+80);
|
||||
vec_xst(t7, 0, vecOffset+96);
|
||||
vec_xst(t8, 0, vecOffset+112);
|
||||
|
||||
aoffset1 += lda;
|
||||
aoffset2 += lda;
|
||||
aoffset3 += lda;
|
||||
aoffset4 += lda;
|
||||
vecOffset += 128;
|
||||
i--;
|
||||
} while (i > 0);
|
||||
}
|
||||
}
|
||||
|
||||
if (rows & 3) {
|
||||
aoffset1 = aoffset;
|
||||
aoffset2 = aoffset1 + lda;
|
||||
aoffset3 = aoffset2 + lda;
|
||||
i = (cols >> 2);
|
||||
if (i > 0) {
|
||||
do {
|
||||
switch(rows) {
|
||||
case 3: c3[1] = reinterpret_cast<VB>(vec_xl(0, aoffset3->qs));
|
||||
case 2: c2[1] = reinterpret_cast<VB>(vec_xl(0, aoffset2->qs));
|
||||
case 1: c1[1] = reinterpret_cast<VB>(vec_xl(0, aoffset1->qs));
|
||||
break;
|
||||
}
|
||||
c1[0] = vec_and(c1[1], lowMask);
|
||||
c1[1] = vec_sr(c1[1], v4);
|
||||
c1[0] = vec_sub(c1[0], v8);
|
||||
c1[1] = vec_sub(c1[1], v8);
|
||||
vsum = vec_sum4s(c1[0], vsum);
|
||||
vsum2 = vec_sum4s(c1[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
vsum = vec_splats(0);
|
||||
vsum2 = vec_splats(0);
|
||||
|
||||
c2[0] = vec_and(c2[1], lowMask);
|
||||
c2[1] = vec_sr(c2[1], v4);
|
||||
c2[0] = vec_sub(c2[0], v8);
|
||||
c2[1] = vec_sub(c2[1], v8);
|
||||
vsum = vec_sum4s(c2[0], vsum);
|
||||
vsum2 = vec_sum4s(c2[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
vsum = vec_splats(0);
|
||||
vsum2 = vec_splats(0);
|
||||
|
||||
c3[0] = vec_and(c3[1], lowMask);
|
||||
c3[1] = vec_sr(c3[1], v4);
|
||||
c3[0] = vec_sub(c3[0], v8);
|
||||
c3[1] = vec_sub(c3[1], v8);
|
||||
vsum = vec_sum4s(c3[0], vsum);
|
||||
vsum2 = vec_sum4s(c3[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
vsum = vec_splats(0);
|
||||
vsum2 = vec_splats(0);
|
||||
|
||||
c4[0] = vec_and(c4[1], lowMask);
|
||||
c4[1] = vec_sr(c4[1], v4);
|
||||
c4[0] = vec_sub(c4[0], v8);
|
||||
c4[1] = vec_sub(c4[1], v8);
|
||||
vsum = vec_sum4s(c4[0], vsum);
|
||||
vsum2 = vec_sum4s(c4[1], vsum2);
|
||||
vsum = vec_add(vsum, vsum2);
|
||||
comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3];
|
||||
vsum = vec_splats(0);
|
||||
vsum2 = vec_splats(0);
|
||||
|
||||
t1 = vec_perm(c1[0], c2[0], swiz1);
|
||||
t2 = vec_perm(c1[0], c2[0], swiz2);
|
||||
t3 = vec_perm(c3[0], c4[0], swiz1);
|
||||
t4 = vec_perm(c3[0], c4[0], swiz2);
|
||||
t5 = vec_perm(t1, t3, swiz3);
|
||||
t6 = vec_perm(t1, t3, swiz4);
|
||||
t7 = vec_perm(t2, t4, swiz3);
|
||||
t8 = vec_perm(t2, t4, swiz4);
|
||||
vec_xst(t5, 0, vecOffset);
|
||||
vec_xst(t6, 0, vecOffset+16);
|
||||
vec_xst(t7, 0, vecOffset+32);
|
||||
vec_xst(t8, 0, vecOffset+48);
|
||||
|
||||
t1 = vec_perm(c1[1], c2[1], swiz1);
|
||||
t2 = vec_perm(c1[1], c2[1], swiz2);
|
||||
t3 = vec_perm(c3[1], c4[1], swiz1);
|
||||
t4 = vec_perm(c3[1], c4[1], swiz2);
|
||||
t5 = vec_perm(t1, t3, swiz3);
|
||||
t6 = vec_perm(t1, t3, swiz4);
|
||||
t7 = vec_perm(t2, t4, swiz3);
|
||||
t8 = vec_perm(t2, t4, swiz4);
|
||||
vec_xst(t5, 0, vecOffset+64);
|
||||
vec_xst(t6, 0, vecOffset+80);
|
||||
vec_xst(t7, 0, vecOffset+96);
|
||||
vec_xst(t8, 0, vecOffset+112);
|
||||
aoffset1 += lda;
|
||||
aoffset2 += lda;
|
||||
aoffset3 += lda;
|
||||
vecOffset += 128;
|
||||
i--;
|
||||
} while(i > 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename VA, typename VB>
|
||||
void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
|
||||
int64_t i, j;
|
||||
TB *aoffset = NULL;
|
||||
VA *vecOffset = NULL;
|
||||
TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
|
||||
TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
|
||||
__vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
|
||||
VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
|
||||
VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
|
||||
@@ -1111,24 +1502,24 @@ class tinyBLAS_Q0_PPC {
|
||||
vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
|
||||
vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
|
||||
|
||||
aoffset = const_cast<TA*>(a);
|
||||
aoffset = const_cast<TB*>(a);
|
||||
vecOffset = vec;
|
||||
j = (rows >> 3);
|
||||
if (j > 0) {
|
||||
do {
|
||||
aoffset1 = aoffset;
|
||||
aoffset2 = aoffset1 + lda;
|
||||
aoffset3 = aoffset2 + lda;
|
||||
aoffset4 = aoffset3 + lda;
|
||||
aoffset5 = aoffset4 + lda;
|
||||
aoffset6 = aoffset5 + lda;
|
||||
aoffset7 = aoffset6 + lda;
|
||||
aoffset8 = aoffset7 + lda;
|
||||
aoffset += 8 * lda;
|
||||
aoffset1 = aoffset;
|
||||
aoffset2 = aoffset1 + lda;
|
||||
aoffset3 = aoffset2 + lda;
|
||||
aoffset4 = aoffset3 + lda;
|
||||
aoffset5 = aoffset4 + lda;
|
||||
aoffset6 = aoffset5 + lda;
|
||||
aoffset7 = aoffset6 + lda;
|
||||
aoffset8 = aoffset7 + lda;
|
||||
aoffset += 8 * lda;
|
||||
|
||||
i = (cols >> 3);
|
||||
if (i > 0) {
|
||||
do {
|
||||
i = (cols >> 3);
|
||||
if (i > 0) {
|
||||
do {
|
||||
C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
|
||||
C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
|
||||
C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
|
||||
@@ -1156,10 +1547,10 @@ class tinyBLAS_Q0_PPC {
|
||||
t7 = vec_perm(t2, t4, swiz3);
|
||||
t8 = vec_perm(t2, t4, swiz4);
|
||||
if (flip == true) {
|
||||
t5 = vec_xor(t5, xor_vector);
|
||||
t6 = vec_xor(t6, xor_vector);
|
||||
t7 = vec_xor(t7, xor_vector);
|
||||
t8 = vec_xor(t8, xor_vector);
|
||||
t5 = vec_xor(t5, xor_vector);
|
||||
t6 = vec_xor(t6, xor_vector);
|
||||
t7 = vec_xor(t7, xor_vector);
|
||||
t8 = vec_xor(t8, xor_vector);
|
||||
}
|
||||
vec_xst(t5, 0, vecOffset);
|
||||
vec_xst(t6, 0, vecOffset+16);
|
||||
@@ -1175,10 +1566,10 @@ class tinyBLAS_Q0_PPC {
|
||||
t7 = vec_perm(t2, t4, swiz3);
|
||||
t8 = vec_perm(t2, t4, swiz4);
|
||||
if (flip == true) {
|
||||
t5 = vec_xor(t5, xor_vector);
|
||||
t6 = vec_xor(t6, xor_vector);
|
||||
t7 = vec_xor(t7, xor_vector);
|
||||
t8 = vec_xor(t8, xor_vector);
|
||||
t5 = vec_xor(t5, xor_vector);
|
||||
t6 = vec_xor(t6, xor_vector);
|
||||
t7 = vec_xor(t7, xor_vector);
|
||||
t8 = vec_xor(t8, xor_vector);
|
||||
}
|
||||
vec_xst(t5, 0, vecOffset+64);
|
||||
vec_xst(t6, 0, vecOffset+80);
|
||||
@@ -1194,10 +1585,10 @@ class tinyBLAS_Q0_PPC {
|
||||
t7 = vec_perm(t2, t4, swiz3);
|
||||
t8 = vec_perm(t2, t4, swiz4);
|
||||
if (flip == true) {
|
||||
t5 = vec_xor(t5, xor_vector);
|
||||
t6 = vec_xor(t6, xor_vector);
|
||||
t7 = vec_xor(t7, xor_vector);
|
||||
t8 = vec_xor(t8, xor_vector);
|
||||
t5 = vec_xor(t5, xor_vector);
|
||||
t6 = vec_xor(t6, xor_vector);
|
||||
t7 = vec_xor(t7, xor_vector);
|
||||
t8 = vec_xor(t8, xor_vector);
|
||||
}
|
||||
vec_xst(t5, 0, vecOffset+128);
|
||||
vec_xst(t6, 0, vecOffset+144);
|
||||
@@ -1213,10 +1604,10 @@ class tinyBLAS_Q0_PPC {
|
||||
t7 = vec_perm(t2, t4, swiz3);
|
||||
t8 = vec_perm(t2, t4, swiz4);
|
||||
if (flip == true) {
|
||||
t5 = vec_xor(t5, xor_vector);
|
||||
t6 = vec_xor(t6, xor_vector);
|
||||
t7 = vec_xor(t7, xor_vector);
|
||||
t8 = vec_xor(t8, xor_vector);
|
||||
t5 = vec_xor(t5, xor_vector);
|
||||
t6 = vec_xor(t6, xor_vector);
|
||||
t7 = vec_xor(t7, xor_vector);
|
||||
t8 = vec_xor(t8, xor_vector);
|
||||
}
|
||||
vec_xst(t5, 0, vecOffset+192);
|
||||
vec_xst(t6, 0, vecOffset+208);
|
||||
@@ -1240,11 +1631,11 @@ class tinyBLAS_Q0_PPC {
|
||||
}
|
||||
|
||||
if (rows & 4) {
|
||||
aoffset1 = aoffset;
|
||||
aoffset2 = aoffset1 + lda;
|
||||
aoffset3 = aoffset2 + lda;
|
||||
aoffset4 = aoffset3 + lda;
|
||||
aoffset += 4 * lda;
|
||||
aoffset1 = aoffset;
|
||||
aoffset2 = aoffset1 + lda;
|
||||
aoffset3 = aoffset2 + lda;
|
||||
aoffset4 = aoffset3 + lda;
|
||||
aoffset += 4 * lda;
|
||||
|
||||
i = (cols >> 3);
|
||||
if (i > 0) {
|
||||
@@ -1311,7 +1702,7 @@ class tinyBLAS_Q0_PPC {
|
||||
aoffset2 = aoffset1 + lda;
|
||||
aoffset3 = aoffset2 + lda;
|
||||
i = (cols >> 3);
|
||||
if (i > 0) {
|
||||
if (i > 0) {
|
||||
do {
|
||||
switch(rows) {
|
||||
case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
|
||||
@@ -1527,13 +1918,18 @@ class tinyBLAS_Q0_PPC {
|
||||
void KERNEL_4x8(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[8], vec_B[16] = {0};
|
||||
acc_t acc_0, acc_1;
|
||||
std::array<int, 4> comparray;
|
||||
std::array<int, 4> comparray {};
|
||||
vector float fin_res[8] = {0};
|
||||
vector float vs[8] = {0};
|
||||
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
for (int l = 0; l < k; l++) {
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
|
||||
if (std::is_same_v<TA, block_q4_0>) {
|
||||
packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
|
||||
for(int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
@@ -1545,15 +1941,17 @@ class tinyBLAS_Q0_PPC {
|
||||
*((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
|
||||
}
|
||||
}
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
const int8_t *at = aoffset->qs;
|
||||
for (int j = 0; j < 32; j++)
|
||||
ca += (int)*at++;
|
||||
comparray[i] = ca;
|
||||
aoffset += lda;
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
auto *at = aoffset->qs;
|
||||
for (int j = 0; j < 32; j++)
|
||||
ca += (int)*at++;
|
||||
comparray[i] = ca;
|
||||
aoffset += lda;
|
||||
}
|
||||
}
|
||||
compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
|
||||
@@ -1565,13 +1963,18 @@ class tinyBLAS_Q0_PPC {
|
||||
void KERNEL_8x4(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[16], vec_B[8] = {0};
|
||||
acc_t acc_0, acc_1;
|
||||
std::array<int, 8> comparray;
|
||||
std::array<int, 8> comparray {};
|
||||
vector float fin_res[8] = {0};
|
||||
vector float vs[8] = {0};
|
||||
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
for (int l = 0; l < k; l++) {
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
|
||||
if (std::is_same_v<TA, block_q4_0>) {
|
||||
packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
|
||||
for(int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
@@ -1582,15 +1985,17 @@ class tinyBLAS_Q0_PPC {
|
||||
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
||||
}
|
||||
}
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
const int8_t *at = aoffset->qs;
|
||||
for (int j = 0; j < 32; j++)
|
||||
ca += (int)*at++;
|
||||
comparray[i] = ca;
|
||||
aoffset += lda;
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
auto *at = aoffset->qs;
|
||||
for (int j = 0; j < 32; j++)
|
||||
ca += (int)*at++;
|
||||
comparray[i] = ca;
|
||||
aoffset += lda;
|
||||
}
|
||||
}
|
||||
compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
|
||||
@@ -1602,15 +2007,20 @@ class tinyBLAS_Q0_PPC {
|
||||
void KERNEL_8x8(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[16], vec_B[16] = {0};
|
||||
acc_t acc_0, acc_1, acc_2, acc_3;
|
||||
std::array<int, 8> comparray;
|
||||
std::array<int, 8> comparray {};
|
||||
vector float fin_res[16] = {0};
|
||||
vector float vs[16] = {0};
|
||||
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
for (int l = 0; l < k; l++) {
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
__builtin_mma_xxsetaccz(&acc_2);
|
||||
__builtin_mma_xxsetaccz(&acc_3);
|
||||
packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
|
||||
if (std::is_same_v<TA, block_q4_0>) {
|
||||
packNormalInt4<int8_t, vector signed char, 8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
|
||||
for(int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
@@ -1624,15 +2034,17 @@ class tinyBLAS_Q0_PPC {
|
||||
*((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
|
||||
}
|
||||
}
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
const int8_t *at = aoffset->qs;
|
||||
for (int j = 0; j < 32; j++)
|
||||
ca += (int)*at++;
|
||||
comparray[i] = ca;
|
||||
aoffset += lda;
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
auto *at = aoffset->qs;
|
||||
for (int j = 0; j < 32; j++)
|
||||
ca += (int)*at++;
|
||||
comparray[i] = ca;
|
||||
aoffset += lda;
|
||||
}
|
||||
}
|
||||
compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
|
||||
@@ -1653,16 +2065,17 @@ class tinyBLAS_Q0_PPC {
|
||||
int64_t duty = (tiles + nth - 1) / nth;
|
||||
int64_t start = duty * ith;
|
||||
int64_t end = start + duty;
|
||||
vec_t vec_A[8], vec_B[8] = {0};
|
||||
vec_t vec_A[8] = {0}, vec_B[8] = {0};
|
||||
vector signed int vec_C[4];
|
||||
acc_t acc_0;
|
||||
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
|
||||
if (end > tiles)
|
||||
end = tiles;
|
||||
for (int64_t job = start; job < end; ++job) {
|
||||
int64_t ii = m0 + job / xtiles * RM;
|
||||
int64_t jj = n0 + job % xtiles * RN;
|
||||
std::array<int, RM> comparray;
|
||||
std::array<int, 4> comparray{};
|
||||
vector float res[4] = {0};
|
||||
vector float fin_res[4] = {0};
|
||||
vector float vs[4] = {0};
|
||||
@@ -1673,7 +2086,11 @@ class tinyBLAS_Q0_PPC {
|
||||
__builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
|
||||
__builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
packNormal<int8_t, vector signed char>((A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
|
||||
if (isAblock_q4) {
|
||||
packNormalInt4<int8_t, vector signed char, 4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
|
||||
for(int x = 0; x < 8; x+=4) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
@@ -1687,17 +2104,18 @@ class tinyBLAS_Q0_PPC {
|
||||
}
|
||||
}
|
||||
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
for (int i = 0; i < RM; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
const int8_t *at = aoffset->qs;
|
||||
for (int j = 0; j < 32; j++)
|
||||
ca += (int)*at++;
|
||||
comparray[i] = ca;
|
||||
aoffset += lda;
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
for (int i = 0; i < RM; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
auto *at = aoffset->qs;
|
||||
for (int j = 0; j < 32; j++)
|
||||
ca += (int)*at++;
|
||||
comparray[i] = ca;
|
||||
aoffset += lda;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < RM; i++) {
|
||||
CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
|
||||
res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
|
||||
@@ -2013,6 +2431,7 @@ class tinyBLAS_PPC {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void KERNEL_4x4(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[4], vec_B[4], vec_C[4];
|
||||
acc_t acc_0;
|
||||
@@ -2259,15 +2678,27 @@ class tinyBLAS_PPC {
|
||||
vec_t vec_C[4];
|
||||
acc_t acc_0;
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
vec_t vec_A[4], vec_B[4];
|
||||
vec_t vec_A[4] {0}, vec_B[4] = {0};
|
||||
for (int l=0; l<k; l+=4) {
|
||||
if (RN >= 4 && RM == 1) {
|
||||
/* 'GEMV Forwarding' concept is used in first two conditional loops.
|
||||
* when one of the matrix has a single row/column, the elements are
|
||||
* broadcasted, instead of using packing routine to prepack the
|
||||
* matrix elements.
|
||||
*/
|
||||
if (RM == 1) {
|
||||
TA* a = const_cast<TA*>(A+(ii)*lda+l);
|
||||
packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
|
||||
packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
|
||||
vec_A[0] = (vec_t)vec_xl(0,a);
|
||||
vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
|
||||
vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
|
||||
vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
|
||||
} else if (RN == 1) {
|
||||
packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
|
||||
TB* b = const_cast<TB*>(B+(jj)*ldb+l);
|
||||
vec_B[0] = (vec_t)vec_xl(0,b);
|
||||
vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1));
|
||||
vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2));
|
||||
vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3));
|
||||
} else {
|
||||
packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
|
||||
packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
|
||||
@@ -2371,8 +2802,10 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
||||
assert(params->ith < params->nth);
|
||||
|
||||
// only enable sgemm for prompt processing
|
||||
#if !defined(__MMA__)
|
||||
if (n < 2)
|
||||
return false;
|
||||
#endif
|
||||
|
||||
if (Ctype != GGML_TYPE_F32)
|
||||
return false;
|
||||
@@ -2503,8 +2936,8 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
||||
params->ith, params->nth};
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
|
||||
#elif defined(__MMA__)
|
||||
//TO-DO: Remove this condition once gemv forwarding is enabled.
|
||||
if (n < 8 && n != 4)
|
||||
return false;
|
||||
if (m < 8 && m != 4)
|
||||
@@ -2516,7 +2949,6 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
||||
params->ith, params->nth};
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
@@ -2541,6 +2973,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
||||
params->ith, params->nth};
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
#elif defined(__MMA__)
|
||||
//TO-DO: Remove this condition once gemv forwarding is enabled.
|
||||
if (n < 8 && n != 4)
|
||||
return false;
|
||||
if (m < 8 && m != 4)
|
||||
return false;
|
||||
tinyBLAS_Q0_PPC<block_q4_0, block_q8_0, float> tb{
|
||||
k, (const block_q4_0 *)A, lda,
|
||||
(const block_q8_0 *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
params->ith, params->nth};
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
#include "unary-ops.h"
|
||||
|
||||
static inline float op_abs(float x) {
|
||||
return fabsf(x);
|
||||
}
|
||||
|
||||
static inline float op_sgn(float x) {
|
||||
return (x > 0.f) ? 1.f : ((x < 0.f) ? -1.f : 0.f);
|
||||
}
|
||||
|
||||
static inline float op_neg(float x) {
|
||||
return -x;
|
||||
}
|
||||
|
||||
static inline float op_step(float x) {
|
||||
return (x > 0.f) ? 1.f : 0.f;
|
||||
}
|
||||
|
||||
static inline float op_tanh(float x) {
|
||||
return tanhf(x);
|
||||
}
|
||||
|
||||
static inline float op_elu(float x) {
|
||||
return (x > 0.f) ? x : expm1f(x);
|
||||
}
|
||||
|
||||
static inline float op_relu(float x) {
|
||||
return (x > 0.f) ? x : 0.f;
|
||||
}
|
||||
|
||||
static inline float op_sigmoid(float x) {
|
||||
return 1.f / (1.f + expf(-x));
|
||||
}
|
||||
|
||||
static inline float op_hardsigmoid(float x) {
|
||||
return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
|
||||
}
|
||||
|
||||
static inline float op_exp(float x) {
|
||||
return expf(x);
|
||||
}
|
||||
|
||||
static inline float op_hardswish(float x) {
|
||||
return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
|
||||
}
|
||||
|
||||
static inline float op_sqr(float x) {
|
||||
return x * x;
|
||||
}
|
||||
|
||||
static inline float op_sqrt(float x) {
|
||||
return sqrtf(x);
|
||||
}
|
||||
|
||||
static inline float op_sin(float x) {
|
||||
return sinf(x);
|
||||
}
|
||||
|
||||
static inline float op_cos(float x) {
|
||||
return cosf(x);
|
||||
}
|
||||
|
||||
static inline float op_log(float x) {
|
||||
return logf(x);
|
||||
}
|
||||
|
||||
template <float (*op)(float), typename src0_t, typename dst_t>
|
||||
static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
|
||||
constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
|
||||
constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
|
||||
|
||||
for (int i = 0; i < n; i++) {
|
||||
y[i] = f32_to_dst(op(src0_to_f32(x[i])));
|
||||
}
|
||||
}
|
||||
|
||||
template <float (*op)(float), typename src0_t, typename dst_t>
|
||||
static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
GGML_ASSERT( nb0 == sizeof(dst_t));
|
||||
GGML_ASSERT(nb00 == sizeof(src0_t));
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, src0);
|
||||
|
||||
for (int64_t ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
|
||||
const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
|
||||
vec_unary_op<op>(ne0, dst_ptr, src0_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Use the 'traits' lookup table (for type conversion fns), instead of a mass of 'if' conditions with long templates
|
||||
template <float (*op)(float)>
|
||||
static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
/* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
|
||||
apply_unary_op<op, float, float>(params, dst);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
|
||||
apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
|
||||
apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
|
||||
apply_unary_op<op, ggml_bf16_t, float>(params, dst);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
||||
apply_unary_op<op, ggml_fp16_t, float>(params, dst);
|
||||
} else {
|
||||
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
|
||||
ggml_type_name(dst->type), ggml_type_name(src0->type));
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_abs>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_sgn(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_sgn>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_neg(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_neg>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_step(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_step>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_tanh(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_tanh>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_elu(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_elu>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_relu(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_relu>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_sigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_sigmoid>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_hardsigmoid(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_hardsigmoid>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_exp(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_exp>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_hardswish(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_hardswish>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_sqr(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_sqr>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_sqrt(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_sqrt>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_sin(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_sin>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_cos>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_log>(params, dst);
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
void ggml_compute_forward_abs(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_sgn(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_neg(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_step(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_tanh(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_elu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_sigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_hardsigmoid(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_exp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_hardswish(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_sqr(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -288,6 +288,10 @@ static __device__ void no_device_code(
|
||||
__trap();
|
||||
|
||||
GGML_UNUSED(no_device_code); // suppress unused function warning
|
||||
|
||||
#if defined(GGML_USE_MUSA)
|
||||
__builtin_unreachable();
|
||||
#endif // defined(GGML_USE_MUSA)
|
||||
}
|
||||
|
||||
#ifdef __CUDA_ARCH__
|
||||
|
||||
@@ -38,7 +38,7 @@ static __global__ void concat_f32_dim1(const float * x, const float * y, float *
|
||||
blockIdx.y * ne0 +
|
||||
blockIdx.z * ne0 * gridDim.y;
|
||||
|
||||
if (blockIdx.y < ne01) { // src0
|
||||
if (blockIdx.y < (unsigned)ne01) { // src0
|
||||
int offset_src =
|
||||
nidx +
|
||||
blockIdx.y * ne0 +
|
||||
@@ -64,7 +64,7 @@ static __global__ void concat_f32_dim2(const float * x, const float * y, float *
|
||||
blockIdx.y * ne0 +
|
||||
blockIdx.z * ne0 * gridDim.y;
|
||||
|
||||
if (blockIdx.z < ne02) { // src0
|
||||
if (blockIdx.z < (unsigned)ne02) { // src0
|
||||
int offset_src =
|
||||
nidx +
|
||||
blockIdx.y * ne0 +
|
||||
|
||||
@@ -34,6 +34,10 @@ static __global__ void conv_transpose_1d_kernel(
|
||||
}
|
||||
}
|
||||
dst[global_index] = accumulator;
|
||||
GGML_UNUSED(p0); GGML_UNUSED(d0); GGML_UNUSED(src0_ne3);
|
||||
GGML_UNUSED(src1_ne3); GGML_UNUSED(dst_ne3);
|
||||
GGML_UNUSED(src1_ne1); GGML_UNUSED(dst_ne1);
|
||||
GGML_UNUSED(src1_ne2); GGML_UNUSED(dst_ne2);
|
||||
}
|
||||
|
||||
static void conv_transpose_1d_f32_f32_cuda(
|
||||
@@ -75,8 +79,6 @@ void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
const int p0 = 0;//opts[3];
|
||||
const int d0 = 1;//opts[4];
|
||||
|
||||
const int64_t kernel_size = ggml_nelements(src0);
|
||||
const int64_t input_size = ggml_nelements(src1);
|
||||
const int64_t output_size = ggml_nelements(dst);
|
||||
|
||||
conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size,
|
||||
|
||||
@@ -577,7 +577,7 @@ static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __res
|
||||
return;
|
||||
}
|
||||
|
||||
const src_t * x = (src_t *) vx;
|
||||
const src_t * x = (const src_t *) vx;
|
||||
|
||||
y[i] = x[i];
|
||||
}
|
||||
|
||||
@@ -315,14 +315,14 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
|
||||
|
||||
float vals[sizeof(int)] = {0.0f};
|
||||
#pragma unroll
|
||||
for (int l = 0; l < sizeof(int); ++l) {
|
||||
for (int l = 0; l < int(sizeof(int)); ++l) {
|
||||
vals[l] = scale * x[4*threadIdx.x + l];
|
||||
}
|
||||
|
||||
float amax = fabsf(vals[0]);
|
||||
float sum = vals[0];
|
||||
#pragma unroll
|
||||
for (int l = 1; l < sizeof(int); ++l) {
|
||||
for (int l = 1; l < int(sizeof(int)); ++l) {
|
||||
amax = fmaxf(amax, fabsf(vals[l]));
|
||||
sum += vals[l];
|
||||
}
|
||||
@@ -338,7 +338,7 @@ static __device__ __forceinline__ void quantize_q8_1_to_shared(
|
||||
|
||||
if (d != 0.0f) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < sizeof(int); ++l) {
|
||||
for (int l = 0; l < int(sizeof(int)); ++l) {
|
||||
q8[l] = roundf(vals[l] / d);
|
||||
}
|
||||
}
|
||||
@@ -638,7 +638,7 @@ static __global__ void flash_attn_combine_results(
|
||||
float VKQ_denominator = 0.0f;
|
||||
for (int l = 0; l < parallel_blocks; ++l) {
|
||||
const float diff = meta[l].x - kqmax;
|
||||
const float KQ_max_scale = expf(diff);
|
||||
float KQ_max_scale = expf(diff);
|
||||
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
||||
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
||||
|
||||
@@ -649,6 +649,7 @@ static __global__ void flash_attn_combine_results(
|
||||
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
|
||||
}
|
||||
|
||||
[[noreturn]]
|
||||
static void on_no_fattn_vec_case(const int D) {
|
||||
if (D == 64) {
|
||||
fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
|
||||
|
||||
@@ -406,6 +406,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||
#endif // CP_ASYNC_AVAILABLE
|
||||
|
||||
#else
|
||||
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
|
||||
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
|
||||
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV);
|
||||
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
|
||||
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
|
||||
GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
|
||||
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
|
||||
GGML_UNUSED(kb0);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
@@ -797,6 +806,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
__syncthreads();
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
|
||||
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
|
||||
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
|
||||
GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask);
|
||||
GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
@@ -931,6 +946,16 @@ static __global__ void flash_attn_ext_f16(
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
|
||||
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
|
||||
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
||||
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
|
||||
GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
|
||||
}
|
||||
@@ -985,38 +1010,38 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
||||
extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \
|
||||
extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8)
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16)
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32)
|
||||
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64);
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64)
|
||||
|
||||
// Kernels with ncols == 128 are only 4% faster due to register pressure.
|
||||
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128);
|
||||
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128);
|
||||
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128);
|
||||
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128);
|
||||
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128);
|
||||
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128); // Needs too much shared memory.
|
||||
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128)
|
||||
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128)
|
||||
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128)
|
||||
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128)
|
||||
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128)
|
||||
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory.
|
||||
|
||||
@@ -282,7 +282,19 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
}
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
||||
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
||||
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
||||
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
}
|
||||
|
||||
|
||||
@@ -281,6 +281,18 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
||||
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
||||
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
||||
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
||||
@@ -292,7 +292,19 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
||||
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
||||
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
||||
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
|
||||
}
|
||||
|
||||
|
||||
@@ -277,6 +277,16 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
|
||||
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
|
||||
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
||||
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
|
||||
GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
|
||||
|
||||
@@ -430,7 +430,17 @@ static __global__ void flash_attn_ext_f16(
|
||||
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
||||
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
||||
GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
||||
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
||||
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,8 @@
|
||||
#include "ggml-cuda/rope.cuh"
|
||||
#include "ggml-cuda/scale.cuh"
|
||||
#include "ggml-cuda/softmax.cuh"
|
||||
#include "ggml-cuda/ssm-conv.cuh"
|
||||
#include "ggml-cuda/ssm-scan.cuh"
|
||||
#include "ggml-cuda/sum.cuh"
|
||||
#include "ggml-cuda/sumrows.cuh"
|
||||
#include "ggml-cuda/tsembd.cuh"
|
||||
@@ -2296,6 +2298,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_SUM_ROWS:
|
||||
ggml_cuda_op_sum_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SSM_CONV:
|
||||
ggml_cuda_op_ssm_conv(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SSM_SCAN:
|
||||
ggml_cuda_op_ssm_scan(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ARGSORT:
|
||||
ggml_cuda_op_argsort(ctx, dst);
|
||||
break;
|
||||
@@ -3193,6 +3201,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_LOG:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
case GGML_OP_SSM_CONV:
|
||||
return true;
|
||||
case GGML_OP_CONT:
|
||||
return op->src[0]->type != GGML_TYPE_BF16;
|
||||
@@ -3232,6 +3242,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
return false;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
||||
// different head sizes of K and V are not supported yet
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 192) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[3] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
|
||||
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
|
||||
: "=r"(ret) : "r"(x));
|
||||
#else
|
||||
GGML_UNUSED(x);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(NEW_MMA_AVAILABLE)
|
||||
return ret;
|
||||
@@ -178,6 +179,7 @@ namespace ggml_cuda_mma {
|
||||
: "l"(xs));
|
||||
#else
|
||||
load_generic(xs0, stride);
|
||||
GGML_UNUSED(t);
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
|
||||
+38
-22
@@ -945,7 +945,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
|
||||
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
@@ -1024,7 +1024,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
|
||||
for (int k01 = 0; k01 < WARP_SIZE/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
|
||||
const int k0 = k00 + k01;
|
||||
|
||||
#pragma unroll
|
||||
@@ -1035,19 +1035,34 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
if (k01 < WARP_SIZE/2) {
|
||||
constexpr int ns = 2;
|
||||
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
|
||||
&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
|
||||
&x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
|
||||
&y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
|
||||
} else {
|
||||
constexpr int ns = 1;
|
||||
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
|
||||
&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
|
||||
&x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
|
||||
&y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
|
||||
}
|
||||
constexpr int ns = 2;
|
||||
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
|
||||
&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
|
||||
&x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
|
||||
&y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
|
||||
// As a workaround 2 separate loops are used instead.
|
||||
#pragma unroll
|
||||
for (int k01 = WARP_SIZE/2; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
|
||||
const int k0 = k00 + k01;
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
constexpr int ns = 1;
|
||||
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
|
||||
&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
|
||||
&x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
|
||||
&y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1176,7 +1191,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
|
||||
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
@@ -1253,7 +1268,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
const float d = bxi->d;
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < sizeof(int); ++l) {
|
||||
for (int l = 0; l < int(sizeof(int)); ++l) {
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
|
||||
}
|
||||
#else
|
||||
@@ -1376,7 +1391,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < sizeof(int); ++l) {
|
||||
for (int l = 0; l < int(sizeof(int)); ++l) {
|
||||
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
|
||||
}
|
||||
}
|
||||
@@ -1517,7 +1532,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < sizeof(int); ++l) {
|
||||
for (int l = 0; l < int(sizeof(int)); ++l) {
|
||||
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
|
||||
}
|
||||
}
|
||||
@@ -1810,7 +1825,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
|
||||
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
@@ -2570,6 +2585,8 @@ static __device__ void mul_mat_q_process_tile(
|
||||
} else {
|
||||
write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
|
||||
}
|
||||
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne10);
|
||||
}
|
||||
|
||||
|
||||
@@ -2695,7 +2712,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
|
||||
const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
|
||||
|
||||
// Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
|
||||
if (it != blockIdx.x || jt != blockIdx.y) {
|
||||
if ((unsigned)it != blockIdx.x || (unsigned)jt != blockIdx.y) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -2825,7 +2842,6 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
template <ggml_type type>
|
||||
void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int nsm = ggml_cuda_info().devices[id].nsm;
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ static __global__ void mul_mat_vec(
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
float sumf;
|
||||
float sumf = 0.0f;
|
||||
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
const half2 * x2 = (const half2 *) x;
|
||||
|
||||
@@ -151,7 +151,7 @@ static __global__ void mul_mat_vec_q(
|
||||
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
|
||||
|
||||
// partial sum for each thread
|
||||
float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
|
||||
float tmp[ncols_y][rows_per_cuda_block] = {{0.0f}};
|
||||
|
||||
const block_q8_1 * y = (const block_q8_1 *) vy;
|
||||
|
||||
@@ -197,10 +197,12 @@ static __global__ void mul_mat_vec_q(
|
||||
tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
|
||||
}
|
||||
|
||||
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
|
||||
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < (unsigned)nrows_dst)) {
|
||||
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
|
||||
}
|
||||
}
|
||||
|
||||
GGML_UNUSED(nrows_x);
|
||||
}
|
||||
|
||||
static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
|
||||
|
||||
@@ -14,7 +14,7 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
|
||||
nidx +
|
||||
blockIdx.y * ne0 +
|
||||
blockIdx.z * ne0 * gridDim.y;
|
||||
if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
|
||||
if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) {
|
||||
int offset_src =
|
||||
nidx +
|
||||
blockIdx.y * ne00 +
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
#include "ssm-conv.cuh"
|
||||
|
||||
template <size_t split_d_inner, size_t d_conv>
|
||||
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
|
||||
const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
|
||||
float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
|
||||
const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
|
||||
const int tid = threadIdx.x;
|
||||
const int bidx = blockIdx.x;
|
||||
const int bidy = blockIdx.y;
|
||||
|
||||
const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
|
||||
const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
|
||||
float * y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);
|
||||
|
||||
const int stride_x = src0_nb1 / sizeof(float);
|
||||
const int stride_w = src1_nb1 / sizeof(float);
|
||||
const int stride_y = dst_nb1 / sizeof(float);
|
||||
|
||||
float x[d_conv] = { 0.0f };
|
||||
float w[d_conv] = { 0.0f };
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < d_conv; j++) {
|
||||
w[j] = w_block[tid * stride_w + j];
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_t; i++) {
|
||||
float sumf = 0.0f;
|
||||
|
||||
if (i == 0) {
|
||||
for (int j = 0; j < d_conv; j++) {
|
||||
x[j] = x_block[tid * stride_x + j];
|
||||
}
|
||||
} else {
|
||||
x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < d_conv; j++) {
|
||||
sumf += x[(i + j) % d_conv] * w[j];
|
||||
}
|
||||
y_block[i * stride_y + tid] = sumf;
|
||||
}
|
||||
}
|
||||
|
||||
template <size_t split_d_inner, size_t d_conv, size_t split_n_t>
|
||||
static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
|
||||
const int src0_nb0, const int src0_nb1, const int src0_nb2,
|
||||
const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
|
||||
const int dst_nb1, const int dst_nb2, const int nc, const int ncs,
|
||||
const int nr, const int n_t, const int n_s) {
|
||||
const int tid = threadIdx.x;
|
||||
const int bidx = blockIdx.x;
|
||||
const int bidy = blockIdx.y;
|
||||
const int bidz = blockIdx.z;
|
||||
|
||||
const float * x_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
|
||||
bidz * split_n_t * src0_nb0);
|
||||
const float * w_block = (const float *) ((char *) src1 + bidy * split_d_inner * src1_nb1);
|
||||
float * y_block =
|
||||
(float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);
|
||||
|
||||
const int stride_x = src0_nb1 / sizeof(float);
|
||||
const int stride_w = src1_nb1 / sizeof(float);
|
||||
const int stride_y = dst_nb1 / sizeof(float);
|
||||
|
||||
float x[d_conv] = { 0.0f };
|
||||
float w[d_conv] = { 0.0f };
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < d_conv; j++) {
|
||||
w[j] = w_block[tid * stride_w + j];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < split_n_t; i++) {
|
||||
if (bidz * split_n_t + i < n_t) {
|
||||
float sumf = 0.0f;
|
||||
|
||||
if (i == 0) {
|
||||
for (int j = 0; j < d_conv; j++) {
|
||||
x[j] = x_block[tid * stride_x + j];
|
||||
}
|
||||
} else {
|
||||
x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < d_conv; j++) {
|
||||
sumf += x[(i + j) % d_conv] * w[j];
|
||||
}
|
||||
y_block[i * stride_y + tid] = sumf;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
|
||||
const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
|
||||
const int dst_nb2, const int nc, const int ncs, const int nr, const int n_t,
|
||||
const int n_s, cudaStream_t stream) {
|
||||
const int threads = 128;
|
||||
GGML_ASSERT(nr % threads == 0);
|
||||
|
||||
if (n_t <= 32) {
|
||||
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
|
||||
if (nc == 4) {
|
||||
ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
|
||||
dst, dst_nb0, dst_nb1, dst_nb2, nc, ncs, nr, n_t,
|
||||
n_s);
|
||||
} else {
|
||||
GGML_ABORT("Only support kernel size = 4 now.");
|
||||
}
|
||||
} else {
|
||||
if (nc == 4) {
|
||||
const int split_n_t = 32;
|
||||
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
|
||||
ssm_conv_long_token_f32<threads, 4, split_n_t>
|
||||
<<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0,
|
||||
dst_nb1, dst_nb2, nc, ncs, nr, n_t, n_s);
|
||||
} else {
|
||||
GGML_ABORT("Only support kernel size = 4 right now.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
|
||||
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
|
||||
|
||||
const int nc = src1->ne[0]; // d_conv
|
||||
const int ncs = src0->ne[0]; // d_conv - 1 + n_t
|
||||
const int nr = src0->ne[1]; // d_inner
|
||||
const int n_t = dst->ne[1]; // tokens per sequence
|
||||
const int n_s = dst->ne[2]; // number of sequences in the batch
|
||||
|
||||
GGML_ASSERT(dst->ne[0] == nr);
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
|
||||
dst->nb[2], nc, ncs, nr, n_t, n_s, stream);
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
@@ -0,0 +1,155 @@
|
||||
#include "ssm-scan.cuh"
|
||||
|
||||
// #include <cuda_runtime.h>
|
||||
// static __device__ void global_to_shared(const float *src, float *dst) {
|
||||
// asm volatile("cp.async.");
|
||||
// }
|
||||
|
||||
template <size_t splitD, size_t N>
|
||||
__global__ void __launch_bounds__(splitD, 2)
|
||||
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
|
||||
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
|
||||
const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
|
||||
const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
|
||||
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
|
||||
float * __restrict__ dst, const int D, const int L, const int B) {
|
||||
const int bidx = blockIdx.x; // split along B
|
||||
const int bidy = blockIdx.y; // split along D
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid / 32;
|
||||
const int wtid = tid % 32;
|
||||
|
||||
extern __shared__ float smem[];
|
||||
const int stride_sA = N + 1;
|
||||
const int stride_ss0 = N + 1;
|
||||
float * smem_A = smem;
|
||||
float * smem_s0 = smem_A + splitD * stride_sA;
|
||||
|
||||
const float * s0_block = (const float *) ((char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
|
||||
const float * x_block = (const float *) ((char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
|
||||
const float * dt_block = (const float *) ((char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
|
||||
const float * A_block = (const float *) ((char *) src3 + bidy * splitD * src3_nb1);
|
||||
const float * B_block = (const float *) ((char *) src4 + (bidx * src4_nb2));
|
||||
const float * C_block = (const float *) ((char *) src5 + (bidx * src5_nb2));
|
||||
float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
|
||||
float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
|
||||
|
||||
const int stride_s0 = src0_nb1 / sizeof(float);
|
||||
const int stride_x = src1_nb1 / sizeof(float);
|
||||
const int stride_dt = src2_nb1 / sizeof(float);
|
||||
const int stride_A = src3_nb1 / sizeof(float);
|
||||
const int stride_B = src4_nb1 / sizeof(float);
|
||||
const int stride_C = src5_nb1 / sizeof(float);
|
||||
const int stride_s = stride_s0;
|
||||
const int stride_y = stride_x;
|
||||
|
||||
// can N not be 16? for example 32?
|
||||
if (N == 16) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < splitD / 4; i += 2) {
|
||||
float value = A_block[(wid * warpSize + i) * stride_A + wtid];
|
||||
// todo: bank conflict
|
||||
// I am always confused with how to use the swizzling method to solve
|
||||
// bank conflit. Hoping somebody can tell me.
|
||||
smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < splitD / 4; i += 2) {
|
||||
float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
|
||||
smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int i = 0; i < L; i++) {
|
||||
float dt_soft_plus = dt_block[i * stride_dt + tid];
|
||||
if (dt_soft_plus <= 20.0f) {
|
||||
dt_soft_plus = log1pf(exp(dt_soft_plus));
|
||||
}
|
||||
float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
|
||||
float sumf = 0.0f;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < N; j++) {
|
||||
float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
|
||||
(B_block[i * stride_B + j] * x_dt);
|
||||
sumf += state * C_block[i * stride_C + j];
|
||||
if (i == L - 1) {
|
||||
s_block[tid * stride_s + j] = state;
|
||||
} else {
|
||||
smem_s0[tid * stride_ss0 + j] = state;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
y_block[i * stride_y + tid] = sumf;
|
||||
}
|
||||
}
|
||||
|
||||
static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
|
||||
const float * src4, const float * src5, const int src0_nb1, const int src0_nb2,
|
||||
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
|
||||
const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
|
||||
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
|
||||
float * dst, const int N, const int D, const int L, const int B, cudaStream_t stream) {
|
||||
const int threads = 128;
|
||||
// todo: consider D cannot be divided,does this situation exist?
|
||||
GGML_ASSERT(D % threads == 0);
|
||||
const dim3 blocks(B, (D + threads - 1) / threads, 1);
|
||||
const int smem_size = (threads * (N + 1) * 2) * sizeof(float);
|
||||
if (N == 16) {
|
||||
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
|
||||
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
|
||||
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, D, L, B);
|
||||
} else {
|
||||
GGML_ABORT("doesn't support N!=16.");
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const struct ggml_tensor * src0 = dst->src[0]; // s
|
||||
const struct ggml_tensor * src1 = dst->src[1]; // x
|
||||
const struct ggml_tensor * src2 = dst->src[2]; // dt
|
||||
const struct ggml_tensor * src3 = dst->src[3]; // A
|
||||
const struct ggml_tensor * src4 = dst->src[4]; // B
|
||||
const struct ggml_tensor * src5 = dst->src[5]; // C
|
||||
|
||||
// const int64_t d_state = src0->ne[0];
|
||||
// const int64_t d_inner = src0->ne[1];
|
||||
// const int64_t l = src1->ne[1];
|
||||
// const int64_t b = src0->ne[2];
|
||||
|
||||
const int64_t nc = src0->ne[0]; // d_state
|
||||
const int64_t nr = src0->ne[1]; // d_inner
|
||||
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
|
||||
const int64_t n_s = src0->ne[2]; // number of sequences in the batch
|
||||
|
||||
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src2->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src3->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src4->nb[0] == sizeof(float));
|
||||
GGML_ASSERT(src5->nb[0] == sizeof(float));
|
||||
// required for the dot product between s and C
|
||||
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
|
||||
// required for per-sequence offsets for states
|
||||
GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float));
|
||||
// required to get correct offset for state destination (i.e. src1->nb[3])
|
||||
GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
const float * src2_d = (const float *) src2->data;
|
||||
const float * src3_d = (const float *) src3->data;
|
||||
const float * src4_d = (const float *) src4->data;
|
||||
const float * src5_d = (const float *) src5->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0],
|
||||
src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1],
|
||||
src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream);
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
@@ -19,7 +19,7 @@ static __global__ void upscale_f32(const float * x, float * dst,
|
||||
int i02 = i12 / sf2;
|
||||
int i03 = i13 / sf3;
|
||||
|
||||
dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
|
||||
dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) );
|
||||
}
|
||||
|
||||
static void upscale_f32_cuda(const float * x, float * dst,
|
||||
|
||||
@@ -381,6 +381,35 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size);
|
||||
return r;
|
||||
}
|
||||
|
||||
#elif defined(__riscv) && defined(GGML_RV_ZFH)
|
||||
|
||||
static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
|
||||
float f;
|
||||
__asm__(
|
||||
"fmv.h.x %[f], %[h]\n\t"
|
||||
"fcvt.s.h %[f], %[f]"
|
||||
: [f] "=&f" (f)
|
||||
: [h] "r" (h)
|
||||
);
|
||||
return f;
|
||||
}
|
||||
|
||||
static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
|
||||
ggml_fp16_t res;
|
||||
__asm__(
|
||||
"fcvt.h.s %[f], %[f]\n\t"
|
||||
"fmv.x.h %[h], %[f]"
|
||||
: [h] "=&r" (res)
|
||||
: [f] "f" (f)
|
||||
);
|
||||
return res;
|
||||
}
|
||||
|
||||
#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
|
||||
#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
|
||||
#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
|
||||
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
|
||||
|
||||
#else
|
||||
|
||||
// FP16 <-> FP32
|
||||
|
||||
@@ -219,9 +219,12 @@ typedef struct {
|
||||
int32_t ne11;
|
||||
int32_t ne_12_2; // assume K and V are same shape
|
||||
int32_t ne_12_3;
|
||||
uint64_t nb_12_1;
|
||||
uint64_t nb_12_2;
|
||||
uint64_t nb_12_3;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
uint64_t nb21;
|
||||
uint64_t nb22;
|
||||
uint64_t nb23;
|
||||
uint64_t nb31;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
|
||||
+543
-412
File diff suppressed because it is too large
Load Diff
@@ -48,7 +48,7 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
|
||||
reg = (type4)(*(src + il));
|
||||
reg = (type4)(*(src));
|
||||
}
|
||||
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
@@ -56,6 +56,11 @@ template <typename type4x4>
|
||||
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
|
||||
reg = (type4x4)(*src);
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
|
||||
reg = (type4)(*(src));
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename type4x4>
|
||||
@@ -3100,7 +3105,8 @@ template<
|
||||
typename vd4x4_t, // key type in device memory
|
||||
short nl_v,
|
||||
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
|
||||
short D, // head size
|
||||
short DK, // K head size
|
||||
short DV, // V head size
|
||||
short Q = 8, // queries per threadgroup
|
||||
short KV = 8, // key/value processed per each simdgroup
|
||||
short C = 32> // cache items per threadgroup
|
||||
@@ -3122,20 +3128,24 @@ kernel void kernel_flash_attn_ext(
|
||||
const int iq2 = tgpig[1];
|
||||
const int iq1 = tgpig[0]*Q;
|
||||
|
||||
const short D4 = D/4;
|
||||
const short D8 = D/8;
|
||||
const short D16 = D/16;
|
||||
const short NW = N_SIMDWIDTH;
|
||||
const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
|
||||
constexpr short DK4 = DK/4;
|
||||
constexpr short DK8 = DK/8;
|
||||
constexpr short DK16 = DK/16;
|
||||
constexpr short DV4 = DV/4;
|
||||
constexpr short DV8 = DV/8;
|
||||
constexpr short DV16 = DV/16;
|
||||
|
||||
constexpr short NW = N_SIMDWIDTH;
|
||||
constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
|
||||
|
||||
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
|
||||
const short T = D + 2*TS; // shared memory size per query in (half)
|
||||
const short T = DK + 2*TS; // shared memory size per query in (half)
|
||||
|
||||
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t
|
||||
threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*D); // reuse query data for accumulation
|
||||
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*D); // same as above but in o4_t
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
|
||||
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
|
||||
threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
|
||||
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
|
||||
|
||||
threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
|
||||
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
|
||||
@@ -3144,23 +3154,23 @@ kernel void kernel_flash_attn_ext(
|
||||
threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
|
||||
|
||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||
o8x8_t lo[D8];
|
||||
o8x8_t lo[DV8];
|
||||
|
||||
// load heads from Q to shared memory
|
||||
for (short j = sgitg; j < Q; j += nsg) {
|
||||
device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
for (short i = tiisg; i < DK4; i += NW) {
|
||||
if (iq1 + j < args.ne01) {
|
||||
sq4[j*D4 + i] = (q4_t) q4[i];
|
||||
sq4[j*DK4 + i] = (q4_t) q4[i];
|
||||
} else {
|
||||
sq4[j*D4 + i] = (q4_t) 0.0f;
|
||||
sq4[j*DK4 + i] = (q4_t) 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// zero out lo
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
for (short i = 0; i < DV8; ++i) {
|
||||
lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
|
||||
}
|
||||
|
||||
@@ -3190,13 +3200,6 @@ kernel void kernel_flash_attn_ext(
|
||||
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
|
||||
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
q8x8_t mq[D8];
|
||||
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_load(mq[i], sq + i*8, D);
|
||||
}
|
||||
|
||||
const bool has_mask = mask != q;
|
||||
|
||||
half slope = 1.0f;
|
||||
@@ -3249,20 +3252,22 @@ kernel void kernel_flash_attn_ext(
|
||||
// this is compile-time check, so it does not have runtime overhead
|
||||
if (is_same<kd4x4_t, k4x4_t>::value) {
|
||||
// we can read directly from global memory
|
||||
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
|
||||
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
|
||||
|
||||
#pragma unroll(D8)
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
#pragma unroll(DK8)
|
||||
for (short i = 0; i < DK8; ++i) {
|
||||
k8x8_t mk;
|
||||
simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
|
||||
simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
|
||||
|
||||
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
||||
q8x8_t mq;
|
||||
simdgroup_load(mq, sq + i*8, DK);
|
||||
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
|
||||
}
|
||||
} else {
|
||||
for (short ii = 0; ii < D16; ii += 4) {
|
||||
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
|
||||
for (short ii = 0; ii < DK16; ii += 4) {
|
||||
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
|
||||
|
||||
if (D16%4 == 0) {
|
||||
if (DK16%4 == 0) {
|
||||
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
|
||||
{
|
||||
k4x4_t tmp;
|
||||
@@ -3275,15 +3280,18 @@ kernel void kernel_flash_attn_ext(
|
||||
#pragma unroll(4)
|
||||
for (short k = 0; k < 4; ++k) {
|
||||
k8x8_t mk;
|
||||
q8x8_t mq;
|
||||
|
||||
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
|
||||
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
|
||||
simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
|
||||
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
|
||||
|
||||
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
|
||||
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
|
||||
simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
|
||||
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
|
||||
}
|
||||
} else {
|
||||
if (ii + tx < D16) {
|
||||
if (ii + tx < DK16) {
|
||||
k4x4_t tmp;
|
||||
deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
|
||||
sk4x4[4*ty + tx] = tmp;
|
||||
@@ -3291,14 +3299,17 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (short k = 0; k < 4 && ii + k < D16; ++k) {
|
||||
for (short k = 0; k < 4 && ii + k < DK16; ++k) {
|
||||
k8x8_t mk;
|
||||
q8x8_t mq;
|
||||
|
||||
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
|
||||
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
|
||||
simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
|
||||
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
|
||||
|
||||
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
|
||||
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
|
||||
simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
|
||||
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3350,8 +3361,8 @@ kernel void kernel_flash_attn_ext(
|
||||
s8x8_t mm;
|
||||
simdgroup_load(mm, ss + 2*C, TS, 0, false);
|
||||
|
||||
#pragma unroll(D8)
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
#pragma unroll(DV8)
|
||||
for (short i = 0; i < DV8; ++i) {
|
||||
simdgroup_multiply(lo[i], mm, lo[i]);
|
||||
}
|
||||
}
|
||||
@@ -3364,20 +3375,20 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
if (is_same<vd4x4_t, v4x4_t>::value) {
|
||||
// we can read directly from global memory
|
||||
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
|
||||
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
|
||||
|
||||
#pragma unroll(D8)
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
#pragma unroll(DV8)
|
||||
for (short i = 0; i < DV8; ++i) {
|
||||
v8x8_t mv;
|
||||
simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
|
||||
simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
|
||||
|
||||
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
|
||||
}
|
||||
} else {
|
||||
for (short ii = 0; ii < D16; ii += 4) {
|
||||
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
|
||||
for (short ii = 0; ii < DV16; ii += 4) {
|
||||
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
|
||||
|
||||
if (D16%4 == 0) {
|
||||
if (DV16%4 == 0) {
|
||||
// no need for bound checks
|
||||
{
|
||||
v4x4_t tmp;
|
||||
@@ -3398,7 +3409,7 @@ kernel void kernel_flash_attn_ext(
|
||||
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
||||
}
|
||||
} else {
|
||||
if (ii + tx < D16) {
|
||||
if (ii + tx < DV16) {
|
||||
v4x4_t tmp;
|
||||
deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
|
||||
sv4x4[4*ty + tx] = tmp;
|
||||
@@ -3406,7 +3417,7 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
simdgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (short k = 0; k < 4 && ii + k < D16; ++k) {
|
||||
for (short k = 0; k < 4 && ii + k < DV16; ++k) {
|
||||
v8x8_t mv;
|
||||
|
||||
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
|
||||
@@ -3440,8 +3451,8 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
// each simdgroup stores its output to shared memory, reusing sq
|
||||
if (sgitg == sg) {
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_store(lo[i], so + i*8, D, 0, false);
|
||||
for (short i = 0; i < DV8; ++i) {
|
||||
simdgroup_store(lo[i], so + i*8, DV, 0, false);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3480,11 +3491,11 @@ kernel void kernel_flash_attn_ext(
|
||||
simdgroup_load(ms0, ss + 2*C, TS, 0, false);
|
||||
simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
|
||||
|
||||
#pragma unroll(D8)
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
#pragma unroll(DV8)
|
||||
for (short i = 0; i < DV8; ++i) {
|
||||
o8x8_t t;
|
||||
|
||||
simdgroup_load (t, so + i*8, D, 0, false);
|
||||
simdgroup_load (t, so + i*8, DV, 0, false);
|
||||
simdgroup_multiply(t, ms1, t);
|
||||
|
||||
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
|
||||
@@ -3495,8 +3506,8 @@ kernel void kernel_flash_attn_ext(
|
||||
|
||||
// store result to shared memory (reuse sq)
|
||||
if (sgitg == 0) {
|
||||
for (short i = 0; i < D8; ++i) {
|
||||
simdgroup_store(lo[i], so + i*8, D, 0, false);
|
||||
for (short i = 0; i < DV8; ++i) {
|
||||
simdgroup_store(lo[i], so + i*8, DV, 0, false);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3507,8 +3518,8 @@ kernel void kernel_flash_attn_ext(
|
||||
for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
|
||||
const float S = ss[j*TS + 0];
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
|
||||
for (short i = tiisg; i < DV4; i += NW) {
|
||||
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3525,80 +3536,94 @@ kernel void kernel_flash_attn_ext(
|
||||
float, simdgroup_float8x8, \
|
||||
half, half4, simdgroup_half8x8
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
|
||||
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
|
||||
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
|
||||
#endif
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
|
||||
|
||||
#undef FA_TYPES
|
||||
|
||||
template<
|
||||
typename q4_t, // query types in shared memory
|
||||
typename q4x4_t,
|
||||
typename k4x4_t, // key types in shared memory
|
||||
typename v4x4_t, // value types in shared memory
|
||||
typename qk_t, // Q*K types
|
||||
typename s_t, // soft-max types
|
||||
typename q4_t, // query types in shared memory
|
||||
typename k4_t, // key types in shared memory
|
||||
typename v4_t, // value types in shared memory
|
||||
typename qk_t, // Q*K types
|
||||
typename s_t, // soft-max types
|
||||
typename s4_t,
|
||||
typename s4x4_t,
|
||||
typename o4x4_t, // attention accumulation types
|
||||
typename kd4x4_t, // key type in device memory
|
||||
typename o4_t, // attention accumulation types
|
||||
typename kd4_t, // key type in device memory
|
||||
short nl_k,
|
||||
void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
|
||||
typename vd4x4_t, // key type in device memory
|
||||
void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
|
||||
typename vd4_t, // key type in device memory
|
||||
short nl_v,
|
||||
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
|
||||
short D, // head size
|
||||
short Q = 1, // queries per threadgroup
|
||||
short C = 32> // cache items per threadgroup
|
||||
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
|
||||
short DK, // K head size
|
||||
short DV, // V head size
|
||||
short NE = 4, // head elements per thread
|
||||
short Q = 1, // queries per threadgroup
|
||||
short C = 32> // cache items per threadgroup
|
||||
kernel void kernel_flash_attn_ext_vec(
|
||||
constant ggml_metal_kargs_flash_attn_ext & args,
|
||||
device const char * q,
|
||||
@@ -3617,29 +3642,28 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
const int iq2 = tgpig[1];
|
||||
const int iq1 = tgpig[0];
|
||||
|
||||
const short D4 = D/4;
|
||||
const short D16 = D/16;
|
||||
const short NW = N_SIMDWIDTH;
|
||||
const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
|
||||
const short SH = 2*C; // shared memory per simdgroup
|
||||
constexpr short DK4 = DK/4;
|
||||
constexpr short DV4 = DV/4;
|
||||
constexpr short NW = N_SIMDWIDTH;
|
||||
constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
|
||||
constexpr short SH = 2*C; // shared memory per simdgroup
|
||||
|
||||
const short T = D + nsg*SH; // shared memory size per query in (half)
|
||||
const short T = DK + nsg*SH; // shared memory size per query in (half)
|
||||
|
||||
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t
|
||||
threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 + 0*D); // same as above but in q4x4_t
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*D); // scratch buffer for attention
|
||||
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*D); // same as above but in s4_t
|
||||
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask
|
||||
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results
|
||||
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
|
||||
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
|
||||
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
|
||||
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask
|
||||
threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
|
||||
|
||||
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
||||
o4x4_t lo[D16/NL];
|
||||
// store the result for all queries in local memory (the O matrix from the paper)
|
||||
o4_t lo[DV4/NL];
|
||||
|
||||
// load heads from Q to shared memory
|
||||
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
for (short i = tiisg; i < DK4; i += NW) {
|
||||
if (iq1 < args.ne01) {
|
||||
sq4[i] = (q4_t) q4[i];
|
||||
} else {
|
||||
@@ -3648,8 +3672,8 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
}
|
||||
|
||||
// zero out lo
|
||||
for (short i = 0; i < D16/NL; ++i) {
|
||||
lo[i] = (o4x4_t) 0.0f;
|
||||
for (short i = 0; i < DV4/NL; ++i) {
|
||||
lo[i] = (o4_t) 0.0f;
|
||||
}
|
||||
|
||||
// zero out shared memory SH
|
||||
@@ -3674,14 +3698,6 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
const short ikv2 = iq2/(args.ne02/args.ne_12_2);
|
||||
const short ikv3 = iq3/(args.ne03/args.ne_12_3);
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
q4x4_t mq[D16/NL];
|
||||
|
||||
#pragma unroll(D16/NL)
|
||||
for (short ii = 0; ii < D16; ii += NL) {
|
||||
mq[ii/NL] = sq4x4[ii + tx];
|
||||
}
|
||||
|
||||
const bool has_mask = mask != q;
|
||||
|
||||
// pointer to the mask
|
||||
@@ -3713,43 +3729,56 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
// Q*K^T
|
||||
{
|
||||
// each simdgroup processes 1 query and 4 (NW/NL) keys
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
|
||||
// each simdgroup processes 1 query and NE (NW/NL) head elements
|
||||
for (short cc = 0; cc < C/NE; ++cc) {
|
||||
qk_t mqk = 0.0f;
|
||||
|
||||
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
|
||||
device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
|
||||
|
||||
#pragma unroll(D16/NL)
|
||||
for (short ii = 0; ii < D16; ii += NL) {
|
||||
#pragma unroll(DK4/NL)
|
||||
for (short ii = 0; ii < DK4; ii += NL) {
|
||||
const short i = ii + tx;
|
||||
|
||||
k4x4_t mk;
|
||||
deq_k(pk + i/nl_k, i%nl_k, mk);
|
||||
k4_t mk;
|
||||
deq_k_t4(pk + i/nl_k, i%nl_k, mk);
|
||||
|
||||
// note: this is less precise than the version below
|
||||
//mqka[0] += dot(mq[ii/NL][0], mk[0]);
|
||||
//mqka[1] += dot(mq[ii/NL][1], mk[1]);
|
||||
//mqka[2] += dot(mq[ii/NL][2], mk[2]);
|
||||
//mqka[3] += dot(mq[ii/NL][3], mk[3]);
|
||||
//mqka[0] += dot(mq[0], mk[0]);
|
||||
//mqka[1] += dot(mq[1], mk[1]);
|
||||
//mqka[2] += dot(mq[2], mk[2]);
|
||||
//mqka[3] += dot(mq[3], mk[3]);
|
||||
|
||||
mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]);
|
||||
mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]);
|
||||
mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]);
|
||||
mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]);
|
||||
//q4x4_t mq = sq4x4[i];
|
||||
//mqka[0] += dot((float4) mq[0], (float4) mk[0]);
|
||||
//mqka[1] += dot((float4) mq[1], (float4) mk[1]);
|
||||
//mqka[2] += dot((float4) mq[2], (float4) mk[2]);
|
||||
//mqka[3] += dot((float4) mq[3], (float4) mk[3]);
|
||||
|
||||
mqk += dot((float4) mk, (float4) sq4[i]);
|
||||
}
|
||||
|
||||
qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
|
||||
static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails
|
||||
|
||||
// simdgroup reduce
|
||||
// simdgroup reduce (NE = 4)
|
||||
// [ 0 .. 7] -> [ 0]
|
||||
// [ 8 .. 15] -> [ 8]
|
||||
// [16 .. 23] -> [16]
|
||||
// [24 .. 31] -> [24]
|
||||
//mqk += simd_shuffle_down(mqk, 16);
|
||||
//mqk += simd_shuffle_down(mqk, 8);
|
||||
mqk += simd_shuffle_down(mqk, 4);
|
||||
mqk += simd_shuffle_down(mqk, 2);
|
||||
mqk += simd_shuffle_down(mqk, 1);
|
||||
if (NE <= 1) {
|
||||
mqk += simd_shuffle_down(mqk, 16);
|
||||
}
|
||||
if (NE <= 2) {
|
||||
mqk += simd_shuffle_down(mqk, 8);
|
||||
}
|
||||
if (NE <= 4) {
|
||||
mqk += simd_shuffle_down(mqk, 4);
|
||||
}
|
||||
if (NE <= 8) {
|
||||
mqk += simd_shuffle_down(mqk, 2);
|
||||
}
|
||||
if (NE <= 16) {
|
||||
mqk += simd_shuffle_down(mqk, 1);
|
||||
}
|
||||
|
||||
// mqk = mqk*scale + mask*slope
|
||||
if (tx == 0) {
|
||||
@@ -3759,9 +3788,9 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
mqk = args.logit_softcap*precise::tanh(mqk);
|
||||
}
|
||||
|
||||
mqk += sm[4*cc + ty]*slope;
|
||||
mqk += sm[NE*cc + ty]*slope;
|
||||
|
||||
ss[4*cc + ty] = mqk;
|
||||
ss[NE*cc + ty] = mqk;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3784,8 +3813,8 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
ss[tiisg] = vs;
|
||||
|
||||
// O = diag(ms)*O
|
||||
#pragma unroll(D16/NL)
|
||||
for (short ii = 0; ii < D16; ii += NL) {
|
||||
#pragma unroll(DV4/NL)
|
||||
for (short ii = 0; ii < DV4; ii += NL) {
|
||||
lo[ii/NL] *= ms;
|
||||
}
|
||||
}
|
||||
@@ -3794,17 +3823,18 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
// O = O + (Q*K^T)*V
|
||||
{
|
||||
for (short cc = 0; cc < C/4; ++cc) {
|
||||
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
|
||||
//#pragma unroll(C/NE)
|
||||
for (short cc = 0; cc < C/NE; ++cc) {
|
||||
device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
|
||||
|
||||
const s4x4_t ms(ss[4*cc + ty]);
|
||||
const s4_t ms(ss[NE*cc + ty]);
|
||||
|
||||
#pragma unroll(D16/NL)
|
||||
for (short ii = 0; ii < D16; ii += NL) {
|
||||
#pragma unroll(DV4/NL)
|
||||
for (short ii = 0; ii < DV4; ii += NL) {
|
||||
const short i = ii + tx;
|
||||
|
||||
v4x4_t mv;
|
||||
deq_v(pv4 + i/nl_v, i%nl_v, mv);
|
||||
v4_t mv;
|
||||
deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
|
||||
|
||||
lo[ii/NL] += mv*ms;
|
||||
}
|
||||
@@ -3819,7 +3849,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
}
|
||||
}
|
||||
|
||||
// simdgroup reduce
|
||||
// simdgroup reduce (NE = 4)
|
||||
// [ 0, 8, 16, 24] -> [ 0]
|
||||
// [ 1, 9, 17, 25] -> [ 1]
|
||||
// [ 2, 10, 18, 26] -> [ 2]
|
||||
@@ -3828,37 +3858,48 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
// [ 5, 13, 21, 29] -> [ 5]
|
||||
// [ 6, 14, 22, 30] -> [ 6]
|
||||
// [ 7, 15, 23, 31] -> [ 7]
|
||||
for (short ii = 0; ii < D16; ii += NL) {
|
||||
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
|
||||
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
|
||||
//lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
|
||||
//lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
|
||||
//lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
|
||||
for (short ii = 0; ii < DV4; ii += NL) {
|
||||
if (NE > 1) {
|
||||
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
|
||||
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
|
||||
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
|
||||
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
|
||||
}
|
||||
|
||||
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
|
||||
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
|
||||
//lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
|
||||
//lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
|
||||
//lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
|
||||
if (NE > 2) {
|
||||
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
|
||||
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
|
||||
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
|
||||
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
|
||||
}
|
||||
|
||||
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
|
||||
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
|
||||
//lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
|
||||
//lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
|
||||
//lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
|
||||
if (NE > 4) {
|
||||
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
|
||||
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
|
||||
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
|
||||
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
|
||||
}
|
||||
|
||||
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
|
||||
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
|
||||
//lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
|
||||
//lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
|
||||
//lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
|
||||
if (NE > 8) {
|
||||
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
|
||||
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
|
||||
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
|
||||
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
|
||||
}
|
||||
|
||||
if (NE > 16) {
|
||||
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
|
||||
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
|
||||
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
|
||||
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// store results to shared memory
|
||||
for (short i = tiisg; i < D16; i += NL) {
|
||||
sr4x4[i] = lo[i/NL];
|
||||
for (short i = tiisg; i < DV4; i += NL) {
|
||||
sr4[i] = lo[i/NL];
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
@@ -3885,22 +3926,22 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
}
|
||||
|
||||
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
||||
for (short i = tiisg; i < D16; i += NW) {
|
||||
sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1;
|
||||
for (short i = tiisg; i < DV4; i += NW) {
|
||||
sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1;
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
device float4x4 * dst44 = (device float4x4 *) dst;
|
||||
device float4 * dst4 = (device float4 *) dst;
|
||||
|
||||
// final rescale with 1/S and store to global memory
|
||||
if (sgitg == 0) {
|
||||
const float S = ss[0];
|
||||
|
||||
for (short i = tiisg; i < D16; i += NW) {
|
||||
dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
|
||||
for (short i = tiisg; i < DV4; i += NW) {
|
||||
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3909,34 +3950,54 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
|
||||
//
|
||||
#define FA_TYPES \
|
||||
half4, half4x4, \
|
||||
half4x4, \
|
||||
half4x4, \
|
||||
float, \
|
||||
half, half4, half4x4, \
|
||||
half4x4
|
||||
half4, \
|
||||
half4, \
|
||||
half4, \
|
||||
float, \
|
||||
half, half4, \
|
||||
half4
|
||||
|
||||
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
|
||||
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 4>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 4>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 4>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 4>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 4>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 4>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 4>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 4>;
|
||||
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 4>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 4>;
|
||||
#endif
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 4>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 4>;
|
||||
|
||||
#undef FA_TYPES
|
||||
|
||||
|
||||
@@ -63,6 +63,7 @@ set(GGML_OPENCL_KERNELS
|
||||
ggml-opencl_transpose_16
|
||||
ggml-opencl_transpose_32
|
||||
ggml-opencl_transpose_32_16
|
||||
ggml-opencl_im2col
|
||||
)
|
||||
|
||||
foreach (K ${GGML_OPENCL_KERNELS})
|
||||
|
||||
@@ -224,12 +224,14 @@ struct ggml_backend_opencl_context {
|
||||
cl_program program;
|
||||
cl_program program_1;
|
||||
cl_program program_2;
|
||||
cl_program program_im2col;
|
||||
|
||||
cl_kernel kernel_add, kernel_add_row;
|
||||
cl_kernel kernel_mul, kernel_mul_row;
|
||||
cl_kernel kernel_scale;
|
||||
cl_kernel kernel_silu, kernel_silu_4;
|
||||
cl_kernel kernel_gelu, kernel_gelu_4;
|
||||
cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;
|
||||
cl_kernel kernel_relu;
|
||||
cl_kernel kernel_clamp;
|
||||
cl_kernel kernel_norm;
|
||||
@@ -239,6 +241,7 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
|
||||
cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
|
||||
cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
|
||||
cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16;
|
||||
cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32;
|
||||
cl_kernel kernel_mul_mat_f32_f32;
|
||||
cl_kernel kernel_mul_mat_f16_f16;
|
||||
@@ -252,6 +255,7 @@ struct ggml_backend_opencl_context {
|
||||
kernel_mul_mat_q4_0_f32_flat_img_v0;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
|
||||
cl_kernel kernel_mul_mv_q6_K_f32;
|
||||
cl_kernel kernel_im2col_f32, kernel_im2col_f16;
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
// Transpose kernels
|
||||
@@ -708,6 +712,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program, "kernel_silu_4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program, "kernel_gelu", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program, "kernel_gelu_quick", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_quick_4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program, "kernel_relu", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_clamp = clCreateKernel(backend_ctx->program, "kernel_clamp", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program, "kernel_norm", &err), err));
|
||||
@@ -722,6 +728,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_rope_neox_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_rope_multi_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f32", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_rope_multi_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_rope_vision_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f32", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_rope_vision_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f16", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f32", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f16", &err), err));
|
||||
@@ -769,6 +779,19 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_2, "kernel_convert_block_q4_0_noshuffle", &err), err));
|
||||
|
||||
// im2col kernels
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src_im2col {
|
||||
#include "ggml-opencl_im2col.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src_im2col = read_file("ggml-opencl_im2col.cl");
|
||||
#endif
|
||||
backend_ctx->program_im2col = build_program_from_source(context, device, kernel_src_im2col.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_im2col_f32 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f32", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_im2col_f16 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f16", &err), err));
|
||||
|
||||
// Kernels for Adreno
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -1187,6 +1210,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_RELU:
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
default:
|
||||
return false;
|
||||
@@ -1216,14 +1240,26 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
return op->ne[3] == 1;
|
||||
case GGML_OP_ROPE: {
|
||||
const int mode = ((const int32_t *) op->op_params)[2];
|
||||
if (mode & GGML_ROPE_TYPE_MROPE) {
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
if (is_mrope && !is_vision) {
|
||||
if (op->src[0]->type == GGML_TYPE_F32 ||
|
||||
op->src[0]->type == GGML_TYPE_F16) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if (mode & GGML_ROPE_TYPE_VISION) {
|
||||
if (is_vision) {
|
||||
if (op->src[0]->type == GGML_TYPE_F32 ||
|
||||
op->src[0]->type == GGML_TYPE_F16) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case GGML_OP_IM2COL:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -2582,6 +2618,53 @@ static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0);
|
||||
GGML_ASSERT(src0->extra);
|
||||
GGML_ASSERT(dst);
|
||||
GGML_ASSERT(dst->extra);
|
||||
|
||||
UNUSED(src1);
|
||||
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
cl_command_queue queue = backend_ctx->queue;
|
||||
|
||||
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
|
||||
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
||||
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||
|
||||
cl_kernel kernel;
|
||||
|
||||
int n = ggml_nelements(dst);
|
||||
|
||||
if (n % 4 == 0) {
|
||||
kernel = backend_ctx->kernel_gelu_quick_4;
|
||||
n /= 4;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_gelu_quick;
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||
|
||||
size_t global_work_size[] = {(size_t)n, 1, 1};
|
||||
size_t local_work_size[] = {64, 1, 1};
|
||||
|
||||
#ifdef GGML_OPENCL_PROFILING
|
||||
cl_event evt;
|
||||
clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt);
|
||||
|
||||
g_profiling_info.emplace_back();
|
||||
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
|
||||
#else
|
||||
clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL);
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0);
|
||||
GGML_ASSERT(src0->extra);
|
||||
@@ -3980,6 +4063,7 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
float attn_factor;
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
int32_t sections[4];
|
||||
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
@@ -3987,23 +4071,23 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int32_t)*4);
|
||||
|
||||
const bool is_neox = mode & 2;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_mrope) {
|
||||
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
||||
}
|
||||
|
||||
if (is_vision) {
|
||||
GGML_ASSERT(n_dims == ne00/2);
|
||||
}
|
||||
|
||||
cl_kernel kernel;
|
||||
|
||||
if (!is_neox) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
kernel = backend_ctx->kernel_rope_norm_f32;
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
kernel = backend_ctx->kernel_rope_norm_f16;
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
};
|
||||
} else {
|
||||
if (is_neox) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
kernel = backend_ctx->kernel_rope_neox_f32;
|
||||
@@ -4014,6 +4098,39 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
};
|
||||
} else if (is_mrope && !is_vision) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
kernel = backend_ctx->kernel_rope_multi_f32;
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
kernel = backend_ctx->kernel_rope_multi_f16;
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
};
|
||||
} else if (is_vision) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
kernel = backend_ctx->kernel_rope_vision_f32;
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
kernel = backend_ctx->kernel_rope_vision_f16;
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
} else {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
kernel = backend_ctx->kernel_rope_norm_f32;
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
kernel = backend_ctx->kernel_rope_norm_f16;
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
};
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
@@ -4049,6 +4166,9 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float), &attn_factor));
|
||||
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float), &beta_fast));
|
||||
CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &beta_slow));
|
||||
if (is_mrope || is_vision) {
|
||||
CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, §ions));
|
||||
}
|
||||
|
||||
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
|
||||
size_t local_work_size[] = {(size_t)nth, 1, 1};
|
||||
@@ -4064,6 +4184,98 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_cl_im2col(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0);
|
||||
GGML_ASSERT(src1);
|
||||
GGML_ASSERT(src1->extra);
|
||||
GGML_ASSERT(dst);
|
||||
GGML_ASSERT(dst->extra);
|
||||
|
||||
// src0 - filter, src1 - input
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
||||
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
cl_command_queue queue = backend_ctx->queue;
|
||||
|
||||
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
|
||||
cl_ulong offset1 = extra1->offset + src1->view_offs;
|
||||
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||
|
||||
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
||||
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
|
||||
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
|
||||
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
|
||||
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
|
||||
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
|
||||
|
||||
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
|
||||
|
||||
const cl_long IC = src1->ne[is_2D ? 2 : 1];
|
||||
const cl_long IH = is_2D ? src1->ne[1] : 1;
|
||||
const cl_long IW = src1->ne[0];
|
||||
|
||||
const cl_long KH = is_2D ? src0->ne[1] : 1;
|
||||
const cl_long KW = src0->ne[0];
|
||||
|
||||
const cl_long OH = is_2D ? dst->ne[2] : 1;
|
||||
const cl_long OW = dst->ne[1];
|
||||
|
||||
// nb is byte offset, src is type float32
|
||||
const cl_ulong delta_offset = src1->nb[is_2D ? 2 : 1]/4;
|
||||
const cl_long batch = src1->ne[is_2D ? 3 : 2];
|
||||
const cl_ulong batch_offset = src1->nb[is_2D ? 3 : 2]/4;
|
||||
|
||||
const cl_long pelements = OW*KW*KH;
|
||||
const cl_long CHW = IC*KH*KW;
|
||||
|
||||
cl_kernel kernel;
|
||||
|
||||
if(dst->type == GGML_TYPE_F16) {
|
||||
kernel = backend_ctx->kernel_im2col_f16;
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_im2col_f32;
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &batch_offset));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &delta_offset));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_long), &IW));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_long), &IH));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_long), &IC));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_long), &OW));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_long), &OH));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_long), &KW));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_long), &KH));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_long), &pelements));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_long), &CHW));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &s0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &s1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &p0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &p1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &d0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &d1));
|
||||
|
||||
const int num_blocks = (pelements + 256 - 1) / 256;
|
||||
size_t global_work_size[] = {(size_t)num_blocks*256, (size_t)OH, (size_t)batch*IC};
|
||||
size_t local_work_size[] = {256, 1, 1};
|
||||
|
||||
#ifdef GGML_OPENCL_PROFILING
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
|
||||
|
||||
g_profiling_info.emplace_back();
|
||||
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
|
||||
#else
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
|
||||
#endif
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Op offloading
|
||||
//------------------------------------------------------------------------------
|
||||
@@ -4122,6 +4334,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
|
||||
}
|
||||
func = ggml_cl_gelu;
|
||||
break;
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cl_gelu_quick;
|
||||
break;
|
||||
case GGML_UNARY_OP_SILU:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
@@ -4194,6 +4412,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
|
||||
}
|
||||
func = ggml_cl_rope;
|
||||
break;
|
||||
case GGML_OP_IM2COL:
|
||||
if (!any_on_device) {
|
||||
return false;
|
||||
}
|
||||
func = ggml_cl_im2col;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -404,6 +404,7 @@ kernel void kernel_scale(
|
||||
// gelu
|
||||
//------------------------------------------------------------------------------
|
||||
#define GELU_COEF_A 0.044715f
|
||||
#define GELU_QUICK_COEF -1.702f
|
||||
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f
|
||||
|
||||
kernel void kernel_gelu(
|
||||
@@ -434,6 +435,32 @@ kernel void kernel_gelu_4(
|
||||
dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_quick(
|
||||
global float * src0,
|
||||
ulong offset0,
|
||||
global float * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global float*)((global char*)src0 + offset0);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
float x = src0[get_global_id(0)];
|
||||
dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_quick_4(
|
||||
global float4 * src0,
|
||||
ulong offset0,
|
||||
global float4 * dst,
|
||||
ulong offsetd
|
||||
) {
|
||||
src0 = (global float4*)((global char*)src0 + offset0);
|
||||
dst = (global float4*)((global char*)dst + offsetd);
|
||||
|
||||
float4 x = src0[get_global_id(0)];
|
||||
dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// silu
|
||||
//------------------------------------------------------------------------------
|
||||
@@ -1325,6 +1352,368 @@ kernel void kernel_rope_neox_f16(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_rope_multi_f32(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global int * src1,
|
||||
ulong offset1,
|
||||
global float * src2,
|
||||
ulong offset2,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3,
|
||||
int n_past,
|
||||
int n_dims,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow,
|
||||
int4 sections
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global int*)((global char*)src1 + offset1);
|
||||
src2 = (global float*)((global char*)src2 + offset2);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i3 = get_group_id(2);
|
||||
int i2 = get_group_id(1);
|
||||
int i1 = get_group_id(0);
|
||||
|
||||
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
|
||||
|
||||
global int * pos = src1;
|
||||
|
||||
const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3;
|
||||
const int sec_w = sections.s1 + sections.s0;
|
||||
|
||||
float inv_ndims = -1.f/n_dims;
|
||||
|
||||
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
|
||||
if (i0 < n_dims) {
|
||||
int ic = i0/2;
|
||||
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
float theta_base = 0.0f;
|
||||
|
||||
if (sector < sections.s0) {
|
||||
theta_base = pos[i2];
|
||||
}
|
||||
else if (sector >= sections.s0 && sector < sec_w) {
|
||||
theta_base = pos[i2 + ne2 * 1];
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections.s2) {
|
||||
theta_base = pos[i2 + ne2 * 2];
|
||||
}
|
||||
else if (sector >= sec_w + sections.s2) {
|
||||
theta_base = pos[i2 + ne2 * 3];
|
||||
}
|
||||
|
||||
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||
|
||||
float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
|
||||
|
||||
global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
|
||||
dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
|
||||
} else {
|
||||
global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_rope_multi_f16(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global int * src1,
|
||||
ulong offset1,
|
||||
global float * src2,
|
||||
ulong offset2,
|
||||
global half * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3,
|
||||
int n_past,
|
||||
int n_dims,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow,
|
||||
int4 sections
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global int*)((global char*)src1 + offset1);
|
||||
src2 = (global float*)((global char*)src2 + offset2);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i3 = get_group_id(2);
|
||||
int i2 = get_group_id(1);
|
||||
int i1 = get_group_id(0);
|
||||
|
||||
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
|
||||
|
||||
global int * pos = src1;
|
||||
|
||||
const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3;
|
||||
const int sec_w = sections.s1 + sections.s0;
|
||||
|
||||
float inv_ndims = -1.f/n_dims;
|
||||
|
||||
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
|
||||
if (i0 < n_dims) {
|
||||
int ic = i0/2;
|
||||
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
float theta_base = 0.0f;
|
||||
|
||||
if (sector < sections.s0) {
|
||||
theta_base = pos[i2];
|
||||
}
|
||||
else if (sector >= sections.s0 && sector < sec_w) {
|
||||
theta_base = pos[i2 + ne2 * 1];
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections.s2) {
|
||||
theta_base = pos[i2 + ne2 * 2];
|
||||
}
|
||||
else if (sector >= sec_w + sections.s2) {
|
||||
theta_base = pos[i2 + ne2 * 3];
|
||||
}
|
||||
|
||||
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||
|
||||
float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
|
||||
|
||||
global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
|
||||
dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
|
||||
} else {
|
||||
global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_rope_vision_f32(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global int * src1,
|
||||
ulong offset1,
|
||||
global float * src2,
|
||||
ulong offset2,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3,
|
||||
int n_past,
|
||||
int n_dims,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow,
|
||||
int4 sections
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global int*)((global char*)src1 + offset1);
|
||||
src2 = (global float*)((global char*)src2 + offset2);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i3 = get_group_id(2);
|
||||
int i2 = get_group_id(1);
|
||||
int i1 = get_group_id(0);
|
||||
|
||||
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
|
||||
|
||||
global int * pos = src1;
|
||||
|
||||
const int sect_dims = sections.s0 + sections.s1;
|
||||
const int sec_w = sections.s1 + sections.s0;
|
||||
|
||||
float inv_ndims = -1.f/n_dims;
|
||||
|
||||
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
|
||||
int ic = i0/2;
|
||||
|
||||
const int sector = (i0/2) % sect_dims;
|
||||
float theta_base = 0.0f;
|
||||
|
||||
if (sector < sections.s0) {
|
||||
const int p = sector;
|
||||
theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p);
|
||||
} else if (sector >= sections.s0 && sector < sec_w) {
|
||||
const int p = sector - sections.s0;
|
||||
theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p);
|
||||
}
|
||||
|
||||
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||
|
||||
float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
|
||||
|
||||
global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims];
|
||||
|
||||
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
|
||||
dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_rope_vision_f16(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global int * src1,
|
||||
ulong offset1,
|
||||
global float * src2,
|
||||
ulong offset2,
|
||||
global half * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3,
|
||||
int n_past,
|
||||
int n_dims,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
float ext_factor,
|
||||
float attn_factor,
|
||||
float beta_fast,
|
||||
float beta_slow,
|
||||
int4 sections
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global int*)((global char*)src1 + offset1);
|
||||
src2 = (global float*)((global char*)src2 + offset2);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
int i3 = get_group_id(2);
|
||||
int i2 = get_group_id(1);
|
||||
int i1 = get_group_id(0);
|
||||
|
||||
float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
|
||||
|
||||
global int * pos = src1;
|
||||
|
||||
const int sect_dims = sections.s0 + sections.s1;
|
||||
const int sec_w = sections.s1 + sections.s0;
|
||||
|
||||
float inv_ndims = -1.f/n_dims;
|
||||
|
||||
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
|
||||
int ic = i0/2;
|
||||
|
||||
const int sector = (i0/2) % sect_dims;
|
||||
float theta_base = 0.0f;
|
||||
|
||||
if (sector < sections.s0) {
|
||||
const int p = sector;
|
||||
theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p);
|
||||
} else if (sector >= sections.s0 && sector < sec_w) {
|
||||
const int p = sector - sections.s0;
|
||||
theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p);
|
||||
}
|
||||
|
||||
const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
|
||||
|
||||
float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
|
||||
|
||||
global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
||||
global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims];
|
||||
|
||||
dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
|
||||
dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// cpy
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
#ifdef cl_khr_fp16
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#elif defined(cl_amd_fp16)
|
||||
#pragma OPENCL EXTENSION cl_amd_fp16 : enable
|
||||
#else
|
||||
#error "Half precision floating point not supportedby OpenCL implementation on your device."
|
||||
#endif
|
||||
|
||||
#ifdef cl_khr_subgroups
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#elif defined(cl_intel_subgroups)
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#error "Subgroup not supported on your device."
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
// Always use subgroup size of 32 on Intel.
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
// Always use subgroups size of 64 on Adreno.
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#else
|
||||
// TODO: do not know how to choose subgroup size on other GPUs.
|
||||
#error "Selecting subgroup size is not supported on your device."
|
||||
#endif
|
||||
|
||||
kernel void kernel_im2col_f32(
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
ulong batch_offset,
|
||||
ulong delta_offset,
|
||||
long IW,
|
||||
long IH,
|
||||
long IC,
|
||||
long OW,
|
||||
long OH,
|
||||
long KW,
|
||||
long KH,
|
||||
long pelements,
|
||||
long CHW,
|
||||
int s0,
|
||||
int s1,
|
||||
int p0,
|
||||
int p1,
|
||||
int d0,
|
||||
int d1
|
||||
) {
|
||||
// threadIdx.x + blockIdx.x * blockDim.x
|
||||
long i = get_global_id(0);
|
||||
if (i >= pelements) {
|
||||
return;
|
||||
}
|
||||
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
long ksize = OW * (KH > 1 ? KW : 1);
|
||||
long kx = i / ksize;
|
||||
long kd = kx * ksize;
|
||||
long ky = (i - kd) / OW;
|
||||
long ix = i % OW;
|
||||
|
||||
long oh = get_group_id(1);
|
||||
long batch = get_group_id(2) / IC;
|
||||
long ic = get_group_id(2) % IC;
|
||||
|
||||
long iiw = ix * s0 + kx * d0 - p0;
|
||||
long iih = oh * s1 + ky * d1 - p1;
|
||||
|
||||
long offset_dst =
|
||||
((batch * OH + oh) * OW + ix) * CHW +
|
||||
(ic * (KW * KH) + ky * KW + kx);
|
||||
|
||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
dst[offset_dst] = 0.0f;
|
||||
} else {
|
||||
long offset_src = ic * delta_offset + batch * batch_offset;
|
||||
dst[offset_dst] = src1[offset_src + iih * IW + iiw];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_im2col_f16(
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global half * dst,
|
||||
ulong offsetd,
|
||||
ulong batch_offset,
|
||||
ulong delta_offset,
|
||||
long IW,
|
||||
long IH,
|
||||
long IC,
|
||||
long OW,
|
||||
long OH,
|
||||
long KW,
|
||||
long KH,
|
||||
long pelements,
|
||||
long CHW,
|
||||
int s0,
|
||||
int s1,
|
||||
int p0,
|
||||
int p1,
|
||||
int d0,
|
||||
int d1
|
||||
) {
|
||||
long i = get_global_id(0);
|
||||
|
||||
if (i >= pelements) {
|
||||
return;
|
||||
}
|
||||
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global half*)((global char*)dst + offsetd);
|
||||
|
||||
long ksize = OW * (KH > 1 ? KW : 1);
|
||||
long kx = i / ksize;
|
||||
long kd = kx * ksize;
|
||||
long ky = (i - kd) / OW;
|
||||
long ix = i % OW;
|
||||
|
||||
long oh = get_group_id(1);
|
||||
long batch = get_group_id(2) / IC;
|
||||
long ic = get_group_id(2) % IC;
|
||||
|
||||
long iiw = ix * s0 + kx * d0 - p0;
|
||||
long iih = oh * s1 + ky * d1 - p1;
|
||||
|
||||
long offset_dst =
|
||||
((batch * OH + oh) * OW + ix) * CHW +
|
||||
(ic * (KW * KH) + ky * KW + kx);
|
||||
|
||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
dst[offset_dst] = 0.0f;
|
||||
} else {
|
||||
long offset_src = ic * delta_offset + batch * batch_offset;
|
||||
dst[offset_dst] = src1[offset_src + iih * IW + iiw];
|
||||
}
|
||||
}
|
||||
@@ -26,6 +26,10 @@
|
||||
# include <unistd.h>
|
||||
#endif
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <filesystem>
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
#ifdef _WIN32
|
||||
typedef SOCKET sockfd_t;
|
||||
@@ -80,6 +84,7 @@ enum rpc_cmd {
|
||||
RPC_CMD_FREE_BUFFER,
|
||||
RPC_CMD_BUFFER_CLEAR,
|
||||
RPC_CMD_SET_TENSOR,
|
||||
RPC_CMD_SET_TENSOR_HASH,
|
||||
RPC_CMD_GET_TENSOR,
|
||||
RPC_CMD_COPY_TENSOR,
|
||||
RPC_CMD_GRAPH_COMPUTE,
|
||||
@@ -89,6 +94,9 @@ enum rpc_cmd {
|
||||
RPC_CMD_COUNT,
|
||||
};
|
||||
|
||||
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
|
||||
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
|
||||
|
||||
struct rpc_msg_get_alloc_size_req {
|
||||
rpc_tensor tensor;
|
||||
};
|
||||
@@ -135,6 +143,10 @@ struct rpc_msg_buffer_clear_req {
|
||||
uint8_t value;
|
||||
};
|
||||
|
||||
struct rpc_msg_set_tensor_hash_rsp {
|
||||
uint8_t result;
|
||||
};
|
||||
|
||||
struct rpc_msg_get_tensor_req {
|
||||
rpc_tensor tensor;
|
||||
uint64_t offset;
|
||||
@@ -187,6 +199,18 @@ struct ggml_backend_rpc_buffer_context {
|
||||
|
||||
// RPC helper functions
|
||||
|
||||
// Computes FNV-1a hash of the data
|
||||
static uint64_t fnv_hash(const uint8_t * data, size_t len) {
|
||||
const uint64_t fnv_prime = 0x100000001b3ULL;
|
||||
uint64_t hash = 0xcbf29ce484222325ULL;
|
||||
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
hash ^= data[i];
|
||||
hash *= fnv_prime;
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
|
||||
static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
|
||||
#ifdef _WIN32
|
||||
if (fd == INVALID_SOCKET) {
|
||||
@@ -483,10 +507,26 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
|
||||
|
||||
static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
||||
// input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
|
||||
rpc_tensor rpc_tensor = serialize_tensor(tensor);
|
||||
if (size > HASH_THRESHOLD) {
|
||||
// input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes)
|
||||
size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + sizeof(uint64_t);
|
||||
std::vector<uint8_t> input(input_size, 0);
|
||||
uint64_t hash = fnv_hash((const uint8_t*)data, size);
|
||||
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
|
||||
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
||||
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &hash, sizeof(hash));
|
||||
rpc_msg_set_tensor_hash_rsp response;
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, input.data(), input.size(), &response, sizeof(response));
|
||||
GGML_ASSERT(status);
|
||||
if (response.result) {
|
||||
// the server has the same data, no need to send it
|
||||
return;
|
||||
}
|
||||
}
|
||||
// input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
|
||||
size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
|
||||
std::vector<uint8_t> input(input_size, 0);
|
||||
rpc_tensor rpc_tensor = serialize_tensor(tensor);
|
||||
memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
|
||||
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
||||
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
|
||||
@@ -772,7 +812,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, si
|
||||
|
||||
class rpc_server {
|
||||
public:
|
||||
rpc_server(ggml_backend_t backend) : backend(backend) {}
|
||||
rpc_server(ggml_backend_t backend, const char * cache_dir)
|
||||
: backend(backend), cache_dir(cache_dir) {
|
||||
}
|
||||
~rpc_server();
|
||||
|
||||
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
|
||||
@@ -782,6 +824,7 @@ public:
|
||||
bool free_buffer(const rpc_msg_free_buffer_req & request);
|
||||
bool buffer_clear(const rpc_msg_buffer_clear_req & request);
|
||||
bool set_tensor(const std::vector<uint8_t> & input);
|
||||
bool set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response);
|
||||
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
|
||||
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
|
||||
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
|
||||
@@ -789,6 +832,7 @@ public:
|
||||
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
|
||||
|
||||
private:
|
||||
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
|
||||
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
|
||||
ggml_tensor * create_node(uint64_t id,
|
||||
struct ggml_context * ctx,
|
||||
@@ -797,6 +841,7 @@ private:
|
||||
|
||||
|
||||
ggml_backend_t backend;
|
||||
const char * cache_dir;
|
||||
std::unordered_set<ggml_backend_buffer_t> buffers;
|
||||
};
|
||||
|
||||
@@ -960,11 +1005,85 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
|
||||
}
|
||||
|
||||
const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
|
||||
if (cache_dir && size > HASH_THRESHOLD) {
|
||||
uint64_t hash = fnv_hash((const uint8_t*)data, size);
|
||||
char hash_str[17];
|
||||
snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
|
||||
// save to cache_dir/hash_str
|
||||
fs::path cache_file = fs::path(cache_dir) / hash_str;
|
||||
std::ofstream ofs(cache_file, std::ios::binary);
|
||||
ofs.write((const char *)data, size);
|
||||
printf("[%s] saved to '%s'\n", __func__, cache_file.c_str());
|
||||
}
|
||||
ggml_backend_tensor_set(tensor, data, offset, size);
|
||||
ggml_free(ctx);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
|
||||
if (!cache_dir) {
|
||||
return false;
|
||||
}
|
||||
char hash_str[17];
|
||||
snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
|
||||
fs::path cache_file = fs::path(cache_dir) / hash_str;
|
||||
if (!fs::exists(cache_file)) {
|
||||
return false;
|
||||
}
|
||||
std::ifstream ifs(cache_file, std::ios::binary);
|
||||
ifs.seekg(0, std::ios::end);
|
||||
size_t size = ifs.tellg();
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
data.resize(size);
|
||||
ifs.read((char *)data.data(), size);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool rpc_server::set_tensor_hash(const std::vector<uint8_t> & input, rpc_msg_set_tensor_hash_rsp & response)
|
||||
{
|
||||
// serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) |
|
||||
if (input.size() != sizeof(rpc_tensor) + 16) {
|
||||
return false;
|
||||
}
|
||||
const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
|
||||
uint64_t offset;
|
||||
memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
|
||||
const uint64_t * hash = (const uint64_t *)(input.data() + sizeof(rpc_tensor) + sizeof(offset));
|
||||
std::vector<uint8_t> cached_file;
|
||||
if (!get_cached_file(*hash, cached_file)) {
|
||||
response.result = 0;
|
||||
return true;
|
||||
}
|
||||
size_t size = cached_file.size();
|
||||
struct ggml_init_params params {
|
||||
/*.mem_size =*/ ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
struct ggml_context * ctx = ggml_init(params);
|
||||
ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
|
||||
if (tensor == nullptr) {
|
||||
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
|
||||
ggml_free(ctx);
|
||||
return false;
|
||||
}
|
||||
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size, *hash);
|
||||
|
||||
// sanitize tensor->data
|
||||
{
|
||||
const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
|
||||
const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
|
||||
|
||||
if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
|
||||
GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
|
||||
}
|
||||
}
|
||||
ggml_backend_tensor_set(tensor, cached_file.data(), offset, size);
|
||||
response.result = 1;
|
||||
ggml_free(ctx);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
|
||||
struct ggml_init_params params {
|
||||
/*.mem_size =*/ ggml_tensor_overhead(),
|
||||
@@ -1148,8 +1267,9 @@ rpc_server::~rpc_server() {
|
||||
}
|
||||
}
|
||||
|
||||
static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
|
||||
rpc_server server(backend);
|
||||
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
||||
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
|
||||
rpc_server server(backend, cache_dir);
|
||||
while (true) {
|
||||
uint8_t cmd;
|
||||
if (!recv_data(sockfd, &cmd, 1)) {
|
||||
@@ -1260,6 +1380,20 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
|
||||
}
|
||||
break;
|
||||
}
|
||||
case RPC_CMD_SET_TENSOR_HASH: {
|
||||
std::vector<uint8_t> input;
|
||||
if (!recv_msg(sockfd, input)) {
|
||||
return;
|
||||
}
|
||||
rpc_msg_set_tensor_hash_rsp response;
|
||||
if (!server.set_tensor_hash(input, response)) {
|
||||
return;
|
||||
}
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
return;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case RPC_CMD_INIT_TENSOR: {
|
||||
rpc_msg_init_tensor_req request;
|
||||
if (!recv_msg(sockfd, &request,sizeof(request))) {
|
||||
@@ -1335,7 +1469,9 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
|
||||
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
|
||||
const char * cache_dir,
|
||||
size_t free_mem, size_t total_mem) {
|
||||
std::string host;
|
||||
int port;
|
||||
if (!parse_endpoint(endpoint, host, port)) {
|
||||
@@ -1364,7 +1500,7 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
|
||||
}
|
||||
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
|
||||
fflush(stdout);
|
||||
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
|
||||
rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
|
||||
printf("Client connection closed\n");
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
@@ -66,41 +66,6 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block
|
||||
return sycl_down_blk_size;
|
||||
}
|
||||
|
||||
void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const ggml_sycl_op_flatten_t op) try {
|
||||
|
||||
const bool use_src1 = src1 != nullptr;
|
||||
if(use_src1)
|
||||
GGML_ASSERT(strcmp(src1->buffer->buft->iface.get_name(src1->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||
GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0);
|
||||
|
||||
// dd = data device
|
||||
float * src0_ddf = (float *) src0->data;
|
||||
float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
|
||||
float * dst_ddf = (float *) dst->data;
|
||||
|
||||
ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
|
||||
ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
|
||||
ggml_sycl_pool_alloc<float> dst_f(ctx.pool());
|
||||
|
||||
ggml_sycl_set_device(ctx.device);
|
||||
queue_ptr main_stream = ctx.stream();
|
||||
// GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
|
||||
// ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
|
||||
|
||||
// do the computation
|
||||
op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
|
||||
// print_ggml_tensor("tensor", dst);
|
||||
}
|
||||
catch (sycl::exception const &exc) {
|
||||
|
||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
||||
<< ", line:" << __LINE__ << std::endl;
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
|
||||
void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams) {
|
||||
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
|
||||
for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
|
||||
|
||||
@@ -494,12 +494,6 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
|
||||
|
||||
int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
|
||||
|
||||
typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd,
|
||||
const float *src1_dd, float *dst_dd,
|
||||
const queue_ptr &main_stream);
|
||||
|
||||
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
||||
int ne0, int ne1, int ne2, int ne3,
|
||||
@@ -757,24 +751,22 @@ struct bin_bcast_sycl {
|
||||
|
||||
template <class op>
|
||||
inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
const ggml_tensor *src1, ggml_tensor *dst) {
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||
op()(ctx, src0, src1, dst, (const float *)src0->data, (const float *)src1->data, (float *)dst->data, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||
op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
|
||||
(sycl::half *)dst_dd, main_stream);
|
||||
op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data,
|
||||
(sycl::half *)dst->data, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
||||
op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
|
||||
op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data, (float *)dst->data,
|
||||
main_stream);
|
||||
} else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
|
||||
op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
|
||||
op()(ctx, src0, src1, dst, (const int32_t *)src0->data, (const int32_t *)src1->data, (int32_t *)dst->data,
|
||||
main_stream);
|
||||
} else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
|
||||
op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
|
||||
op()(ctx, src0, src1, dst, (const int16_t *)src0->data, (const int16_t *)src1->data, (int16_t *)dst->data,
|
||||
main_stream);
|
||||
} else {
|
||||
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
|
||||
@@ -784,8 +776,4 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
|
||||
}
|
||||
|
||||
bool gpu_has_xmx(sycl::device &dev);
|
||||
|
||||
void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const ggml_sycl_op_flatten_t op);
|
||||
#endif // GGML_SYCL_COMMON_HPP
|
||||
|
||||
+185
-273
@@ -509,497 +509,409 @@ static void pad_f32_sycl(const float *x, float *dst, const int ne00,
|
||||
});
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd,
|
||||
const float *src1_dd, float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd,
|
||||
const float *src1_dd, float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd,
|
||||
const float *src1_dd, float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd,
|
||||
const float *src1_dd, float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd, const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd, const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd, const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
log_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd, const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
sigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
log_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd, const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
sigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd, const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
sin_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd, const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
cos_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
sin_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd, const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
step_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
cos_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd, const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
step_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
float negative_slope;
|
||||
memcpy(&negative_slope, dst->op_params, sizeof(float));
|
||||
|
||||
leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), negative_slope, main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd,
|
||||
const float *src1_dd, float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(dst->src[0]), main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
const float sf0 = (float)dst->ne[0]/src0->ne[0];
|
||||
const float sf1 = (float)dst->ne[1]/src0->ne[1];
|
||||
const float sf2 = (float)dst->ne[2]/src0->ne[2];
|
||||
const float sf3 = (float)dst->ne[3]/src0->ne[3];
|
||||
const float sf0 = (float)dst->ne[0]/dst->src[0]->ne[0];
|
||||
const float sf1 = (float)dst->ne[1]/dst->src[0]->ne[1];
|
||||
const float sf2 = (float)dst->ne[2]/dst->src[0]->ne[2];
|
||||
const float sf3 = (float)dst->ne[3]/dst->src[0]->ne[3];
|
||||
|
||||
upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
|
||||
upscale_f32_sycl(src0_dd, dst_dd, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2], dst->src[0]->nb[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
|
||||
main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd,
|
||||
const float *src1_dd, float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
|
||||
GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
pad_f32_sycl(src0_dd, dst_dd,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2],
|
||||
dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd,
|
||||
const float *src1_dd, float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
const float * src1_dd = static_cast<const float*>(dst->src[1]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
|
||||
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
|
||||
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
|
||||
int offset = dst->op_params[3] / 4; // offset in bytes
|
||||
|
||||
acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream);
|
||||
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(ctx);
|
||||
acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd,
|
||||
const float *src1_dd, float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, dst->src[0], dst->src[1], dst);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd,
|
||||
const float *src1_dd, float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd,
|
||||
const float *src1_dd, float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd,
|
||||
const float *src1_dd, float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, dst->src[0], dst->src[1], dst);
|
||||
}
|
||||
|
||||
|
||||
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sqrt);
|
||||
ggml_sycl_op_sqrt(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sin);
|
||||
ggml_sycl_op_sin(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_cos);
|
||||
ggml_sycl_op_cos(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_acc);
|
||||
ggml_sycl_op_acc(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_gelu);
|
||||
ggml_sycl_op_gelu(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_silu);
|
||||
ggml_sycl_op_silu(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_gelu_quick);
|
||||
ggml_sycl_op_gelu_quick(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_tanh);
|
||||
ggml_sycl_op_tanh(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_relu);
|
||||
ggml_sycl_op_relu(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sigmoid);
|
||||
ggml_sycl_op_sigmoid(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_hardsigmoid);
|
||||
ggml_sycl_op_hardsigmoid(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_hardswish);
|
||||
ggml_sycl_op_hardswish(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
|
||||
void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_exp);
|
||||
ggml_sycl_op_exp(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_log);
|
||||
ggml_sycl_op_log(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_neg);
|
||||
ggml_sycl_op_neg(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_step);
|
||||
ggml_sycl_op_step(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_leaky_relu);
|
||||
ggml_sycl_op_leaky_relu(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sqr);
|
||||
ggml_sycl_op_sqr(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_upscale);
|
||||
ggml_sycl_op_upscale(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_pad);
|
||||
ggml_sycl_op_pad(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
@@ -1007,24 +919,24 @@ void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_add);
|
||||
ggml_sycl_op_add(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sub);
|
||||
ggml_sycl_op_sub(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_mul);
|
||||
ggml_sycl_op_mul(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_div);
|
||||
ggml_sycl_op_div(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
@@ -257,50 +257,54 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
|
||||
void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_d, const float *src1_d,
|
||||
float *dst_d, const queue_ptr &stream) {
|
||||
void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
||||
GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
|
||||
GGML_ASSERT(dst->src[0]->nb[0] == ggml_type_size(dst->src[0]->type));
|
||||
GGML_ASSERT(dst->src[1]->nb[0] == ggml_type_size(dst->src[1]->type));
|
||||
GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
|
||||
|
||||
const int32_t * src1_i32 = (const int32_t *) src1_d;
|
||||
|
||||
switch (src0->type) {
|
||||
const int32_t * src1_i32 = (const int32_t *) dst->src[1]->data;
|
||||
/* TODO: Refactor and remove duplicates */
|
||||
switch (dst->src[0]->type) {
|
||||
case GGML_TYPE_F16:
|
||||
get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
|
||||
src1_i32, dst_d, stream);
|
||||
get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const sycl::half *)dst->src[0]->data,
|
||||
src1_i32, (float *)dst->data, ctx.stream());
|
||||
break;
|
||||
case GGML_TYPE_F32:
|
||||
get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
|
||||
src1_i32, (float *)dst->data, ctx.stream());
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
if (ctx.opt_feature.reorder && dst->op == GGML_OP_MUL_MAT) {
|
||||
get_rows_sycl_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
get_rows_sycl_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
|
||||
src1_i32, (float *)dst->data, ctx.stream());
|
||||
} else {
|
||||
get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
|
||||
src1_i32, (float *)dst->data, ctx.stream());
|
||||
}
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
|
||||
src1_i32, (float *)dst->data, ctx.stream());
|
||||
break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
|
||||
src1_i32, (float *)dst->data, ctx.stream());
|
||||
break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
|
||||
src1_i32, (float *)dst->data, ctx.stream());
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
|
||||
get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
|
||||
src1_i32, (float *)dst->data, ctx.stream());
|
||||
break;
|
||||
default:
|
||||
// TODO: k-quants
|
||||
GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
|
||||
GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(dst->src[0]->type));
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,9 +15,6 @@
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_d, const float *src1_d,
|
||||
float *dst_d, const queue_ptr &stream);
|
||||
void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
|
||||
|
||||
#endif // GGML_SYCL_GETROWS_HPP
|
||||
|
||||
+104
-128
@@ -37,6 +37,7 @@
|
||||
#include "ggml-backend-impl.h"
|
||||
|
||||
#include "ggml-sycl/backend.hpp"
|
||||
#include "ggml-sycl/common.hpp"
|
||||
#include "ggml-sycl/presets.hpp"
|
||||
#include "ggml-sycl/gemm.hpp"
|
||||
#include "ggml-sycl/sycl_hw.hpp"
|
||||
@@ -490,6 +491,23 @@ catch (sycl::exception const &exc) {
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
static void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value,
|
||||
size_t offset, size_t size) {
|
||||
GGML_SYCL_DEBUG(" [SYCL] call %s\n", __func__);
|
||||
ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context;
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx->device));
|
||||
auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
|
||||
if (size == 0) {
|
||||
return; // Nothing to do
|
||||
}
|
||||
if (tensor->data == nullptr) {
|
||||
GGML_ABORT("Error: Tensor data pointer is null.\n");
|
||||
}
|
||||
void * target_ptr = static_cast<char *>(tensor->data) + offset;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memset(target_ptr, value, size)));
|
||||
SYCL_CHECK(CHECK_TRY_ERROR((*stream).wait()));
|
||||
}
|
||||
|
||||
static void ggml_backend_sycl_buffer_reset(ggml_backend_buffer_t buffer) {
|
||||
GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__);
|
||||
if (buffer == nullptr) {
|
||||
@@ -510,7 +528,7 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
|
||||
/* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer,
|
||||
/* .get_base = */ ggml_backend_sycl_buffer_get_base,
|
||||
/* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor,
|
||||
/* .memset_tensor = */ NULL,
|
||||
/* .memset_tensor = */ ggml_backend_sycl_buffer_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor,
|
||||
/* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor,
|
||||
@@ -1970,16 +1988,8 @@ catch (sycl::exception const &exc) {
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_d, const float *src1_d,
|
||||
float *dst_d,
|
||||
const queue_ptr &main_stream) {
|
||||
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(src1_d);
|
||||
static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, dst->src[0], dst);
|
||||
}
|
||||
|
||||
|
||||
@@ -2114,13 +2124,14 @@ catch (sycl::exception const &exc) {
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd, const queue_ptr &main_stream) {
|
||||
static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
const int32_t * opts = (const int32_t *)dst->op_params;
|
||||
enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
|
||||
@@ -2131,8 +2142,8 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
|
||||
const int p0 = opts[5];
|
||||
const int p1 = opts[6];
|
||||
|
||||
const int64_t IH = src0->ne[1];
|
||||
const int64_t IW = src0->ne[0];
|
||||
const int64_t IH = dst->src[0]->ne[1];
|
||||
const int64_t IW = dst->src[0]->ne[0];
|
||||
|
||||
const int64_t N = dst->ne[3];
|
||||
const int64_t OC = dst->ne[2];
|
||||
@@ -2151,163 +2162,125 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
|
||||
parallel_elements, src0_dd, dst_dd, op,
|
||||
item_ct1);
|
||||
});
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
const int64_t ne = ggml_nelements(src0);
|
||||
const int64_t ne = ggml_nelements(dst->src[0]);
|
||||
|
||||
sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
const int64_t ncols = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
const int64_t ncols = dst->src[0]->ne[0];
|
||||
const int64_t nrows = ggml_nrows(dst->src[0]);
|
||||
|
||||
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_I32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_I32);
|
||||
|
||||
const int64_t ncols = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
const int64_t ncols = dst->src[0]->ne[0];
|
||||
const int64_t nrows = ggml_nrows(dst->src[0]);
|
||||
|
||||
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
||||
|
||||
argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_I32);
|
||||
|
||||
const int64_t ncols = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
int32_t * dst_dd = static_cast<int32_t *>(dst->data);
|
||||
|
||||
argmax_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, main_stream);
|
||||
const int64_t ncols = dst->src[0]->ne[0];
|
||||
const int64_t nrows = ggml_nrows(dst->src[0]);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd,
|
||||
const float *src1_dd, float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx,ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int nrows0 = ggml_nrows(src0);
|
||||
const int64_t ne00 = dst->src[0]->ne[0];
|
||||
const int64_t ne01 = dst->src[0]->ne[1];
|
||||
const int nrows0 = ggml_nrows(dst->src[0]);
|
||||
|
||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
|
||||
diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd,
|
||||
const float *src1_dd, float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
float scale;
|
||||
memcpy(&scale, dst->op_params, sizeof(float));
|
||||
|
||||
scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
|
||||
scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
|
||||
/*
|
||||
DPCT1010:87: SYCL uses exceptions to report errors and does not use the
|
||||
error codes. The call was replaced with 0. You need to rewrite this code.
|
||||
*/
|
||||
SYCL_CHECK(0);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd,
|
||||
const float *src1_dd, float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
float min;
|
||||
float max;
|
||||
memcpy(&min, dst->op_params, sizeof(float));
|
||||
memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
|
||||
|
||||
clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
|
||||
clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(dst->src[0]), ctx.stream());
|
||||
/*
|
||||
DPCT1010:88: SYCL uses exceptions to report errors and does not use the
|
||||
error codes. The call was replaced with 0. You need to rewrite this code.
|
||||
*/
|
||||
SYCL_CHECK(0);
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
|
||||
static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
|
||||
@@ -2677,37 +2650,37 @@ catch (sycl::exception const &exc) {
|
||||
|
||||
static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_repeat);
|
||||
ggml_sycl_op_repeat(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_get_rows);
|
||||
ggml_sycl_op_get_rows(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_norm);
|
||||
ggml_sycl_op_norm(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rms_norm);
|
||||
ggml_sycl_op_rms_norm(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm);
|
||||
ggml_sycl_op_l2_norm(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
|
||||
ggml_sycl_op_group_norm(ctx, dst);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
@@ -3251,48 +3224,48 @@ catch (sycl::exception const &exc) {
|
||||
}
|
||||
|
||||
static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_scale);
|
||||
ggml_sycl_op_scale(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_clamp);
|
||||
ggml_sycl_op_clamp(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
|
||||
ggml_sycl_op_diag_mask_inf(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope);
|
||||
ggml_sycl_op_rope(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_pool2d);
|
||||
ggml_sycl_op_pool2d(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_im2col);
|
||||
ggml_sycl_op_im2col(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum);
|
||||
ggml_sycl_op_sum(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum_rows);
|
||||
ggml_sycl_op_sum_rows(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argsort);
|
||||
ggml_sycl_op_argsort(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argmax);
|
||||
ggml_sycl_op_argmax(ctx, dst);
|
||||
}
|
||||
|
||||
|
||||
@@ -3317,7 +3290,7 @@ catch (sycl::exception const &exc) {
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
|
||||
static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) try {
|
||||
if (!g_sycl_loaded) return false;
|
||||
|
||||
if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
|
||||
@@ -3510,6 +3483,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
}
|
||||
|
||||
return true;
|
||||
} catch (sycl::exception & e) {
|
||||
std::cerr << e.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
|
||||
std::exit(1);
|
||||
}
|
||||
|
||||
GGML_API void ggml_backend_sycl_get_device_description(int device, char *description,
|
||||
|
||||
@@ -82,10 +82,9 @@ static void im2col_sycl(
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_im2col(
|
||||
ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
@@ -115,12 +114,8 @@ void ggml_sycl_op_im2col(
|
||||
const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
|
||||
|
||||
if (dst->type == GGML_TYPE_F16) {
|
||||
im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
|
||||
im2col_sycl((const float *) src1->data, (sycl::half *)dst->data, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, ctx.stream());
|
||||
} else {
|
||||
im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
|
||||
im2col_sycl((const float *) src1->data, (float *)dst->data, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, ctx.stream());
|
||||
}
|
||||
|
||||
GGML_UNUSED(src0);
|
||||
GGML_UNUSED(src0_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
|
||||
@@ -16,8 +16,6 @@
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_im2col(
|
||||
ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
|
||||
const queue_ptr &main_stream);
|
||||
ggml_backend_sycl_context & ctx, ggml_tensor *dst);
|
||||
|
||||
#endif // GGML_SYCL_IM2COL_HPP
|
||||
|
||||
+35
-47
@@ -397,90 +397,78 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
|
||||
ggml_tensor* dst, const float* src0_dd,
|
||||
const float* src1_dd, float* dst_dd,
|
||||
const queue_ptr& main_stream) {
|
||||
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
const int64_t ne00 = dst->src[0]->ne[0];
|
||||
const int64_t nrows = ggml_nrows(dst->src[0]);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
|
||||
|
||||
(void)src1;
|
||||
(void)dst;
|
||||
(void)src1_dd;
|
||||
}
|
||||
|
||||
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
||||
const ggml_tensor* src1, ggml_tensor* dst,
|
||||
const float* src0_dd, const float* src1_dd,
|
||||
float* dst_dd,
|
||||
const queue_ptr& main_stream) {
|
||||
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
int num_groups = dst->op_params[0];
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params + 1, sizeof(float));
|
||||
|
||||
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
|
||||
group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
|
||||
|
||||
(void)src1;
|
||||
(void)dst;
|
||||
(void)src1_dd;
|
||||
GGML_UNUSED(ctx);
|
||||
int group_size = dst->src[0]->ne[0] * dst->src[0]->ne[1] * ((dst->src[0]->ne[2] + num_groups - 1) / num_groups);
|
||||
group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, dst->src[0]->ne[0] * dst->src[0]->ne[1] * dst->src[0]->ne[2], main_stream, ctx.device);
|
||||
}
|
||||
|
||||
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
||||
const ggml_tensor* src1, ggml_tensor* dst,
|
||||
const float* src0_dd, const float* src1_dd,
|
||||
float* dst_dd,
|
||||
const queue_ptr& main_stream) {
|
||||
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
const int64_t ne00 = dst->src[0]->ne[0];
|
||||
const int64_t nrows = ggml_nrows(dst->src[0]);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
|
||||
|
||||
(void)src1;
|
||||
(void)dst;
|
||||
(void)src1_dd;
|
||||
}
|
||||
|
||||
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
||||
const ggml_tensor* src1, ggml_tensor* dst,
|
||||
const float* src0_dd, const float* src1_dd,
|
||||
float* dst_dd,
|
||||
const queue_ptr& main_stream) {
|
||||
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
||||
const int64_t ne00 = dst->src[0]->ne[0];
|
||||
const int64_t nrows = ggml_nrows(dst->src[0]);
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
|
||||
|
||||
(void)src1;
|
||||
(void)dst;
|
||||
(void)src1_dd;
|
||||
}
|
||||
|
||||
@@ -15,27 +15,12 @@
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
|
||||
ggml_tensor* dst, const float* src0_dd,
|
||||
const float* src1_dd, float* dst_dd,
|
||||
const queue_ptr& main_stream);
|
||||
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
||||
|
||||
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
||||
const ggml_tensor* src1, ggml_tensor* dst,
|
||||
const float* src0_dd, const float* src1_dd,
|
||||
float* dst_dd,
|
||||
const queue_ptr& main_stream);
|
||||
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
||||
|
||||
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
||||
const ggml_tensor* src1, ggml_tensor* dst,
|
||||
const float* src0_dd, const float* src1_dd,
|
||||
float* dst_dd,
|
||||
const queue_ptr& main_stream);
|
||||
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
||||
|
||||
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
||||
const ggml_tensor* src1, ggml_tensor* dst,
|
||||
const float* src0_dd, const float* src1_dd,
|
||||
float* dst_dd,
|
||||
const queue_ptr& main_stream);
|
||||
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
||||
|
||||
#endif // GGML_SYCL_NORM_HPP
|
||||
|
||||
+20
-25
@@ -192,18 +192,15 @@ static void rope_neox_sycl(
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_rope(
|
||||
ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) {
|
||||
const ggml_tensor * src2 = dst->src[2];
|
||||
void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src0->type == dst->type);
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t nr = ggml_nrows(src0);
|
||||
const int64_t ne00 = dst->src[0]->ne[0];
|
||||
const int64_t ne01 = dst->src[0]->ne[1];
|
||||
const int64_t nr = ggml_nrows(dst->src[0]);
|
||||
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
@@ -228,49 +225,47 @@ void ggml_sycl_op_rope(
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1_dd;
|
||||
const int32_t * pos = (const int32_t *) dst->src[1]->data;
|
||||
|
||||
const float * freq_factors = nullptr;
|
||||
if (src2 != nullptr) {
|
||||
freq_factors = (const float *) src2->data;
|
||||
if (dst->src[2] != nullptr) {
|
||||
freq_factors = (const float *) dst->src[2]->data;
|
||||
}
|
||||
|
||||
rope_corr_dims corr_dims;
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
||||
// compute
|
||||
if (is_neox) {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
if (dst->src[0]->type == GGML_TYPE_F32) {
|
||||
rope_neox_sycl(
|
||||
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
(const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, freq_factors, main_stream
|
||||
);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||
rope_neox_sycl(
|
||||
(const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
(const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, freq_factors, main_stream
|
||||
);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
} else {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
if (dst->src[0]->type == GGML_TYPE_F32) {
|
||||
rope_norm_sycl(
|
||||
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
(const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, freq_factors, main_stream
|
||||
);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||
rope_norm_sycl(
|
||||
(const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
(const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
|
||||
attn_factor, corr_dims, freq_factors, main_stream
|
||||
);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
GGML_UNUSED(src1);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_dd);
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_rope(
|
||||
ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream);
|
||||
void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
|
||||
|
||||
#endif // GGML_SYCL_ROPE_HPP
|
||||
|
||||
@@ -23,32 +23,64 @@ if (Vulkan_FOUND)
|
||||
../../include/ggml-vulkan.h
|
||||
)
|
||||
|
||||
# Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported.
|
||||
# If it's not, there will be an error to stderr.
|
||||
# If it's supported, set a define to indicate that we should compile those shaders
|
||||
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
|
||||
OUTPUT_VARIABLE glslc_output
|
||||
ERROR_VARIABLE glslc_error)
|
||||
if(NOT DEFINED GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
# Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported.
|
||||
# If it's not, there will be an error to stderr.
|
||||
# If it's supported, set a define to indicate that we should compile those shaders
|
||||
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
|
||||
OUTPUT_VARIABLE glslc_output
|
||||
ERROR_VARIABLE glslc_error)
|
||||
|
||||
if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*")
|
||||
message(STATUS "GL_KHR_cooperative_matrix not supported by glslc")
|
||||
if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*")
|
||||
message(STATUS "GL_KHR_cooperative_matrix not supported by glslc")
|
||||
set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF CACHE INTERNAL "Whether coopmat is supported by glslc")
|
||||
else()
|
||||
message(STATUS "GL_KHR_cooperative_matrix supported by glslc")
|
||||
set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON CACHE INTERNAL "Whether coopmat is supported by glslc")
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "GL_KHR_cooperative_matrix supported by glslc")
|
||||
message(STATUS "GL_KHR_cooperative_matrix support already defined: ${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT}")
|
||||
endif()
|
||||
|
||||
if(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
endif()
|
||||
|
||||
# Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported.
|
||||
if(NOT DEFINED GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
# Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported.
|
||||
# If it's not, there will be an error to stderr.
|
||||
# If it's supported, set a define to indicate that we should compile those shaders
|
||||
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
|
||||
OUTPUT_VARIABLE glslc_output
|
||||
ERROR_VARIABLE glslc_error)
|
||||
|
||||
if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*")
|
||||
message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc")
|
||||
set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT OFF CACHE INTERNAL "Whether coopmat2 is supported by glslc")
|
||||
else()
|
||||
message(STATUS "GL_NV_cooperative_matrix2 supported by glslc")
|
||||
set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT ON CACHE INTERNAL "Whether coopmat2 is supported by glslc")
|
||||
endif()
|
||||
else()
|
||||
message(STATUS "GL_NV_cooperative_matrix2 support already defined: ${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT}")
|
||||
endif()
|
||||
|
||||
if(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
endif()
|
||||
|
||||
# Compile a test shader to determine whether GL_EXT_integer_dot_product is supported.
|
||||
# If it's not, there will be an error to stderr.
|
||||
# If it's supported, set a define to indicate that we should compile those shaders
|
||||
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
|
||||
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
|
||||
OUTPUT_VARIABLE glslc_output
|
||||
ERROR_VARIABLE glslc_error)
|
||||
|
||||
if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*")
|
||||
message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc")
|
||||
if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*")
|
||||
message(STATUS "GL_EXT_integer_dot_product not supported by glslc")
|
||||
else()
|
||||
message(STATUS "GL_NV_cooperative_matrix2 supported by glslc")
|
||||
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
message(STATUS "GL_EXT_integer_dot_product supported by glslc")
|
||||
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
endif()
|
||||
|
||||
target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
|
||||
@@ -119,6 +151,8 @@ if (Vulkan_FOUND)
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
|
||||
CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}
|
||||
-DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
|
||||
-DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT}
|
||||
-DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT}
|
||||
BUILD_COMMAND ${CMAKE_COMMAND} --build .
|
||||
INSTALL_COMMAND ${CMAKE_COMMAND} --install .
|
||||
INSTALL_DIR ${CMAKE_BINARY_DIR}
|
||||
|
||||
@@ -234,6 +234,8 @@ struct vk_device_struct {
|
||||
bool float_controls_rte_fp16;
|
||||
bool subgroup_add;
|
||||
|
||||
bool integer_dot_product;
|
||||
|
||||
bool subgroup_size_control;
|
||||
uint32_t subgroup_min_size;
|
||||
uint32_t subgroup_max_size;
|
||||
@@ -245,6 +247,12 @@ struct vk_device_struct {
|
||||
uint32_t coopmat_m;
|
||||
uint32_t coopmat_n;
|
||||
uint32_t coopmat_k;
|
||||
|
||||
bool coopmat_int_support;
|
||||
uint32_t coopmat_int_m;
|
||||
uint32_t coopmat_int_n;
|
||||
uint32_t coopmat_int_k;
|
||||
|
||||
bool coopmat2;
|
||||
|
||||
size_t idx;
|
||||
@@ -263,10 +271,10 @@ struct vk_device_struct {
|
||||
vk_matmul_pipeline pipeline_matmul_f32_f16 {};
|
||||
vk_matmul_pipeline2 pipeline_matmul_f16;
|
||||
vk_matmul_pipeline2 pipeline_matmul_f16_f32;
|
||||
vk_pipeline pipeline_matmul_split_k_reduce;
|
||||
|
||||
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
|
||||
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
|
||||
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
|
||||
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
|
||||
|
||||
vk_matmul_pipeline pipeline_matmul_id_f32 {};
|
||||
vk_matmul_pipeline2 pipeline_matmul_id_f16;
|
||||
@@ -274,6 +282,9 @@ struct vk_device_struct {
|
||||
|
||||
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
|
||||
|
||||
vk_pipeline pipeline_matmul_split_k_reduce;
|
||||
vk_pipeline pipeline_quantize_q8_1;
|
||||
|
||||
vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
|
||||
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
|
||||
@@ -640,6 +651,13 @@ struct vk_op_rwkv_wkv7_push_constants {
|
||||
uint32_t H;
|
||||
};
|
||||
|
||||
struct vk_op_upscale_push_constants {
|
||||
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
|
||||
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
||||
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
|
||||
float sf0; float sf1; float sf2; float sf3;
|
||||
};
|
||||
|
||||
// Allow pre-recording command buffers
|
||||
struct vk_staging_memcpy {
|
||||
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
|
||||
@@ -649,13 +667,6 @@ struct vk_staging_memcpy {
|
||||
size_t n;
|
||||
};
|
||||
|
||||
struct vk_op_upscale_push_constants {
|
||||
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
|
||||
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
||||
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
|
||||
float sf0; float sf1; float sf2; float sf3;
|
||||
};
|
||||
|
||||
struct vk_context_struct {
|
||||
vk_submission * s;
|
||||
std::vector<vk_sequence> seqs;
|
||||
@@ -1598,6 +1609,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
// mulmat
|
||||
std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
|
||||
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
|
||||
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
|
||||
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
|
||||
l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
|
||||
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
|
||||
@@ -1662,6 +1674,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
|
||||
|
||||
const uint32_t tm_int_l = device->coopmat_int_support ? device->coopmat_int_m : 4;
|
||||
const uint32_t tm_int_m = device->coopmat_int_support ? device->coopmat_int_m : 4;
|
||||
const uint32_t tm_int_s = device->coopmat_int_support ? device->coopmat_int_m : 2;
|
||||
const uint32_t tn_int_l = device->coopmat_int_support ? device->coopmat_int_n : 4;
|
||||
const uint32_t tn_int_m = device->coopmat_int_support ? device->coopmat_int_n : 2;
|
||||
const uint32_t tn_int_s = device->coopmat_int_support ? device->coopmat_int_n : 2;
|
||||
const uint32_t tk_int_l = device->coopmat_int_support ? device->coopmat_int_k : 1;
|
||||
const uint32_t tk_int_m = device->coopmat_int_support ? device->coopmat_int_k : 1;
|
||||
const uint32_t tk_int_s = device->coopmat_int_support ? device->coopmat_int_k : 1;
|
||||
|
||||
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_int_l, tn_int_l, tk_int_l, subgroup_size_8 };
|
||||
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_int_m, tn_int_m, tk_int_m, subgroup_size_8 };
|
||||
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_int_s, tn_int_s, tk_int_s, subgroup_size_8 };
|
||||
|
||||
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
|
||||
m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
|
||||
s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
|
||||
@@ -2000,6 +2026,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
||||
|
||||
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
||||
|
||||
// Create 2 variants, {f16,f32} accumulator
|
||||
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
@@ -2031,6 +2065,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
}
|
||||
#endif
|
||||
|
||||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||
@@ -2056,6 +2100,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
||||
#undef CREATE_MM2
|
||||
#undef CREATE_MMQ
|
||||
#undef CREATE_MM
|
||||
} else {
|
||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||
@@ -2073,6 +2118,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
||||
|
||||
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
||||
|
||||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
||||
@@ -2099,6 +2152,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (device->integer_dot_product) {
|
||||
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
||||
}
|
||||
#endif
|
||||
|
||||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
||||
@@ -2132,7 +2195,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
uint32_t rm_stdq = 1;
|
||||
uint32_t rm_kq = 2;
|
||||
if (device->vendor_id == VK_VENDOR_ID_AMD) {
|
||||
if (device->subgroup_min_size == 64 && device->subgroup_max_size == 64) { // GCN
|
||||
if (device->architecture == AMD_GCN) {
|
||||
rm_stdq = 2;
|
||||
rm_kq = 4;
|
||||
}
|
||||
@@ -2266,6 +2329,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
|
||||
|
||||
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
|
||||
if (device->subgroup_add && device->subgroup_require_full_support) {
|
||||
@@ -2452,6 +2516,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
bool pipeline_robustness = false;
|
||||
bool coopmat2_support = false;
|
||||
device->coopmat_support = false;
|
||||
device->integer_dot_product = false;
|
||||
|
||||
for (const auto& properties : ext_props) {
|
||||
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
|
||||
@@ -2477,6 +2542,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
|
||||
coopmat2_support = true;
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
||||
device->integer_dot_product = true;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2490,6 +2560,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
vk::PhysicalDeviceVulkan11Properties vk11_props;
|
||||
vk::PhysicalDeviceVulkan12Properties vk12_props;
|
||||
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
|
||||
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
|
||||
|
||||
props2.pNext = &props3;
|
||||
props3.pNext = &subgroup_props;
|
||||
@@ -2524,6 +2595,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
}
|
||||
#endif
|
||||
|
||||
if (device->integer_dot_product) {
|
||||
last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
|
||||
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
|
||||
}
|
||||
|
||||
device->physical_device.getProperties2(&props2);
|
||||
device->properties = props2.properties;
|
||||
device->vendor_id = device->properties.vendorID;
|
||||
@@ -2570,6 +2646,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
device->coopmat_support = false;
|
||||
}
|
||||
|
||||
device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
|
||||
|
||||
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
|
||||
|
||||
// Try to find a non-graphics compute queue and transfer-focused queues
|
||||
@@ -2662,6 +2740,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
device_extensions.push_back("VK_KHR_maintenance4");
|
||||
}
|
||||
|
||||
VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
|
||||
shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
|
||||
if (device->integer_dot_product) {
|
||||
last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
|
||||
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
|
||||
device_extensions.push_back("VK_KHR_shader_integer_dot_product");
|
||||
}
|
||||
|
||||
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
|
||||
|
||||
device->fp16 = device->fp16 && vk12_features.shaderFloat16;
|
||||
@@ -2831,6 +2917,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
device->coopmat_acc_f16_support = true;
|
||||
}
|
||||
}
|
||||
} else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
|
||||
(vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
|
||||
(vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 &&
|
||||
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 &&
|
||||
(vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup &&
|
||||
device->coopmat_int_m == 0
|
||||
) {
|
||||
device->coopmat_int_support = true;
|
||||
device->coopmat_int_m = prop.MSize;
|
||||
device->coopmat_int_n = prop.NSize;
|
||||
device->coopmat_int_k = prop.KSize;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2935,25 +3032,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
||||
vk::PhysicalDevice physical_device = devices[dev_num];
|
||||
std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();
|
||||
|
||||
vk::PhysicalDeviceProperties2 props2;
|
||||
vk::PhysicalDeviceMaintenance3Properties props3;
|
||||
vk::PhysicalDeviceSubgroupProperties subgroup_props;
|
||||
vk::PhysicalDeviceDriverProperties driver_props;
|
||||
props2.pNext = &props3;
|
||||
props3.pNext = &subgroup_props;
|
||||
subgroup_props.pNext = &driver_props;
|
||||
physical_device.getProperties2(&props2);
|
||||
|
||||
vk_device_architecture arch = get_device_architecture(physical_device);
|
||||
uint32_t default_subgroup_size = get_subgroup_size("", arch);
|
||||
const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
|
||||
|
||||
const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
|
||||
|
||||
bool fp16_storage = false;
|
||||
bool fp16_compute = false;
|
||||
bool coopmat_support = false;
|
||||
bool coopmat2_support = false;
|
||||
bool integer_dot_product = false;
|
||||
|
||||
for (auto properties : ext_props) {
|
||||
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
|
||||
@@ -2969,27 +3052,44 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
||||
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
|
||||
coopmat2_support = true;
|
||||
#endif
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
||||
integer_dot_product = true;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
const vk_device_architecture device_architecture = get_device_architecture(physical_device);
|
||||
|
||||
if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
|
||||
coopmat_support = false;
|
||||
}
|
||||
|
||||
const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
|
||||
bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
|
||||
|
||||
bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
|
||||
|
||||
vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures();
|
||||
vk::PhysicalDeviceProperties2 props2;
|
||||
vk::PhysicalDeviceMaintenance3Properties props3;
|
||||
vk::PhysicalDeviceSubgroupProperties subgroup_props;
|
||||
vk::PhysicalDeviceDriverProperties driver_props;
|
||||
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
|
||||
props2.pNext = &props3;
|
||||
props3.pNext = &subgroup_props;
|
||||
subgroup_props.pNext = &driver_props;
|
||||
|
||||
// Pointer to the last chain element
|
||||
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
|
||||
|
||||
if (integer_dot_product) {
|
||||
last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
|
||||
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
|
||||
}
|
||||
|
||||
physical_device.getProperties2(&props2);
|
||||
|
||||
VkPhysicalDeviceFeatures2 device_features2;
|
||||
device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
|
||||
device_features2.pNext = nullptr;
|
||||
device_features2.features = (VkPhysicalDeviceFeatures)device_features;
|
||||
|
||||
VkPhysicalDeviceVulkan11Features vk11_features;
|
||||
vk11_features.pNext = nullptr;
|
||||
@@ -3002,7 +3102,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
||||
vk11_features.pNext = &vk12_features;
|
||||
|
||||
// Pointer to the last chain element
|
||||
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features;
|
||||
last_struct = (VkBaseOutStructure *)&vk12_features;
|
||||
|
||||
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
|
||||
@@ -3014,20 +3114,37 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
||||
last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
|
||||
last_struct = (VkBaseOutStructure *)&coopmat_features;
|
||||
}
|
||||
#endif
|
||||
|
||||
VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
|
||||
shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
|
||||
if (integer_dot_product) {
|
||||
last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
|
||||
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
|
||||
}
|
||||
|
||||
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
|
||||
|
||||
fp16 = fp16 && vk12_features.shaderFloat16;
|
||||
|
||||
coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix;
|
||||
#endif
|
||||
uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
|
||||
const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
|
||||
const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
|
||||
|
||||
integer_dot_product = integer_dot_product
|
||||
&& shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated
|
||||
&& shader_integer_dot_product_features.shaderIntegerDotProduct;
|
||||
|
||||
coopmat_support = coopmat_support
|
||||
&& coopmat_features.cooperativeMatrix
|
||||
&& ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);
|
||||
|
||||
std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
|
||||
|
||||
std::string device_name = props2.properties.deviceName.data();
|
||||
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n",
|
||||
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
|
||||
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
|
||||
props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str());
|
||||
props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
|
||||
|
||||
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
|
||||
GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
|
||||
@@ -3293,6 +3410,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
||||
}
|
||||
}
|
||||
|
||||
// MMQ
|
||||
if (src1_type == GGML_TYPE_Q8_1) {
|
||||
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc;
|
||||
|
||||
if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return pipelines;
|
||||
}
|
||||
|
||||
if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
|
||||
return nullptr;
|
||||
}
|
||||
@@ -3585,8 +3713,6 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo
|
||||
return s;
|
||||
}
|
||||
|
||||
|
||||
|
||||
static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
|
||||
const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
|
||||
const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
|
||||
@@ -4016,8 +4142,8 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
|
||||
return split_k;
|
||||
}
|
||||
|
||||
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
|
||||
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
|
||||
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
|
||||
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
|
||||
|
||||
if (ctx->device->coopmat2) {
|
||||
// Use large shader when the N dimension is greater than the medium shader's tile size
|
||||
@@ -4042,9 +4168,9 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
|
||||
return aligned ? mmp->a_l : mmp->l;
|
||||
}
|
||||
|
||||
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
|
||||
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
|
||||
return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align;
|
||||
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
|
||||
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
|
||||
return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;
|
||||
}
|
||||
|
||||
static void ggml_vk_matmul(
|
||||
@@ -4054,7 +4180,7 @@ static void ggml_vk_matmul(
|
||||
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
|
||||
uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
|
||||
uint32_t padded_n) {
|
||||
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
|
||||
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
if (split_k == 1) {
|
||||
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
|
||||
@@ -4072,7 +4198,7 @@ static void ggml_vk_matmul(
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
|
||||
}
|
||||
|
||||
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
|
||||
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
|
||||
VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
|
||||
|
||||
if (ctx->device->coopmat2) {
|
||||
@@ -4214,6 +4340,25 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
|
||||
}
|
||||
|
||||
static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
|
||||
switch(type) {
|
||||
case GGML_TYPE_Q8_1:
|
||||
return ctx->device->pipeline_quantize_q8_1;
|
||||
default:
|
||||
std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) {
|
||||
VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
|
||||
|
||||
vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
|
||||
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 });
|
||||
}
|
||||
|
||||
static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
|
||||
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
|
||||
@@ -4265,10 +4410,19 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
|
||||
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
||||
|
||||
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
|
||||
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
|
||||
|
||||
// Check for mmq first
|
||||
vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
|
||||
|
||||
if (mmp == nullptr) {
|
||||
// Fall back to f16 dequant mul mat
|
||||
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
|
||||
quantize_y = false;
|
||||
}
|
||||
|
||||
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
||||
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
|
||||
const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig);
|
||||
|
||||
if (qx_needs_dequant) {
|
||||
// Fall back to dequant + f16 mulmat
|
||||
@@ -4278,13 +4432,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
// Not implemented
|
||||
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
||||
|
||||
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
|
||||
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
|
||||
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
|
||||
const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
|
||||
|
||||
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
|
||||
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
|
||||
|
||||
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
|
||||
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
|
||||
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
|
||||
const int x_ne = ne01 * ne00;
|
||||
const int y_ne = padded_n * ne10;
|
||||
const int d_ne = ne11 * ne01;
|
||||
@@ -4294,11 +4448,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
||||
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
||||
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
|
||||
const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
|
||||
const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
|
||||
const uint64_t d_sz = sizeof(float) * d_ne;
|
||||
|
||||
vk_pipeline to_fp16_vk_0 = nullptr;
|
||||
vk_pipeline to_fp16_vk_1 = nullptr;
|
||||
vk_pipeline to_q8_1 = nullptr;
|
||||
|
||||
if (x_non_contig) {
|
||||
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
|
||||
@@ -4313,6 +4468,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
|
||||
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
|
||||
|
||||
if (quantize_y) {
|
||||
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
|
||||
}
|
||||
|
||||
if (dryrun) {
|
||||
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
|
||||
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
|
||||
@@ -4326,7 +4485,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
|
||||
ctx->prealloc_size_x = x_sz_upd;
|
||||
}
|
||||
if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
|
||||
if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
|
||||
ctx->prealloc_size_y = y_sz_upd;
|
||||
}
|
||||
if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
|
||||
@@ -4341,6 +4500,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
if (qy_needs_dequant) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
|
||||
}
|
||||
if (quantize_y) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx->device, to_q8_1, 1);
|
||||
}
|
||||
if (split_k > 1) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
|
||||
}
|
||||
@@ -4376,6 +4538,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
if (qy_needs_dequant) {
|
||||
d_Y = ctx->prealloc_y;
|
||||
GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
|
||||
} else if (quantize_y) {
|
||||
d_Y = ctx->prealloc_y;
|
||||
GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
|
||||
} else {
|
||||
d_Y = d_Qy;
|
||||
y_buf_offset = qy_buf_offset;
|
||||
@@ -4392,6 +4557,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
if (y_non_contig) {
|
||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
||||
}
|
||||
if (quantize_y) {
|
||||
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
|
||||
}
|
||||
|
||||
uint32_t stride_batch_x = ne00*ne01;
|
||||
uint32_t stride_batch_y = ne10*ne11;
|
||||
@@ -4400,7 +4568,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
|
||||
}
|
||||
|
||||
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
|
||||
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
|
||||
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
|
||||
}
|
||||
|
||||
@@ -6929,6 +7097,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
||||
}
|
||||
}
|
||||
|
||||
if (ctx->device->need_compiles) {
|
||||
ggml_vk_load_shaders(ctx->device);
|
||||
}
|
||||
|
||||
ggml_pipeline_allocate_descriptor_sets(ctx->device);
|
||||
|
||||
vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
||||
@@ -7177,6 +7349,10 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
|
||||
|
||||
ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
|
||||
|
||||
if (ctx->device->need_compiles) {
|
||||
ggml_vk_load_shaders(ctx->device);
|
||||
}
|
||||
|
||||
ggml_pipeline_allocate_descriptor_sets(ctx->device);
|
||||
|
||||
ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
|
||||
@@ -7236,66 +7412,198 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
|
||||
free(x_chk);
|
||||
}
|
||||
|
||||
static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) {
|
||||
// This does not work without ggml q8_1 quantization support
|
||||
//
|
||||
// typedef uint16_t ggml_half;
|
||||
// typedef uint32_t ggml_half2;
|
||||
//
|
||||
// #define QK8_1 32
|
||||
// typedef struct {
|
||||
// union {
|
||||
// struct {
|
||||
// ggml_half d; // delta
|
||||
// ggml_half s; // d * sum(qs[i])
|
||||
// } GGML_COMMON_AGGR_S;
|
||||
// ggml_half2 ds;
|
||||
// } GGML_COMMON_AGGR_U;
|
||||
// int8_t qs[QK8_1]; // quants
|
||||
// } block_q8_1;
|
||||
//
|
||||
// static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
|
||||
// VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")");
|
||||
// GGML_ASSERT(quant == GGML_TYPE_Q8_1);
|
||||
//
|
||||
// const size_t x_sz = sizeof(float) * ne;
|
||||
// const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
|
||||
// float * x = (float *) malloc(x_sz);
|
||||
// block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz);
|
||||
// block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz);
|
||||
// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
||||
// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
||||
//
|
||||
// for (size_t i = 0; i < ne; i++) {
|
||||
// x[i] = rand() / (float)RAND_MAX;
|
||||
// }
|
||||
//
|
||||
// vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant);
|
||||
//
|
||||
// ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
|
||||
//
|
||||
// if (ctx->device->need_compiles) {
|
||||
// ggml_vk_load_shaders(ctx->device);
|
||||
// }
|
||||
//
|
||||
// ggml_pipeline_allocate_descriptor_sets(ctx->device);
|
||||
//
|
||||
// ggml_vk_buffer_write(x_buf, 0, x, x_sz);
|
||||
//
|
||||
// vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
|
||||
// ggml_vk_ctx_begin(ctx->device, subctx);
|
||||
// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne);
|
||||
// ggml_vk_ctx_end(subctx);
|
||||
//
|
||||
// auto begin = std::chrono::high_resolution_clock::now();
|
||||
//
|
||||
// ggml_vk_submit(subctx, ctx->fence);
|
||||
// VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences");
|
||||
// ctx->device->device.resetFences({ ctx->fence });
|
||||
//
|
||||
// auto end = std::chrono::high_resolution_clock::now();
|
||||
//
|
||||
// double ms_quant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
|
||||
// ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz);
|
||||
//
|
||||
// ggml_vk_quantize_data(x, qx_res, ne, quant);
|
||||
//
|
||||
// int first_err = -1;
|
||||
//
|
||||
// for (size_t i = 0; i < ne / 32; i++) {
|
||||
// double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d));
|
||||
//
|
||||
// if (first_err < 0 && error > 0.1) {
|
||||
// first_err = i;
|
||||
// }
|
||||
//
|
||||
// error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s));
|
||||
//
|
||||
// if (first_err < 0 && error > 0.1) {
|
||||
// first_err = i;
|
||||
// }
|
||||
//
|
||||
// for (size_t j = 0; j < 32; j++) {
|
||||
// uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]);
|
||||
//
|
||||
// if (first_err < 0 && error > 1) {
|
||||
// first_err = i;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl;
|
||||
//
|
||||
// if (first_err != -1) {
|
||||
// std::cerr << "first_error = " << first_err << std::endl;
|
||||
// std::cerr << "Actual result: " << std::endl << std::endl;
|
||||
// std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
|
||||
// for (size_t j = 0; j < 32; j++) {
|
||||
// std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " ";
|
||||
// }
|
||||
// std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl;
|
||||
// std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
|
||||
// for (size_t j = 0; j < 32; j++) {
|
||||
// std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " ";
|
||||
// }
|
||||
// std::cerr << std::endl;
|
||||
// }
|
||||
//
|
||||
// ggml_vk_destroy_buffer(x_buf);
|
||||
// ggml_vk_destroy_buffer(qx_buf);
|
||||
//
|
||||
// free(x);
|
||||
// free(qx);
|
||||
// free(qx_res);
|
||||
// }
|
||||
|
||||
static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) {
|
||||
VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
|
||||
const size_t x_ne = m * k * batch;
|
||||
const size_t y_ne = k * n * batch;
|
||||
const size_t d_ne = m * n * batch;
|
||||
|
||||
vk_matmul_pipeline2 * pipelines;
|
||||
|
||||
if (mmq) {
|
||||
pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1;
|
||||
} else {
|
||||
pipelines = ctx->device->pipeline_dequant_mul_mat_mat;
|
||||
}
|
||||
|
||||
const bool fp16acc = ctx->device->fp16;
|
||||
|
||||
vk_pipeline p;
|
||||
std::string shname;
|
||||
if (shader_size == 0) {
|
||||
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
|
||||
p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;
|
||||
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
|
||||
} else if (shader_size == 1) {
|
||||
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
|
||||
p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;
|
||||
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
|
||||
} else if (shader_size == 2) {
|
||||
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
|
||||
p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;
|
||||
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
|
||||
} else {
|
||||
GGML_ASSERT(0);
|
||||
}
|
||||
|
||||
const size_t kpad = ggml_vk_align_size(k, p->align);
|
||||
const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align);
|
||||
|
||||
if (k != kpad) {
|
||||
if (mmq || k != kpad) {
|
||||
if (shader_size == 0) {
|
||||
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
|
||||
p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;
|
||||
shname = std::string(ggml_type_name(quant)) + "_S";
|
||||
} else if (shader_size == 1) {
|
||||
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
|
||||
p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;
|
||||
shname = std::string(ggml_type_name(quant)) + "_M";
|
||||
} else if (shader_size == 2) {
|
||||
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
|
||||
p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;
|
||||
shname = std::string(ggml_type_name(quant)) + "_L";
|
||||
} else {
|
||||
GGML_ASSERT(0);
|
||||
}
|
||||
}
|
||||
|
||||
if (p == nullptr) {
|
||||
std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t x_sz = sizeof(float) * x_ne;
|
||||
const size_t y_sz = sizeof(float) * y_ne;
|
||||
const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
|
||||
const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz;
|
||||
const size_t d_sz = sizeof(float) * d_ne;
|
||||
float * x = (float *) malloc(x_sz);
|
||||
float * y = (float *) malloc(y_sz);
|
||||
void * qx = malloc(qx_sz);
|
||||
vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
||||
vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
||||
vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
||||
vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
||||
float * d = (float *) malloc(d_sz);
|
||||
float * d_chk = (float *) malloc(d_sz);
|
||||
|
||||
for (size_t i = 0; i < x_ne; i++) {
|
||||
x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
|
||||
// x[i] = (i % k == i / k) ? 1.0f : 0.0f;
|
||||
// x[i] = i % k;
|
||||
}
|
||||
|
||||
ggml_vk_quantize_data(x, qx, x_ne, quant);
|
||||
|
||||
for (size_t i = 0; i < y_ne; i++) {
|
||||
// y[i] = rand() / (float)RAND_MAX;
|
||||
y[i] = (i % k == i / k) ? 1.0f : 0.0f;
|
||||
y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
|
||||
// y[i] = (i % k == i / k) ? 1.0f : 0.0f;
|
||||
// y[i] = i % k;
|
||||
}
|
||||
|
||||
ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
|
||||
@@ -7310,6 +7618,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
||||
ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
||||
}
|
||||
}
|
||||
if (mmq) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_quantize_q8_1, num_it);
|
||||
}
|
||||
|
||||
if (ctx->device->need_compiles) {
|
||||
ggml_vk_load_shaders(ctx->device);
|
||||
}
|
||||
|
||||
ggml_pipeline_allocate_descriptor_sets(ctx->device);
|
||||
|
||||
@@ -7318,13 +7633,25 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
||||
|
||||
vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
|
||||
ggml_vk_ctx_begin(ctx->device, subctx);
|
||||
for (size_t i = 0; i < num_it; i++) {
|
||||
ggml_vk_matmul(
|
||||
ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
|
||||
m, n, k,
|
||||
k, k, m, k*m, k*n, m*n,
|
||||
split_k, batch, batch, batch, 1, 1, n
|
||||
);
|
||||
if (mmq) {
|
||||
for (size_t i = 0; i < num_it; i++) {
|
||||
ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);
|
||||
ggml_vk_matmul(
|
||||
ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
|
||||
m, n, k,
|
||||
k, k, m, k*m, k*n, m*n,
|
||||
split_k, batch, batch, batch, 1, 1, n
|
||||
);
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < num_it; i++) {
|
||||
ggml_vk_matmul(
|
||||
ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
|
||||
m, n, k,
|
||||
k, k, m, k*m, k*n, m*n,
|
||||
split_k, batch, batch, batch, 1, 1, n
|
||||
);
|
||||
}
|
||||
}
|
||||
ggml_vk_ctx_end(subctx);
|
||||
|
||||
@@ -7382,7 +7709,11 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
||||
|
||||
double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
|
||||
|
||||
std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
|
||||
std::cerr << "TEST dequant matmul " << shname;
|
||||
if (mmq) {
|
||||
std::cerr << " mmq";
|
||||
}
|
||||
std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
|
||||
|
||||
if (avg_err > 0.01 || std::isnan(avg_err)) {
|
||||
std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
|
||||
@@ -7392,6 +7723,12 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
||||
std::cerr << "Expected result: " << std::endl << std::endl;
|
||||
ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
||||
|
||||
std::cerr << "src0: " << std::endl << std::endl;
|
||||
ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b);
|
||||
std::cerr << std::endl;
|
||||
std::cerr << "src1: " << std::endl << std::endl;
|
||||
ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b);
|
||||
|
||||
if (split_k > 1) {
|
||||
float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
|
||||
ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
|
||||
@@ -7414,6 +7751,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
||||
|
||||
ggml_vk_destroy_buffer(qx_buf);
|
||||
ggml_vk_destroy_buffer(y_buf);
|
||||
ggml_vk_destroy_buffer(qy_buf);
|
||||
ggml_vk_destroy_buffer(d_buf);
|
||||
|
||||
free(x);
|
||||
@@ -7446,7 +7784,25 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
|
||||
128, 49, 49,
|
||||
4096, 49, 4096,
|
||||
};
|
||||
const size_t num_it = 100;
|
||||
const size_t num_it = 1;
|
||||
|
||||
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0);
|
||||
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0);
|
||||
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0);
|
||||
|
||||
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true);
|
||||
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true);
|
||||
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true);
|
||||
|
||||
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0);
|
||||
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0);
|
||||
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0);
|
||||
|
||||
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true);
|
||||
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true);
|
||||
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true);
|
||||
|
||||
abort();
|
||||
|
||||
for (size_t i = 0; i < vals.size(); i += 3) {
|
||||
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
|
||||
@@ -8764,6 +9120,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
||||
// different head sizes of K and V are not supported yet
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
@@ -9254,7 +9614,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
}
|
||||
|
||||
if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
|
||||
const float *params = (const float *)tensor->op_params;
|
||||
const float * params = (const float *)tensor->op_params;
|
||||
tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
|
||||
} else if (tensor->op == GGML_OP_MUL_MAT) {
|
||||
tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
|
||||
@@ -9271,7 +9631,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
} else if (tensor->op == GGML_OP_UPSCALE) {
|
||||
tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
|
||||
} else if (tensor->op == GGML_OP_SCALE) {
|
||||
tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]);
|
||||
const float * params = (const float *)tensor->op_params;
|
||||
tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]);
|
||||
} else if (tensor->op == GGML_OP_SQR) {
|
||||
tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
|
||||
} else if (tensor->op == GGML_OP_SIN) {
|
||||
@@ -9279,7 +9640,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
} else if (tensor->op == GGML_OP_COS) {
|
||||
tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
|
||||
} else if (tensor->op == GGML_OP_CLAMP) {
|
||||
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
|
||||
const float * params = (const float *)tensor->op_params;
|
||||
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
|
||||
} else if (tensor->op == GGML_OP_PAD) {
|
||||
tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
|
||||
} else if (tensor->op == GGML_OP_REPEAT) {
|
||||
@@ -9293,7 +9655,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
} else if (tensor->op == GGML_OP_NORM) {
|
||||
tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
|
||||
} else if (tensor->op == GGML_OP_GROUP_NORM) {
|
||||
tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
|
||||
const float * float_params = (const float *)tensor->op_params;
|
||||
tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);
|
||||
} else if (tensor->op == GGML_OP_RMS_NORM) {
|
||||
tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
|
||||
} else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
|
||||
@@ -9306,14 +9669,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||
tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
|
||||
} else if (tensor->op == GGML_OP_SOFT_MAX) {
|
||||
if (src1 != nullptr) {
|
||||
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
|
||||
const float * params = (const float *)tensor->op_params;
|
||||
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]);
|
||||
} else {
|
||||
tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
|
||||
}
|
||||
} else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
|
||||
tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
|
||||
} else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
|
||||
tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
|
||||
tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]);
|
||||
} else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
|
||||
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
||||
const int mode = ((int32_t *) tensor->op_params)[2];
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
find_package (Threads REQUIRED)
|
||||
|
||||
if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
endif()
|
||||
if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
endif()
|
||||
set(TARGET vulkan-shaders-gen)
|
||||
add_executable(${TARGET} vulkan-shaders-gen.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
|
||||
@@ -212,7 +212,7 @@ void main() {
|
||||
#else
|
||||
ACC_TYPE sums[WMITER * TM * WNITER * TN];
|
||||
FLOAT_TYPE cache_a[WMITER * TM];
|
||||
FLOAT_TYPE cache_b[WNITER * TN];
|
||||
FLOAT_TYPE cache_b[TN];
|
||||
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = ACC_TYPE(0.0f);
|
||||
@@ -744,16 +744,14 @@ void main() {
|
||||
}
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint j = 0; j < TN; j++) {
|
||||
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
|
||||
cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
|
||||
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]);
|
||||
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,444 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
||||
|
||||
#extension GL_EXT_integer_dot_product : require
|
||||
|
||||
#ifdef FLOAT16
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#endif
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
|
||||
#endif
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
|
||||
#if defined(A_TYPE_PACKED32)
|
||||
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
|
||||
#endif
|
||||
layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
||||
#endif
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint M;
|
||||
uint N;
|
||||
uint K;
|
||||
uint stride_a;
|
||||
uint stride_b;
|
||||
uint stride_d;
|
||||
|
||||
uint batch_stride_a;
|
||||
uint batch_stride_b;
|
||||
uint batch_stride_d;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
uint nei0;
|
||||
uint nei1;
|
||||
uint nbi1;
|
||||
uint ne11;
|
||||
#else
|
||||
uint k_split;
|
||||
uint ne02;
|
||||
uint ne12;
|
||||
uint broadcast2;
|
||||
uint broadcast3;
|
||||
#endif
|
||||
} p;
|
||||
|
||||
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
|
||||
layout (constant_id = 1) const uint BM = 64;
|
||||
layout (constant_id = 2) const uint BN = 64;
|
||||
// layout (constant_id = 3) const uint BK = 32;
|
||||
layout (constant_id = 4) const uint WM = 32;
|
||||
layout (constant_id = 5) const uint WN = 32;
|
||||
layout (constant_id = 6) const uint WMITER = 2;
|
||||
layout (constant_id = 7) const uint TM = 4;
|
||||
layout (constant_id = 8) const uint TN = 2;
|
||||
layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
|
||||
layout (constant_id = 10) const uint WARP = 32;
|
||||
|
||||
#define BK 32
|
||||
|
||||
#ifdef COOPMAT
|
||||
#define SHMEM_STRIDE (BK / 4 + 4)
|
||||
#else
|
||||
#define SHMEM_STRIDE (BK / 4 + 1)
|
||||
#endif
|
||||
|
||||
shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
|
||||
|
||||
#ifndef COOPMAT
|
||||
#if QUANT_AUXF == 1
|
||||
shared FLOAT_TYPE buf_a_dm[BM];
|
||||
#else
|
||||
shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
|
||||
#endif
|
||||
#endif
|
||||
|
||||
shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
|
||||
#ifndef COOPMAT
|
||||
shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
|
||||
#endif
|
||||
|
||||
#define LOAD_VEC_A (4 * QUANT_R)
|
||||
#define LOAD_VEC_B 4
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
shared u16vec2 row_ids[3072];
|
||||
#endif // MUL_MAT_ID
|
||||
|
||||
#define NUM_WARPS (BLOCK_SIZE / WARP)
|
||||
|
||||
#ifdef COOPMAT
|
||||
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
|
||||
#endif
|
||||
|
||||
#include "mul_mmq_funcs.comp"
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
const uint expert_idx = gl_GlobalInvocationID.z;
|
||||
#else
|
||||
const uint batch_idx = gl_GlobalInvocationID.z;
|
||||
|
||||
const uint i13 = batch_idx / p.ne12;
|
||||
const uint i12 = batch_idx % p.ne12;
|
||||
|
||||
const uint i03 = i13 / p.broadcast3;
|
||||
const uint i02 = i12 / p.broadcast2;
|
||||
|
||||
const uint batch_idx_a = i03 * p.ne02 + i02;
|
||||
#endif
|
||||
|
||||
const uint blocks_m = (p.M + BM - 1) / BM;
|
||||
const uint ir = gl_WorkGroupID.x % blocks_m;
|
||||
const uint ik = gl_WorkGroupID.x / blocks_m;
|
||||
const uint ic = gl_WorkGroupID.y;
|
||||
|
||||
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
|
||||
const uint WSUBM = WM / WMITER;
|
||||
const uint WSUBN = WN / WNITER;
|
||||
|
||||
#ifdef COOPMAT
|
||||
const uint warp_i = gl_SubgroupID;
|
||||
|
||||
const uint tiw = gl_SubgroupInvocationID;
|
||||
|
||||
const uint cms_per_row = WM / TM;
|
||||
const uint cms_per_col = WN / TN;
|
||||
|
||||
const uint storestride = WARP / TM;
|
||||
const uint store_r = tiw % TM;
|
||||
const uint store_c = tiw / TM;
|
||||
#else
|
||||
const uint warp_i = gl_LocalInvocationID.x / WARP;
|
||||
|
||||
const uint tiw = gl_LocalInvocationID.x % WARP;
|
||||
|
||||
const uint tiwr = tiw % (WSUBM / TM);
|
||||
const uint tiwc = tiw / (WSUBM / TM);
|
||||
#endif
|
||||
|
||||
const uint warp_r = warp_i % (BM / WM);
|
||||
const uint warp_c = warp_i / (BM / WM);
|
||||
|
||||
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
|
||||
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
|
||||
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
|
||||
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
|
||||
|
||||
const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK;
|
||||
const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
uint _ne1 = 0;
|
||||
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
|
||||
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
|
||||
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
|
||||
row_ids[_ne1] = u16vec2(ii0, ii1);
|
||||
_ne1++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
// Workgroup has no work
|
||||
if (ic * BN >= _ne1) return;
|
||||
#endif
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
const uint start_k = 0;
|
||||
const uint end_k = p.K;
|
||||
#else
|
||||
const uint start_k = ik * p.k_split;
|
||||
const uint end_k = min(p.K, (ik + 1) * p.k_split);
|
||||
#endif
|
||||
|
||||
uint pos_a_ib = (
|
||||
#ifdef MUL_MAT_ID
|
||||
expert_idx * p.batch_stride_a +
|
||||
#else
|
||||
batch_idx_a * p.batch_stride_a +
|
||||
#endif
|
||||
ir * BM * p.stride_a + start_k) / BK;
|
||||
#ifdef MUL_MAT_ID
|
||||
uint pos_b_ib = 0;
|
||||
#else
|
||||
uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
|
||||
coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
|
||||
coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
|
||||
|
||||
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
|
||||
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
|
||||
}
|
||||
#else
|
||||
int32_t cache_a_qs[WMITER * TM * BK / 4];
|
||||
|
||||
int32_t cache_b_qs[TN * BK / 4];
|
||||
|
||||
ACC_TYPE sums[WMITER * TM * WNITER * TN];
|
||||
|
||||
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
|
||||
sums[i] = ACC_TYPE(0.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if QUANT_AUXF == 1
|
||||
FLOAT_TYPE cache_a_dm[TM];
|
||||
#else
|
||||
FLOAT_TYPE_VEC2 cache_a_dm[TM];
|
||||
#endif
|
||||
|
||||
FLOAT_TYPE_VEC2 cache_b_ds[TN];
|
||||
|
||||
for (uint block = start_k; block < end_k; block += BK) {
|
||||
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
|
||||
const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
|
||||
const uint iqs = loadr_a;
|
||||
const uint buf_ib = loadc_a + l;
|
||||
|
||||
// Should ds be gated to a single thread?
|
||||
if (iqs == 0) {
|
||||
#if QUANT_AUXF == 1
|
||||
buf_a_dm[buf_ib] = get_d(ib);
|
||||
#else
|
||||
buf_a_dm[buf_ib] = get_dm(ib);
|
||||
#endif
|
||||
}
|
||||
#if QUANT_R == 1
|
||||
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
|
||||
#else
|
||||
const i32vec2 vals = repack(ib, iqs);
|
||||
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
|
||||
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
|
||||
#endif
|
||||
}
|
||||
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
|
||||
#ifdef MUL_MAT_ID
|
||||
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
|
||||
const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
|
||||
const uint ib = idx / 8;
|
||||
const uint iqs = idx & 0x7;
|
||||
#else
|
||||
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
|
||||
const uint iqs = loadr_b;
|
||||
#endif
|
||||
|
||||
const uint buf_ib = loadc_b + l;
|
||||
|
||||
// Should ds be gated to a single thread?
|
||||
if (iqs == 0) {
|
||||
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
|
||||
}
|
||||
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs];
|
||||
}
|
||||
|
||||
barrier();
|
||||
|
||||
pos_a_ib += 1;
|
||||
pos_b_ib += 1;
|
||||
|
||||
#ifdef COOPMAT
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
const uint ib_a = warp_r * WM + cm_row * TM;
|
||||
// Load from shared into cache
|
||||
coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
|
||||
|
||||
// TODO: only cache values that are actually needed
|
||||
[[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
|
||||
cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
const uint ib_b = warp_c * WN + cm_col * TN;
|
||||
coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
// TODO: only cache values that are actually needed
|
||||
[[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
|
||||
cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
|
||||
}
|
||||
|
||||
cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
|
||||
cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
|
||||
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
|
||||
}
|
||||
|
||||
coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
|
||||
}
|
||||
}
|
||||
#else
|
||||
// Load from shared into cache
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
|
||||
cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
|
||||
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
|
||||
cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
|
||||
cache_b_ds[cc] = buf_b_ds[ib];
|
||||
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
|
||||
cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
const uint cache_a_idx = wsir * TM + cr;
|
||||
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
|
||||
int32_t q_sum = 0;
|
||||
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
|
||||
q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
|
||||
cache_b_qs[cc * (BK / 4) + idx_k]);
|
||||
}
|
||||
|
||||
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
barrier();
|
||||
}
|
||||
|
||||
const uint dr = ir * BM + warp_r * WM;
|
||||
const uint dc = ic * BN + warp_c * WN;
|
||||
|
||||
#ifndef MUL_MAT_ID
|
||||
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
#ifdef MUL_MAT_ID
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
|
||||
const uint row_i = dc + cm_col * TN + col + store_c;
|
||||
if (row_i >= _ne1) break;
|
||||
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
|
||||
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
|
||||
|
||||
if (is_aligned && is_in_bounds) {
|
||||
// Full coopMat is within bounds and stride_d is aligned with 16B
|
||||
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
|
||||
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
} else if (is_in_bounds) {
|
||||
// Full coopMat is within bounds, but stride_d is not aligned
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
|
||||
// Partial coopMat is within bounds
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
|
||||
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
|
||||
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // MUL_MAT_ID
|
||||
#else
|
||||
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
|
||||
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
|
||||
|
||||
const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
|
||||
const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
|
||||
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
|
||||
#ifdef MUL_MAT_ID
|
||||
const uint row_i = dc_warp + cc;
|
||||
if (row_i >= _ne1) break;
|
||||
|
||||
const u16vec2 row_idx = row_ids[row_i];
|
||||
#endif // MUL_MAT_ID
|
||||
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
|
||||
#ifdef MUL_MAT_ID
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
||||
#else
|
||||
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
|
||||
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
|
||||
}
|
||||
#endif // MUL_MAT_ID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // COOPMAT
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
// Each iqs value maps to a 32-bit integer
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
// Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
|
||||
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
|
||||
data_a[ib].qs[iqs * 2 + 1]);
|
||||
const uint32_t vui = pack32(quants);
|
||||
return i32vec2( vui & 0x0F0F0F0F,
|
||||
(vui >> 4) & 0x0F0F0F0F);
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
|
||||
return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0 * dsb.y));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_1)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
// Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
|
||||
const uint32_t vui = data_a_packed32[ib].qs[iqs];
|
||||
return i32vec2( vui & 0x0F0F0F0F,
|
||||
(vui >> 4) & 0x0F0F0F0F);
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
|
||||
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q5_0)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
// Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
|
||||
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
|
||||
data_a[ib].qs[iqs * 2 + 1]);
|
||||
const uint32_t vui = pack32(quants);
|
||||
const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
|
||||
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
|
||||
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
|
||||
|
||||
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
|
||||
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
|
||||
|
||||
return i32vec2(v0, v1);
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
|
||||
return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0 * dsb.y));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q5_1)
|
||||
i32vec2 repack(uint ib, uint iqs) {
|
||||
// Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
|
||||
const uint32_t vui = data_a_packed32[ib].qs[iqs];
|
||||
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
|
||||
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
|
||||
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
|
||||
|
||||
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
|
||||
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
|
||||
|
||||
return i32vec2(v0, v1);
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
|
||||
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
int32_t repack(uint ib, uint iqs) {
|
||||
// Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
|
||||
return pack32(i16vec2(data_a[ib].qs[iqs * 2 ],
|
||||
data_a[ib].qs[iqs * 2 + 1]));
|
||||
}
|
||||
|
||||
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
|
||||
return ACC_TYPE(float(q_sum) * da * dsb.x);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
|
||||
FLOAT_TYPE get_d(uint ib) {
|
||||
return FLOAT_TYPE(data_a[ib].d);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
|
||||
FLOAT_TYPE_VEC2 get_dm(uint ib) {
|
||||
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
|
||||
}
|
||||
#endif
|
||||
@@ -0,0 +1,77 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint ne;
|
||||
} p;
|
||||
|
||||
#include "types.comp"
|
||||
|
||||
layout(constant_id = 0) const uint GROUP_SIZE = 32;
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {vec4 data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];};
|
||||
|
||||
shared float shmem[GROUP_SIZE];
|
||||
|
||||
void quantize() {
|
||||
const uint wgid = gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
// Each thread handles a vec4, so 8 threads handle a block
|
||||
const uint blocks_per_group = GROUP_SIZE / 8;
|
||||
|
||||
const uint block_in_wg = tid / 8;
|
||||
|
||||
const uint ib = wgid * blocks_per_group + block_in_wg;
|
||||
const uint iqs = tid % 8;
|
||||
|
||||
if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint a_idx = ib * 8 + iqs;
|
||||
|
||||
vec4 vals = a_idx < p.ne ? data_a[a_idx] : vec4(0.0f);
|
||||
const vec4 abs_vals = abs(vals);
|
||||
|
||||
// Find absolute max for each block
|
||||
shmem[tid] = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
|
||||
barrier();
|
||||
[[unroll]] for (uint s = 4; s > 0; s >>= 1) {
|
||||
if (iqs < s) {
|
||||
shmem[tid] = max(shmem[tid], shmem[tid + s]);
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
const float amax = shmem[block_in_wg * 8];
|
||||
const float d = amax / 127.0;
|
||||
const float d_inv = d != 0.0 ? 1.0 / d : 0.0;
|
||||
vals = round(vals * d_inv);
|
||||
data_b[ib].qs[iqs] = pack32(i8vec4(round(vals)));
|
||||
barrier();
|
||||
|
||||
// Calculate the sum for each block
|
||||
shmem[tid] = vals.x + vals.y + vals.z + vals.w;
|
||||
barrier();
|
||||
[[unroll]] for (uint s = 4; s > 0; s >>= 1) {
|
||||
if (iqs < s) {
|
||||
shmem[tid] += shmem[tid + s];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
if (iqs == 0) {
|
||||
const float sum = shmem[tid];
|
||||
|
||||
data_b[ib].ds = f16vec2(vec2(d, sum * d));
|
||||
}
|
||||
}
|
||||
|
||||
void main() {
|
||||
quantize();
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
#version 460
|
||||
|
||||
#extension GL_EXT_integer_dot_product : require
|
||||
|
||||
void main()
|
||||
{
|
||||
}
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
#if !defined(GGML_TYPES_COMP)
|
||||
#define GGML_TYPES_COMP
|
||||
|
||||
@@ -51,6 +50,7 @@ struct block_q4_0_packed16
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define QUANT_K QUANT_K_Q4_0
|
||||
#define QUANT_R QUANT_R_Q4_0
|
||||
#define QUANT_AUXF 1
|
||||
#define A_TYPE block_q4_0
|
||||
#define A_TYPE_PACKED16 block_q4_0_packed16
|
||||
#endif
|
||||
@@ -72,11 +72,19 @@ struct block_q4_1_packed16
|
||||
uint16_t qs[16/2];
|
||||
};
|
||||
|
||||
struct block_q4_1_packed32
|
||||
{
|
||||
f16vec2 dm;
|
||||
uint32_t qs[16/4];
|
||||
};
|
||||
|
||||
#if defined(DATA_A_Q4_1)
|
||||
#define QUANT_K QUANT_K_Q4_1
|
||||
#define QUANT_R QUANT_R_Q4_1
|
||||
#define QUANT_AUXF 2
|
||||
#define A_TYPE block_q4_1
|
||||
#define A_TYPE_PACKED16 block_q4_1_packed16
|
||||
#define A_TYPE_PACKED32 block_q4_1_packed32
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q5_0 32
|
||||
@@ -99,6 +107,7 @@ struct block_q5_0_packed16
|
||||
#if defined(DATA_A_Q5_0)
|
||||
#define QUANT_K QUANT_K_Q5_0
|
||||
#define QUANT_R QUANT_R_Q5_0
|
||||
#define QUANT_AUXF 1
|
||||
#define A_TYPE block_q5_0
|
||||
#define A_TYPE_PACKED16 block_q5_0_packed16
|
||||
#endif
|
||||
@@ -122,11 +131,20 @@ struct block_q5_1_packed16
|
||||
uint16_t qs[16/2];
|
||||
};
|
||||
|
||||
struct block_q5_1_packed32
|
||||
{
|
||||
f16vec2 dm;
|
||||
uint qh;
|
||||
uint32_t qs[16/4];
|
||||
};
|
||||
|
||||
#if defined(DATA_A_Q5_1)
|
||||
#define QUANT_K QUANT_K_Q5_1
|
||||
#define QUANT_R QUANT_R_Q5_1
|
||||
#define QUANT_AUXF 2
|
||||
#define A_TYPE block_q5_1
|
||||
#define A_TYPE_PACKED16 block_q5_1_packed16
|
||||
#define A_TYPE_PACKED32 block_q5_1_packed32
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q8_0 32
|
||||
@@ -142,14 +160,40 @@ struct block_q8_0_packed16
|
||||
float16_t d;
|
||||
int16_t qs[32/2];
|
||||
};
|
||||
struct block_q8_0_packed32
|
||||
{
|
||||
float16_t d;
|
||||
int32_t qs[32/4];
|
||||
};
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
#define QUANT_K QUANT_K_Q8_0
|
||||
#define QUANT_R QUANT_R_Q8_0
|
||||
#define QUANT_AUXF 1
|
||||
#define A_TYPE block_q8_0
|
||||
#define A_TYPE_PACKED16 block_q8_0_packed16
|
||||
#define A_TYPE_PACKED32 block_q8_0_packed32
|
||||
#endif
|
||||
|
||||
#define QUANT_K_Q8_1 32
|
||||
#define QUANT_R_Q8_1 1
|
||||
|
||||
struct block_q8_1
|
||||
{
|
||||
f16vec2 ds;
|
||||
int8_t qs[32];
|
||||
};
|
||||
struct block_q8_1_packed16
|
||||
{
|
||||
f16vec2 ds;
|
||||
int16_t qs[16];
|
||||
};
|
||||
struct block_q8_1_packed32
|
||||
{
|
||||
f16vec2 ds;
|
||||
int32_t qs[8];
|
||||
};
|
||||
|
||||
// K-quants
|
||||
#define QUANT_K_Q2_K 256
|
||||
|
||||
|
||||
@@ -295,7 +295,10 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
||||
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
|
||||
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
|
||||
|
||||
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}};
|
||||
std::map<std::string, std::string> base_dict = {
|
||||
{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"},
|
||||
{"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
|
||||
};
|
||||
std::string shader_name = "matmul";
|
||||
|
||||
if (matmul_id) {
|
||||
@@ -313,9 +316,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
||||
base_dict["COOPMAT"] = "1";
|
||||
}
|
||||
|
||||
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
|
||||
|
||||
std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
|
||||
const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
|
||||
|
||||
// Shaders with f16 B_TYPE
|
||||
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
|
||||
@@ -339,14 +340,20 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
|
||||
|
||||
// don't generate f32 variants for coopmat2
|
||||
if (!coopmat2) {
|
||||
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
|
||||
if (tname != "f16" && tname != "f32") {
|
||||
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
|
||||
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -458,6 +465,7 @@ void process_shaders() {
|
||||
string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
|
||||
string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
|
||||
|
||||
string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
|
||||
+1
-1
@@ -4369,7 +4369,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
|
||||
}
|
||||
|
||||
// permute(0, 2, 1, 3)
|
||||
int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
|
||||
int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] };
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
float params[] = { scale, max_bias, logit_softcap };
|
||||
|
||||
@@ -286,6 +286,8 @@ class MODEL_ARCH(IntEnum):
|
||||
GRANITE_MOE = auto()
|
||||
CHAMELEON = auto()
|
||||
WAVTOKENIZER_DEC = auto()
|
||||
PLM = auto()
|
||||
BAILINGMOE = auto()
|
||||
|
||||
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
@@ -488,6 +490,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.GRANITE_MOE: "granitemoe",
|
||||
MODEL_ARCH.CHAMELEON: "chameleon",
|
||||
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
|
||||
MODEL_ARCH.PLM: "plm",
|
||||
MODEL_ARCH.BAILINGMOE: "bailingmoe",
|
||||
}
|
||||
|
||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
@@ -1464,6 +1468,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B,
|
||||
],
|
||||
MODEL_ARCH.PLM: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA,
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM,
|
||||
MODEL_TENSOR.ATTN_KV_B,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
],
|
||||
MODEL_ARCH.CHATGLM : [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
@@ -1651,6 +1669,25 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.POSNET_ATTN_V,
|
||||
MODEL_TENSOR.POSNET_ATTN_OUT,
|
||||
],
|
||||
MODEL_ARCH.BAILINGMOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
@@ -1703,6 +1740,9 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
],
|
||||
MODEL_ARCH.BAILINGMOE: [
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
],
|
||||
}
|
||||
|
||||
#
|
||||
|
||||
@@ -29,6 +29,7 @@ class TensorNameMap:
|
||||
"shared", # t5
|
||||
"rwkv.embeddings", # rwkv6
|
||||
"model.embeddings", # rwkv7
|
||||
"model.word_embeddings", # bailingmoe
|
||||
),
|
||||
|
||||
# Token type embeddings
|
||||
|
||||
@@ -108,6 +108,8 @@ extern "C" {
|
||||
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
|
||||
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
|
||||
LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
|
||||
LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
|
||||
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
|
||||
};
|
||||
|
||||
enum llama_rope_type {
|
||||
@@ -1265,6 +1267,10 @@ extern "C" {
|
||||
float tau,
|
||||
float eta);
|
||||
|
||||
/// @details Intializes a GBNF grammar, see grammars/README.md for details.
|
||||
/// @param vocab The vocabulary that this grammar will be used with.
|
||||
/// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails.
|
||||
/// @param grammar_root The name of the start symbol for the grammar.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
|
||||
const struct llama_vocab * vocab,
|
||||
const char * grammar_str,
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg id="Layer_1" xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 1500 500">
|
||||
<!-- Generator: Adobe Illustrator 29.3.1, SVG Export Plug-In . SVG Version: 2.1.0 Build 151) -->
|
||||
<defs>
|
||||
<style>
|
||||
.st0 {
|
||||
fill: #ff8236;
|
||||
}
|
||||
|
||||
.st1 {
|
||||
fill: #fff;
|
||||
}
|
||||
|
||||
.st2 {
|
||||
fill: #1b1f20;
|
||||
}
|
||||
</style>
|
||||
</defs>
|
||||
<rect class="st2" width="1500" height="500" rx="16" ry="16"/>
|
||||
<g>
|
||||
<path class="st1" d="M749.4,353.8l5.4-204.1,20.4-.8,45.1,98.8,42.5-99h19l6.5,205h-38l-2-98-24.9,61.4c-1,1.3-8,1.3-9-1l-25.6-61.4-1.5,99h-38Z"/>
|
||||
<path class="st1" d="M727.5,240.1c-10.8-27.1-53.1-24.5-75.3-14.7l3.1,28.4c9.2-1.9,30-8,37.5-1,.9.9,3.5,5.7,3.5,6.5v16.5c-31.8-17.2-54.5,6.1-54.4,38.5,0,36.5,28.4,57.3,56.4,27.5v12h32v-104.5c0-.5-2.4-8-2.8-9.2ZM696.4,327.8c-8.4,1.7-15.4,2.9-19.2-6.3-5.8-14,.6-37.9,19.2-27.2v33.5Z"/>
|
||||
<path class="st1" d="M899.4,353.8l47.6-205.1h30.3c0,.1,47,205.1,47,205.1h-38l-7.9-33.6h-34.1l-7.9,33.6h-37ZM951.4,285.8h20l-10.5-56-9.5,56Z"/>
|
||||
<polygon class="st1" points="490.4 148.8 490.4 317.3 491.9 318.8 534.4 318.8 534.4 353.8 451.4 353.8 451.4 150.3 452.9 148.8 490.4 148.8"/>
|
||||
<polygon class="st1" points="589.4 148.8 589.4 318.8 633.4 318.8 633.4 353.8 550.4 353.8 550.4 148.8 589.4 148.8"/>
|
||||
<g>
|
||||
<path class="st0" d="M1163.3,226.8l-13.5,24c-17.8-13.7-44.2-15.7-62-1-28.7,23.7-26.7,78.5,18,78.8,12.5,0,23.1-5.9,34.5-9.8l6,23.9c-10.1,4.7-20.4,9.5-31.5,11-101.2,13.8-95.4-132.3-3.9-139.9,19.2-1.6,36.1,3.4,52.5,13Z"/>
|
||||
<path class="st0" d="M1093.4,203.8c-15.4,4.6-29.7,13.1-40.5,25-2-24.2,3.4-73.1,30.3-82.7,4-1.4,17.7-4.9,17.3,2.2s-9.9,19.3-12.2,25.9c-4,11.6-.3,19.6,5.2,29.7Z"/>
|
||||
<polygon class="st0" points="1131.4 258.8 1131.4 276.8 1147.4 276.8 1147.4 290.8 1131.4 290.8 1131.4 307.8 1116.4 307.8 1116.4 290.8 1099.4 290.8 1099.4 276.8 1114.9 276.8 1116.4 275.3 1116.4 258.8 1131.4 258.8"/>
|
||||
<polygon class="st0" points="1186.4 258.8 1186.4 275.3 1187.9 276.8 1203.4 276.8 1203.4 290.8 1186.4 290.8 1186.4 307.8 1171.4 307.8 1171.4 290.8 1155.4 290.8 1155.4 276.8 1171.4 276.8 1171.4 258.8 1186.4 258.8"/>
|
||||
<path class="st0" d="M1142.3,156.9c2,3-9.3,15.9-11.1,19.2-5.2,9.8-1.7,15.4,2.2,24.7-11.3-1.7-21.8-.3-33,1,2.5-21.5,14.6-52.8,41.9-44.9Z"/>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.3 KiB |
+16
-3
@@ -69,7 +69,11 @@ while read c; do
|
||||
git format-patch -U${ctx} -k $c~1..$c --stdout -- \
|
||||
CMakeLists.txt \
|
||||
src/CMakeLists.txt \
|
||||
cmake/FindSIMD.cmake \
|
||||
cmake/BuildTypes.cmake \
|
||||
cmake/GitVars.cmake \
|
||||
cmake/common.cmake \
|
||||
cmake/ggml-config.cmake.in \
|
||||
src/ggml-cpu/cmake/FindSIMD.cmake \
|
||||
src/ggml*.h \
|
||||
src/ggml*.c \
|
||||
src/ggml*.cpp \
|
||||
@@ -121,7 +125,12 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
|
||||
#
|
||||
# CMakelists.txt -> ggml/CMakeLists.txt
|
||||
# src/CMakeLists.txt -> ggml/src/CMakeLists.txt
|
||||
# cmake/FindSIMD.cmake -> ggml/cmake/FindSIMD.cmake
|
||||
|
||||
# cmake/BuildTypes.cmake -> ggml/cmake/BuildTypes.cmake
|
||||
# cmake/GitVars.cmake -> ggml/cmake/GitVars.cmake
|
||||
# cmake/common.cmake -> ggml/cmake/common.cmake
|
||||
# cmake/ggml-config.cmake.in -> ggml/cmake/ggml-config.cmake.in
|
||||
# src/ggml-cpu/cmake/FindSIMD.cmake -> ggml/src/ggml-cpu/cmake/FindSIMD.cmake
|
||||
#
|
||||
# src/ggml*.c -> ggml/src/ggml*.c
|
||||
# src/ggml*.cpp -> ggml/src/ggml*.cpp
|
||||
@@ -151,7 +160,11 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
|
||||
cat ggml-src.patch | sed -E \
|
||||
-e 's/(^[[:space:]]| [ab]\/)CMakeLists.txt/\1ggml\/CMakeLists.txt/g' \
|
||||
-e 's/(^[[:space:]]| [ab]\/)src\/CMakeLists.txt/\1ggml\/src\/CMakeLists.txt/g' \
|
||||
-e 's/(^[[:space:]]| [ab]\/)cmake\/FindSIMD.cmake/\1ggml\/cmake\/FindSIMD.cmake/g' \
|
||||
-e 's/(^[[:space:]]| [ab]\/)cmake\/BuildTypes.cmake/\1ggml\/cmake\/BuildTypes.cmake/g' \
|
||||
-e 's/(^[[:space:]]| [ab]\/)cmake\/GitVars.cmake/\1ggml\/cmake\/GitVars.cmake/g' \
|
||||
-e 's/(^[[:space:]]| [ab]\/)cmake\/common.cmake/\1ggml\/cmake\/common.cmake/g' \
|
||||
-e 's/(^[[:space:]]| [ab]\/)cmake\/ggml-config.cmake.in/\1ggml\/cmake\/ggml-config.cmake.in/g' \
|
||||
-e 's/(^[[:space:]]| [ab]\/)src\/ggml-cpu\/cmake\/FindSIMD.cmake/\1ggml\/src\/ggml-cpu\/cmake\/FindSIMD.cmake/g' \
|
||||
-e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.c/\1ggml\/src\/ggml\2.c/g' \
|
||||
-e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.cpp/\1ggml\/src\/ggml\2.cpp/g' \
|
||||
-e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.h/\1ggml\/src\/ggml\2.h/g' \
|
||||
|
||||
@@ -1 +1 @@
|
||||
c7dfe3d174f98b14801f9ed12f129179d3e7b638
|
||||
f06264eda2e2bf6e814db5a32bbf42e0b2b1ed98
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
|
||||
cp -rpv ../ggml/CMakeLists.txt ./ggml/CMakeLists.txt
|
||||
cp -rpv ../ggml/src/CMakeLists.txt ./ggml/src/CMakeLists.txt
|
||||
cp -rpv ../ggml/cmake/FindSIMD.cmake ./ggml/cmake/FindSIMD.cmake
|
||||
|
||||
cp -rpv ../ggml/cmake/* ./ggml/cmake/
|
||||
cp -rpv ../ggml/src/ggml-cpu/cmake/* ./ggml/src/ggml-cpu/cmake/
|
||||
|
||||
cp -rpv ../ggml/src/ggml*.c ./ggml/src/
|
||||
cp -rpv ../ggml/src/ggml*.cpp ./ggml/src/
|
||||
|
||||
+37
-1
@@ -247,6 +247,26 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
|
||||
}
|
||||
}
|
||||
|
||||
// get extra buffer types of the CPU
|
||||
// TODO: a more general solution for non-CPU extra buft should be imlpemented in the future
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/12593#pullrequestreview-2718659948
|
||||
std::vector<ggml_backend_buffer_type_t> buft_extra;
|
||||
{
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
|
||||
|
||||
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
|
||||
ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
|
||||
|
||||
if (ggml_backend_dev_get_extra_bufts_fn) {
|
||||
ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
|
||||
while (extra_bufts && *extra_bufts) {
|
||||
buft_extra.emplace_back(*extra_bufts);
|
||||
++extra_bufts;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add tensors
|
||||
for (auto & it : ab_map) {
|
||||
const std::string & name = it.first;
|
||||
@@ -263,7 +283,23 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
|
||||
throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)");
|
||||
}
|
||||
|
||||
ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
|
||||
auto * buft = ggml_backend_buffer_get_type(model_tensor->buffer);
|
||||
|
||||
// do not load loras to extra buffer types (i.e. bufts for repacking) -> use the CPU in that case
|
||||
for (auto & ex : buft_extra) {
|
||||
if (ex == buft) {
|
||||
LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
|
||||
|
||||
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
|
||||
buft = ggml_backend_dev_buffer_type(cpu_dev);
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: lora for '%s' -> '%s'\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
|
||||
|
||||
ggml_context * dev_ctx = ctx_for_buft(buft);
|
||||
// validate tensor shape
|
||||
if (is_token_embd) {
|
||||
// expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd()
|
||||
|
||||
@@ -65,6 +65,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
|
||||
{ LLM_ARCH_CHAMELEON, "chameleon" },
|
||||
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
|
||||
{ LLM_ARCH_PLM, "plm" },
|
||||
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@@ -1043,6 +1045,22 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_PLM,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
|
||||
{ LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
|
||||
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_CHATGLM,
|
||||
{
|
||||
@@ -1392,6 +1410,29 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_BAILINGMOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
|
||||
@@ -69,6 +69,8 @@ enum llm_arch {
|
||||
LLM_ARCH_GRANITE_MOE,
|
||||
LLM_ARCH_CHAMELEON,
|
||||
LLM_ARCH_WAVTOKENIZER_DEC,
|
||||
LLM_ARCH_PLM,
|
||||
LLM_ARCH_BAILINGMOE,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
|
||||
+41
-1
@@ -59,6 +59,8 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
|
||||
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
|
||||
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
|
||||
{ "yandex", LLM_CHAT_TEMPLATE_YANDEX },
|
||||
{ "bailing", LLM_CHAT_TEMPLATE_BAILING },
|
||||
};
|
||||
|
||||
llm_chat_template llm_chat_template_from_str(const std::string & name) {
|
||||
@@ -168,6 +170,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
||||
return LLM_CHAT_TEMPLATE_GIGACHAT;
|
||||
} else if (tmpl_contains("<|role_start|>")) {
|
||||
return LLM_CHAT_TEMPLATE_MEGREZ;
|
||||
} else if (tmpl_contains(" Ассистент:")) {
|
||||
return LLM_CHAT_TEMPLATE_YANDEX;
|
||||
} else if (tmpl_contains("<role>ASSISTANT</role>") && tmpl_contains("'HUMAN'")) {
|
||||
return LLM_CHAT_TEMPLATE_BAILING;
|
||||
}
|
||||
return LLM_CHAT_TEMPLATE_UNKNOWN;
|
||||
}
|
||||
@@ -567,6 +573,41 @@ int32_t llm_chat_apply_template(
|
||||
if (add_ass) {
|
||||
ss << "<|role_start|>assistant<|role_end|>";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_YANDEX) {
|
||||
// Yandex template ("\n\n" is defined as EOT token)
|
||||
|
||||
ss << "<s>";
|
||||
|
||||
for (size_t i = 0; i < chat.size(); i++) {
|
||||
std::string role(chat[i]->role);
|
||||
if (role == "user") {
|
||||
ss << " Пользователь: " << chat[i]->content << "\n\n";
|
||||
} else if (role == "assistant") {
|
||||
ss << " Ассистент: " << chat[i]->content << "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Add generation prompt if needed
|
||||
if (add_ass) {
|
||||
ss << " Ассистент:[SEP]";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_BAILING) {
|
||||
// Bailing (Ling) template
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
|
||||
if (role == "user") {
|
||||
role = "HUMAN";
|
||||
} else {
|
||||
std::transform(role.begin(), role.end(), role.begin(), ::toupper);
|
||||
}
|
||||
|
||||
ss << "<role>" << role << "</role>" << message->content;
|
||||
}
|
||||
|
||||
if (add_ass) {
|
||||
ss << "<role>ASSISTANT</role>";
|
||||
}
|
||||
} else {
|
||||
// template not supported
|
||||
return -1;
|
||||
@@ -585,4 +626,3 @@ int32_t llama_chat_builtin_templates(const char ** output, size_t len) {
|
||||
}
|
||||
return (int32_t) LLM_CHAT_TEMPLATES.size();
|
||||
}
|
||||
|
||||
|
||||
@@ -38,6 +38,8 @@ enum llm_chat_template {
|
||||
LLM_CHAT_TEMPLATE_GRANITE,
|
||||
LLM_CHAT_TEMPLATE_GIGACHAT,
|
||||
LLM_CHAT_TEMPLATE_MEGREZ,
|
||||
LLM_CHAT_TEMPLATE_YANDEX,
|
||||
LLM_CHAT_TEMPLATE_BAILING,
|
||||
LLM_CHAT_TEMPLATE_UNKNOWN,
|
||||
};
|
||||
|
||||
|
||||
@@ -1317,8 +1317,8 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
n_outputs = n_outputs_new;
|
||||
}
|
||||
|
||||
// non-causal masks do not use the KV cache
|
||||
if (hparams.causal_attn) {
|
||||
// find KV slot
|
||||
{
|
||||
kv_self_update();
|
||||
|
||||
// if we have enough unused cells before the current head ->
|
||||
@@ -2316,11 +2316,6 @@ llama_context * llama_init_from_model(
|
||||
params.flash_attn = false;
|
||||
}
|
||||
|
||||
if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
|
||||
LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
|
||||
params.flash_attn = false;
|
||||
}
|
||||
|
||||
if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
|
||||
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
|
||||
return nullptr;
|
||||
|
||||
+70
-104
@@ -402,120 +402,86 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
||||
|
||||
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
||||
if (self_kq_mask || self_kq_mask_swa) {
|
||||
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
|
||||
if (cparams.causal_attn) {
|
||||
const int64_t n_kv = kv_self->n;
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||
const int64_t n_seqs = ubatch->n_seqs;
|
||||
const int64_t n_kv = kv_self->n;
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||
const int64_t n_seqs = ubatch->n_seqs;
|
||||
|
||||
float * data = nullptr;
|
||||
float * data_swa = nullptr;
|
||||
|
||||
if (self_kq_mask) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
|
||||
data = (float *) self_kq_mask->data;
|
||||
}
|
||||
|
||||
if (self_kq_mask_swa) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
|
||||
data_swa = (float *) self_kq_mask_swa->data;
|
||||
}
|
||||
|
||||
// For causal attention, use only the previous KV cells
|
||||
// of the correct sequence for each token of the ubatch.
|
||||
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
||||
|
||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
||||
const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
|
||||
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
float f;
|
||||
if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) {
|
||||
f = -INFINITY;
|
||||
} else {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(kv_self->cells[i].pos - pos);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
if (data) {
|
||||
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
||||
}
|
||||
|
||||
// may need to cut off old tokens for sliding window
|
||||
if (data_swa) {
|
||||
if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||
f = -INFINITY;
|
||||
}
|
||||
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (data) {
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (data_swa) {
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||
const int64_t n_seqs = ubatch->n_seqs;
|
||||
// when using kv cache, the mask needs to match the kv cache size
|
||||
const int64_t n_stride = n_tokens;
|
||||
float * data = nullptr;
|
||||
float * data_swa = nullptr;
|
||||
|
||||
if (self_kq_mask) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
|
||||
data = (float *) self_kq_mask->data;
|
||||
}
|
||||
|
||||
float * data = (float *) self_kq_mask->data;
|
||||
if (self_kq_mask_swa) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
|
||||
data_swa = (float *) self_kq_mask_swa->data;
|
||||
}
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int s1 = 0; s1 < n_seqs; ++s1) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s1][0];
|
||||
// Use only the previous KV cells of the correct sequence for each token of the ubatch.
|
||||
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
|
||||
// Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
|
||||
// Causal mask:
|
||||
// xxx-------
|
||||
// xxxx------
|
||||
// xxxxx-----
|
||||
// Non-causal mask:
|
||||
// xxxxx-----
|
||||
// xxxxx-----
|
||||
// xxxxx-----
|
||||
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
||||
|
||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
||||
const int32_t tj = s1*n_seq_tokens + j;
|
||||
|
||||
for (int s0 = 0; s0 < n_seqs; ++s0) {
|
||||
for (int i = 0; i < n_seq_tokens; ++i) {
|
||||
const int32_t ti = s0*n_seq_tokens + i;
|
||||
float f = -INFINITY;
|
||||
|
||||
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
||||
if (ubatch->seq_id[s0][s] == seq_id) {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
|
||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
||||
const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
float f;
|
||||
// mask the token if:
|
||||
if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
|
||||
|| (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
|
||||
) {
|
||||
f = -INFINITY;
|
||||
} else {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(kv_self->cells[i].pos - pos);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = n_tokens; i < n_stride; ++i) {
|
||||
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
|
||||
if (data) {
|
||||
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
||||
}
|
||||
|
||||
// may need to cut off old tokens for sliding window
|
||||
if (data_swa) {
|
||||
if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
|
||||
f = -INFINITY;
|
||||
}
|
||||
data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mask padded tokens
|
||||
if (data) {
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// mask padded tokens
|
||||
if (data_swa) {
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user