mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-01 10:07:44 +02:00
Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0644baefde | |||
| 490eb96b88 | |||
| 3bb78133ab | |||
| 79cc0f2daf | |||
| 338085c69e | |||
| 4c61875bf8 | |||
| 4b385bfcf8 | |||
| f488429380 | |||
| 4d688f9ebb | |||
| ff599039a9 | |||
| f486ce9f30 | |||
| 38adc7d469 | |||
| 3b3a948134 | |||
| 6845f7f87f | |||
| fa16e517a3 | |||
| 313493de53 | |||
| b1ff83bbb0 | |||
| 4ae1b7517a | |||
| 4d3daf80f8 | |||
| 914dde72ba | |||
| 3136a849db | |||
| e463bbdf65 | |||
| 53de59f67d | |||
| 9ab072ebbe | |||
| ada90bf2ba | |||
| 0c1f39a9ae | |||
| 73cd5e1b97 | |||
| 8ee538ce73 | |||
| 6d95707827 | |||
| 89181c0b6d | |||
| ceaa89b786 | |||
| 2cce9fddb7 |
+1
-1
@@ -20,7 +20,7 @@ If AI is used to generate any portion of the code, contributors must adhere to t
|
||||
1. Explicitly disclose the manner in which AI was employed.
|
||||
2. Perform a comprehensive manual review prior to submitting the pull request.
|
||||
3. Be prepared to explain every line of code they submitted when asked about it by a maintainer.
|
||||
4. Using AI to write pull request descriptions or to respond to human reviewers is strictly prohibited.
|
||||
4. It is strictly prohibited to use AI to write your posts for you (bug reports, feature requests, pull request descriptions, Github discussions, responding to humans, ...).
|
||||
|
||||
For more info, please refer to the [AGENTS.md](AGENTS.md) file.
|
||||
|
||||
|
||||
@@ -534,7 +534,7 @@ xcodebuild -create-xcframework \
|
||||
-framework $(pwd)/build-ios-device/framework/llama.framework \
|
||||
-debug-symbols $(pwd)/build-ios-device/dSYMs/llama.dSYM \
|
||||
-framework $(pwd)/build-macos/framework/llama.framework \
|
||||
-debug-symbols $(pwd)/build-macos/dSYMS/llama.dSYM \
|
||||
-debug-symbols $(pwd)/build-macos/dSYMs/llama.dSYM \
|
||||
-framework $(pwd)/build-visionos/framework/llama.framework \
|
||||
-debug-symbols $(pwd)/build-visionos/dSYMs/llama.dSYM \
|
||||
-framework $(pwd)/build-visionos-sim/framework/llama.framework \
|
||||
|
||||
+1
-1
@@ -1301,7 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, bool value) {
|
||||
params.kv_unified = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH}));
|
||||
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
|
||||
add_opt(common_arg(
|
||||
{"--context-shift"},
|
||||
{"--no-context-shift"},
|
||||
|
||||
+21
-100
@@ -1,7 +1,3 @@
|
||||
#if defined(_MSC_VER)
|
||||
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
|
||||
#endif
|
||||
|
||||
#include "ggml.h"
|
||||
#include "gguf.h"
|
||||
|
||||
@@ -9,12 +5,12 @@
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
#include "sampling.h"
|
||||
#include "unicode.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <codecvt>
|
||||
#include <chrono>
|
||||
#include <cstdarg>
|
||||
#include <cstring>
|
||||
@@ -706,45 +702,28 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::u32string filename_utf32;
|
||||
try {
|
||||
#if defined(__clang__)
|
||||
// disable C++17 deprecation warning for std::codecvt_utf8
|
||||
# pragma clang diagnostic push
|
||||
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic push
|
||||
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
#endif
|
||||
size_t offset = 0;
|
||||
while (offset < filename.size()) {
|
||||
utf8_parse_result result = parse_utf8_codepoint(filename, offset);
|
||||
|
||||
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
|
||||
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic pop
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
filename_utf32 = converter.from_bytes(filename);
|
||||
|
||||
// If the reverse conversion mismatches, it means overlong UTF-8 sequences were used,
|
||||
// or invalid encodings were encountered. Reject such attempts
|
||||
std::string filename_reencoded = converter.to_bytes(filename_utf32);
|
||||
if (filename_reencoded != filename) {
|
||||
if (result.status != utf8_parse_result::SUCCESS) {
|
||||
return false;
|
||||
}
|
||||
} catch (const std::exception &) {
|
||||
return false;
|
||||
}
|
||||
uint32_t c = result.codepoint;
|
||||
|
||||
// Check for forbidden codepoints:
|
||||
// - Control characters
|
||||
// - Unicode equivalents of illegal characters
|
||||
// - UTF-16 surrogate pairs
|
||||
// - UTF-8 replacement character
|
||||
// - Byte order mark (BOM)
|
||||
// - Illegal characters: / \ : * ? " < > |
|
||||
for (char32_t c : filename_utf32) {
|
||||
if ((result.bytes_consumed == 2 && c < 0x80) ||
|
||||
(result.bytes_consumed == 3 && c < 0x800) ||
|
||||
(result.bytes_consumed == 4 && c < 0x10000)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check for forbidden codepoints:
|
||||
// - Control characters
|
||||
// - Unicode equivalents of illegal characters
|
||||
// - UTF-16 surrogate pairs
|
||||
// - UTF-8 replacement character
|
||||
// - Byte order mark (BOM)
|
||||
// - Illegal characters: / \ : * ? " < > |
|
||||
if (c <= 0x1F // Control characters (C0)
|
||||
|| c == 0x7F // Control characters (DEL)
|
||||
|| (c >= 0x80 && c <= 0x9F) // Control characters (C1)
|
||||
@@ -752,6 +731,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
||||
|| c == 0x2215 // Division Slash (forward slash equivalent)
|
||||
|| c == 0x2216 // Set Minus (backslash equivalent)
|
||||
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|
||||
|| c > 0x10FFFF // Max Unicode limit
|
||||
|| c == 0xFFFD // Replacement Character (UTF-8)
|
||||
|| c == 0xFEFF // Byte Order Mark (BOM)
|
||||
|| c == ':' || c == '*' // Illegal characters
|
||||
@@ -762,6 +742,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|
||||
// Subdirectories not allowed, reject path separators
|
||||
return false;
|
||||
}
|
||||
offset += result.bytes_consumed;
|
||||
}
|
||||
|
||||
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
|
||||
@@ -1469,66 +1450,6 @@ void common_batch_add(
|
||||
batch.n_tokens++;
|
||||
}
|
||||
|
||||
//
|
||||
// Token utils
|
||||
//
|
||||
|
||||
size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
|
||||
size_t i;
|
||||
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
|
||||
|
||||
return i;
|
||||
}
|
||||
|
||||
size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
|
||||
// check for empty sequences
|
||||
if (a.empty() || b.empty()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// get the lengths of the input sequences
|
||||
size_t a_len = a.size();
|
||||
size_t b_len = b.size();
|
||||
|
||||
// initialize the maximum length of the longest common subsequence (LCS)
|
||||
size_t max_length = 0;
|
||||
|
||||
// use two rows instead of a 2D matrix to optimize space
|
||||
std::vector<size_t> prev_row(b_len + 1, 0);
|
||||
std::vector<size_t> curr_row(b_len + 1, 0);
|
||||
|
||||
// iterate through the elements of a
|
||||
for (size_t i = 1; i <= a_len; i++) {
|
||||
// iterate through the elements of b
|
||||
for (size_t j = 1; j <= b_len; j++) {
|
||||
// if elements at the current positions match
|
||||
if (a[i - 1] == b[j - 1]) {
|
||||
// if it's the first element of either sequences, set LCS length to 1
|
||||
if (i == 1 || j == 1) {
|
||||
curr_row[j] = 1;
|
||||
} else {
|
||||
// increment LCS length by 1 compared to the previous element
|
||||
curr_row[j] = prev_row[j - 1] + 1;
|
||||
}
|
||||
|
||||
// update max_length if necessary
|
||||
if (curr_row[j] > max_length) {
|
||||
max_length = curr_row[j];
|
||||
}
|
||||
} else {
|
||||
// reset LCS length if elements don't match
|
||||
curr_row[j] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// update the previous row for the next iteration
|
||||
prev_row = curr_row;
|
||||
}
|
||||
|
||||
// return the maximum length of the LCS
|
||||
return max_length;
|
||||
}
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
||||
@@ -779,16 +779,6 @@ void common_batch_add(
|
||||
const std::vector<llama_seq_id> & seq_ids,
|
||||
bool logits);
|
||||
|
||||
//
|
||||
// Token utils
|
||||
//
|
||||
|
||||
// longest common prefix
|
||||
size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
|
||||
|
||||
// longet common subsequence
|
||||
size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
||||
+4
-1
@@ -305,7 +305,10 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
);
|
||||
|
||||
if (!res) {
|
||||
LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
|
||||
LOG_ERR("%s: download failed: %s (status: %d)\n",
|
||||
__func__,
|
||||
httplib::to_string(res.error()).c_str(),
|
||||
res ? res->status : -1);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -461,7 +461,7 @@ void common_ngram_map_draft(common_ngram_map & map,
|
||||
slot_max = v;
|
||||
}
|
||||
}
|
||||
// What is sum of the other occurences?
|
||||
// What is sum of the other occurrences?
|
||||
uint32_t sum_occur = 0;
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
if (v == slot_max) {
|
||||
|
||||
+2
-2
@@ -44,7 +44,7 @@ llama_tokens common_ngram_simple_draft(
|
||||
// statistics of a m-gram after a known n-gram
|
||||
struct common_ngram_map_value {
|
||||
size_t value_idx = 0; // index of value m-gram in token-history (0 if unused)
|
||||
uint16_t value_num = 0; // number of occurences of this value m-gram after the key n-gram (0 in an unused values-slot)
|
||||
uint16_t value_num = 0; // number of occurrences of this value m-gram after the key n-gram (0 in an unused values-slot)
|
||||
int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused)
|
||||
};
|
||||
|
||||
@@ -53,7 +53,7 @@ struct common_ngram_map_key {
|
||||
size_t key_idx; // index of key n-gram in token-history
|
||||
size_t stat_idx; // index of last token of stastistics computation (key_num, values)
|
||||
|
||||
uint16_t key_num; // number of occurences of this key n-gram in token-history
|
||||
uint16_t key_num; // number of occurrences of this key n-gram in token-history
|
||||
common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key
|
||||
};
|
||||
|
||||
|
||||
+115
-9
@@ -160,8 +160,6 @@ class ModelBase:
|
||||
self.ftype = gguf.LlamaFileType.MOSTLY_F16
|
||||
logger.info("heuristics unable to detect tensor dtype, defaulting to --outtype f16")
|
||||
|
||||
self.dequant_model()
|
||||
|
||||
# Configure GGUF Writer
|
||||
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
|
||||
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
|
||||
@@ -527,6 +525,8 @@ class ModelBase:
|
||||
return ()
|
||||
|
||||
def prepare_tensors(self):
|
||||
self.dequant_model()
|
||||
|
||||
# Handle empty tensor_map for models with block_count=0 (like MobileNetV5)
|
||||
if self.tensor_map.mapping:
|
||||
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
|
||||
@@ -1815,7 +1815,7 @@ class MmprojModel(ModelBase):
|
||||
preprocessor_config: dict[str, Any]
|
||||
global_config: dict[str, Any]
|
||||
|
||||
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]
|
||||
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers", "vt_num_hidden_layers"]
|
||||
|
||||
has_vision_encoder: bool = True # by default
|
||||
has_audio_encoder: bool = False
|
||||
@@ -1870,7 +1870,15 @@ class MmprojModel(ModelBase):
|
||||
preprocessor_config_path = self.dir_model / "preprocessor_config.json"
|
||||
if preprocessor_config_path.is_file():
|
||||
with open(preprocessor_config_path, "r", encoding="utf-8") as f:
|
||||
self.preprocessor_config = json.load(f)
|
||||
cfg = json.load(f)
|
||||
# move media_proc_cfg to root level for compat
|
||||
if "media_proc_cfg" in cfg:
|
||||
cfg = {
|
||||
**cfg,
|
||||
**cfg["media_proc_cfg"],
|
||||
}
|
||||
# merge configs
|
||||
self.preprocessor_config = {**self.preprocessor_config, **cfg}
|
||||
|
||||
# prefer processor_config.json if possible
|
||||
processor_config_path = self.dir_model / "processor_config.json"
|
||||
@@ -1919,10 +1927,10 @@ class MmprojModel(ModelBase):
|
||||
self.image_size = self.find_vparam(["image_size"])
|
||||
self.gguf_writer.add_vision_image_size(self.image_size)
|
||||
self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"]))
|
||||
self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
|
||||
self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
|
||||
self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size", "vt_hidden_size"]))
|
||||
self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size", "vt_intermediate_size"]))
|
||||
self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
|
||||
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads"]))
|
||||
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads", "vt_num_attention_heads"]))
|
||||
|
||||
# preprocessor config
|
||||
image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
|
||||
@@ -7695,6 +7703,7 @@ class DeepseekModel(TextModel):
|
||||
"DeepseekV2ForCausalLM",
|
||||
"DeepseekV3ForCausalLM",
|
||||
"KimiVLForConditionalGeneration",
|
||||
"KimiK25ForConditionalGeneration",
|
||||
"YoutuForCausalLM",
|
||||
"YoutuVLForConditionalGeneration",
|
||||
)
|
||||
@@ -7813,8 +7822,8 @@ class DeepseekV2Model(TextModel):
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# skip vision tensors and remove "language_model." for Kimi-VL
|
||||
if "vision_tower" in name or "multi_modal_projector" in name:
|
||||
# skip vision tensors and remove "language_model." for Kimi-VL and Kimi-K2.5
|
||||
if "vision_tower" in name or "multi_modal_projector" in name or "mm_projector" in name:
|
||||
return
|
||||
if name.startswith("siglip2.") or name.startswith("merger."):
|
||||
return
|
||||
@@ -11176,6 +11185,103 @@ class KimiVLModel(MmprojModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("KimiK25ForConditionalGeneration")
|
||||
class KimiK25Model(MmprojModel):
|
||||
"""Kimi-K2.5 with MoonViT3d vision encoder"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
assert self.hparams_vision is not None, "Kimi-K2.5 requires vision_config in model config"
|
||||
|
||||
self.merge_kernel_size = tuple(self.hparams_vision.get("merge_kernel_size", [2, 2]))
|
||||
self.patch_size = self.hparams_vision.get("patch_size", 14)
|
||||
|
||||
# Set image_size for compatibility with base class
|
||||
# Use position embedding dimensions as image_size reference
|
||||
pos_emb_h = self.hparams_vision.get("init_pos_emb_height", 64)
|
||||
self.hparams_vision["image_size"] = pos_emb_h * self.patch_size
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
# Base class MmprojModel.set_gguf_parameters() already writes:
|
||||
# - vision_block_count, vision_head_count, vision_embedding_length
|
||||
# - vision_feed_forward_length, vision_patch_size, image_mean, image_std
|
||||
# via find_vparam() which handles the vt_* prefixed keys in Kimi-K2.5's config
|
||||
super().set_gguf_parameters()
|
||||
assert self.hparams_vision is not None
|
||||
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIK25)
|
||||
|
||||
# Position embedding parameters (for interpolation)
|
||||
self.gguf_writer.add_uint32("vision.pos_emb_height", self.hparams_vision.get("init_pos_emb_height", 64))
|
||||
self.gguf_writer.add_uint32("vision.pos_emb_width", self.hparams_vision.get("init_pos_emb_width", 64))
|
||||
self.gguf_writer.add_uint32("vision.pos_emb_time", self.hparams_vision.get("init_pos_emb_time", 4))
|
||||
|
||||
# Projector parameters
|
||||
self.gguf_writer.add_vision_use_gelu(self.hparams_vision.get("projector_hidden_act", "gelu") == "gelu")
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("projector_ln_eps", 1e-5))
|
||||
self.gguf_writer.add_vision_projector_scale_factor(self.merge_kernel_size[0])
|
||||
|
||||
# Image size limits
|
||||
# Note: in_patch_limit is for images, in_patch_limit_each_frame is for video (not supported yet)
|
||||
in_patch_limit = self.preprocessor_config.get("in_patch_limit", 16384)
|
||||
min_patches = 8 # reasonable minimum
|
||||
pixels_per_patch = self.patch_size ** 2
|
||||
self.gguf_writer.add_vision_min_pixels(min_patches * pixels_per_patch)
|
||||
self.gguf_writer.add_vision_max_pixels(in_patch_limit * pixels_per_patch)
|
||||
|
||||
@staticmethod
|
||||
def permute(weights: Tensor, n_head: int) -> Tensor:
|
||||
out_dim, in_dim = weights.shape
|
||||
head_dim = out_dim // n_head
|
||||
w = weights.reshape(n_head, head_dim // 4, 2, 2, in_dim)
|
||||
w = w.permute(0, 2, 1, 3, 4)
|
||||
return w.reshape(out_dim, in_dim)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# Only process vision and projector tensors
|
||||
is_vision = any(x in name for x in ["vision_tower", "mm_projector"])
|
||||
|
||||
if not is_vision:
|
||||
return
|
||||
|
||||
assert self.hparams_vision is not None
|
||||
n_head = self.hparams_vision.get("num_attention_heads", 16)
|
||||
|
||||
# Permute Q/K weights/biases from interleaved to split RoPE format
|
||||
# This allows using build_rope_2d at runtime without post-permutation.
|
||||
if "wqkv" in name:
|
||||
out_dim = data_torch.shape[0]
|
||||
qkv_dim = out_dim // 3
|
||||
head_dim = qkv_dim // n_head
|
||||
|
||||
if "weight" in name:
|
||||
wq, wk, wv = data_torch[:qkv_dim, :], data_torch[qkv_dim:2 * qkv_dim, :], data_torch[2 * qkv_dim:, :]
|
||||
wq = self.permute(wq, n_head)
|
||||
wk = self.permute(wk, n_head)
|
||||
data_torch = torch.cat([wq, wk, wv], dim=0)
|
||||
elif "bias" in name:
|
||||
bq, bk, bv = data_torch[:qkv_dim], data_torch[qkv_dim:2 * qkv_dim], data_torch[2 * qkv_dim:]
|
||||
bq = bq.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1)
|
||||
bk = bk.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1)
|
||||
data_torch = torch.cat([bq, bk, bv], dim=0)
|
||||
|
||||
# Temporal embeddings: (T, 1, C) → (T, C)
|
||||
if "pos_emb.time_weight" in name:
|
||||
T, _, C = data_torch.shape
|
||||
data_torch = data_torch.reshape(T, C)
|
||||
|
||||
# PatchMergerMLP tensor name mapping
|
||||
# proj.0.weight → proj.linear_1.weight
|
||||
# proj.2.weight → proj.linear_2.weight
|
||||
if "mm_projector.proj.0." in name:
|
||||
name = name.replace(".proj.0.", ".proj.linear_1.")
|
||||
elif "mm_projector.proj.2." in name:
|
||||
name = name.replace(".proj.2.", ".proj.linear_2.")
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("CogVLMForCausalLM")
|
||||
class CogVLMVisionModel(MmprojModel):
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ Adapt below build commands accordingly.
|
||||
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
|
||||
|
||||
```
|
||||
[d]/workspace> cp docs/backend/hexagon/CMakeUserPresets.json .
|
||||
[d]/workspace> cp docs/backend/snapdragon/CMakeUserPresets.json .
|
||||
|
||||
[d]/workspace> cmake --preset arm64-android-snapdragon-release -B build-snapdragon
|
||||
Preset CMake variables:
|
||||
|
||||
@@ -59,11 +59,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
|
||||
GGML_ASSERT(nb00 == sizeof(src0_t));
|
||||
|
||||
const auto [ir0, ir1] = get_thread_range(params, src0);
|
||||
const bool is_src1_contiguous = (nb10 == sizeof(src1_t));
|
||||
|
||||
if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
||||
}
|
||||
const bool is_src1_contiguous_rows = ggml_is_contiguous_rows(src1);
|
||||
|
||||
#ifdef GGML_USE_ACCELERATE
|
||||
vDSP_fn_t vDSP_op = nullptr;
|
||||
@@ -94,7 +90,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
|
||||
const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
|
||||
const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
|
||||
|
||||
if (is_src1_contiguous) {
|
||||
if (is_src1_contiguous_rows) {
|
||||
// src1 is broadcastable across src0 and dst in i1, i2, i3
|
||||
const int64_t nr0 = ne00 / ne10;
|
||||
|
||||
|
||||
+104
-40
@@ -2096,10 +2096,14 @@ static void ggml_compute_forward_gelu_f32(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2113,10 +2117,14 @@ static void ggml_compute_forward_gelu_f32(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_gelu_f32(nc,
|
||||
(float *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(float *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2135,10 +2143,14 @@ static void ggml_compute_forward_gelu_f16(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2152,10 +2164,14 @@ static void ggml_compute_forward_gelu_f16(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_gelu_f16(nc,
|
||||
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2276,10 +2292,14 @@ static void ggml_compute_forward_gelu_erf_f32(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2293,10 +2313,14 @@ static void ggml_compute_forward_gelu_erf_f32(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_gelu_erf_f32(nc,
|
||||
(float *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(float *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2315,10 +2339,14 @@ static void ggml_compute_forward_gelu_erf_f16(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2332,10 +2360,14 @@ static void ggml_compute_forward_gelu_erf_f16(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_gelu_erf_f16(nc,
|
||||
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2379,10 +2411,14 @@ static void ggml_compute_forward_gelu_quick_f32(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2396,10 +2432,14 @@ static void ggml_compute_forward_gelu_quick_f32(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_gelu_quick_f32(nc,
|
||||
(float *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(float *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2418,10 +2458,14 @@ static void ggml_compute_forward_gelu_quick_f16(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2435,10 +2479,14 @@ static void ggml_compute_forward_gelu_quick_f16(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_gelu_quick_f16(nc,
|
||||
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2482,10 +2530,14 @@ static void ggml_compute_forward_silu_f32(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2499,10 +2551,14 @@ static void ggml_compute_forward_silu_f32(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_silu_f32(nc,
|
||||
(float *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(float *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
@@ -2521,10 +2577,14 @@ static void ggml_compute_forward_silu_f16(
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
assert(ggml_is_contiguous_1(src0));
|
||||
assert(ggml_is_contiguous_1(dst));
|
||||
assert(ggml_is_contiguous_rows(src0));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -2538,10 +2598,14 @@ static void ggml_compute_forward_silu_f16(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int i3 = ir/(ne02*ne01);
|
||||
const int i2 = (ir - i3*ne02*ne01)/ne01;
|
||||
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
|
||||
|
||||
ggml_vec_silu_f16(nc,
|
||||
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
|
||||
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
|
||||
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int k = 0; k < nc; k++) {
|
||||
|
||||
@@ -111,7 +111,7 @@ template <float (*op)(float), typename src0_t, typename dst_t>
|
||||
static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
|
||||
@@ -39,13 +39,16 @@ static __global__ void k_bin_bcast(const src0_t * src0,
|
||||
const uint3 ne11,
|
||||
const uint3 ne12,
|
||||
const uint3 ne13,
|
||||
/*int s0, */ const int s1,
|
||||
/*const int s0,*/
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
/*int s00,*/ const int s01,
|
||||
const int s00,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
/*int s10,*/ const int s11,
|
||||
const int s10,
|
||||
const int s11,
|
||||
const int s12,
|
||||
const int s13,
|
||||
src1_ptrs... src1s) {
|
||||
@@ -72,11 +75,11 @@ static __global__ void k_bin_bcast(const src0_t * src0,
|
||||
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
|
||||
const uint32_t i10 = fastmodulo(i0, ne10);
|
||||
|
||||
float result = src0_row ? (float) src0_row[i0] : 0.0f;
|
||||
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
|
||||
if constexpr (sizeof...(src1_ptrs) > 0) {
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
|
||||
} else {
|
||||
result = bin_op(result, (float)src1[i_src1 + i10]);
|
||||
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
|
||||
}
|
||||
|
||||
dst_row[i0] = (dst_t) result;
|
||||
@@ -101,13 +104,16 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
|
||||
const uint3 ne11,
|
||||
const uint3 ne12,
|
||||
const uint3 ne13,
|
||||
/*int s0, */ const int s1,
|
||||
/*const int s0,*/
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
/*int s00,*/ const int s01,
|
||||
const int s00,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
/*int s10,*/ const int s11,
|
||||
const int s10,
|
||||
const int s11,
|
||||
const int s12,
|
||||
const int s13,
|
||||
src1_ptrs... src1s) {
|
||||
@@ -135,11 +141,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
|
||||
|
||||
const int i10 = fastmodulo(i0, ne10);
|
||||
|
||||
float result = src0_row ? (float) src0_row[i0] : 0.0f;
|
||||
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
|
||||
if constexpr (sizeof...(src1_ptrs) > 0) {
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
|
||||
} else {
|
||||
result = bin_op(result, (float)src1[i_src1 + i10]);
|
||||
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
|
||||
}
|
||||
|
||||
dst_row[i0] = (dst_t) result;
|
||||
@@ -179,7 +185,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
cnb[3] *= cne[3];
|
||||
};
|
||||
|
||||
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
|
||||
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (nr[i] != 1) {
|
||||
break;
|
||||
@@ -221,7 +227,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
size_t nb12 = cnb1[2];
|
||||
size_t nb13 = cnb1[3];
|
||||
|
||||
size_t s0 = nb0 / sizeof(dst_t);
|
||||
//size_t s0 = nb0 / sizeof(dst_t);
|
||||
size_t s1 = nb1 / sizeof(dst_t);
|
||||
size_t s2 = nb2 / sizeof(dst_t);
|
||||
size_t s3 = nb3 / sizeof(dst_t);
|
||||
@@ -251,10 +257,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
|
||||
|
||||
GGML_ASSERT(s0 == 1);
|
||||
GGML_ASSERT(s00 == 1);
|
||||
GGML_ASSERT(s10 == 1);
|
||||
|
||||
const int block_size = 128;
|
||||
|
||||
int64_t hne0 = std::max(ne0 / 2LL, 1LL);
|
||||
@@ -284,31 +286,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
|
||||
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
|
||||
ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00, s01, s02, s03,
|
||||
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
} else {
|
||||
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
|
||||
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
|
||||
ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12, s13);
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00, s01, s02, s03,
|
||||
s10, s11, s12, s13);
|
||||
}
|
||||
} else {
|
||||
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
|
||||
if constexpr (sizeof...(I) > 0) {
|
||||
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00 ,s01, s02, s03,
|
||||
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
} else {
|
||||
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
|
||||
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
|
||||
/* s0, */ s1, s2, s3,
|
||||
/* s00,*/ s01, s02, s03,
|
||||
/* s10,*/ s11, s12, s13);
|
||||
/*s0,*/ s1, s2, s3,
|
||||
s00, s01, s02, s03,
|
||||
s10, s11, s12, s13);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr int frag_m = ncols == 8 ? 32 : 16;
|
||||
constexpr int frag_n = ncols == 8 ? 8 : 16;
|
||||
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
|
||||
#if defined(GGML_USE_HIP)
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::row_major> frag_a_K;
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_a_V;
|
||||
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_b;
|
||||
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
|
||||
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, _Float16> frag_c_VKQ;
|
||||
#else
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
|
||||
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
|
||||
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
|
||||
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
|
||||
#endif
|
||||
|
||||
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
|
||||
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
|
||||
@@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
|
||||
half2 * VKQ2 = (half2 *) VKQ;
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
const _Float16 * K_h_f16 = reinterpret_cast<const _Float16 *>(K_h);
|
||||
const _Float16 * V_h_f16 = reinterpret_cast<const _Float16 *>(V_h);
|
||||
_Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ);
|
||||
_Float16 * VKQ_f16 = reinterpret_cast<_Float16 *>(VKQ);
|
||||
#else
|
||||
const half * K_h_f16 = K_h;
|
||||
const half * V_h_f16 = V_h;
|
||||
half * KQ_f16 = KQ;
|
||||
half * VKQ_f16 = VKQ;
|
||||
#endif
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
@@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
for (int i0 = 0; i0 < D; i0 += 16) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
|
||||
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
|
||||
frag_a_K K_a;
|
||||
wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
||||
wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
|
||||
@@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||
wmma::load_matrix_sync(
|
||||
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
|
||||
KQ + j0*(kqar*kqs_padded) + k,
|
||||
KQ_f16 + j0*(kqar*kqs_padded) + k,
|
||||
kqar*kqs_padded);
|
||||
}
|
||||
}
|
||||
@@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||
|
||||
frag_a_V v_a;
|
||||
wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
||||
wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
||||
@@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
wmma::store_matrix_sync(
|
||||
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
||||
KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
||||
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
|
||||
D_padded, wmma::mem_col_major);
|
||||
}
|
||||
|
||||
@@ -1935,11 +1935,6 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: add support for non-contigiuos tensors
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1991,6 +1986,25 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: add support for non-contigiuos tensors
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session * sess,
|
||||
const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
@@ -2111,6 +2125,26 @@ static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session *
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0]; // values
|
||||
const struct ggml_tensor * dst = op; // indices
|
||||
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dst->type != GGML_TYPE_I32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src0->ne[0] > (16*1024)) {
|
||||
// reject tensors with huge rows for now
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const int32_t * op_params = &op->op_params[0];
|
||||
|
||||
@@ -2278,6 +2312,9 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu
|
||||
case GGML_OP_SUB:
|
||||
req->op = HTP_OP_SUB;
|
||||
break;
|
||||
case GGML_OP_DIV:
|
||||
req->op = HTP_OP_DIV;
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op);
|
||||
break;
|
||||
@@ -2316,6 +2353,17 @@ static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer *
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
req->op = HTP_OP_ARGSORT;
|
||||
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
|
||||
|
||||
size_t n_bufs = 0;
|
||||
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
|
||||
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
template <bool _is_src0_constant>
|
||||
static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
switch (t->op) {
|
||||
@@ -2370,6 +2418,16 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
|
||||
supported = true;
|
||||
break;
|
||||
|
||||
case GGML_OP_SQR:
|
||||
req->op = HTP_OP_SQR;
|
||||
supported = true;
|
||||
break;
|
||||
|
||||
case GGML_OP_SQRT:
|
||||
req->op = HTP_OP_SQRT;
|
||||
supported = true;
|
||||
break;
|
||||
|
||||
case GGML_OP_UNARY:
|
||||
if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
|
||||
req->op = HTP_OP_UNARY_SILU;
|
||||
@@ -2387,6 +2445,9 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
|
||||
} else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) {
|
||||
req->op = HTP_OP_GLU_SWIGLU_OAI;
|
||||
supported = true;
|
||||
} else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) {
|
||||
req->op = HTP_OP_GLU_GEGLU;
|
||||
supported = true;
|
||||
}
|
||||
break;
|
||||
|
||||
@@ -2411,6 +2472,17 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
|
||||
req->op = HTP_OP_SUM_ROWS;
|
||||
|
||||
size_t n_bufs = 0;
|
||||
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
|
||||
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
|
||||
|
||||
return n_bufs;
|
||||
}
|
||||
|
||||
static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
|
||||
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
|
||||
req->op = HTP_OP_ROPE;
|
||||
@@ -2519,6 +2591,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_DIV:
|
||||
ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
|
||||
break;
|
||||
case GGML_OP_ADD_ID:
|
||||
@@ -2528,6 +2601,13 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
case GGML_OP_SCALE:
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
break;
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
ggml_hexagon_dispatch_op<init_sum_rows_req>(sess, node, flags);
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) ||
|
||||
(ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) {
|
||||
@@ -2536,7 +2616,8 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
break;
|
||||
case GGML_OP_GLU:
|
||||
if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
|
||||
(ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) {
|
||||
(ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) ||
|
||||
(ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) {
|
||||
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
|
||||
}
|
||||
break;
|
||||
@@ -2564,6 +2645,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
case GGML_OP_ARGSORT:
|
||||
ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
|
||||
break;
|
||||
|
||||
default:
|
||||
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
|
||||
}
|
||||
@@ -2916,6 +3001,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_DIV:
|
||||
supp = ggml_hexagon_supported_binary(sess, op);
|
||||
break;
|
||||
|
||||
@@ -2928,6 +3014,15 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
supp = ggml_hexagon_supported_unary(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
supp = ggml_hexagon_supported_unary(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_SUM_ROWS:
|
||||
supp = ggml_hexagon_supported_sum_rows(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_SOFT_MAX:
|
||||
supp = ggml_hexagon_supported_softmax(sess, op);
|
||||
break;
|
||||
@@ -2943,7 +3038,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
case GGML_OP_GLU:
|
||||
{
|
||||
const auto glu_op = ggml_get_glu_op(op);
|
||||
if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI)) {
|
||||
if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) {
|
||||
supp = ggml_hexagon_supported_activations(sess, op);
|
||||
}
|
||||
break;
|
||||
@@ -2968,6 +3063,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
|
||||
supp = ggml_hexagon_supported_cpy(sess, op);
|
||||
break;
|
||||
|
||||
case GGML_OP_ARGSORT:
|
||||
supp = ggml_hexagon_supported_argsort(sess, op);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
|
||||
include_directories(
|
||||
${HEXAGON_SDK_ROOT}/incs
|
||||
${HEXAGON_SDK_ROOT}/incs/stddef
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../..
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
@@ -21,6 +22,7 @@ add_library(${HTP_LIB} SHARED
|
||||
matmul-ops.c
|
||||
binary-ops.c
|
||||
unary-ops.c
|
||||
sum-rows-ops.c
|
||||
softmax-ops.c
|
||||
act-ops.c
|
||||
rope-ops.c
|
||||
@@ -28,6 +30,7 @@ add_library(${HTP_LIB} SHARED
|
||||
set-rows-ops.c
|
||||
get-rows-ops.c
|
||||
cpy-ops.c
|
||||
argsort-ops.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
|
||||
@@ -410,7 +410,7 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
|
||||
// gelu = x * sigmoid(1.702 * x) // current implementation
|
||||
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0);
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
@@ -516,7 +516,7 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
||||
|
||||
// silu = x * sigmoid(x)
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0);
|
||||
hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
@@ -541,6 +541,143 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
||||
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static const float GELU_COEF_A = 0.044715f;
|
||||
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * src1_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
dma_queue * dma_queue) {
|
||||
htp_act_preamble3;
|
||||
|
||||
size_t src0_row_size = nb01;
|
||||
size_t src1_row_size = nb11;
|
||||
size_t dst_row_size = nb1;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
|
||||
const bool src1_valid = src1->ne[0];
|
||||
const int nc = (src1_valid) ? ne00 : ne00 / 2;
|
||||
if (!src1_valid) {
|
||||
const int32_t swapped = op_params[1];
|
||||
data_src1 = data_src0;
|
||||
src1_row_size = src0_row_size;
|
||||
|
||||
const size_t nc_in_bytes = nc * SIZEOF_FP32;
|
||||
data_src0 += swapped ? nc_in_bytes : 0;
|
||||
data_src1 += swapped ? 0 : nc_in_bytes;
|
||||
}
|
||||
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
|
||||
|
||||
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
|
||||
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
|
||||
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
|
||||
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
|
||||
|
||||
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR,
|
||||
"geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
src0_spad->size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
|
||||
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
|
||||
dst_row_size, dst_row_size_aligned, 0);
|
||||
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
|
||||
src0_row_size_aligned, src0_row_size, block_size);
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
|
||||
src1_row_size_aligned, src1_row_size, block_size);
|
||||
}
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
|
||||
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
|
||||
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
|
||||
float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
|
||||
|
||||
for (uint32_t ib = 0; ib < block_size; ib++) {
|
||||
const uint8_t * src0_spad_ptr = (const uint8_t *)(src0_spad + ib * (src0_row_size_aligned / sizeof(float)));
|
||||
const uint8_t * src1_spad_ptr = (const uint8_t *)(src1_spad + ib * (src1_row_size_aligned / sizeof(float)));
|
||||
uint8_t * dst_spad_ptr = (uint8_t *)(dst_spad + ib * (dst_row_size_aligned / sizeof(float)));
|
||||
|
||||
// geglu tanh implementation
|
||||
// geglu(x, g) = gelu(x) * g
|
||||
// gelu(x) = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)))
|
||||
hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, src0_spad_ptr, nc); // res = x*x
|
||||
hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, GELU_COEF_A, nc); // res = res * GELU_COEF_A
|
||||
hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
|
||||
hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x
|
||||
hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, SQRT_2_OVER_PI, nc); // res = result * SQRT_2_OVER_PI
|
||||
hvx_tanh_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc); // res = tanh(res)
|
||||
hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
|
||||
hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x
|
||||
hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 0.5f, nc); // res = res + 0.5f
|
||||
hvx_mul_f32_aaa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, src1_spad_ptr, nc); // res = res * g
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
|
||||
dst_row_size_aligned, block_size);
|
||||
|
||||
// prefetch N+2 loop iteration if any
|
||||
const uint32_t pref_block = (ir + BLOCK * 2);
|
||||
if (pref_block < src0_end_row) {
|
||||
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
|
||||
src0_row_size_aligned, src0_row_size, pref_block_size);
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
|
||||
src1_row_size_aligned, src1_row_size, pref_block_size);
|
||||
}
|
||||
}
|
||||
|
||||
dma_queue_flush(dma_queue);
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "geglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
|
||||
@@ -559,6 +696,12 @@ static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) {
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
@@ -593,6 +736,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
act_op_func = unary_gelu_f32;
|
||||
op_type = "gelu-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_GEGLU:
|
||||
act_op_func = glu_geglu_f32;
|
||||
op_type = "geglu-f32";
|
||||
break;
|
||||
default:
|
||||
FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
|
||||
@@ -0,0 +1,281 @@
|
||||
#include <string.h>
|
||||
#include <stdlib.h>
|
||||
#include <math.h>
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_perf.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#include "hvx-utils.h"
|
||||
#include "hex-dma.h"
|
||||
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
#ifndef MIN
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#endif
|
||||
|
||||
struct htp_argsort_context {
|
||||
struct htp_ops_context * octx;
|
||||
uint32_t nrows_per_thread;
|
||||
};
|
||||
|
||||
static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y)
|
||||
{
|
||||
const HVX_Vector one = Q6_V_vsplat_R(1);
|
||||
const HVX_Vector zero = Q6_V_vzero();
|
||||
|
||||
HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(x, y);
|
||||
HVX_Vector matches = Q6_V_vmux_QVV(pred, one, zero);
|
||||
HVX_Vector sum = hvx_vec_reduce_sum_i32(matches);
|
||||
return hvx_vec_get_i32(sum) == 32;
|
||||
}
|
||||
|
||||
// Sorts values and mirrors swaps to indices.
|
||||
static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) {
|
||||
if (left >= right) return;
|
||||
|
||||
int pivot_idx = (left + right) / 2;
|
||||
float pivot = values[pivot_idx];
|
||||
int i = left;
|
||||
int j = right;
|
||||
|
||||
HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
|
||||
while (i <= j) {
|
||||
// Vectorized scan for i
|
||||
while (i <= j) {
|
||||
// Check if we have at least one full vector
|
||||
if (i + 32 <= j) {
|
||||
HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
|
||||
if (all_greater_f32(pivot_vec, vals_vec)) {
|
||||
// If all elements are < pivot, we can skip this whole block
|
||||
i += 32;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Scalar fallback / cleanup
|
||||
if (values[i] < pivot) {
|
||||
i++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Vectorized scan for j
|
||||
while (i <= j) {
|
||||
if (j - 32 >= i) {
|
||||
// Load 32 elements ending at j.
|
||||
// Since we want `values[j] > pivot`, let's load from j-31 to j.
|
||||
HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
|
||||
if (all_greater_f32(vals_vec, pivot_vec)) {
|
||||
j -= 32;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (values[j] > pivot) {
|
||||
j--;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (i <= j) {
|
||||
float tmp_val = values[i];
|
||||
values[i] = values[j];
|
||||
values[j] = tmp_val;
|
||||
|
||||
int32_t tmp_idx = indices[i];
|
||||
indices[i] = indices[j];
|
||||
indices[j] = tmp_idx;
|
||||
i++;
|
||||
j--;
|
||||
}
|
||||
}
|
||||
|
||||
if (left < j) quicksort_values_indices_asc(values, indices, left, j);
|
||||
if (i < right) quicksort_values_indices_asc(values, indices, i, right);
|
||||
}
|
||||
|
||||
static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) {
|
||||
if (left >= right) return;
|
||||
|
||||
int pivot_idx = (left + right) / 2;
|
||||
float pivot = values[pivot_idx];
|
||||
int i = left;
|
||||
int j = right;
|
||||
|
||||
HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
|
||||
|
||||
while (i <= j) {
|
||||
// Vectorized scan for i (values[i] > pivot)
|
||||
while (i <= j) {
|
||||
if (i + 32 <= j) {
|
||||
HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
|
||||
if (all_greater_f32(vals_vec, pivot_vec)) {
|
||||
i += 32;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (values[i] > pivot) {
|
||||
i++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Vectorized scan for j (values[j] < pivot)
|
||||
while (i <= j) {
|
||||
if (j - 32 >= i) {
|
||||
HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
|
||||
if (all_greater_f32(pivot_vec, vals_vec)) {
|
||||
j -= 32;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (values[j] < pivot) {
|
||||
j--;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (i <= j) {
|
||||
float tmp_val = values[i];
|
||||
values[i] = values[j];
|
||||
values[j] = tmp_val;
|
||||
|
||||
int32_t tmp_idx = indices[i];
|
||||
indices[i] = indices[j];
|
||||
indices[j] = tmp_idx;
|
||||
i++;
|
||||
j--;
|
||||
}
|
||||
}
|
||||
|
||||
if (left < j) quicksort_values_indices_desc(values, indices, left, j);
|
||||
if (i < right) quicksort_values_indices_desc(values, indices, i, right);
|
||||
}
|
||||
|
||||
static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_argsort_context * actx = (struct htp_argsort_context *)data;
|
||||
struct htp_ops_context * octx = actx->octx;
|
||||
|
||||
// Unpack context
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
// Scratchpad memory
|
||||
uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;
|
||||
|
||||
// Dimensions
|
||||
uint32_t ne00 = src0->ne[0];
|
||||
uint32_t ne01 = src0->ne[1];
|
||||
uint32_t ne02 = src0->ne[2];
|
||||
uint32_t ne03 = src0->ne[3];
|
||||
|
||||
uint32_t nb01 = src0->nb[1];
|
||||
//uint32_t nb02 = src0->nb[2];
|
||||
//uint32_t nb03 = src0->nb[3];
|
||||
|
||||
uint32_t nb1 = dst->nb[1];
|
||||
//uint32_t nb2 = dst->nb[2];
|
||||
//uint32_t nb3 = dst->nb[3];
|
||||
|
||||
// Sort order
|
||||
enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0];
|
||||
|
||||
// Rows to process
|
||||
uint32_t total_rows = ne01 * ne02 * ne03;
|
||||
uint32_t rows_per_thread = actx->nrows_per_thread;
|
||||
uint32_t start_row = rows_per_thread * i;
|
||||
uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
|
||||
|
||||
// Scratchpad layout:
|
||||
// We need space for one row of float data (values) and one row of int32 indices.
|
||||
// values: ne00 * sizeof(float)
|
||||
// indices: ne00 * sizeof(int32_t)
|
||||
// Padded to 128 bytes.
|
||||
|
||||
size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
|
||||
float * values_buf = (float *) spad;
|
||||
int32_t * indices_buf = (int32_t *) (spad + values_size);
|
||||
|
||||
for (uint32_t r = start_row; r < end_row; r++) {
|
||||
uint32_t src_offset = r * nb01;
|
||||
uint32_t dst_offset = r * nb1;
|
||||
|
||||
uint8_t * src_ptr = (uint8_t *) src0->data + src_offset;
|
||||
uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset;
|
||||
|
||||
hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);
|
||||
hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00);
|
||||
|
||||
// Initialize indices
|
||||
for (uint32_t j = 0; j < ne00; j++) {
|
||||
indices_buf[j] = j;
|
||||
}
|
||||
|
||||
// Sort values and mirror swaps to indices
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1);
|
||||
} else {
|
||||
quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1);
|
||||
}
|
||||
|
||||
// Copy indices back to DDR
|
||||
hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00);
|
||||
}
|
||||
}
|
||||
|
||||
int op_argsort(struct htp_ops_context * octx) {
|
||||
// Check supported types
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
// Allocate scratchpad
|
||||
// We need 1 row of float + 1 row of int32 per thread.
|
||||
uint32_t ne00 = octx->src0.ne[0];
|
||||
size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
|
||||
size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128);
|
||||
size_t spad_per_thread = values_size + indices_size;
|
||||
|
||||
// Make sure we round up to 256 for alignment requirements
|
||||
spad_per_thread = hex_round_up(spad_per_thread, 256);
|
||||
|
||||
size_t total_spad_size = spad_per_thread * octx->n_threads;
|
||||
|
||||
if (octx->ctx->vtcm_size < total_spad_size) {
|
||||
FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src0_spad.size = total_spad_size;
|
||||
octx->src0_spad.size_per_thread = spad_per_thread;
|
||||
|
||||
FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
|
||||
octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3],
|
||||
octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
|
||||
octx->src0.data, octx->dst.data);
|
||||
|
||||
uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
|
||||
uint32_t n_jobs = MIN(total_rows, octx->n_threads);
|
||||
|
||||
struct htp_argsort_context actx;
|
||||
actx.octx = octx;
|
||||
actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs;
|
||||
|
||||
// Run jobs
|
||||
worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -42,32 +42,36 @@ enum htp_data_type {
|
||||
HTP_TYPE_COUNT
|
||||
};
|
||||
|
||||
// These values are manually translated over to HTP
|
||||
// !!!! DO NOT ALTER THE ORDER OF THE FIRST FOUR ENUMS !!!!
|
||||
// Do not reorder first 4 (used as an index)
|
||||
enum htp_op {
|
||||
HTP_OP_MUL = 0,
|
||||
HTP_OP_ADD = 1,
|
||||
HTP_OP_SUB = 2,
|
||||
HTP_OP_DIV = 3,
|
||||
HTP_OP_MUL_MAT = 4,
|
||||
HTP_OP_MUL_MAT_ID = 5,
|
||||
HTP_OP_RMS_NORM = 6,
|
||||
HTP_OP_UNARY_SILU = 7,
|
||||
HTP_OP_UNARY_GELU = 8,
|
||||
HTP_OP_GLU_SWIGLU = 9,
|
||||
HTP_OP_GLU_SWIGLU_OAI = 10,
|
||||
HTP_OP_SOFTMAX = 11,
|
||||
HTP_OP_ADD_ID = 12,
|
||||
HTP_OP_ROPE = 13,
|
||||
HTP_OP_FLASH_ATTN_EXT = 14,
|
||||
HTP_OP_SET_ROWS = 15,
|
||||
HTP_OP_SCALE = 16,
|
||||
HTP_OP_GET_ROWS = 17,
|
||||
HTP_OP_CPY = 18,
|
||||
HTP_OP_MUL = 0,
|
||||
HTP_OP_ADD = 1,
|
||||
HTP_OP_SUB = 2,
|
||||
HTP_OP_DIV = 3,
|
||||
HTP_OP_MUL_MAT,
|
||||
HTP_OP_MUL_MAT_ID,
|
||||
HTP_OP_RMS_NORM,
|
||||
HTP_OP_UNARY_SILU,
|
||||
HTP_OP_UNARY_GELU,
|
||||
HTP_OP_GLU_SWIGLU,
|
||||
HTP_OP_GLU_SWIGLU_OAI,
|
||||
HTP_OP_GLU_GEGLU,
|
||||
HTP_OP_SOFTMAX,
|
||||
HTP_OP_ADD_ID,
|
||||
HTP_OP_ROPE,
|
||||
HTP_OP_FLASH_ATTN_EXT,
|
||||
HTP_OP_SET_ROWS,
|
||||
HTP_OP_GET_ROWS,
|
||||
HTP_OP_SCALE,
|
||||
HTP_OP_CPY,
|
||||
HTP_OP_ARGSORT,
|
||||
HTP_OP_SQR,
|
||||
HTP_OP_SQRT,
|
||||
HTP_OP_SUM_ROWS,
|
||||
INVALID
|
||||
};
|
||||
|
||||
static inline size_t htp_type_block_size(uint32_t t) {
|
||||
static inline size_t htp_t_block_size(uint32_t t) {
|
||||
switch (t) {
|
||||
case HTP_TYPE_F32:
|
||||
return 1;
|
||||
@@ -103,22 +107,6 @@ static inline size_t htp_type_nbytes(uint32_t t) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static const char * htp_type_name(uint32_t t) {
|
||||
switch (t) {
|
||||
case HTP_TYPE_F32:
|
||||
return "fp32";
|
||||
case HTP_TYPE_F16:
|
||||
return "fp16";
|
||||
case HTP_TYPE_Q4_0:
|
||||
return "q4_0";
|
||||
case HTP_TYPE_Q8_0:
|
||||
return "q8_0";
|
||||
case HTP_TYPE_MXFP4:
|
||||
return "mxfp4";
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Internal types
|
||||
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
|
||||
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
|
||||
|
||||
@@ -64,25 +64,12 @@ struct htp_ops_context {
|
||||
struct fastdiv_values broadcast_rv2;
|
||||
struct fastdiv_values broadcast_rv3;
|
||||
|
||||
struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
|
||||
struct fastdiv_values mm_div_ne1; // fastdiv values for ne1
|
||||
struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02
|
||||
struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03
|
||||
|
||||
struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
|
||||
struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
|
||||
|
||||
struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
|
||||
struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
|
||||
|
||||
struct fastdiv_values cpy_div_ne01; // fastdiv values for ne01
|
||||
struct fastdiv_values cpy_div_ne02; // fastdiv values for ne02
|
||||
struct fastdiv_values cpy_div_ne03; // fastdiv values for ne03
|
||||
|
||||
struct fastdiv_values cpy_rshp_div_n0; // fastdiv values for ne00
|
||||
struct fastdiv_values cpy_rshp_div_n1n0; // fastdiv values for ne00*ne01
|
||||
struct fastdiv_values cpy_rshp_div_n2n1n0; // fastdiv values for ne00*ne01*ne02
|
||||
|
||||
uint32_t flags;
|
||||
};
|
||||
|
||||
@@ -90,6 +77,7 @@ int op_matmul(struct htp_ops_context * octx);
|
||||
int op_matmul_id(struct htp_ops_context * octx);
|
||||
int op_binary(struct htp_ops_context * octx);
|
||||
int op_unary(struct htp_ops_context * octx);
|
||||
int op_sum_rows(struct htp_ops_context * octx);
|
||||
int op_activations(struct htp_ops_context * octx);
|
||||
int op_softmax(struct htp_ops_context * octx);
|
||||
int op_add_id(struct htp_ops_context * octx);
|
||||
@@ -98,5 +86,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx);
|
||||
int op_set_rows(struct htp_ops_context * octx);
|
||||
int op_get_rows(struct htp_ops_context * octx);
|
||||
int op_cpy(struct htp_ops_context * octx);
|
||||
int op_argsort(struct htp_ops_context * octx);
|
||||
|
||||
#endif /* HTP_OPS_H */
|
||||
|
||||
@@ -46,127 +46,76 @@
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#endif
|
||||
|
||||
// ADD variants
|
||||
// Generic macro to define alignment permutations for an op
|
||||
#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \
|
||||
static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) dst % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) src0 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
assert((uintptr_t) src1 % 128 == 0); \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
|
||||
} \
|
||||
static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
|
||||
} \
|
||||
|
||||
static inline void hvx_add_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD);
|
||||
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD)
|
||||
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB)
|
||||
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL)
|
||||
|
||||
// Dispatcher logic
|
||||
#define HVX_BINARY_DISPATCHER(OP_NAME) \
|
||||
static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \
|
||||
if (hex_is_aligned((void *) dst, 128)) { \
|
||||
if (hex_is_aligned((void *) src0, 128)) { \
|
||||
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \
|
||||
else OP_NAME##_aau(dst, src0, src1, num_elems); \
|
||||
} else { \
|
||||
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \
|
||||
else OP_NAME##_auu(dst, src0, src1, num_elems); \
|
||||
} \
|
||||
} else { \
|
||||
if (hex_is_aligned((void *) src0, 128)) { \
|
||||
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \
|
||||
else OP_NAME##_uau(dst, src0, src1, num_elems); \
|
||||
} else { \
|
||||
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \
|
||||
else OP_NAME##_uuu(dst, src0, src1, num_elems); \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
static inline void hvx_add_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD);
|
||||
}
|
||||
|
||||
static inline void hvx_add_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD);
|
||||
}
|
||||
|
||||
static inline void hvx_add_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD);
|
||||
}
|
||||
|
||||
// SUB variants
|
||||
|
||||
static inline void hvx_sub_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB);
|
||||
}
|
||||
|
||||
static inline void hvx_sub_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB);
|
||||
}
|
||||
|
||||
static inline void hvx_sub_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB);
|
||||
}
|
||||
|
||||
static inline void hvx_sub_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB);
|
||||
}
|
||||
|
||||
// MUL variants
|
||||
|
||||
static inline void hvx_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL);
|
||||
}
|
||||
|
||||
static inline void hvx_mul_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL);
|
||||
}
|
||||
|
||||
static inline void hvx_mul_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
assert((unsigned long) src1 % 128 == 0);
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL);
|
||||
}
|
||||
|
||||
static inline void hvx_mul_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL);
|
||||
}
|
||||
|
||||
// Dispatchers
|
||||
|
||||
static inline void hvx_add_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_add_f32_aa(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_add_f32_au(dst, src0, src1, num_elems);
|
||||
}
|
||||
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_add_f32_ua(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_add_f32_uu(dst, src0, src1, num_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_sub_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_sub_f32_aa(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_sub_f32_au(dst, src0, src1, num_elems);
|
||||
}
|
||||
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_sub_f32_ua(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_sub_f32_uu(dst, src0, src1, num_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_mul_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_mul_f32_aa(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_mul_f32_au(dst, src0, src1, num_elems);
|
||||
}
|
||||
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
|
||||
hvx_mul_f32_ua(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
hvx_mul_f32_uu(dst, src0, src1, num_elems);
|
||||
}
|
||||
}
|
||||
HVX_BINARY_DISPATCHER(hvx_add_f32)
|
||||
HVX_BINARY_DISPATCHER(hvx_sub_f32)
|
||||
HVX_BINARY_DISPATCHER(hvx_mul_f32)
|
||||
|
||||
// Mul-Mul Optimized
|
||||
|
||||
static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src0 % 128 == 0);
|
||||
@@ -443,6 +392,68 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Square
|
||||
//
|
||||
|
||||
#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const uint32_t elem_size = sizeof(float); \
|
||||
const uint32_t epv = 128 / elem_size; \
|
||||
const uint32_t nvec = n / epv; \
|
||||
const uint32_t nloe = n % epv; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \
|
||||
vec_store((void *) &vdst[i], nloe * elem_size, v); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128)) {
|
||||
if (hex_is_aligned((void *) src, 128)) {
|
||||
hvx_sqr_f32_aa(dst, src, num_elems);
|
||||
} else {
|
||||
hvx_sqr_f32_au(dst, src, num_elems);
|
||||
}
|
||||
} else {
|
||||
if (hex_is_aligned((void *) src, 128)) {
|
||||
hvx_sqr_f32_ua(dst, src, num_elems);
|
||||
} else {
|
||||
hvx_sqr_f32_uu(dst, src, num_elems);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#undef HVX_OP_ADD
|
||||
#undef HVX_OP_SUB
|
||||
#undef HVX_OP_MUL
|
||||
@@ -453,5 +464,7 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
|
||||
#undef hvx_scalar_loop_body
|
||||
#undef HVX_OP_MIN_SCALAR
|
||||
#undef HVX_OP_CLAMP_SCALAR
|
||||
#undef DEFINE_HVX_BINARY_OP_VARIANTS
|
||||
#undef HVX_BINARY_DISPATCHER
|
||||
|
||||
#endif // HVX_ARITH_H
|
||||
|
||||
@@ -66,6 +66,12 @@ static inline float hvx_vec_get_f32(HVX_Vector v) {
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline int32_t hvx_vec_get_i32(HVX_Vector v) {
|
||||
int32_t __attribute__((aligned(128))) x;
|
||||
hvx_vec_store_a(&x, 4, v);
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {
|
||||
// abs by clearing the fp16 sign bit
|
||||
HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
|
||||
|
||||
@@ -136,8 +136,6 @@ static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restr
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0); \
|
||||
\
|
||||
const uint32_t elem_size = sizeof(__fp16); \
|
||||
const uint32_t epv = 128 / elem_size; \
|
||||
const uint32_t nvec = n / epv; \
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
#ifndef HVX_DIV_H
|
||||
#define HVX_DIV_H
|
||||
|
||||
#include <HAP_farf.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
#include <assert.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "hvx-base.h"
|
||||
#include "hex-utils.h"
|
||||
#include "hvx-inverse.h"
|
||||
#include "hvx-arith.h"
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
#else
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#endif
|
||||
|
||||
#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src0_type * restrict vsrc0 = (src0_type *) src0; \
|
||||
src1_type * restrict vsrc1 = (src1_type *) src1; \
|
||||
\
|
||||
const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
|
||||
\
|
||||
const uint32_t nvec = n / VLEN_FP32; \
|
||||
const uint32_t nloe = n % VLEN_FP32; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
|
||||
HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
|
||||
vdst[i] = res; \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
|
||||
HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
|
||||
vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
// 3-letter suffix variants
|
||||
static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
assert((uintptr_t) src0 % 128 == 0);
|
||||
assert((uintptr_t) src1 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
assert((uintptr_t) src0 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
assert((uintptr_t) src1 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) src0 % 128 == 0);
|
||||
assert((uintptr_t) src1 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) src0 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
assert((uintptr_t) src1 % 128 == 0);
|
||||
hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
|
||||
hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
|
||||
if (hex_is_aligned((void *) dst, 128)) {
|
||||
if (hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems);
|
||||
else hvx_div_f32_aau(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems);
|
||||
else hvx_div_f32_auu(dst, src0, src1, num_elems);
|
||||
}
|
||||
} else {
|
||||
if (hex_is_aligned((void *) src0, 128)) {
|
||||
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems);
|
||||
else hvx_div_f32_uau(dst, src0, src1, num_elems);
|
||||
} else {
|
||||
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems);
|
||||
else hvx_div_f32_uuu(dst, src0, src1, num_elems);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#undef HVX_OP_MUL
|
||||
|
||||
#endif // HVX_DIV_H
|
||||
@@ -91,6 +91,27 @@ static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) {
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
#define hvx_tanh_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const uint32_t epv = 128 / sizeof(float); \
|
||||
const uint32_t nvec = n / epv; \
|
||||
const uint32_t nloe = n % epv; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
vdst[i] = hvx_vec_tanh_f32(vsrc[i]); \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]); \
|
||||
vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
@@ -111,4 +132,10 @@ static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * re
|
||||
hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
#endif /* HVX_SIGMOID_H */
|
||||
|
||||
@@ -12,11 +12,17 @@
|
||||
#define RSQRT_ONE_HALF 0x3f000000 // 0.5
|
||||
#define RSQRT_THREE_HALVES 0x3fc00000 // 1.5
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
#else
|
||||
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#endif
|
||||
|
||||
static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
|
||||
//Algorithm :
|
||||
// x2 = input*0.5
|
||||
// y = * (long *) &input
|
||||
// y = 0x5f3759df - (y>>2)
|
||||
// y = 0x5f3759df - (y>>1)
|
||||
// y = y*(threehalfs - x2*y*y)
|
||||
|
||||
HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);
|
||||
@@ -57,4 +63,64 @@ static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
|
||||
return Q6_Vsf_equals_Vqf32(temp);
|
||||
}
|
||||
|
||||
// Compute sqrt(x) as x*inv_sqrt(x)
|
||||
#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
const uint32_t nvec = n / VLEN_FP32; \
|
||||
const uint32_t nloe = n % VLEN_FP32; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \
|
||||
HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \
|
||||
vdst[i] = sqrt_res; \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \
|
||||
HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \
|
||||
vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, sqrt_res); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
static inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) {
|
||||
if ((unsigned long) dst % 128 == 0) {
|
||||
if ((unsigned long) src % 128 == 0) {
|
||||
hvx_sqrt_f32_aa(dst, src, num_elems);
|
||||
} else {
|
||||
hvx_sqrt_f32_au(dst, src, num_elems);
|
||||
}
|
||||
} else {
|
||||
if ((unsigned long) src % 128 == 0) {
|
||||
hvx_sqrt_f32_ua(dst, src, num_elems);
|
||||
} else {
|
||||
hvx_sqrt_f32_uu(dst, src, num_elems);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif /* HVX_SQRT_H */
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "hvx-sigmoid.h"
|
||||
#include "hvx-sqrt.h"
|
||||
#include "hvx-arith.h"
|
||||
#include "hvx-div.h"
|
||||
#include "hvx-base.h"
|
||||
|
||||
#endif /* HVX_UTILS_H */
|
||||
|
||||
@@ -189,7 +189,7 @@ static int vtcm_release_callback(unsigned int rctx, void * state) {
|
||||
// otherwise we'll release it once we're done with the current Op.
|
||||
|
||||
if (ctx->vtcm_inuse) {
|
||||
ctx->vtcm_needs_release = false;
|
||||
ctx->vtcm_needs_release = true;
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -440,6 +440,45 @@ static void proc_matmul_req(struct htp_context * ctx,
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[0].fd = bufs[1].fd;
|
||||
rsp_bufs[0].ptr = bufs[1].ptr;
|
||||
rsp_bufs[0].offset = bufs[1].offset;
|
||||
rsp_bufs[0].size = bufs[1].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[1].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_argsort(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[1];
|
||||
|
||||
@@ -679,6 +718,45 @@ static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * re
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[0].fd = bufs[1].fd;
|
||||
rsp_bufs[0].ptr = bufs[1].ptr;
|
||||
rsp_bufs[0].offset = bufs[1].offset;
|
||||
rsp_bufs[0].size = bufs[1].size;
|
||||
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[1].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_sum_rows(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
|
||||
}
|
||||
|
||||
static void proc_activations_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
@@ -951,6 +1029,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
case HTP_OP_MUL:
|
||||
case HTP_OP_ADD:
|
||||
case HTP_OP_SUB:
|
||||
case HTP_OP_DIV:
|
||||
if (n_bufs != 3) {
|
||||
FARF(ERROR, "Bad binary-req buffer list");
|
||||
continue;
|
||||
@@ -968,6 +1047,25 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
proc_unary_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_SQR:
|
||||
case HTP_OP_SQRT:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad unary-req buffer list");
|
||||
continue;
|
||||
}
|
||||
|
||||
proc_unary_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_SUM_ROWS:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad unary-req buffer list");
|
||||
continue;
|
||||
}
|
||||
|
||||
proc_sum_rows_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_UNARY_SILU:
|
||||
case HTP_OP_UNARY_GELU:
|
||||
if (n_bufs != 2) {
|
||||
@@ -980,6 +1078,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
case HTP_OP_GLU_SWIGLU:
|
||||
case HTP_OP_GLU_SWIGLU_OAI:
|
||||
case HTP_OP_SOFTMAX:
|
||||
case HTP_OP_GLU_GEGLU:
|
||||
if ((n_bufs != 2) && (n_bufs != 3)) {
|
||||
FARF(ERROR, "Bad act-req buffer list");
|
||||
continue;
|
||||
@@ -1035,6 +1134,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
proc_cpy_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_ARGSORT:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad argsort-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_argsort_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unknown Op %u", req.op);
|
||||
break;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,115 @@
|
||||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_perf.h>
|
||||
|
||||
#include <string.h>
|
||||
#include <math.h>
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
|
||||
#define sum_rows_preamble \
|
||||
struct htp_tensor *src0 = &octx->src0;\
|
||||
struct htp_tensor *dst = &octx->dst; \
|
||||
\
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3]; \
|
||||
|
||||
static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
sum_rows_preamble;
|
||||
|
||||
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
int opt_path = 0;
|
||||
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src = (const uint8_t *) src0->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
|
||||
const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
|
||||
float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
|
||||
|
||||
for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) {
|
||||
const float * restrict src_local = src_th + (ir * ne00);
|
||||
|
||||
if (ir + 1 < src0_nrows_per_thread) {
|
||||
hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);
|
||||
} else {
|
||||
dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);
|
||||
}
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) {
|
||||
sum_rows_thread_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
int op_sum_rows(struct htp_ops_context * octx) {
|
||||
sum_rows_preamble;
|
||||
|
||||
if (octx->src0.type != HTP_TYPE_F32) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
const int n_threads = octx->n_threads;
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
||||
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
@@ -132,6 +132,56 @@ static void rms_norm_htp_f32(const float * restrict src,
|
||||
}
|
||||
}
|
||||
|
||||
static void sqr_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
} else {
|
||||
hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void sqrt_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
} else {
|
||||
hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void unary_job_f32_per_thread(const struct htp_tensor * src,
|
||||
struct htp_tensor * dst,
|
||||
uint8_t * spad,
|
||||
@@ -181,6 +231,12 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
|
||||
case HTP_OP_SCALE:
|
||||
scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
case HTP_OP_SQR:
|
||||
sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
case HTP_OP_SQRT:
|
||||
sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
@@ -218,6 +274,14 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "scale-f32";
|
||||
break;
|
||||
case HTP_OP_SQR:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "sqr-f32";
|
||||
break;
|
||||
case HTP_OP_SQRT:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "sqrt-f32";
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
|
||||
|
||||
@@ -264,15 +264,25 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_TRI:
|
||||
case GGML_OP_DIAG:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_GLU:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_UNARY:
|
||||
case GGML_OP_GET_ROWS:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_SET_ROWS:
|
||||
case GGML_OP_SET:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_REPEAT:
|
||||
return true;
|
||||
default:
|
||||
return ggml_op_is_empty(op);
|
||||
@@ -312,7 +322,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
|
||||
h_add(mrs1, node0);
|
||||
|
||||
// that many nodes forward to search for a concurrent node
|
||||
constexpr int N_FORWARD = 8;
|
||||
constexpr int N_FORWARD = 64;
|
||||
|
||||
for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
|
||||
if (used[i1]) {
|
||||
|
||||
@@ -212,61 +212,69 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_meta
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
const int64_t n = ggml_nelements(op);
|
||||
int op_num = -1;
|
||||
|
||||
const char * op_str = "undefined";
|
||||
switch (op->op) {
|
||||
case GGML_OP_SCALE: op_str = "scale"; break;
|
||||
case GGML_OP_FILL: op_str = "fill"; break;
|
||||
case GGML_OP_CLAMP: op_str = "clamp"; break;
|
||||
case GGML_OP_SQR: op_str = "sqr"; break;
|
||||
case GGML_OP_SQRT: op_str = "sqrt"; break;
|
||||
case GGML_OP_SIN: op_str = "sin"; break;
|
||||
case GGML_OP_COS: op_str = "cos"; break;
|
||||
case GGML_OP_LOG: op_str = "log"; break;
|
||||
case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break;
|
||||
case GGML_OP_SCALE: op_num = OP_UNARY_NUM_SCALE; break;
|
||||
case GGML_OP_FILL: op_num = OP_UNARY_NUM_FILL; break;
|
||||
case GGML_OP_CLAMP: op_num = OP_UNARY_NUM_CLAMP; break;
|
||||
case GGML_OP_SQR: op_num = OP_UNARY_NUM_SQR; break;
|
||||
case GGML_OP_SQRT: op_num = OP_UNARY_NUM_SQRT; break;
|
||||
case GGML_OP_SIN: op_num = OP_UNARY_NUM_SIN; break;
|
||||
case GGML_OP_COS: op_num = OP_UNARY_NUM_COS; break;
|
||||
case GGML_OP_LOG: op_num = OP_UNARY_NUM_LOG; break;
|
||||
case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_TANH: op_str = "tanh"; break;
|
||||
case GGML_UNARY_OP_RELU: op_str = "relu"; break;
|
||||
case GGML_UNARY_OP_SIGMOID: op_str = "sigmoid"; break;
|
||||
case GGML_UNARY_OP_GELU: op_str = "gelu"; break;
|
||||
case GGML_UNARY_OP_GELU_ERF: op_str = "gelu_erf"; break;
|
||||
case GGML_UNARY_OP_GELU_QUICK: op_str = "gelu_quick"; break;
|
||||
case GGML_UNARY_OP_SILU: op_str = "silu"; break;
|
||||
case GGML_UNARY_OP_ELU: op_str = "elu"; break;
|
||||
case GGML_UNARY_OP_NEG: op_str = "neg"; break;
|
||||
case GGML_UNARY_OP_ABS: op_str = "abs"; break;
|
||||
case GGML_UNARY_OP_SGN: op_str = "sgn"; break;
|
||||
case GGML_UNARY_OP_STEP: op_str = "step"; break;
|
||||
case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break;
|
||||
case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
|
||||
case GGML_UNARY_OP_EXP: op_str = "exp"; break;
|
||||
case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break;
|
||||
case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break;
|
||||
case GGML_UNARY_OP_TANH: op_num = OP_UNARY_NUM_TANH; break;
|
||||
case GGML_UNARY_OP_RELU: op_num = OP_UNARY_NUM_RELU; break;
|
||||
case GGML_UNARY_OP_SIGMOID: op_num = OP_UNARY_NUM_SIGMOID; break;
|
||||
case GGML_UNARY_OP_GELU: op_num = OP_UNARY_NUM_GELU; break;
|
||||
case GGML_UNARY_OP_GELU_ERF: op_num = OP_UNARY_NUM_GELU_ERF; break;
|
||||
case GGML_UNARY_OP_GELU_QUICK: op_num = OP_UNARY_NUM_GELU_QUICK; break;
|
||||
case GGML_UNARY_OP_SILU: op_num = OP_UNARY_NUM_SILU; break;
|
||||
case GGML_UNARY_OP_ELU: op_num = OP_UNARY_NUM_ELU; break;
|
||||
case GGML_UNARY_OP_NEG: op_num = OP_UNARY_NUM_NEG; break;
|
||||
case GGML_UNARY_OP_ABS: op_num = OP_UNARY_NUM_ABS; break;
|
||||
case GGML_UNARY_OP_SGN: op_num = OP_UNARY_NUM_SGN; break;
|
||||
case GGML_UNARY_OP_STEP: op_num = OP_UNARY_NUM_STEP; break;
|
||||
case GGML_UNARY_OP_HARDSWISH: op_num = OP_UNARY_NUM_HARDSWISH; break;
|
||||
case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break;
|
||||
case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break;
|
||||
case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break;
|
||||
case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
} break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
const char * suffix = "";
|
||||
if (n % 4 == 0) {
|
||||
suffix = "_4";
|
||||
}
|
||||
const char * t0_str = ggml_type_name(op->src[0]->type);
|
||||
const char * t_str = ggml_type_name(op->type);
|
||||
|
||||
snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix);
|
||||
snprintf(name, 256, "%s", base);
|
||||
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
|
||||
const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768;
|
||||
|
||||
snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
|
||||
snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0);
|
||||
ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
}
|
||||
|
||||
res.c4 = is_c4;
|
||||
res.cnt = is_cnt;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -320,31 +328,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
const char * op_str = "undefined";
|
||||
int op_num = -1;
|
||||
|
||||
switch (op->op) {
|
||||
case GGML_OP_SUM_ROWS:
|
||||
op_str = "sum_rows"; break;
|
||||
case GGML_OP_MEAN:
|
||||
op_str = "mean"; break;
|
||||
case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break;
|
||||
case GGML_OP_MEAN: op_num = OP_SUM_ROWS_NUM_MEAN; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
|
||||
const char * t0_str = ggml_type_name(op->src[0]->type);
|
||||
const char * t_str = ggml_type_name(op->type);
|
||||
|
||||
snprintf(name, 256, "%s", base);
|
||||
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
|
||||
|
||||
snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
|
||||
snprintf(name, 256, "%s_op=%d", base, op_num);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
}
|
||||
|
||||
res.smem = 32*sizeof(float);
|
||||
|
||||
if (is_c4) {
|
||||
res.smem *= 4;
|
||||
}
|
||||
|
||||
res.c4 = is_c4;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -1472,13 +1495,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_met
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_L2_NORM);
|
||||
|
||||
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
|
||||
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_l2_norm_f32");
|
||||
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
|
||||
|
||||
const char * t0_str = ggml_type_name(op->src[0]->type);
|
||||
const char * t_str = ggml_type_name(op->type);
|
||||
|
||||
snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
@@ -1486,6 +1511,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
}
|
||||
|
||||
res.c4 = is_c4;
|
||||
res.smem = 32*sizeof(float);
|
||||
|
||||
return res;
|
||||
|
||||
@@ -1011,6 +1011,15 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
}
|
||||
|
||||
switch (op->op) {
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_FILL:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_LOG:
|
||||
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_TANH:
|
||||
@@ -1030,7 +1039,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_UNARY_OP_EXP:
|
||||
case GGML_UNARY_OP_SOFTPLUS:
|
||||
case GGML_UNARY_OP_EXPM1:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -1061,8 +1070,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_SCALE:
|
||||
case GGML_OP_FILL:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
return true;
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
@@ -1070,14 +1077,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
|
||||
op->src[1]->type == GGML_TYPE_F32 &&
|
||||
op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_CLAMP:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_LOG:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SUM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_TRI:
|
||||
@@ -1087,9 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_L2_NORM:
|
||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
||||
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
return has_simdgroup_reduction &&
|
||||
op->src[0]->type == GGML_TYPE_I32 &&
|
||||
@@ -1161,6 +1159,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return has_simdgroup_reduction;
|
||||
case GGML_OP_SET:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CONT:
|
||||
|
||||
@@ -80,7 +80,9 @@
|
||||
#define FC_SSM_CONV 900
|
||||
#define FC_SOLVE_TRI 1000
|
||||
#define FC_COUNT_EQUAL 1100
|
||||
#define FC_BIN 1200
|
||||
#define FC_UNARY 1200
|
||||
#define FC_BIN 1300
|
||||
#define FC_SUM_ROWS 1400
|
||||
|
||||
// op-specific constants
|
||||
#define OP_FLASH_ATTN_EXT_NQPSG 8
|
||||
@@ -89,6 +91,37 @@
|
||||
#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
|
||||
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
|
||||
|
||||
#define OP_UNARY_NUM_SCALE 10
|
||||
#define OP_UNARY_NUM_FILL 11
|
||||
#define OP_UNARY_NUM_CLAMP 12
|
||||
#define OP_UNARY_NUM_SQR 13
|
||||
#define OP_UNARY_NUM_SQRT 14
|
||||
#define OP_UNARY_NUM_SIN 15
|
||||
#define OP_UNARY_NUM_COS 16
|
||||
#define OP_UNARY_NUM_LOG 17
|
||||
#define OP_UNARY_NUM_LEAKY_RELU 18
|
||||
|
||||
#define OP_UNARY_NUM_TANH 100
|
||||
#define OP_UNARY_NUM_RELU 101
|
||||
#define OP_UNARY_NUM_SIGMOID 102
|
||||
#define OP_UNARY_NUM_GELU 103
|
||||
#define OP_UNARY_NUM_GELU_ERF 104
|
||||
#define OP_UNARY_NUM_GELU_QUICK 105
|
||||
#define OP_UNARY_NUM_SILU 106
|
||||
#define OP_UNARY_NUM_ELU 107
|
||||
#define OP_UNARY_NUM_NEG 108
|
||||
#define OP_UNARY_NUM_ABS 109
|
||||
#define OP_UNARY_NUM_SGN 110
|
||||
#define OP_UNARY_NUM_STEP 111
|
||||
#define OP_UNARY_NUM_HARDSWISH 112
|
||||
#define OP_UNARY_NUM_HARDSIGMOID 113
|
||||
#define OP_UNARY_NUM_EXP 114
|
||||
#define OP_UNARY_NUM_SOFTPLUS 115
|
||||
#define OP_UNARY_NUM_EXPM1 116
|
||||
|
||||
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
|
||||
#define OP_SUM_ROWS_NUM_MEAN 11
|
||||
|
||||
// kernel argument structs
|
||||
//
|
||||
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
|
||||
@@ -124,6 +157,31 @@ typedef struct {
|
||||
int32_t dim;
|
||||
} ggml_metal_kargs_concat;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
float slope;
|
||||
float scale;
|
||||
float bias;
|
||||
float val;
|
||||
float min;
|
||||
float max;
|
||||
} ggml_metal_kargs_unary;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
@@ -181,20 +239,6 @@ typedef struct {
|
||||
uint64_t nb3;
|
||||
} ggml_metal_kargs_repeat;
|
||||
|
||||
typedef struct {
|
||||
float scale;
|
||||
float bias;
|
||||
} ggml_metal_kargs_scale;
|
||||
|
||||
typedef struct {
|
||||
float val;
|
||||
} ggml_metal_kargs_fill;
|
||||
|
||||
typedef struct {
|
||||
float min;
|
||||
float max;
|
||||
} ggml_metal_kargs_clamp;
|
||||
|
||||
typedef struct {
|
||||
int64_t nk0;
|
||||
int64_t ne00;
|
||||
@@ -498,8 +542,21 @@ typedef struct {
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne00_4;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
int32_t ne03;
|
||||
uint64_t nb00;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
float eps;
|
||||
} ggml_metal_kargs_l2_norm;
|
||||
|
||||
@@ -881,10 +938,6 @@ typedef struct {
|
||||
int max_period;
|
||||
} ggml_metal_kargs_timestep_embedding;
|
||||
|
||||
typedef struct {
|
||||
float slope;
|
||||
} ggml_metal_kargs_leaky_relu;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
|
||||
@@ -287,17 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
n_fuse = ggml_metal_op_acc(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SCALE:
|
||||
{
|
||||
n_fuse = ggml_metal_op_scale(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_FILL:
|
||||
{
|
||||
n_fuse = ggml_metal_op_fill(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_CLAMP:
|
||||
{
|
||||
n_fuse = ggml_metal_op_clamp(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
@@ -426,10 +418,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_top_k(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
{
|
||||
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_TRI:
|
||||
{
|
||||
n_fuse = ggml_metal_op_tri(ctx, idx);
|
||||
@@ -438,6 +426,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
{
|
||||
n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_SET:
|
||||
{
|
||||
n_fuse = ggml_metal_op_set(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
@@ -722,119 +714,6 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float scale;
|
||||
float bias;
|
||||
memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
|
||||
memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float));
|
||||
|
||||
ggml_metal_kargs_scale args = {
|
||||
/*.scale =*/ scale,
|
||||
/*.bias =*/ bias,
|
||||
};
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
|
||||
if (n % 4 == 0) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
const float val = ggml_get_op_params_f32(op, 0);
|
||||
|
||||
ggml_metal_kargs_fill args = {
|
||||
/*.val =*/ val
|
||||
};
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
|
||||
if (n % 4 == 0) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float min;
|
||||
float max;
|
||||
memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float));
|
||||
memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
|
||||
|
||||
ggml_metal_kargs_clamp args = {
|
||||
/*.min =*/ min,
|
||||
/*.max =*/ max,
|
||||
};
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
|
||||
if (n % 4 == 0) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
@@ -846,19 +725,79 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
if (n % 4 == 0) {
|
||||
n /= 4;
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
ggml_metal_kargs_unary args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
/*.slope =*/ 0.0,
|
||||
/*.scale =*/ 0.0,
|
||||
/*.bias =*/ 0.0,
|
||||
/*.val =*/ 0.0,
|
||||
/*.min =*/ 0.0,
|
||||
/*.max =*/ 0.0,
|
||||
};
|
||||
|
||||
if (op->op == GGML_OP_LEAKY_RELU) {
|
||||
args.slope = ggml_get_op_params_f32(op, 0);
|
||||
}
|
||||
|
||||
if (op->op == GGML_OP_SCALE) {
|
||||
args.scale = ggml_get_op_params_f32(op, 0);
|
||||
args.bias = ggml_get_op_params_f32(op, 1);
|
||||
}
|
||||
|
||||
if (op->op == GGML_OP_FILL) {
|
||||
args.val = ggml_get_op_params_f32(op, 0);
|
||||
}
|
||||
|
||||
if (op->op == GGML_OP_CLAMP) {
|
||||
args.min = ggml_get_op_params_f32(op, 0);
|
||||
args.max = ggml_get_op_params_f32(op, 1);
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
|
||||
if (pipeline.c4) {
|
||||
args.ne00 = ne00/4;
|
||||
args.ne0 = ne0/4;
|
||||
}
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
if (pipeline.cnt) {
|
||||
const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
} else {
|
||||
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
const int nth = MIN(args.ne00, nth_max);
|
||||
|
||||
const int nk0 = (args.ne00 + nth - 1)/nth;
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
@@ -969,6 +908,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
ggml_metal_kargs_sum_rows args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
@@ -990,21 +934,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
|
||||
|
||||
if (pipeline.c4) {
|
||||
args.ne00 = ne00/4;
|
||||
args.ne0 = ne0/4;
|
||||
}
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
nth = std::min(nth, ne00);
|
||||
nth = std::min(nth, (int) args.ne00);
|
||||
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
@@ -1664,6 +1613,134 @@ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
|
||||
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
|
||||
const size_t pnb3 = ((const int32_t *) op->op_params)[2];
|
||||
const size_t offs = ((const int32_t *) op->op_params)[3];
|
||||
|
||||
const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
|
||||
|
||||
if (!inplace) {
|
||||
// run a separete kernel to cpy src->dst
|
||||
// not sure how to avoid this
|
||||
// TODO: make a simpler cpy_bytes kernel
|
||||
|
||||
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
|
||||
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
|
||||
|
||||
ggml_metal_kargs_cpy args = {
|
||||
/*.nk0 =*/ ne00,
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
};
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
}
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
|
||||
|
||||
GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
|
||||
|
||||
int64_t nk0 = ne10;
|
||||
if (ggml_is_quantized(op->src[1]->type)) {
|
||||
nk0 = ne10/16;
|
||||
} else if (ggml_is_quantized(op->type)) {
|
||||
nk0 = ne10/ggml_blck_size(op->type);
|
||||
}
|
||||
|
||||
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
// when rows are small, we can batch them together in a single threadgroup
|
||||
int nrptg = 1;
|
||||
|
||||
// TODO: relax this constraint in the future
|
||||
if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
|
||||
if (nth > nk0) {
|
||||
nrptg = (nth + nk0 - 1)/nk0;
|
||||
nth = nk0;
|
||||
|
||||
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nrptg--;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nth = std::min<int>(nth, nk0);
|
||||
|
||||
ggml_metal_kargs_cpy args = {
|
||||
/*.nk0 =*/ nk0,
|
||||
/*.ne00 =*/ ne10,
|
||||
/*.ne01 =*/ ne11,
|
||||
/*.ne02 =*/ ne12,
|
||||
/*.ne03 =*/ ne13,
|
||||
/*.nb00 =*/ nb10,
|
||||
/*.nb01 =*/ nb11,
|
||||
/*.nb02 =*/ nb12,
|
||||
/*.nb03 =*/ nb13,
|
||||
/*.ne0 =*/ ne10,
|
||||
/*.ne1 =*/ ne11,
|
||||
/*.ne2 =*/ ne12,
|
||||
/*.ne3 =*/ ne13,
|
||||
/*.nb0 =*/ ggml_element_size(op),
|
||||
/*.nb1 =*/ pnb1,
|
||||
/*.nb2 =*/ pnb2,
|
||||
/*.nb3 =*/ pnb3,
|
||||
};
|
||||
|
||||
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
|
||||
|
||||
bid_dst.offs += offs;
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
@@ -3044,39 +3121,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, op->op_params, sizeof(float));
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
ggml_metal_kargs_l2_norm args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne00_4 =*/ ne00/4,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.eps =*/ eps,
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
/*.eps =*/ eps,
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
|
||||
|
||||
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
if (pipeline.c4) {
|
||||
args.ne00 = ne00/4;
|
||||
args.ne0 = ne0/4;
|
||||
}
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
nth = std::min(nth, ne00/4);
|
||||
|
||||
const size_t smem = pipeline.smem;
|
||||
|
||||
const int64_t nrows = ggml_nrows(op->src[0]);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
|
||||
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
@@ -4084,42 +4181,6 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
|
||||
|
||||
float slope;
|
||||
memcpy(&slope, op->op_params, sizeof(float));
|
||||
|
||||
ggml_metal_kargs_leaky_relu args = {
|
||||
/*.slope =*/ slope
|
||||
};
|
||||
|
||||
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
|
||||
|
||||
int64_t n = ggml_nelements(op);
|
||||
|
||||
if (n % 4 == 0) {
|
||||
n /= 4;
|
||||
}
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
|
||||
@@ -46,9 +46,6 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
|
||||
int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_fill (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
|
||||
@@ -62,6 +59,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_set (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx);
|
||||
@@ -86,7 +84,6 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
|
||||
|
||||
@@ -77,6 +77,14 @@ static inline float dot(float x, float y) {
|
||||
return x*y;
|
||||
}
|
||||
|
||||
static inline float sum(float x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline float sum(float4 x) {
|
||||
return x[0] + x[1] + x[2] + x[3];
|
||||
}
|
||||
|
||||
// NOTE: this is not dequantizing - we are simply fitting the template
|
||||
template <typename type4x4>
|
||||
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
||||
@@ -895,6 +903,210 @@ enum ggml_sort_order {
|
||||
GGML_SORT_ORDER_DESC,
|
||||
};
|
||||
|
||||
constant float GELU_COEF_A = 0.044715f;
|
||||
constant float GELU_QUICK_COEF = -1.702f;
|
||||
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
||||
|
||||
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
|
||||
// ref: https://www.johndcook.com/blog/python_erf/
|
||||
constant float p_erf = 0.3275911f;
|
||||
constant float a1_erf = 0.254829592f;
|
||||
constant float a2_erf = -0.284496736f;
|
||||
constant float a3_erf = 1.421413741f;
|
||||
constant float a4_erf = -1.453152027f;
|
||||
constant float a5_erf = 1.061405429f;
|
||||
|
||||
template<typename T>
|
||||
inline T erf_approx(T x) {
|
||||
T sign_x = sign(x);
|
||||
x = fabs(x);
|
||||
T t = 1.0f / (1.0f + p_erf * x);
|
||||
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
|
||||
return sign_x * y;
|
||||
}
|
||||
|
||||
template<typename T> T elu_approx(T x);
|
||||
|
||||
template<> inline float elu_approx<float>(float x) {
|
||||
return (x > 0.f) ? x : (exp(x) - 1);
|
||||
}
|
||||
|
||||
template<> inline float4 elu_approx<float4>(float4 x) {
|
||||
float4 res;
|
||||
|
||||
res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
|
||||
res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
|
||||
res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
|
||||
res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
|
||||
constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
|
||||
|
||||
template <typename T0, typename T, typename TC>
|
||||
kernel void kernel_unary_impl(
|
||||
constant ggml_metal_kargs_unary & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
#define FC_OP FC_unary_op
|
||||
#define FC_CNT FC_unary_cnt
|
||||
|
||||
device const T0 * src0_ptr;
|
||||
device T * dst_ptr;
|
||||
|
||||
int i0;
|
||||
|
||||
if (FC_CNT) {
|
||||
i0 = tgpig.x;
|
||||
|
||||
src0_ptr = (device const T0 *) (src0);
|
||||
dst_ptr = (device T *) (dst);
|
||||
} else {
|
||||
const int i03 = tgpig.z;
|
||||
const int i02 = tgpig.y;
|
||||
const int k0 = tgpig.x/args.ne01;
|
||||
const int i01 = tgpig.x - k0*args.ne01;
|
||||
|
||||
i0 = k0*ntg.x + tpitg.x;
|
||||
|
||||
src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 );
|
||||
}
|
||||
|
||||
{
|
||||
//threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
if (!FC_CNT) {
|
||||
if (i0 >= args.ne0) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const TC x = (TC) src0_ptr[i0];
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SCALE) {
|
||||
dst_ptr[i0] = (T) (args.scale * x + args.bias);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_FILL) {
|
||||
dst_ptr[i0] = (T) args.val;
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_CLAMP) {
|
||||
dst_ptr[i0] = (T) clamp(x, args.min, args.max);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SQR) {
|
||||
dst_ptr[i0] = (T) (x * x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SQRT) {
|
||||
dst_ptr[i0] = (T) sqrt(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SIN) {
|
||||
dst_ptr[i0] = (T) sin(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_COS) {
|
||||
dst_ptr[i0] = (T) cos(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_LOG) {
|
||||
dst_ptr[i0] = (T) log(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
|
||||
dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_TANH) {
|
||||
dst_ptr[i0] = (T) precise::tanh(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_RELU) {
|
||||
dst_ptr[i0] = (T) fmax(0, x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SIGMOID) {
|
||||
dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_GELU) {
|
||||
dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
|
||||
dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
|
||||
dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SILU) {
|
||||
dst_ptr[i0] = (T) (x / (1 + exp(-x)));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_ELU) {
|
||||
dst_ptr[i0] = (T) elu_approx(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_NEG) {
|
||||
dst_ptr[i0] = (T) -x;
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_ABS) {
|
||||
dst_ptr[i0] = (T) fabs(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SGN) {
|
||||
dst_ptr[i0] = T(x > 0) - T(x < 0);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_STEP) {
|
||||
dst_ptr[i0] = T(x > 0);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
|
||||
dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
|
||||
dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_EXP) {
|
||||
dst_ptr[i0] = (T) exp(x);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
|
||||
dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
|
||||
}
|
||||
|
||||
if (FC_OP == OP_UNARY_NUM_EXPM1) {
|
||||
// TODO: precise implementation
|
||||
dst_ptr[i0] = (T) (exp(x) - 1);
|
||||
}
|
||||
}
|
||||
|
||||
#undef FC_OP
|
||||
#undef FC_CNT
|
||||
}
|
||||
|
||||
typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
|
||||
|
||||
template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>;
|
||||
template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
|
||||
template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>;
|
||||
template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>;
|
||||
|
||||
// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
|
||||
constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
|
||||
constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
|
||||
@@ -1114,414 +1326,6 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat
|
||||
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
|
||||
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
|
||||
|
||||
kernel void kernel_scale_f32(
|
||||
constant ggml_metal_kargs_scale & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * args.scale + args.bias;
|
||||
}
|
||||
|
||||
kernel void kernel_scale_f32_4(
|
||||
constant ggml_metal_kargs_scale & args,
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * args.scale + args.bias;
|
||||
}
|
||||
|
||||
kernel void kernel_fill_f32(
|
||||
constant ggml_metal_kargs_fill & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = args.val;
|
||||
}
|
||||
|
||||
kernel void kernel_fill_f32_4(
|
||||
constant ggml_metal_kargs_fill & args,
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = args.val;
|
||||
}
|
||||
|
||||
kernel void kernel_clamp_f32(
|
||||
constant ggml_metal_kargs_clamp & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = clamp(src0[tpig], args.min, args.max);
|
||||
}
|
||||
|
||||
kernel void kernel_clamp_f32_4(
|
||||
constant ggml_metal_kargs_clamp & args,
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = clamp(src0[tpig], args.min, args.max);
|
||||
}
|
||||
|
||||
kernel void kernel_relu_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = max(0.0f, src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_relu_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = max(0.0f, src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sigmoid_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
|
||||
}
|
||||
|
||||
kernel void kernel_sigmoid_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
|
||||
}
|
||||
|
||||
kernel void kernel_tanh_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = precise::tanh(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_tanh_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = precise::tanh(src0[tpig]);
|
||||
}
|
||||
|
||||
constant float GELU_COEF_A = 0.044715f;
|
||||
constant float GELU_QUICK_COEF = -1.702f;
|
||||
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
|
||||
|
||||
kernel void kernel_gelu_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
|
||||
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float4 & x = src0[tpig];
|
||||
|
||||
// BEWARE !!!
|
||||
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
|
||||
// This was observed with Falcon 7B and 40B models
|
||||
//
|
||||
dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_quick_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
|
||||
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_quick_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float4 & x = src0[tpig];
|
||||
|
||||
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
|
||||
}
|
||||
|
||||
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
|
||||
// ref: https://www.johndcook.com/blog/python_erf/
|
||||
constant float p_erf = 0.3275911f;
|
||||
constant float a1_erf = 0.254829592f;
|
||||
constant float a2_erf = -0.284496736f;
|
||||
constant float a3_erf = 1.421413741f;
|
||||
constant float a4_erf = -1.453152027f;
|
||||
constant float a5_erf = 1.061405429f;
|
||||
|
||||
template<typename T>
|
||||
T erf_approx(T x) {
|
||||
T sign_x = sign(x);
|
||||
x = fabs(x);
|
||||
T t = 1.0f / (1.0f + p_erf * x);
|
||||
T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
|
||||
return sign_x * y;
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_erf_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
|
||||
dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
|
||||
}
|
||||
|
||||
kernel void kernel_gelu_erf_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float4 & x = src0[tpig];
|
||||
|
||||
dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
|
||||
}
|
||||
|
||||
kernel void kernel_silu_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
dst[tpig] = x / (1.0f + exp(-x));
|
||||
}
|
||||
|
||||
kernel void kernel_silu_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float4 & x = src0[tpig];
|
||||
dst[tpig] = x / (1.0f + exp(-x));
|
||||
}
|
||||
|
||||
kernel void kernel_elu_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const float x = src0[tpig];
|
||||
dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
|
||||
}
|
||||
|
||||
kernel void kernel_elu_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const float4 x = src0[tpig];
|
||||
dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
|
||||
dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
|
||||
dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
|
||||
dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
|
||||
}
|
||||
|
||||
kernel void kernel_sqr_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * src0[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_sqr_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * src0[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_sqrt_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = sqrt(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sqrt_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = sqrt(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sin_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = sin(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sin_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = sin(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_cos_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = cos(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_cos_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = cos(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_log_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = log(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_log_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = log(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_neg_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = -src0[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_neg_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = -src0[tpig];
|
||||
}
|
||||
|
||||
kernel void kernel_abs_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = fabs(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_abs_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = fabs(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sgn_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = sign(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_sgn_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = sign(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_step_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = step(0.0f, src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_step_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = step(0.0f, src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_hardswish_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const float x = src0[tpig];
|
||||
dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
||||
}
|
||||
|
||||
kernel void kernel_hardswish_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const float4 x = src0[tpig];
|
||||
dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
||||
}
|
||||
|
||||
kernel void kernel_hardsigmoid_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const float x = src0[tpig];
|
||||
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
||||
}
|
||||
|
||||
kernel void kernel_hardsigmoid_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const float4 x = src0[tpig];
|
||||
dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
|
||||
}
|
||||
|
||||
kernel void kernel_exp_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = exp(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_exp_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = exp(src0[tpig]);
|
||||
}
|
||||
|
||||
kernel void kernel_softplus_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float & x = src0[tpig];
|
||||
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
|
||||
}
|
||||
|
||||
kernel void kernel_softplus_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
device const float4 & x = src0[tpig];
|
||||
dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
|
||||
}
|
||||
|
||||
kernel void kernel_expm1_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = exp(src0[tpig]) - 1.0f;
|
||||
}
|
||||
|
||||
kernel void kernel_expm1_f32_4(
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = exp(src0[tpig]) - 1.0f;
|
||||
}
|
||||
|
||||
kernel void kernel_reglu_f32(
|
||||
constant ggml_metal_kargs_glu & args,
|
||||
device const char * src0,
|
||||
@@ -1705,33 +1509,35 @@ kernel void kernel_op_sum_f32(
|
||||
}
|
||||
}
|
||||
|
||||
template <bool norm>
|
||||
kernel void kernel_sum_rows(
|
||||
constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
|
||||
|
||||
template <typename T0, typename T>
|
||||
kernel void kernel_sum_rows_impl(
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
int64_t i3 = tgpig.z;
|
||||
int64_t i2 = tgpig.y;
|
||||
int64_t i1 = tgpig.x;
|
||||
#define FC_OP FC_sum_rows_op
|
||||
|
||||
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
const int i3 = tgpig.z;
|
||||
const int i2 = tgpig.y;
|
||||
const int i1 = tgpig.x;
|
||||
|
||||
threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
shmem_t[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
float sumf = 0;
|
||||
T0 sumf = T0(0.0f);
|
||||
|
||||
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
sumf += src_row[i0];
|
||||
@@ -1742,23 +1548,33 @@ kernel void kernel_sum_rows(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
shmem_t[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = shmem_t[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
dst_row[0] = norm ? sumf / args.ne00 : sumf;
|
||||
if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
|
||||
if (is_same<float4, T0>::value) {
|
||||
dst_row[0] = sum(sumf) / (4*args.ne00);
|
||||
} else {
|
||||
dst_row[0] = sum(sumf) / args.ne00;
|
||||
}
|
||||
} else {
|
||||
dst_row[0] = sum(sumf);
|
||||
}
|
||||
}
|
||||
|
||||
#undef FC_OP
|
||||
}
|
||||
|
||||
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
||||
typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
|
||||
|
||||
template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
||||
template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
||||
template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
|
||||
template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_cumsum_blk(
|
||||
@@ -2639,9 +2455,6 @@ kernel void kernel_solve_tri_f32(
|
||||
const short K = FC_solve_tri_k;
|
||||
const short NP = PAD2(N, NW);
|
||||
|
||||
const int32_t ne02 = args.ne02;
|
||||
const int32_t ne03 = args.ne03;
|
||||
|
||||
const int32_t i03 = tgpig.z;
|
||||
const int32_t i02 = tgpig.y;
|
||||
const int32_t i01 = tgpig.x*NSG + sgitg;
|
||||
@@ -2928,26 +2741,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f
|
||||
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
|
||||
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
|
||||
|
||||
kernel void kernel_l2_norm_f32(
|
||||
template <typename T0, typename T>
|
||||
kernel void kernel_l2_norm_impl(
|
||||
constant ggml_metal_kargs_l2_norm & args,
|
||||
device const char * src0,
|
||||
device char * dst,
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort ntg[[threads_per_threadgroup]]) {
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i03 = tgpig.z;
|
||||
const int i02 = tgpig.y;
|
||||
const int i01 = tgpig.x;
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
|
||||
device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
|
||||
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
// parallel sum
|
||||
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
||||
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
|
||||
sumf += dot(x[i00], x[i00]);
|
||||
}
|
||||
sumf = simd_sum(sumf);
|
||||
@@ -2965,12 +2784,16 @@ kernel void kernel_l2_norm_f32(
|
||||
|
||||
const float scale = 1.0f/sqrt(max(sumf, args.eps));
|
||||
|
||||
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
|
||||
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
||||
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
|
||||
y[i00] = x[i00] * scale;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
|
||||
|
||||
template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
|
||||
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
|
||||
|
||||
kernel void kernel_group_norm_f32(
|
||||
constant ggml_metal_kargs_group_norm & args,
|
||||
device const float * src0,
|
||||
@@ -5072,24 +4895,6 @@ kernel void kernel_argsort_merge_f32_i32(
|
||||
template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
|
||||
template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
|
||||
|
||||
kernel void kernel_leaky_relu_f32(
|
||||
constant ggml_metal_kargs_leaky_relu & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const float x = src0[tpig];
|
||||
dst[tpig] = x > 0.0f ? x : x * args.slope;
|
||||
}
|
||||
|
||||
kernel void kernel_leaky_relu_f32_4(
|
||||
constant ggml_metal_kargs_leaky_relu & args,
|
||||
device const float4 * src0,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
const float4 x = src0[tpig];
|
||||
dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
|
||||
}
|
||||
|
||||
constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
|
||||
|
||||
constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
|
||||
@@ -6161,7 +5966,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
|
||||
static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
|
||||
|
||||
const short T = PK + NSG*SH; // shared memory size per query in (half)
|
||||
//const short T = PK + NSG*SH; // shared memory size per query in (half)
|
||||
|
||||
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
|
||||
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
|
||||
@@ -8749,7 +8554,9 @@ kernel void kernel_mul_mm(
|
||||
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
||||
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
||||
|
||||
#ifdef GGML_METAL_HAS_TENSOR
|
||||
threadgroup float * sc = (threadgroup float *)(shmem);
|
||||
#endif
|
||||
|
||||
constexpr int NR0 = 64;
|
||||
constexpr int NR1 = 32;
|
||||
@@ -8872,8 +8679,8 @@ kernel void kernel_mul_mm(
|
||||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
const short dx = sx;
|
||||
const short dy = sy;
|
||||
//const short dx = sx;
|
||||
//const short dy = sy;
|
||||
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
|
||||
@@ -9122,7 +8929,9 @@ kernel void kernel_mul_mm_id(
|
||||
threadgroup S0 * sa = (threadgroup S0 *)(shmem);
|
||||
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
|
||||
|
||||
#ifdef GGML_METAL_HAS_TENSOR
|
||||
threadgroup float * sc = (threadgroup float *)(shmem);
|
||||
#endif
|
||||
|
||||
constexpr int NR0 = 64;
|
||||
constexpr int NR1 = 32;
|
||||
@@ -9257,8 +9066,8 @@ kernel void kernel_mul_mm_id(
|
||||
const short sx = (tiitg%NL1);
|
||||
const short sy = (tiitg/NL1)/8;
|
||||
|
||||
const short dx = sx;
|
||||
const short dy = sy;
|
||||
//const short dx = sx;
|
||||
//const short dy = sy;
|
||||
|
||||
const short ly = (tiitg/NL1)%8;
|
||||
|
||||
@@ -9939,7 +9748,7 @@ kernel void kernel_opt_step_sgd_f32(
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_memset(
|
||||
constant ggml_metal_kargs_fill & args,
|
||||
constant ggml_metal_kargs_memset & args,
|
||||
device T * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = args.val;
|
||||
|
||||
@@ -85,6 +85,9 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mv_q4_0_f32_8x_flat
|
||||
mul_mv_q4_0_f32_1d_8x_flat
|
||||
mul_mv_q4_0_f32_1d_16x_flat
|
||||
mul_mv_q4_1_f32
|
||||
mul_mv_q4_1_f32_flat
|
||||
mul_mv_q4_k_f32
|
||||
mul_mv_q6_k_f32
|
||||
mul_mv_q6_k_f32_flat
|
||||
mul_mv_q8_0_f32
|
||||
@@ -100,7 +103,10 @@ set(GGML_OPENCL_KERNELS
|
||||
gemv_moe_mxfp4_f32
|
||||
mul_mm_f32_f32_l4_lm
|
||||
mul_mm_f16_f32_l4_lm
|
||||
mul_mm_q4_0_f32_l4_lm
|
||||
mul_mm_q4_1_f32_l4_lm
|
||||
mul_mm_q8_0_f32_l4_lm
|
||||
mul_mm_q6_k_f32_l4_lm
|
||||
mul_mm_q8_0_f32_8x4
|
||||
gemv_noshuffle_general_q8_0_f32
|
||||
mul
|
||||
|
||||
@@ -525,6 +525,7 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_mul_mm_f16_f32_kq;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
|
||||
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
|
||||
cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1;
|
||||
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
|
||||
cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
|
||||
@@ -532,6 +533,9 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_restore_block_q4_0_noshuffle;
|
||||
cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
|
||||
cl_kernel kernel_mul_mv_q4_1_f32;
|
||||
cl_kernel kernel_mul_mv_q4_1_f32_flat;
|
||||
cl_kernel kernel_mul_mv_q4_K_f32;
|
||||
cl_kernel kernel_mul_mv_q6_K_f32;
|
||||
cl_kernel kernel_mul_mv_q6_K_f32_flat;
|
||||
cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;
|
||||
@@ -563,7 +567,10 @@ struct ggml_backend_opencl_context {
|
||||
cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
|
||||
cl_kernel kernel_mul_mm_f32_f32_l4_lm;
|
||||
cl_kernel kernel_mul_mm_f16_f32_l4_lm;
|
||||
cl_kernel kernel_mul_mm_q4_0_f32_l4_lm;
|
||||
cl_kernel kernel_mul_mm_q4_1_f32_l4_lm;
|
||||
cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
|
||||
cl_kernel kernel_mul_mm_q6_k_f32_l4_lm;
|
||||
|
||||
std::vector<ProfilingInfo> profiling_info;
|
||||
|
||||
@@ -886,6 +893,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_noshuffle", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err));
|
||||
@@ -1117,6 +1126,57 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_q4_1_f32
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mv_q4_1_f32.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mv_q4_1_f32.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_1_f32", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_q4_1_f32_flat
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mv_q4_1_f32_flat.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mv_q4_1_f32_flat.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q4_1_f32_flat", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_q4_k_f32
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mv_q4_k_f32.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mv_q4_k_f32.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mv_q4_K_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_K_f32", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mv_q6_k_f32
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -1342,6 +1402,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mm_q4_0_f32_l4_lm
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mm_q4_0_f32_l4_lm.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mm_q4_0_f32_l4_lm.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_0_f32_l4_lm", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mm_q4_1_f32_l4_lm
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mm_q4_1_f32_l4_lm.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mm_q4_1_f32_l4_lm.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_1_f32_l4_lm", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mm_q8_0_f32_l4_lm
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -1358,6 +1450,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mm_q6_k_f32_l4_lm
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "mul_mm_q6_k_f32_l4_lm.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("mul_mm_q6_k_f32_l4_lm.cl");
|
||||
#endif
|
||||
cl_program prog =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q6_k_f32_l4_lm", &err), err));
|
||||
CL_CHECK(clReleaseProgram(prog));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// mul_mm_f16_f32_kq_kqv
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
@@ -2887,6 +2996,59 @@ struct ggml_tensor_extra_cl_q4_0 {
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_tensor_extra_cl_q4_1 {
|
||||
// Quantized values.
|
||||
cl_mem q = nullptr;
|
||||
// Quantized values in image1d_buffer_t.
|
||||
cl_mem q_img = nullptr;
|
||||
// Scales.
|
||||
cl_mem d = nullptr;
|
||||
// Scales in image1d_buffer_t.
|
||||
cl_mem d_img = nullptr;
|
||||
// Min
|
||||
cl_mem m = nullptr;
|
||||
// Min in image1d_buffer_t.
|
||||
cl_mem m_img = nullptr;
|
||||
// Size of quantized values.
|
||||
size_t size_q = 0;
|
||||
// Size of scales.
|
||||
size_t size_d = 0;
|
||||
// Size of min values.
|
||||
size_t size_m = 0;
|
||||
|
||||
~ggml_tensor_extra_cl_q4_1() {
|
||||
reset();
|
||||
}
|
||||
|
||||
void reset() {
|
||||
// q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer.
|
||||
// They must be properly released so that the original buffer can be
|
||||
// properly released to avoid memory leak.
|
||||
if (q != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(q));
|
||||
q = nullptr;
|
||||
}
|
||||
if (d != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(d));
|
||||
d = nullptr;
|
||||
}
|
||||
if (m != nullptr) {
|
||||
CL_CHECK(clReleaseMemObject(m));
|
||||
m = nullptr;
|
||||
}
|
||||
// Currently, q_img and d_img are only initialized when SMALL_ALLOC is
|
||||
// enabled. They point to the images in ggml_backend_opencl_buffer_context.
|
||||
// So, there is no need to release them here.
|
||||
// TODO: initialize them for non SMALL_PATH path, or remove them.
|
||||
q_img = nullptr;
|
||||
d_img = nullptr;
|
||||
m_img = nullptr;
|
||||
size_q = 0;
|
||||
size_d = 0;
|
||||
size_m = 0;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_tensor_extra_cl_mxfp4 {
|
||||
// Quantized values.
|
||||
cl_mem q = nullptr;
|
||||
@@ -3363,7 +3525,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||
return true;
|
||||
} else if (op->src[0]->type == GGML_TYPE_F32) {
|
||||
return op->src[1]->type == GGML_TYPE_F32;
|
||||
} else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_MXFP4 ||
|
||||
} else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 ||
|
||||
op->src[0]->type == GGML_TYPE_MXFP4 ||
|
||||
op->src[0]->type == GGML_TYPE_Q4_K ||
|
||||
op->src[0]->type == GGML_TYPE_Q6_K) {
|
||||
return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
|
||||
} else if (op->src[0]->type == GGML_TYPE_Q8_0) {
|
||||
@@ -3592,6 +3756,21 @@ struct ggml_backend_opencl_buffer_context {
|
||||
return extra;
|
||||
}
|
||||
|
||||
ggml_tensor_extra_cl_q4_1 * ggml_opencl_alloc_temp_tensor_extra_q4_1() {
|
||||
ggml_tensor_extra_cl_q4_1 * extra;
|
||||
if (temp_tensor_extras_q4_1.empty()) {
|
||||
extra = new ggml_tensor_extra_cl_q4_1();
|
||||
} else {
|
||||
extra = temp_tensor_extras_q4_1.back();
|
||||
temp_tensor_extras_q4_1.pop_back();
|
||||
}
|
||||
|
||||
temp_tensor_extras_q4_1_in_use.push_back(extra);
|
||||
|
||||
extra->reset();
|
||||
return extra;
|
||||
}
|
||||
|
||||
ggml_tensor_extra_cl_mxfp4 * ggml_opencl_alloc_temp_tensor_extra_mxfp4() {
|
||||
ggml_tensor_extra_cl_mxfp4 * extra;
|
||||
if (temp_tensor_extras_mxfp4.empty()) {
|
||||
@@ -3648,6 +3827,11 @@ struct ggml_backend_opencl_buffer_context {
|
||||
}
|
||||
temp_tensor_extras_q4_0_in_use.clear();
|
||||
|
||||
for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1_in_use) {
|
||||
temp_tensor_extras_q4_1.push_back(e);
|
||||
}
|
||||
temp_tensor_extras_q4_1_in_use.clear();
|
||||
|
||||
for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) {
|
||||
temp_tensor_extras_mxfp4.push_back(e);
|
||||
}
|
||||
@@ -3673,6 +3857,8 @@ struct ggml_backend_opencl_buffer_context {
|
||||
std::vector<ggml_tensor_extra_cl *> temp_tensor_extras_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0;
|
||||
std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_q4_1 *> temp_tensor_extras_q4_1;
|
||||
std::vector<ggml_tensor_extra_cl_q4_1 *> temp_tensor_extras_q4_1_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4;
|
||||
std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4_in_use;
|
||||
std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0;
|
||||
@@ -4042,6 +4228,75 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
return;
|
||||
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_Q4_1) {
|
||||
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
|
||||
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
|
||||
|
||||
// Allocate the new extra and create aliases from the original.
|
||||
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
|
||||
ggml_tensor_extra_cl_q4_1 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_1();
|
||||
|
||||
size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
|
||||
size_t size_m = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
|
||||
size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
|
||||
GGML_ASSERT(size_d + size_m + size_q == ggml_nbytes(tensor) && "Incorrect tensor size");
|
||||
|
||||
cl_int err;
|
||||
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
|
||||
ggml_nbytes(tensor), NULL, &err);
|
||||
CL_CHECK(err);
|
||||
CL_CHECK(clEnqueueWriteBuffer(
|
||||
queue, data_device, CL_TRUE, 0,
|
||||
ggml_nbytes(tensor), data, 0, NULL, NULL));
|
||||
|
||||
cl_buffer_region region;
|
||||
|
||||
// The original tensor memory is divided into scales and quants, i.e.,
|
||||
// we first store scales, mins, then quants.
|
||||
// Create subbuffer for scales.
|
||||
region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
|
||||
region.size = size_d;
|
||||
extra->d = clCreateSubBuffer(
|
||||
extra_orig->data_device, CL_MEM_READ_WRITE,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err);
|
||||
CL_CHECK(err);
|
||||
auto previous_origin = region.origin;
|
||||
|
||||
// Create subbuffer for mins.
|
||||
region.origin = align_to(previous_origin + size_d, backend_ctx->alignment);
|
||||
region.size = size_m;
|
||||
extra->m = clCreateSubBuffer(
|
||||
extra_orig->data_device, CL_MEM_READ_WRITE,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err);
|
||||
CL_CHECK(err);
|
||||
previous_origin = region.origin;
|
||||
|
||||
// Create subbuffer for quants.
|
||||
region.origin = align_to(previous_origin + size_m, backend_ctx->alignment);
|
||||
region.size = size_q;
|
||||
extra->q = clCreateSubBuffer(
|
||||
extra_orig->data_device, CL_MEM_READ_WRITE,
|
||||
CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1;
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
|
||||
size_t local_work_size[] = {64, 1, 1};
|
||||
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
|
||||
tensor->extra = extra;
|
||||
|
||||
return;
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_MXFP4) {
|
||||
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
|
||||
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
|
||||
@@ -4544,7 +4799,35 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
||||
size, data, 0, NULL, NULL));
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
return;
|
||||
} else if (tensor->type == GGML_TYPE_MXFP4) {
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_Q4_1) {
|
||||
ggml_tensor_extra_cl_q4_1 * extra = (ggml_tensor_extra_cl_q4_1 *)tensor->extra;
|
||||
|
||||
cl_int err;
|
||||
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
|
||||
ggml_nbytes(tensor), NULL, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1;
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->m));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
|
||||
size_t local_work_size[] = {1, 1, 1};
|
||||
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
|
||||
global_work_size, local_work_size, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
CL_CHECK(clEnqueueReadBuffer(
|
||||
queue, data_device, CL_TRUE, offset,
|
||||
size, data, 0, NULL, NULL));
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
return;
|
||||
}
|
||||
if (tensor->type == GGML_TYPE_MXFP4) {
|
||||
ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra;
|
||||
|
||||
cl_int err;
|
||||
@@ -8372,6 +8655,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
|
||||
#ifdef GGML_OPENCL_SOA_Q
|
||||
ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;
|
||||
ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra;
|
||||
ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;
|
||||
ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra;
|
||||
ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra;
|
||||
@@ -8885,6 +9169,91 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
return;
|
||||
}
|
||||
case GGML_TYPE_Q4_0: {
|
||||
if (ne11 < 32) {
|
||||
break;
|
||||
}
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
|
||||
break;
|
||||
}
|
||||
|
||||
kernel = backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm;
|
||||
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
|
||||
|
||||
int batch_stride_a = ne00*ne01;
|
||||
int batch_stride_b = ne10*ne11;
|
||||
int batch_stride_d = ne0*ne1;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
|
||||
|
||||
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
|
||||
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
|
||||
size_t local_work_size[] = {(size_t)nth0, 1, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
return;
|
||||
}
|
||||
case GGML_TYPE_Q4_1: {
|
||||
if (ne11 < 32) {
|
||||
break;
|
||||
}
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
|
||||
break;
|
||||
}
|
||||
|
||||
kernel = backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm;
|
||||
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
|
||||
|
||||
int batch_stride_a = ne00*ne01;
|
||||
int batch_stride_b = ne10*ne11;
|
||||
int batch_stride_d = ne0*ne1;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_a
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_b
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne01)); // stride_d
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_a));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_b));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r3));
|
||||
|
||||
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
|
||||
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
|
||||
size_t local_work_size[] = {(size_t)nth0, 1, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
return;
|
||||
}
|
||||
case GGML_TYPE_Q8_0: {
|
||||
if (ne11 < 32) {
|
||||
break;
|
||||
@@ -8927,6 +9296,50 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
return;
|
||||
}
|
||||
case GGML_TYPE_Q6_K: {
|
||||
if (ne11 < 32) {
|
||||
break;
|
||||
}
|
||||
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
|
||||
break;
|
||||
}
|
||||
|
||||
kernel = backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm;
|
||||
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
|
||||
|
||||
int batch_stride_a = ne00*ne01;
|
||||
int batch_stride_b = ne10*ne11;
|
||||
int batch_stride_d = ne0*ne1;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3));
|
||||
|
||||
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
|
||||
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
|
||||
size_t local_work_size[] = {(size_t)nth0, 1, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
return;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -9181,7 +9594,71 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
|
||||
#endif // GGML_OPENCL_SOA_Q
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q4_1: {
|
||||
#ifdef GGML_OPENCL_SOA_Q
|
||||
if (backend_ctx->gpu_family == INTEL) {
|
||||
nth0 = 16;
|
||||
nth1 = 1;
|
||||
ndst = 4;
|
||||
} else if (backend_ctx->gpu_family == ADRENO) {
|
||||
nth0 = 64;
|
||||
nth1 = 1;
|
||||
ndst = 4;
|
||||
} else {
|
||||
GGML_ASSERT(false && "TODO: Unknown GPU");
|
||||
}
|
||||
|
||||
kernel = backend_ctx->kernel_mul_mv_q4_1_f32_flat;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r3));
|
||||
#else
|
||||
if (backend_ctx->gpu_family == INTEL) {
|
||||
nth0 = 16;
|
||||
nth1 = 1;
|
||||
ndst = 4;
|
||||
} else if (backend_ctx->gpu_family == ADRENO) {
|
||||
nth0 = 64;
|
||||
nth1 = 1;
|
||||
ndst = 4;
|
||||
} else {
|
||||
GGML_ASSERT(false && "TODO: Unknown GPU");
|
||||
}
|
||||
|
||||
kernel = backend_ctx->kernel_mul_mv_q4_1_f32;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
|
||||
#endif // GGML_OPENCL_SOA_Q
|
||||
break;
|
||||
}
|
||||
case GGML_TYPE_Q8_0: {
|
||||
#ifdef GGML_OPENCL_SOA_Q
|
||||
kernel = backend_ctx->kernel_mul_mv_q8_0_f32_flat;
|
||||
@@ -9262,7 +9739,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
}
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q4_K: {
|
||||
kernel = backend_ctx->kernel_mul_mv_q4_K_f32;
|
||||
|
||||
if (backend_ctx->gpu_family == INTEL) {
|
||||
nth0 = 16;
|
||||
nth1 = 1;
|
||||
ndst = 4;
|
||||
} else if (backend_ctx->gpu_family == ADRENO) {
|
||||
nth0 = 64;
|
||||
nth1 = 1;
|
||||
ndst = 4;
|
||||
} else {
|
||||
GGML_ASSERT(false && "TODO: Unknown GPU");
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
|
||||
break;
|
||||
}
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
#ifdef GGML_OPENCL_SOA_Q
|
||||
@@ -9424,7 +9936,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
} else if (src0t == GGML_TYPE_Q4_K) {
|
||||
GGML_ASSERT(false && "not implemented");
|
||||
size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
|
||||
size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
|
||||
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||
} else if (src0t == GGML_TYPE_Q3_K) {
|
||||
GGML_ASSERT(false && "not implemented");
|
||||
} else if (src0t == GGML_TYPE_Q5_K) {
|
||||
|
||||
@@ -46,6 +46,15 @@ struct block_q4_0
|
||||
uint8_t qs[QK4_0 / 2];
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q4_1
|
||||
//------------------------------------------------------------------------------
|
||||
struct block_q4_1 {
|
||||
half d; // delta
|
||||
half m; // min
|
||||
uchar qs[QK4_1 / 2]; // nibbles / quants
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q6_K
|
||||
//------------------------------------------------------------------------------
|
||||
@@ -148,6 +157,48 @@ kernel void kernel_restore_block_q4_0_noshuffle(
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// kernel_convert_block_q4_1
|
||||
// Convert the block_q4_1 format to 2 separate arrays (AOS -> SOA).
|
||||
// This kernel does not deshuffle the bits.
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_convert_block_q4_1(
|
||||
global struct block_q4_1 * src0,
|
||||
global uchar * dst_q,
|
||||
global half * dst_d,
|
||||
global half * dst_m
|
||||
) {
|
||||
global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0);
|
||||
global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0);
|
||||
global half * d = (global half *) dst_d + get_global_id(0);
|
||||
global half * m = (global half *) dst_m + get_global_id(0);
|
||||
|
||||
*d = b->d;
|
||||
*m = b->m;
|
||||
|
||||
for (int i = 0; i < QK4_1/2; ++i) {
|
||||
q[i] = b->qs[i];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_restore_block_q4_1(
|
||||
global uchar * src_q,
|
||||
global half * src_d,
|
||||
global half * src_m,
|
||||
global struct block_q4_1 * dst
|
||||
) {
|
||||
global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0);
|
||||
global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0);
|
||||
global half * d = (global half *) src_d + get_global_id(0);
|
||||
global half * m = (global half *) src_m + get_global_id(0);
|
||||
|
||||
b->d = *d;
|
||||
b->m = *m;
|
||||
for (int i = 0; i < QK4_1/2; ++i) {
|
||||
b->qs[i] = q[i];
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_mxfp4
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#define LOAD_VEC_A 8
|
||||
#define LOAD_VEC_B 4
|
||||
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#define BK 32
|
||||
#define TM 4
|
||||
#define TN 8
|
||||
|
||||
kernel void kernel_mul_mm_q4_0_f32_l4_lm(
|
||||
global uchar4 * src0_q,
|
||||
global half * src0_d,
|
||||
global float4 * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne11,
|
||||
int ne12,
|
||||
|
||||
int stride_a,
|
||||
int stride_b,
|
||||
int stride_d,
|
||||
|
||||
int batch_stride_a,
|
||||
int batch_stride_b,
|
||||
int batch_stride_d,
|
||||
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global float4*)((global char*)src1 + offset1);
|
||||
dst = (global float *)((global char*)dst + offsetd);
|
||||
|
||||
local float buf_a[BM * BK];
|
||||
local float buf_b[BN * BK];
|
||||
|
||||
const int batch_idx = get_global_id(2);
|
||||
|
||||
const int i13 = batch_idx / ne12;
|
||||
const int i12 = batch_idx % ne12;
|
||||
|
||||
const int i03 = i13 / r3;
|
||||
const int i02 = i12 / r2;
|
||||
|
||||
const int batch_idx_a = i03 * ne02 + i02;
|
||||
|
||||
const int ir = get_group_id(0);
|
||||
const int ic = get_group_id(1);
|
||||
|
||||
const int tid = get_local_id(0);
|
||||
const int th_r = tid % (BM / TM);
|
||||
const int th_c = tid / (BM / TM);
|
||||
|
||||
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
|
||||
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
|
||||
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
|
||||
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
|
||||
|
||||
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
|
||||
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
|
||||
|
||||
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
|
||||
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
|
||||
|
||||
float sums[TM * TN];
|
||||
float cache_a[TM];
|
||||
float cache_b[TN];
|
||||
|
||||
for (int i = 0; i < TM * TN; i++) {
|
||||
sums[i] = 0.0f;
|
||||
}
|
||||
|
||||
for (int block = 0; block < ne00; block += BK) {
|
||||
for (int l = 0; l < BM; l += loadstride_a) {
|
||||
if (ir*BM + loadc_a + l < ne01) {
|
||||
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
|
||||
int ib = idx / 4;
|
||||
int iqs = idx % 4;
|
||||
|
||||
float d = (float)src0_d[ib];
|
||||
global uchar4 * qs = src0_q + ib*4 + iqs;
|
||||
uchar4 q = *qs;
|
||||
float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)) - 8.0f)*d;
|
||||
float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)) - 8.0f)*d;
|
||||
|
||||
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0;
|
||||
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1;
|
||||
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2;
|
||||
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3;
|
||||
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
|
||||
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
|
||||
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
|
||||
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
|
||||
} else {
|
||||
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int l = 0; l < BN; l += loadstride_b) {
|
||||
if (ic*BN + loadc_b + l < ne11) {
|
||||
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
|
||||
} else {
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
pos_a += BK / LOAD_VEC_A;
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
|
||||
for (int i = 0; i < BK; i++) {
|
||||
for (int j = 0; j < TM; j++) {
|
||||
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
|
||||
}
|
||||
|
||||
for (int j = 0; j < TN; j++) {
|
||||
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
|
||||
}
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
const int sums_idx = cc*TM + cr;
|
||||
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
|
||||
const int dr = ir * BM + th_r * TM;
|
||||
const int dc = ic * BN + th_c * TN;
|
||||
|
||||
const int offsets = batch_idx * batch_stride_d;
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
if (dr + cr < ne01 && dc + cc < ne11) {
|
||||
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,165 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#define LOAD_VEC_A 8
|
||||
#define LOAD_VEC_B 4
|
||||
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#define BK 32
|
||||
#define TM 4
|
||||
#define TN 8
|
||||
|
||||
kernel void kernel_mul_mm_q4_1_f32_l4_lm(
|
||||
global uchar4 * src0_q,
|
||||
global half * src0_d,
|
||||
global half * src0_m,
|
||||
global float4 * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne11,
|
||||
int ne12,
|
||||
|
||||
int stride_a,
|
||||
int stride_b,
|
||||
int stride_d,
|
||||
|
||||
int batch_stride_a,
|
||||
int batch_stride_b,
|
||||
int batch_stride_d,
|
||||
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global float4*)((global char*)src1 + offset1);
|
||||
dst = (global float *)((global char*)dst + offsetd);
|
||||
|
||||
local float buf_a[BM * BK];
|
||||
local float buf_b[BN * BK];
|
||||
|
||||
const int batch_idx = get_global_id(2);
|
||||
|
||||
const int i13 = batch_idx / ne12;
|
||||
const int i12 = batch_idx % ne12;
|
||||
|
||||
const int i03 = i13 / r3;
|
||||
const int i02 = i12 / r2;
|
||||
|
||||
const int batch_idx_a = i03 * ne02 + i02;
|
||||
|
||||
const int ir = get_group_id(0);
|
||||
const int ic = get_group_id(1);
|
||||
|
||||
const int tid = get_local_id(0);
|
||||
const int th_r = tid % (BM / TM);
|
||||
const int th_c = tid / (BM / TM);
|
||||
|
||||
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
|
||||
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
|
||||
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
|
||||
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
|
||||
|
||||
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
|
||||
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
|
||||
|
||||
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
|
||||
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
|
||||
|
||||
float sums[TM * TN];
|
||||
float cache_a[TM];
|
||||
float cache_b[TN];
|
||||
|
||||
for (int i = 0; i < TM * TN; i++) {
|
||||
sums[i] = 0.0f;
|
||||
}
|
||||
|
||||
for (int block = 0; block < ne00; block += BK) {
|
||||
for (int l = 0; l < BM; l += loadstride_a) {
|
||||
if (ir*BM + loadc_a + l < ne01) {
|
||||
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
|
||||
int ib = idx / 4;
|
||||
int iqs = idx % 4;
|
||||
|
||||
float d = (float)src0_d[ib];
|
||||
float m = (float)src0_m[ib];
|
||||
global uchar4 * qs = src0_q + ib*4 + iqs;
|
||||
uchar4 q = *qs;
|
||||
float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)))*d + m;
|
||||
float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)))*d + m;
|
||||
|
||||
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0;
|
||||
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1;
|
||||
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2;
|
||||
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3;
|
||||
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
|
||||
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
|
||||
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
|
||||
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
|
||||
} else {
|
||||
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int l = 0; l < BN; l += loadstride_b) {
|
||||
if (ic*BN + loadc_b + l < ne11) {
|
||||
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
|
||||
} else {
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
pos_a += BK / LOAD_VEC_A;
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
|
||||
for (int i = 0; i < BK; i++) {
|
||||
for (int j = 0; j < TM; j++) {
|
||||
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
|
||||
}
|
||||
|
||||
for (int j = 0; j < TN; j++) {
|
||||
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
|
||||
}
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
const int sums_idx = cc*TM + cr;
|
||||
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
|
||||
const int dr = ir * BM + th_r * TM;
|
||||
const int dc = ic * BN + th_c * TN;
|
||||
|
||||
const int offsets = batch_idx * batch_stride_d;
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
if (dr + cr < ne01 && dc + cc < ne11) {
|
||||
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#define LOAD_VEC_A 2
|
||||
#define LOAD_VEC_B 4
|
||||
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#define BK 32
|
||||
#define TM 4
|
||||
#define TN 8
|
||||
|
||||
kernel void kernel_mul_mm_q6_k_f32_l4_lm(
|
||||
global uchar * src0_ql,
|
||||
global uchar * src0_qh,
|
||||
global char * src0_s,
|
||||
global half * src0_d,
|
||||
global float4 * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne11,
|
||||
int ne12,
|
||||
|
||||
int stride_a,
|
||||
int stride_b,
|
||||
int stride_d,
|
||||
|
||||
int batch_stride_a,
|
||||
int batch_stride_b,
|
||||
int batch_stride_d,
|
||||
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global float4*)((global char*)src1 + offset1);
|
||||
dst = (global float *)((global char*)dst + offsetd);
|
||||
|
||||
local float buf_a[BM * BK];
|
||||
local float buf_b[BN * BK];
|
||||
|
||||
const int batch_idx = get_global_id(2);
|
||||
|
||||
const int i13 = batch_idx / ne12;
|
||||
const int i12 = batch_idx % ne12;
|
||||
|
||||
const int i03 = i13 / r3;
|
||||
const int i02 = i12 / r2;
|
||||
|
||||
const int batch_idx_a = i03 * ne02 + i02;
|
||||
|
||||
const int ir = get_group_id(0);
|
||||
const int ic = get_group_id(1);
|
||||
|
||||
const int tid = get_local_id(0);
|
||||
const int th_r = tid % (BM / TM);
|
||||
const int th_c = tid / (BM / TM);
|
||||
|
||||
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
|
||||
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
|
||||
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
|
||||
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
|
||||
|
||||
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
|
||||
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
|
||||
|
||||
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
|
||||
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
|
||||
|
||||
float sums[TM * TN];
|
||||
float cache_a[TM];
|
||||
float cache_b[TN];
|
||||
|
||||
for (int i = 0; i < TM * TN; i++) {
|
||||
sums[i] = 0.0f;
|
||||
}
|
||||
|
||||
for (int block = 0; block < ne00; block += BK) {
|
||||
for (int l = 0; l < BM; l += loadstride_a) {
|
||||
if (ir*BM + loadc_a + l < ne01) {
|
||||
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
|
||||
|
||||
int ib = idx / 128; // 2 values per idx
|
||||
int iqs = idx % 128; // 0..127
|
||||
|
||||
int n = iqs / 64; // 0,1
|
||||
int b = (iqs % 64) / 32; // 0,1
|
||||
int is_b = (iqs % 16) / 8; // 0,1
|
||||
int qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
|
||||
int is = 8 * n + qhshift + is_b; // 0..15
|
||||
int qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
|
||||
int qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
|
||||
|
||||
float dscale = (float)src0_d[ib] * (float)src0_s[ib*16 + is];
|
||||
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32);
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32);
|
||||
} else {
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int l = 0; l < BN; l += loadstride_b) {
|
||||
if (ic*BN + loadc_b + l < ne11) {
|
||||
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
|
||||
} else {
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
pos_a += BK / LOAD_VEC_A;
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
|
||||
for (int i = 0; i < BK; i++) {
|
||||
for (int j = 0; j < TM; j++) {
|
||||
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
|
||||
}
|
||||
|
||||
for (int j = 0; j < TN; j++) {
|
||||
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
|
||||
}
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
const int sums_idx = cc*TM + cr;
|
||||
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
|
||||
const int dr = ir * BM + th_r * TM;
|
||||
const int dc = ic * BN + th_c * TN;
|
||||
|
||||
const int offsets = batch_idx * batch_stride_d;
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
if (dr + cr < ne01 && dc + cc < ne11) {
|
||||
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,219 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK4_1 32
|
||||
|
||||
struct block_q4_1 {
|
||||
half d; // delta
|
||||
half m; // min
|
||||
uchar qs[QK4_1 / 2]; // nibbles / quants
|
||||
};
|
||||
|
||||
inline float block_q4_1_dot_y(
|
||||
global const struct block_q4_1 * qb_curr,
|
||||
float sumy,
|
||||
float16 yl,
|
||||
int il
|
||||
) {
|
||||
float d = qb_curr->d;
|
||||
float m = qb_curr->m;
|
||||
|
||||
float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
|
||||
|
||||
global const ushort * qs = ((global const ushort *) qb_curr + 2 + il/2);
|
||||
|
||||
acc.s0 += yl.s0 * (qs[0] & 0x000F);
|
||||
acc.s0 += yl.s1 * (qs[0] & 0x0F00);
|
||||
acc.s0 += yl.s8 * (qs[0] & 0x00F0);
|
||||
acc.s3 += yl.s9 * (qs[0] & 0xF000);
|
||||
|
||||
acc.s0 += yl.s2 * (qs[1] & 0x000F);
|
||||
acc.s1 += yl.s3 * (qs[1] & 0x0F00);
|
||||
acc.s2 += yl.sa * (qs[1] & 0x00F0);
|
||||
acc.s3 += yl.sb * (qs[1] & 0xF000);
|
||||
|
||||
acc.s0 += yl.s4 * (qs[2] & 0x000F);
|
||||
acc.s1 += yl.s5 * (qs[2] & 0x0F00);
|
||||
acc.s2 += yl.sc * (qs[2] & 0x00F0);
|
||||
acc.s3 += yl.sd * (qs[2] & 0xF000);
|
||||
|
||||
acc.s0 += yl.s6 * (qs[3] & 0x000F);
|
||||
acc.s1 += yl.s7 * (qs[3] & 0x0F00);
|
||||
acc.s2 += yl.se * (qs[3] & 0x00F0);
|
||||
acc.s3 += yl.sf * (qs[3] & 0xF000);
|
||||
|
||||
return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m;
|
||||
}
|
||||
|
||||
#undef N_DST
|
||||
#undef N_SIMDGROUP
|
||||
#undef N_SIMDWIDTH
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_DST 4 // each subgroup works on 4 rows
|
||||
#define N_SIMDGROUP 1 // number of subgroups in a thread group
|
||||
#define N_SIMDWIDTH 16 // assuming subgroup size is 16
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_DST 4
|
||||
#define N_SIMDGROUP 1
|
||||
#define N_SIMDWIDTH 64
|
||||
#endif
|
||||
|
||||
inline void mul_vec_q_n_f32(
|
||||
global void * src0,
|
||||
global float * src1,
|
||||
global float * dst,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
const ulong nb = ne00/QK4_1;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
|
||||
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
|
||||
global struct block_q4_1 * x = (global struct block_q4_1 *) src0 + offset0;
|
||||
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float16 yl;
|
||||
float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
|
||||
|
||||
int ix = get_sub_group_local_id()/2;
|
||||
int il = 8*(get_sub_group_local_id()%2);
|
||||
|
||||
global float * yb = y + ix * QK4_1 + il;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
|
||||
float sumy = 0;
|
||||
|
||||
sumy += yb[0];
|
||||
sumy += yb[1];
|
||||
sumy += yb[2];
|
||||
sumy += yb[3];
|
||||
sumy += yb[4];
|
||||
sumy += yb[5];
|
||||
sumy += yb[6];
|
||||
sumy += yb[7];
|
||||
|
||||
sumy += yb[16];
|
||||
sumy += yb[17];
|
||||
sumy += yb[18];
|
||||
sumy += yb[19];
|
||||
sumy += yb[20];
|
||||
sumy += yb[21];
|
||||
sumy += yb[22];
|
||||
sumy += yb[23];
|
||||
|
||||
|
||||
yl.s0 = yb[0];
|
||||
yl.s1 = yb[1]/256.f;
|
||||
|
||||
yl.s2 = yb[2];
|
||||
yl.s3 = yb[3]/256.f;
|
||||
|
||||
yl.s4 = yb[4];
|
||||
yl.s5 = yb[5]/256.f;
|
||||
|
||||
yl.s6 = yb[6];
|
||||
yl.s7 = yb[7]/256.f;
|
||||
|
||||
yl.s8 = yb[16]/16.f;
|
||||
yl.s9 = yb[17]/4096.f;
|
||||
|
||||
yl.sa = yb[18]/16.f;
|
||||
yl.sb = yb[19]/4096.f;
|
||||
|
||||
yl.sc = yb[20]/16.f;
|
||||
yl.sd = yb[21]/4096.f;
|
||||
|
||||
yl.se = yb[22]/16.f;
|
||||
yl.sf = yb[23]/4096.f;
|
||||
|
||||
sumf.s0 += block_q4_1_dot_y(x+ib+0*nb, sumy, yl, il);
|
||||
sumf.s1 += block_q4_1_dot_y(x+ib+1*nb, sumy, yl, il);
|
||||
sumf.s2 += block_q4_1_dot_y(x+ib+2*nb, sumy, yl, il);
|
||||
sumf.s3 += block_q4_1_dot_y(x+ib+3*nb, sumy, yl, il);
|
||||
|
||||
yb += QK4_1 * (N_SIMDWIDTH/2);
|
||||
}
|
||||
|
||||
float4 tot = (float4)(
|
||||
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
|
||||
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
|
||||
);
|
||||
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
if (first_row + 0 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
|
||||
}
|
||||
if (first_row + 1 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
|
||||
}
|
||||
if (first_row + 2 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
|
||||
}
|
||||
if (first_row + 3 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mv_q4_1_f32(
|
||||
global void * src0,
|
||||
ulong offset0,
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src0 = (global void*)((global char*)src0 + offset0);
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
|
||||
}
|
||||
@@ -0,0 +1,229 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK4_1 32
|
||||
|
||||
struct block_q4_1 {
|
||||
half d; // delta
|
||||
half m; // min
|
||||
uchar qs[QK4_1 / 2]; // nibbles / quants
|
||||
};
|
||||
|
||||
inline float block_q4_1_dot_y_flat(
|
||||
global const uchar * x,
|
||||
global const half * dh,
|
||||
global const half * mh,
|
||||
float sumy,
|
||||
float16 yl,
|
||||
int il
|
||||
) {
|
||||
float d = *dh;
|
||||
float m = *mh;
|
||||
global const ushort * qs = ((global const ushort *) x + il/2);
|
||||
|
||||
float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
|
||||
|
||||
acc.s0 += yl.s0 * (qs[0] & 0x000F);
|
||||
acc.s0 += yl.s1 * (qs[0] & 0x0F00);
|
||||
acc.s0 += yl.s8 * (qs[0] & 0x00F0);
|
||||
acc.s3 += yl.s9 * (qs[0] & 0xF000);
|
||||
|
||||
acc.s0 += yl.s2 * (qs[1] & 0x000F);
|
||||
acc.s1 += yl.s3 * (qs[1] & 0x0F00);
|
||||
acc.s2 += yl.sa * (qs[1] & 0x00F0);
|
||||
acc.s3 += yl.sb * (qs[1] & 0xF000);
|
||||
|
||||
acc.s0 += yl.s4 * (qs[2] & 0x000F);
|
||||
acc.s1 += yl.s5 * (qs[2] & 0x0F00);
|
||||
acc.s2 += yl.sc * (qs[2] & 0x00F0);
|
||||
acc.s3 += yl.sd * (qs[2] & 0xF000);
|
||||
|
||||
acc.s0 += yl.s6 * (qs[3] & 0x000F);
|
||||
acc.s1 += yl.s7 * (qs[3] & 0x0F00);
|
||||
acc.s2 += yl.se * (qs[3] & 0x00F0);
|
||||
acc.s3 += yl.sf * (qs[3] & 0xF000);
|
||||
|
||||
return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m;
|
||||
}
|
||||
|
||||
#undef N_DST
|
||||
#undef N_SIMDGROUP
|
||||
#undef N_SIMDWIDTH
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_DST 4 // each subgroup works on 4 rows
|
||||
#define N_SIMDGROUP 1 // number of subgroups in a thread group
|
||||
#define N_SIMDWIDTH 16 // assuming subgroup size is 16
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_DST 4
|
||||
#define N_SIMDGROUP 1
|
||||
#define N_SIMDWIDTH 64
|
||||
#endif
|
||||
|
||||
inline void mul_vec_q_n_f32_flat(
|
||||
global void * src0_q,
|
||||
global void * src0_d,
|
||||
global void * src0_m,
|
||||
global float * src1,
|
||||
global float * dst,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
const ulong nb = ne00/QK4_1;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
|
||||
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
|
||||
// The number of scales/mins is the same as the number of blocks.
|
||||
ulong offset0_dm = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02));
|
||||
// Each block contains QK4_1/2 uchars, hence offset for qs is as follows.
|
||||
ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_1/2;
|
||||
|
||||
global uchar * x = (global uchar *) src0_q + offset0_q;
|
||||
global half * d = (global half *) src0_d + offset0_dm;
|
||||
global half * m = (global half *) src0_m + offset0_dm;
|
||||
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float16 yl;
|
||||
float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
|
||||
|
||||
int ix = get_sub_group_local_id()/2;
|
||||
int il = 8*(get_sub_group_local_id()%2);
|
||||
|
||||
global float * yb = y + ix * QK4_1 + il;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
|
||||
float sumy = 0;
|
||||
|
||||
sumy += yb[0];
|
||||
sumy += yb[1];
|
||||
sumy += yb[2];
|
||||
sumy += yb[3];
|
||||
sumy += yb[4];
|
||||
sumy += yb[5];
|
||||
sumy += yb[6];
|
||||
sumy += yb[7];
|
||||
|
||||
sumy += yb[16];
|
||||
sumy += yb[17];
|
||||
sumy += yb[18];
|
||||
sumy += yb[19];
|
||||
sumy += yb[20];
|
||||
sumy += yb[21];
|
||||
sumy += yb[22];
|
||||
sumy += yb[23];
|
||||
|
||||
|
||||
yl.s0 = yb[0];
|
||||
yl.s1 = yb[1]/256.f;
|
||||
|
||||
yl.s2 = yb[2];
|
||||
yl.s3 = yb[3]/256.f;
|
||||
|
||||
yl.s4 = yb[4];
|
||||
yl.s5 = yb[5]/256.f;
|
||||
|
||||
yl.s6 = yb[6];
|
||||
yl.s7 = yb[7]/256.f;
|
||||
|
||||
yl.s8 = yb[16]/16.f;
|
||||
yl.s9 = yb[17]/4096.f;
|
||||
|
||||
yl.sa = yb[18]/16.f;
|
||||
yl.sb = yb[19]/4096.f;
|
||||
|
||||
yl.sc = yb[20]/16.f;
|
||||
yl.sd = yb[21]/4096.f;
|
||||
|
||||
yl.se = yb[22]/16.f;
|
||||
yl.sf = yb[23]/4096.f;
|
||||
|
||||
sumf.s0 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 0*nb*QK4_1/2, d + ib + 0*nb, m + ib + 0*nb, sumy, yl, il);
|
||||
sumf.s1 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 1*nb*QK4_1/2, d + ib + 1*nb, m + ib + 1*nb, sumy, yl, il);
|
||||
sumf.s2 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 2*nb*QK4_1/2, d + ib + 2*nb, m + ib + 2*nb, sumy, yl, il);
|
||||
sumf.s3 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 3*nb*QK4_1/2, d + ib + 3*nb, m + ib + 3*nb, sumy, yl, il);
|
||||
|
||||
yb += QK4_1 * (N_SIMDWIDTH/2);
|
||||
}
|
||||
|
||||
float4 tot = (float4)(
|
||||
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
|
||||
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
|
||||
);
|
||||
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
if (first_row + 0 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
|
||||
}
|
||||
if (first_row + 1 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
|
||||
}
|
||||
if (first_row + 2 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
|
||||
}
|
||||
if (first_row + 3 < ne01) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mv_q4_1_f32_flat(
|
||||
global void * src0_q,
|
||||
global void * src0_d,
|
||||
global void * src0_m,
|
||||
global float * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global float*)((global char*)src1 + offset1);
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
|
||||
mul_vec_q_n_f32_flat(src0_q, src0_d, src0_m, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
|
||||
}
|
||||
@@ -0,0 +1,180 @@
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q4_K
|
||||
//------------------------------------------------------------------------------
|
||||
#define QK_K 256
|
||||
#define K_SCALE_SIZE 12
|
||||
|
||||
// 8 blocks of 32 elements each
|
||||
// weight is represented as x = a * q + b
|
||||
typedef struct {
|
||||
half d; // super-block scale for quantized scales
|
||||
half dmin; // super-block scale for quantized mins
|
||||
|
||||
uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
|
||||
uchar qs[QK_K/2]; // 4-bit quants
|
||||
} block_q4_K;
|
||||
|
||||
#undef N_DST
|
||||
#undef N_SIMDGROUP
|
||||
#undef N_SIMDWIDTH
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_DST 4 // number of rows each SIMD group works on
|
||||
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
|
||||
#define N_SIMDWIDTH 16 // SIMD group size
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_DST 4
|
||||
#define N_SIMDGROUP 1
|
||||
#define N_SIMDWIDTH 64
|
||||
#endif
|
||||
|
||||
#undef BLOCK_STRIDE
|
||||
// number of (super) blocks each subgroup processes
|
||||
// each thread in a subgroup processes a block (32 weights)
|
||||
#define BLOCK_STRIDE (N_SIMDWIDTH/8)
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mv_q4_K_f32(
|
||||
global char * src0,
|
||||
int offset0,
|
||||
global char * src1,
|
||||
int offset1,
|
||||
global char * dst,
|
||||
int offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne12,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src0 = src0 + offset0;
|
||||
src1 = src1 + offset1;
|
||||
dst = dst + offsetd;
|
||||
|
||||
ushort kmask1 = 0x3f3f;
|
||||
ushort kmask2 = 0x0f0f;
|
||||
ushort kmask3 = 0xc0c0;
|
||||
|
||||
int ix = get_sub_group_local_id()/8; // super block index
|
||||
int it = get_sub_group_local_id()%8; // block index (inside super block)
|
||||
int iq = it/4; // 0 or 1 - first or second half of the super block
|
||||
int ir = it%4; // 0...3 - block index in the half super block
|
||||
|
||||
int nb = ne00/QK_K;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
|
||||
|
||||
int i12 = im%ne12;
|
||||
int i13 = im/ne12;
|
||||
|
||||
int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
||||
int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
|
||||
|
||||
global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0);
|
||||
global float * y = (global float *) (src1 + offset_src1);
|
||||
|
||||
float yl[16];
|
||||
float yh[16];
|
||||
float sumf[N_DST] = {0.f};
|
||||
float all_sum;
|
||||
|
||||
global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
|
||||
|
||||
ushort sc16[4];
|
||||
uchar * sc8 = (uchar *)sc16;
|
||||
|
||||
for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
|
||||
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
yl[i+0] = y4[i+0];
|
||||
sumy.s0 += yl[i+0];
|
||||
|
||||
yl[i+8] = y4[i+32];
|
||||
sumy.s1 += yl[i+8];
|
||||
|
||||
yh[i+0] = y4[i+128];
|
||||
sumy.s2 += yh[i+0];
|
||||
|
||||
yh[i+8] = y4[i+160];
|
||||
sumy.s3 += yh[i+8];
|
||||
}
|
||||
|
||||
global ushort * sc = (global ushort *)x[ib].scales + iq;
|
||||
global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir;
|
||||
global half * dh = &x[ib].d;
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
sc16[0] = sc[0] & kmask1;
|
||||
sc16[1] = sc[2] & kmask1;
|
||||
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
|
||||
sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
|
||||
|
||||
global ushort * q2 = q1 + 32;
|
||||
|
||||
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
||||
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int i = 0; i < 8; i += 2) {
|
||||
acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F);
|
||||
acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00);
|
||||
acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0);
|
||||
acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000);
|
||||
acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F);
|
||||
acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00);
|
||||
acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0);
|
||||
acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000);
|
||||
}
|
||||
|
||||
float dall = dh[0];
|
||||
float dmin = dh[1];
|
||||
sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] +
|
||||
(acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f +
|
||||
(acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] +
|
||||
(acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) -
|
||||
dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]);
|
||||
|
||||
q1 += nb01/2;
|
||||
sc += nb01/2;
|
||||
dh += nb01/2;
|
||||
}
|
||||
|
||||
y4 += BLOCK_STRIDE * QK_K;
|
||||
}
|
||||
|
||||
global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = sub_group_reduce_add(sumf[row]);
|
||||
if (first_row + row < ne01) {
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
dst_f32[first_row + row] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
+1
-1
@@ -5749,7 +5749,7 @@ static struct ggml_tensor * ggml_unary_impl(
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_unary_op op,
|
||||
bool inplace) {
|
||||
GGML_ASSERT(ggml_is_contiguous_1(a));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(a));
|
||||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
|
||||
@@ -3766,6 +3766,7 @@ class VisionProjectorType:
|
||||
VOXTRAL = "voxtral"
|
||||
LFM2 = "lfm2"
|
||||
KIMIVL = "kimivl"
|
||||
KIMIK25 = "kimik25"
|
||||
LIGHTONOCR = "lightonocr"
|
||||
COGVLM = "cogvlm"
|
||||
JANUS_PRO = "janus_pro"
|
||||
|
||||
@@ -1303,6 +1303,7 @@ class TensorNameMap:
|
||||
|
||||
MODEL_TENSOR.V_MMPROJ: (
|
||||
"multi_modal_projector.linear_{bid}",
|
||||
"mm_projector.proj.linear_{bid}", # Kimi-K2.5
|
||||
"visual.merger.mlp.{bid}", # qwen2vl
|
||||
"merger.mlp.{bid}",
|
||||
),
|
||||
@@ -1364,6 +1365,7 @@ class TensorNameMap:
|
||||
MODEL_TENSOR.V_ENC_ATTN_QKV: (
|
||||
"visual.blocks.{bid}.attn.qkv", # qwen3vl
|
||||
"model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm
|
||||
"vision_tower.encoder.blocks.{bid}.wqkv" # Kimi-K2.5
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q: (
|
||||
@@ -1538,6 +1540,7 @@ class TensorNameMap:
|
||||
"multi_modal_projector.norm",
|
||||
"multi_modal_projector.layer_norm",
|
||||
"multi_modal_projector.pre_norm",
|
||||
"mm_projector.pre_norm", # Kimi-K2.5
|
||||
"pre_mm_projector_norm",
|
||||
"model.vision.linear_proj.norm1", # cogvlm
|
||||
"merger.ln_q",
|
||||
|
||||
+3
-3
@@ -482,7 +482,7 @@ extern "C" {
|
||||
enum llama_params_fit_status {
|
||||
LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit
|
||||
LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit
|
||||
LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occured, e.g. because no model could be found at the specified path
|
||||
LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occurred, e.g. because no model could be found at the specified path
|
||||
};
|
||||
|
||||
// fits mparams and cparams to free device memory (assumes system memory is unlimited)
|
||||
@@ -1150,9 +1150,9 @@ extern "C" {
|
||||
//
|
||||
|
||||
/// Apply chat template. Inspired by hf apply_chat_template() on python.
|
||||
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
|
||||
///
|
||||
/// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
|
||||
/// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.
|
||||
/// @param tmpl A Jinja template to use for this chat.
|
||||
/// @param chat Pointer to a list of multiple llama_chat_message
|
||||
/// @param n_msg Number of llama_chat_message in this chat
|
||||
/// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.
|
||||
|
||||
+8
-2
@@ -30,12 +30,18 @@ fi
|
||||
PR=$1
|
||||
[[ "$PR" =~ ^[0-9]+$ ]] || { echo "error: PR number must be numeric"; exit 1; }
|
||||
|
||||
url_origin=$(git config --get remote.upstream.url 2>/dev/null) || \
|
||||
url_origin=$(git config --get remote.origin.url) || {
|
||||
echo "error: no remote named 'origin' in this repository"
|
||||
echo "error: no remote named 'upstream' or 'origin' in this repository"
|
||||
exit 1
|
||||
}
|
||||
|
||||
org_repo=$(echo $url_origin | cut -d/ -f4-)
|
||||
# Extract org/repo from either https or ssh format.
|
||||
if [[ $url_origin =~ ^git@ ]]; then
|
||||
org_repo=$(echo $url_origin | cut -d: -f2)
|
||||
else
|
||||
org_repo=$(echo $url_origin | cut -d/ -f4-)
|
||||
fi
|
||||
org_repo=${org_repo%.git}
|
||||
|
||||
echo "org/repo: $org_repo"
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
import urllib.request
|
||||
|
||||
HTTPLIB_VERSION = "f80864ca031932351abef49b74097c67f14719c6"
|
||||
|
||||
vendor = {
|
||||
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
|
||||
"https://github.com/nlohmann/json/releases/latest/download/json_fwd.hpp": "vendor/nlohmann/json_fwd.hpp",
|
||||
@@ -12,8 +14,8 @@ vendor = {
|
||||
# "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h",
|
||||
"https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h",
|
||||
|
||||
"https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.2/httplib.h": "vendor/cpp-httplib/httplib.h",
|
||||
"https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.2/LICENSE": "vendor/cpp-httplib/LICENSE",
|
||||
f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/httplib.h": "vendor/cpp-httplib/httplib.h",
|
||||
f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/LICENSE": "vendor/cpp-httplib/LICENSE",
|
||||
|
||||
"https://raw.githubusercontent.com/sheredom/subprocess.h/b49c56e9fe214488493021017bf3954b91c7c1f5/subprocess.h": "vendor/sheredom/subprocess.h",
|
||||
}
|
||||
|
||||
+90
-99
@@ -677,7 +677,7 @@ enum llama_pooling_type llama_context::pooling_type() const {
|
||||
float * llama_context::get_logits() {
|
||||
output_reorder();
|
||||
|
||||
return logits;
|
||||
return logits.data;
|
||||
}
|
||||
|
||||
int64_t llama_context::output_resolve_row(int32_t i) const {
|
||||
@@ -715,7 +715,7 @@ float * llama_context::get_logits_ith(int32_t i) {
|
||||
output_reorder();
|
||||
|
||||
try {
|
||||
if (logits == nullptr) {
|
||||
if (logits.data == nullptr) {
|
||||
throw std::runtime_error("no logits");
|
||||
}
|
||||
|
||||
@@ -739,7 +739,7 @@ float * llama_context::get_logits_ith(int32_t i) {
|
||||
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
|
||||
}
|
||||
|
||||
return logits + j*model.vocab.n_tokens();
|
||||
return logits.data + j*model.vocab.n_tokens();
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
|
||||
#ifndef NDEBUG
|
||||
@@ -753,11 +753,11 @@ float * llama_context::get_logits_ith(int32_t i) {
|
||||
float * llama_context::get_embeddings() {
|
||||
output_reorder();
|
||||
|
||||
return embd;
|
||||
return embd.data;
|
||||
}
|
||||
|
||||
llama_token * llama_context::get_sampled_tokens() const{
|
||||
return sampling.sampled;
|
||||
return sampling.sampled.data;
|
||||
}
|
||||
|
||||
float * llama_context::get_embeddings_ith(int32_t i) {
|
||||
@@ -766,7 +766,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
|
||||
output_reorder();
|
||||
|
||||
try {
|
||||
if (embd == nullptr) {
|
||||
if (embd.data == nullptr) {
|
||||
throw std::runtime_error("no embeddings");
|
||||
}
|
||||
|
||||
@@ -791,7 +791,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
|
||||
}
|
||||
|
||||
const uint32_t n_embd_out = model.hparams.n_embd_out();
|
||||
return embd + j*n_embd_out;
|
||||
return embd.data + j*n_embd_out;
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
|
||||
#ifndef NDEBUG
|
||||
@@ -814,14 +814,14 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
|
||||
llama_token llama_context::get_sampled_token_ith(int32_t idx) {
|
||||
output_reorder();
|
||||
|
||||
if (sampling.sampled == nullptr) {
|
||||
if (!sampling.sampled.has_data()) {
|
||||
return LLAMA_TOKEN_NULL;
|
||||
}
|
||||
|
||||
try {
|
||||
const int64_t row = output_resolve_row(idx);
|
||||
GGML_ASSERT(row < (int64_t) sampling.sampled_size);
|
||||
return sampling.sampled[row];
|
||||
GGML_ASSERT(row < (int64_t) sampling.sampled.size);
|
||||
return sampling.sampled.data[row];
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
|
||||
return LLAMA_TOKEN_NULL;
|
||||
@@ -831,7 +831,7 @@ llama_token llama_context::get_sampled_token_ith(int32_t idx) {
|
||||
float * llama_context::get_sampled_probs_ith(int32_t idx) {
|
||||
output_reorder();
|
||||
|
||||
if (sampling.probs == nullptr) {
|
||||
if (!sampling.probs.has_data()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -840,7 +840,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) {
|
||||
if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
return sampling.probs + row*model.vocab.n_tokens();
|
||||
return sampling.probs.data + row*model.vocab.n_tokens();
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
|
||||
return nullptr;
|
||||
@@ -850,7 +850,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) {
|
||||
float * llama_context::get_sampled_logits_ith(int32_t idx) {
|
||||
output_reorder();
|
||||
|
||||
if (sampling.logits == nullptr) {
|
||||
if (!sampling.logits.has_data()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -859,7 +859,7 @@ float * llama_context::get_sampled_logits_ith(int32_t idx) {
|
||||
if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
return sampling.logits + row*model.vocab.n_tokens();
|
||||
return sampling.logits.data + row*model.vocab.n_tokens();
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
|
||||
return nullptr;
|
||||
@@ -871,10 +871,10 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
|
||||
|
||||
try {
|
||||
const int64_t row = output_resolve_row(idx);
|
||||
if (sampling.candidates != nullptr &&
|
||||
if (sampling.candidates.has_data() &&
|
||||
(size_t) row < sampling.candidates_count.size() &&
|
||||
sampling.candidates_count[row] > 0) {
|
||||
return sampling.candidates + row*model.vocab.n_tokens();
|
||||
return sampling.candidates.data + row*model.vocab.n_tokens();
|
||||
}
|
||||
} catch (const std::exception & err) {
|
||||
// fallback to full vocab list
|
||||
@@ -886,7 +886,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
|
||||
size_t llama_context::get_sampled_candidates_count(int32_t idx) {
|
||||
output_reorder();
|
||||
|
||||
if (sampling.candidates == nullptr) {
|
||||
if (!sampling.candidates.has_data()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -905,7 +905,7 @@ size_t llama_context::get_sampled_candidates_count(int32_t idx) {
|
||||
size_t llama_context::get_sampled_logits_count(int32_t idx) {
|
||||
output_reorder();
|
||||
|
||||
if (sampling.logits == nullptr) {
|
||||
if (!sampling.logits.has_data()) {
|
||||
return model.vocab.n_tokens();
|
||||
}
|
||||
|
||||
@@ -924,7 +924,7 @@ size_t llama_context::get_sampled_logits_count(int32_t idx) {
|
||||
size_t llama_context::get_sampled_probs_count(int32_t idx) {
|
||||
output_reorder();
|
||||
|
||||
if (sampling.probs == nullptr) {
|
||||
if (!sampling.probs.has_data()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -1254,16 +1254,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
|
||||
|
||||
// extract logits
|
||||
if (logits && t_logits) {
|
||||
if (logits.data && t_logits) {
|
||||
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
|
||||
GGML_ASSERT(backend_res != nullptr);
|
||||
GGML_ASSERT(logits != nullptr);
|
||||
GGML_ASSERT(logits.data != nullptr);
|
||||
|
||||
ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
|
||||
ggml_backend_tensor_get_async(backend_res, t_logits, logits.data, 0, n_tokens*n_vocab*sizeof(float));
|
||||
}
|
||||
|
||||
// extract embeddings
|
||||
if (embd && t_embd) {
|
||||
if (embd.data && t_embd) {
|
||||
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
||||
GGML_ASSERT(backend_embd != nullptr);
|
||||
|
||||
@@ -1271,11 +1271,11 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
case LLAMA_POOLING_TYPE_NONE:
|
||||
{
|
||||
// extract token embeddings
|
||||
GGML_ASSERT(embd != nullptr);
|
||||
GGML_ASSERT(embd.data != nullptr);
|
||||
const uint32_t n_embd_out = hparams.n_embd_out();
|
||||
|
||||
GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
|
||||
GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float));
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_MEAN:
|
||||
case LLAMA_POOLING_TYPE_CLS:
|
||||
@@ -1323,7 +1323,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
cross.n_embd = t_embd->ne[0];
|
||||
cross.n_enc = t_embd->ne[1];
|
||||
cross.v_embd.resize(cross.n_embd*cross.n_enc);
|
||||
memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
|
||||
memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd));
|
||||
|
||||
const auto & batch = balloc->get_batch();
|
||||
|
||||
@@ -1363,11 +1363,10 @@ static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubat
|
||||
|
||||
static void copy_tensor_async_ints(
|
||||
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
||||
llama_token * sampled,
|
||||
size_t sampled_size,
|
||||
const buffer_view<llama_token> & sampled,
|
||||
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
||||
ggml_backend_sched_t sched) {
|
||||
if (sampled == nullptr) {
|
||||
if (!sampled.has_data()) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1378,23 +1377,23 @@ static void copy_tensor_async_ints(
|
||||
}
|
||||
|
||||
const uint32_t row = it->second;
|
||||
GGML_ASSERT(row < sampled_size);
|
||||
GGML_ASSERT(row < sampled.size);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
|
||||
|
||||
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
||||
ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
|
||||
ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row]));
|
||||
}
|
||||
}
|
||||
|
||||
static void copy_tensor_async_floats(
|
||||
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
||||
float * dst,
|
||||
const buffer_view<float> & dst,
|
||||
size_t stride,
|
||||
std::vector<uint32_t> & counts,
|
||||
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
||||
ggml_backend_sched_t sched) {
|
||||
if (dst == nullptr) {
|
||||
if (!dst.has_data()) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1410,7 +1409,7 @@ static void copy_tensor_async_floats(
|
||||
GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
|
||||
|
||||
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
||||
float * row_ptr = dst + (size_t) row * stride;
|
||||
float * row_ptr = dst.data + (size_t) row * stride;
|
||||
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
|
||||
|
||||
// Update the actual number of logits/probabilities that were written for this row.
|
||||
@@ -1420,12 +1419,12 @@ static void copy_tensor_async_floats(
|
||||
|
||||
static void copy_tensor_async_candidates(
|
||||
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
||||
llama_token * dst,
|
||||
const buffer_view<llama_token> & dst,
|
||||
size_t stride,
|
||||
std::vector<uint32_t> & counts,
|
||||
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
||||
ggml_backend_sched_t sched) {
|
||||
if (dst == nullptr) {
|
||||
if (!dst.has_data()) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1441,7 +1440,7 @@ static void copy_tensor_async_candidates(
|
||||
GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
|
||||
|
||||
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
||||
llama_token * row_ptr = dst + (size_t) row * stride;
|
||||
llama_token * row_ptr = dst.data + (size_t) row * stride;
|
||||
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
|
||||
|
||||
// Update the actual number of candidates that were written.
|
||||
@@ -1671,22 +1670,22 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
}
|
||||
|
||||
// extract logits
|
||||
if (logits && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) {
|
||||
if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) {
|
||||
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
|
||||
GGML_ASSERT(backend_res != nullptr);
|
||||
GGML_ASSERT(logits != nullptr);
|
||||
GGML_ASSERT(logits.data != nullptr);
|
||||
|
||||
float * logits_out = logits + n_outputs_prev*n_vocab;
|
||||
float * logits_out = logits.data + n_outputs_prev*n_vocab;
|
||||
|
||||
if (n_outputs) {
|
||||
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
||||
GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
|
||||
GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits.size);
|
||||
ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
// extract embeddings
|
||||
if (embd && t_embd && n_outputs > 0) {
|
||||
if (embd.data && t_embd && n_outputs > 0) {
|
||||
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
||||
GGML_ASSERT(backend_embd != nullptr);
|
||||
|
||||
@@ -1694,13 +1693,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
case LLAMA_POOLING_TYPE_NONE:
|
||||
{
|
||||
// extract token embeddings
|
||||
GGML_ASSERT(embd != nullptr);
|
||||
GGML_ASSERT(embd.data != nullptr);
|
||||
const uint32_t n_embd_out = hparams.n_embd_out();
|
||||
float * embd_out = embd + n_outputs_prev*n_embd_out;
|
||||
float * embd_out = embd.data + n_outputs_prev*n_embd_out;
|
||||
|
||||
if (n_outputs) {
|
||||
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
||||
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size);
|
||||
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd.size);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
@@ -1747,7 +1746,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
const auto stride = n_vocab;
|
||||
|
||||
// async copy the sampling data from the backend to the host
|
||||
copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
|
||||
copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get());
|
||||
|
||||
copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
|
||||
copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
|
||||
@@ -1841,19 +1840,14 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
size_t backend_float_count = 0;
|
||||
size_t backend_token_count = 0;
|
||||
|
||||
logits_size = has_logits ? n_vocab*n_outputs_max : 0;
|
||||
embd_size = has_embd ? n_embd_out*n_outputs_max : 0;
|
||||
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
|
||||
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
|
||||
|
||||
// Allocate backend sampling output buffers if there are backend samplers configured.
|
||||
const bool has_sampling = !sampling.samplers.empty();
|
||||
if (has_sampling) {
|
||||
sampling.logits_size = n_vocab*n_outputs_max;
|
||||
sampling.probs_size = n_vocab*n_outputs_max;
|
||||
sampling.sampled_size = n_outputs_max;
|
||||
sampling.candidates_size = n_vocab*n_outputs_max;
|
||||
|
||||
backend_float_count = sampling.logits_size + sampling.probs_size;
|
||||
backend_token_count = sampling.sampled_size + sampling.candidates_size;
|
||||
backend_float_count = 2 * n_vocab * n_outputs_max; // logits + probs
|
||||
backend_token_count = (1 + n_vocab) * n_outputs_max; // sampled + candidates
|
||||
}
|
||||
|
||||
if (output_ids.empty()) {
|
||||
@@ -1863,7 +1857,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
|
||||
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
|
||||
const size_t new_size =
|
||||
(logits_size + embd_size + backend_float_count) * sizeof(float) +
|
||||
(logits.size + embd.size + backend_float_count) * sizeof(float) +
|
||||
( backend_token_count) * sizeof(llama_token);
|
||||
|
||||
// alloc only when more than the current capacity is required
|
||||
@@ -1878,8 +1872,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
|
||||
// TODO: not needed?
|
||||
buf_output = nullptr;
|
||||
logits = nullptr;
|
||||
embd = nullptr;
|
||||
logits.data = nullptr;
|
||||
embd.data = nullptr;
|
||||
}
|
||||
|
||||
auto * buft = ggml_backend_cpu_buffer_type();
|
||||
@@ -1898,35 +1892,32 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
|
||||
float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
|
||||
|
||||
logits = nullptr;
|
||||
embd = nullptr;
|
||||
|
||||
size_t offset = 0;
|
||||
uint8_t * base = (uint8_t *) output_base;
|
||||
|
||||
logits = has_logits ? output_base : nullptr;
|
||||
offset += logits_size * sizeof(float);
|
||||
logits = has_logits ? buffer_view<float>{output_base, logits.size} : buffer_view<float>{nullptr, 0};
|
||||
offset += logits.size * sizeof(float);
|
||||
|
||||
embd = has_embd ? (float *) (base + offset) : nullptr;
|
||||
offset += embd_size * sizeof(float);
|
||||
embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
|
||||
offset += embd.size * sizeof(float);
|
||||
|
||||
sampling.logits = nullptr;
|
||||
sampling.probs = nullptr;
|
||||
sampling.sampled = nullptr;
|
||||
sampling.candidates = nullptr;
|
||||
sampling.logits = {nullptr, 0};
|
||||
sampling.probs = {nullptr, 0};
|
||||
sampling.sampled = {nullptr, 0};
|
||||
sampling.candidates = {nullptr, 0};
|
||||
|
||||
if (has_sampling) {
|
||||
sampling.logits = (float *) (base + offset);
|
||||
offset += sampling.logits_size * sizeof(float);
|
||||
sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
|
||||
offset += sampling.logits.size * sizeof(float);
|
||||
|
||||
sampling.probs = (float *) (base + offset);
|
||||
offset += sampling.probs_size * sizeof(float);
|
||||
sampling.probs = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
|
||||
offset += sampling.probs.size * sizeof(float);
|
||||
|
||||
sampling.sampled = (llama_token *) (base + offset);
|
||||
offset += sampling.sampled_size * sizeof(llama_token);
|
||||
sampling.sampled = {(llama_token *) (base + offset), (size_t)n_outputs_max};
|
||||
offset += sampling.sampled.size * sizeof(llama_token);
|
||||
|
||||
sampling.candidates = (llama_token *) (base + offset);
|
||||
offset += sampling.candidates_size * sizeof(llama_token);
|
||||
sampling.candidates = {(llama_token *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
|
||||
offset += sampling.candidates.size * sizeof(llama_token);
|
||||
|
||||
// The count vectors keep track of the actual number of logits/probs/candidates
|
||||
// copied from the backend for each output row.
|
||||
@@ -1939,7 +1930,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
|
||||
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
|
||||
|
||||
std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
|
||||
std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL);
|
||||
}
|
||||
|
||||
// set all ids as invalid (negative)
|
||||
@@ -1958,38 +1949,38 @@ void llama_context::output_reorder() {
|
||||
const uint64_t i0 = output_swaps[s].i0;
|
||||
const uint64_t i1 = output_swaps[s].i1;
|
||||
|
||||
if (logits_size > 0) {
|
||||
if (logits.size > 0) {
|
||||
for (uint64_t k = 0; k < n_vocab; k++) {
|
||||
std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
|
||||
std::swap(logits.data[i0*n_vocab + k], logits.data[i1*n_vocab + k]);
|
||||
}
|
||||
}
|
||||
|
||||
if (embd_size > 0) {
|
||||
if (embd.size > 0) {
|
||||
for (uint64_t k = 0; k < n_embd; k++) {
|
||||
std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
|
||||
std::swap(embd.data[i0*n_embd + k], embd.data[i1*n_embd + k]);
|
||||
}
|
||||
}
|
||||
|
||||
if (sampling.logits && sampling.logits_size > 0) {
|
||||
if (sampling.logits.has_data()) {
|
||||
for (uint64_t k = 0; k < n_vocab; ++k) {
|
||||
std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
|
||||
std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]);
|
||||
}
|
||||
}
|
||||
|
||||
if (sampling.probs && sampling.probs_size > 0) {
|
||||
if (sampling.probs.has_data()) {
|
||||
for (uint64_t k = 0; k < n_vocab; ++k) {
|
||||
std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
|
||||
std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]);
|
||||
}
|
||||
}
|
||||
|
||||
if (sampling.candidates && sampling.candidates_size > 0) {
|
||||
if (sampling.candidates.has_data()) {
|
||||
for (uint64_t k = 0; k < n_vocab; ++k) {
|
||||
std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
|
||||
std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]);
|
||||
}
|
||||
}
|
||||
|
||||
if (sampling.sampled && sampling.sampled_size > 0) {
|
||||
std::swap(sampling.sampled[i0], sampling.sampled[i1]);
|
||||
if (sampling.sampled.has_data()) {
|
||||
std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
|
||||
}
|
||||
|
||||
if (!sampling.logits_count.empty()) {
|
||||
@@ -2533,12 +2524,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
||||
{
|
||||
LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
|
||||
|
||||
const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
|
||||
const uint64_t logits_size = std::min((uint64_t) this->logits.size, (uint64_t) n_outputs * model.vocab.n_tokens());
|
||||
|
||||
io.write(&logits_size, sizeof(logits_size));
|
||||
|
||||
if (logits_size) {
|
||||
io.write(logits, logits_size * sizeof(float));
|
||||
io.write(logits.data, logits_size * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2546,12 +2537,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
||||
{
|
||||
LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
|
||||
|
||||
const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
|
||||
const uint64_t embd_size = std::min((uint64_t) this->embd.size, (uint64_t) n_outputs * model.hparams.n_embd);
|
||||
|
||||
io.write(&embd_size, sizeof(embd_size));
|
||||
|
||||
if (embd_size) {
|
||||
io.write(embd, embd_size * sizeof(float));
|
||||
io.write(embd.data, embd_size * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2619,12 +2610,12 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
||||
uint64_t logits_size;
|
||||
io.read_to(&logits_size, sizeof(logits_size));
|
||||
|
||||
if (this->logits_size < logits_size) {
|
||||
if (this->logits.size < logits_size) {
|
||||
throw std::runtime_error("logits buffer too small");
|
||||
}
|
||||
|
||||
if (logits_size) {
|
||||
io.read_to(this->logits, logits_size * sizeof(float));
|
||||
io.read_to(this->logits.data, logits_size * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2635,12 +2626,12 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
||||
uint64_t embd_size;
|
||||
io.read_to(&embd_size, sizeof(embd_size));
|
||||
|
||||
if (this->embd_size < embd_size) {
|
||||
if (this->embd.size < embd_size) {
|
||||
throw std::runtime_error("embeddings buffer too small");
|
||||
}
|
||||
|
||||
if (embd_size) {
|
||||
io.read_to(this->embd, embd_size * sizeof(float));
|
||||
io.read_to(this->embd.data, embd_size * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+7
-16
@@ -4,6 +4,7 @@
|
||||
#include "llama-cparams.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-adapter.h"
|
||||
#include "llama-impl.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
#include "ggml-opt.h"
|
||||
@@ -269,29 +270,19 @@ private:
|
||||
std::unique_ptr<llama_memory_i> memory;
|
||||
|
||||
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
||||
size_t logits_size = 0; // capacity (of floats) for logits
|
||||
float * logits = nullptr;
|
||||
struct buffer_view<float> logits = {nullptr, 0};
|
||||
|
||||
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
||||
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
||||
size_t embd_size = 0; // capacity (of floats) for embeddings
|
||||
float * embd = nullptr;
|
||||
struct buffer_view<float> embd = {nullptr, 0};
|
||||
|
||||
// TODO: simplify
|
||||
struct sampling_info {
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
|
||||
float * logits = nullptr;
|
||||
size_t logits_size = 0;
|
||||
|
||||
llama_token * sampled = nullptr;
|
||||
size_t sampled_size = 0;
|
||||
|
||||
float * probs = nullptr;
|
||||
size_t probs_size = 0;
|
||||
|
||||
llama_token * candidates = nullptr;
|
||||
size_t candidates_size = 0;
|
||||
struct buffer_view<float> logits = {nullptr, 0};
|
||||
struct buffer_view<llama_token> sampled = {nullptr, 0};
|
||||
struct buffer_view<float> probs = {nullptr, 0};
|
||||
struct buffer_view<llama_token> candidates = {nullptr, 0};
|
||||
|
||||
std::vector<uint32_t> logits_count;
|
||||
std::vector<uint32_t> probs_count;
|
||||
|
||||
@@ -42,7 +42,6 @@ struct llama_hparams {
|
||||
|
||||
uint32_t n_ctx_train; // context size the model was trained on
|
||||
uint32_t n_embd;
|
||||
uint32_t n_embd_features = 0;
|
||||
uint32_t n_layer;
|
||||
int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
|
||||
uint32_t n_rot;
|
||||
|
||||
@@ -49,6 +49,16 @@ struct time_meas {
|
||||
int64_t & t_acc;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct buffer_view {
|
||||
T * data;
|
||||
size_t size = 0;
|
||||
|
||||
bool has_data() const {
|
||||
return data && size > 0;
|
||||
}
|
||||
};
|
||||
|
||||
void replace_all(std::string & s, const std::string & search, const std::string & replace);
|
||||
|
||||
// TODO: rename to llama_format ?
|
||||
|
||||
+6
-5
@@ -523,7 +523,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false);
|
||||
|
||||
if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
|
||||
ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features);
|
||||
ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd);
|
||||
ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd_out_impl);
|
||||
|
||||
ml.get_key(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd);
|
||||
ml.get_key(LLM_KV_POSNET_BLOCK_COUNT, hparams.posnet.n_layer);
|
||||
@@ -6046,9 +6047,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
} break;
|
||||
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0);
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd, n_vocab}, 0);
|
||||
|
||||
conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd_features, hparams.posnet.n_embd}, 0);
|
||||
conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd, hparams.posnet.n_embd}, 0);
|
||||
conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {1, hparams.posnet.n_embd}, 0);
|
||||
|
||||
// posnet
|
||||
@@ -6144,8 +6145,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
||||
}
|
||||
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
|
||||
output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, hparams.n_embd_out()}, 0);
|
||||
output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {hparams.n_embd_out()}, 0);
|
||||
} break;
|
||||
case LLM_ARCH_BAILINGMOE:
|
||||
{
|
||||
|
||||
+81
-43
@@ -1943,7 +1943,11 @@ struct test_unary : public test_case {
|
||||
|
||||
ggml_tensor * a;
|
||||
if (v & 1) {
|
||||
auto ne = ne_a; ne[0] *= 3;
|
||||
auto ne = ne_a;
|
||||
ne[0] *= 3;
|
||||
ne[1] *= 2;
|
||||
ne[2] *= 5;
|
||||
ne[3] *= 4;
|
||||
a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
if (grad_supported) {
|
||||
ggml_set_param(a);
|
||||
@@ -2782,9 +2786,10 @@ struct test_set : public test_case {
|
||||
const ggml_type type_dst;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const int dim;
|
||||
const bool inplace;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type_src, type_dst, ne, dim);
|
||||
return VARS_TO_STR5(type_src, type_dst, ne, dim, inplace);
|
||||
}
|
||||
|
||||
size_t op_size(ggml_tensor * t) override {
|
||||
@@ -2792,8 +2797,8 @@ struct test_set : public test_case {
|
||||
}
|
||||
|
||||
test_set(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {6, 5, 4, 3}, int dim = 1)
|
||||
: type_src(type_src), type_dst(type_dst), ne(ne), dim(dim) {}
|
||||
std::array<int64_t, 4> ne = {6, 5, 4, 3}, int dim = 1, bool inplace = false)
|
||||
: type_src(type_src), type_dst(type_dst), ne(ne), dim(dim), inplace(inplace) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
|
||||
@@ -2804,7 +2809,7 @@ struct test_set : public test_case {
|
||||
for (int i = 0; i < dim; ++i) {
|
||||
ne_dst[i] *= 2;
|
||||
}
|
||||
ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());
|
||||
ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());
|
||||
ggml_set_param(dst);
|
||||
ggml_set_name(dst, "dst");
|
||||
|
||||
@@ -2812,9 +2817,16 @@ struct test_set : public test_case {
|
||||
for (int i = 0; i < dim; ++i) {
|
||||
offset += ((ne_dst[i] - ne[i])/2)*dst->nb[i];
|
||||
}
|
||||
ggml_tensor * out = ggml_set(ctx, dst, src,
|
||||
// The backward pass requires setting a contiguous region:
|
||||
src->nb[1], src->nb[2], src->nb[3], offset);
|
||||
ggml_tensor * out;
|
||||
if (inplace) {
|
||||
out = ggml_set_inplace(ctx, dst, src,
|
||||
// The backward pass requires setting a contiguous region:
|
||||
src->nb[1], src->nb[2], src->nb[3], offset);
|
||||
} else {
|
||||
out = ggml_set(ctx, dst, src,
|
||||
// The backward pass requires setting a contiguous region:
|
||||
src->nb[1], src->nb[2], src->nb[3], offset);
|
||||
}
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
return out;
|
||||
@@ -2964,11 +2976,12 @@ struct test_bin_bcast : public test_case {
|
||||
const std::array<int64_t, 4> ne;
|
||||
const std::array<int, 4> nr;
|
||||
int nf; // number of fused ops, nf == 1 -> single op (no fusion)
|
||||
bool perm1; // permute src1?
|
||||
|
||||
bool run_whole_graph() override { return nf > 1; }
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, ne, nr, nf);
|
||||
return VARS_TO_STR5(type, ne, nr, nf, perm1);
|
||||
}
|
||||
|
||||
size_t op_size(ggml_tensor * t) override {
|
||||
@@ -2978,8 +2991,9 @@ struct test_bin_bcast : public test_case {
|
||||
test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {10, 10, 1, 1},
|
||||
std::array<int, 4> nr = {1, 2, 1, 1},
|
||||
int nf = 1)
|
||||
: op(op), type(type), ne(ne), nr(nr), nf(nf) {}
|
||||
int nf = 1,
|
||||
bool perm1 = false)
|
||||
: op(op), type(type), ne(ne), nr(nr), nf(nf), perm1(perm1) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
GGML_ASSERT(nf <= 16);
|
||||
@@ -2989,12 +3003,19 @@ struct test_bin_bcast : public test_case {
|
||||
|
||||
ggml_tensor * b[16];
|
||||
for (int i = 0; i < nf; ++i) {
|
||||
b[i] = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
if (perm1) {
|
||||
const int p[4] = { 1, 2, 0, 3 }; // hardcoded for now
|
||||
|
||||
b[i] = ggml_new_tensor_4d(ctx, type, ne[p[0]], ne[p[1]], ne[p[2]], ne[p[3]]);
|
||||
b[i] = ggml_permute(ctx, b[i], p[0], p[1], p[2], p[3]);
|
||||
} else {
|
||||
b[i] = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
}
|
||||
ggml_set_name(b[i], (std::string("b") + std::to_string(i)).c_str());
|
||||
}
|
||||
|
||||
// The backward pass supports broadcasting only for GGML_ADD:
|
||||
const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1;
|
||||
const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1 && !perm1;
|
||||
if (grad_supported) {
|
||||
ggml_set_param(a);
|
||||
ggml_set_param(b[0]);
|
||||
@@ -7415,11 +7436,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {1, 2, 0, 3}));
|
||||
|
||||
for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
|
||||
test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim));
|
||||
test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim, false));
|
||||
test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim, true));
|
||||
}
|
||||
|
||||
for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
|
||||
test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
|
||||
test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim, false));
|
||||
test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim, true));
|
||||
}
|
||||
|
||||
// same-type copy
|
||||
@@ -7477,25 +7500,27 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
}
|
||||
|
||||
auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
|
||||
auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr, bool perm1 = false) {
|
||||
for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
|
||||
test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
|
||||
test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr, 1, perm1));
|
||||
}
|
||||
};
|
||||
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
|
||||
add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1});
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1});
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2});
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2});
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2});
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2});
|
||||
for (bool perm1 : {false, true}) {
|
||||
add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1}, perm1);
|
||||
add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1}, perm1);
|
||||
add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1}, perm1);
|
||||
add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1}, perm1);
|
||||
add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1}, perm1);
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1}, perm1);
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1}, perm1);
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1}, perm1);
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1}, perm1);
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2}, perm1);
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2}, perm1);
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2}, perm1);
|
||||
add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2}, perm1);
|
||||
}
|
||||
|
||||
// test case for k_bin_bcast_unravel in CUDA backend
|
||||
add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1});
|
||||
@@ -7882,20 +7907,27 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_round (type));
|
||||
test_cases.emplace_back(new test_trunc (type));
|
||||
test_cases.emplace_back(new test_sqr (type, {7, 1, 5, 3}));
|
||||
test_cases.emplace_back(new test_sqr (type, {1024, 1024, 1, 1}));
|
||||
test_cases.emplace_back(new test_sqrt (type, {7, 1, 5, 3}));
|
||||
test_cases.emplace_back(new test_sqrt (type, {1024, 1024, 1, 1}));
|
||||
test_cases.emplace_back(new test_log (type, {7, 1, 5, 3}));
|
||||
test_cases.emplace_back(new test_log (type, {1024, 1024, 1, 1}));
|
||||
test_cases.emplace_back(new test_sin (type, {7, 1, 5, 3}));
|
||||
test_cases.emplace_back(new test_sin (type, {1024, 1024, 1, 1}));
|
||||
test_cases.emplace_back(new test_cos (type, {7, 1, 5, 3}));
|
||||
test_cases.emplace_back(new test_cos (type, {1024, 1024, 1, 1}));
|
||||
test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3}));
|
||||
test_cases.emplace_back(new test_clamp (type, {1024, 1024, 1, 1}));
|
||||
test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3}));
|
||||
test_cases.emplace_back(new test_leaky_relu(type, {1024, 1024, 1, 1}));
|
||||
test_cases.emplace_back(new test_floor (type, {7, 1, 5, 3}));
|
||||
test_cases.emplace_back(new test_floor (type, { 1024, 1024, 1, 1 }));
|
||||
test_cases.emplace_back(new test_floor (type, {1024, 1024, 1, 1}));
|
||||
test_cases.emplace_back(new test_ceil (type, {7, 1, 5, 3}));
|
||||
test_cases.emplace_back(new test_ceil (type, { 1024, 1024, 1, 1 }));
|
||||
test_cases.emplace_back(new test_ceil (type, {1024, 1024, 1, 1}));
|
||||
test_cases.emplace_back(new test_round (type, {7, 1, 5, 3}));
|
||||
test_cases.emplace_back(new test_round (type, { 1024, 1024, 1, 1 }));
|
||||
test_cases.emplace_back(new test_round (type, {1024, 1024, 1, 1}));
|
||||
test_cases.emplace_back(new test_trunc (type, {7, 1, 5, 3}));
|
||||
test_cases.emplace_back(new test_trunc (type, { 1024, 1024, 1, 1 }));
|
||||
test_cases.emplace_back(new test_trunc (type, {1024, 1024, 1, 1}));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
|
||||
@@ -8110,24 +8142,30 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_sum());
|
||||
test_cases.emplace_back(new test_sum_rows());
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1}));
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2}));
|
||||
test_cases.emplace_back(new test_mean());
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 256, 1, 1 }));
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32768, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
|
||||
test_cases.emplace_back(new test_sum_rows());
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
|
||||
test_cases.emplace_back(new test_mean());
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, false));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, false, true));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, true));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
|
||||
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
|
||||
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
|
||||
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
|
||||
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
|
||||
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
|
||||
test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {64, 64, 320, 1}));
|
||||
|
||||
@@ -19,6 +19,7 @@ add_library(mtmd
|
||||
models/glm4v.cpp
|
||||
models/internvl.cpp
|
||||
models/kimivl.cpp
|
||||
models/kimik25.cpp
|
||||
models/llama4.cpp
|
||||
models/llava.cpp
|
||||
models/minicpmv.cpp
|
||||
|
||||
@@ -235,6 +235,7 @@ enum projector_type {
|
||||
PROJECTOR_TYPE_LFM2A,
|
||||
PROJECTOR_TYPE_GLM4V,
|
||||
PROJECTOR_TYPE_YOUTUVL,
|
||||
PROJECTOR_TYPE_KIMIK25,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -268,6 +269,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
{ PROJECTOR_TYPE_LFM2A, "lfm2a"},
|
||||
{ PROJECTOR_TYPE_GLM4V, "glm4v"},
|
||||
{ PROJECTOR_TYPE_YOUTUVL, "youtuvl"},
|
||||
{ PROJECTOR_TYPE_KIMIK25, "kimik25"},
|
||||
};
|
||||
|
||||
static projector_type clip_projector_type_from_string(const std::string & str) {
|
||||
|
||||
+86
-4
@@ -673,8 +673,8 @@ ggml_tensor * clip_graph::build_rope_2d(
|
||||
{
|
||||
first = ggml_view_3d(ctx0, cur,
|
||||
n_dim/2, n_head, n_pos,
|
||||
ggml_row_size(cur->type, n_dim),
|
||||
ggml_row_size(cur->type, n_dim*n_head),
|
||||
cur->nb[1],
|
||||
cur->nb[2],
|
||||
0);
|
||||
first = ggml_rope_ext(
|
||||
ctx0,
|
||||
@@ -692,8 +692,8 @@ ggml_tensor * clip_graph::build_rope_2d(
|
||||
{
|
||||
second = ggml_view_3d(ctx0, cur,
|
||||
n_dim/2, n_head, n_pos,
|
||||
ggml_row_size(cur->type, n_dim),
|
||||
ggml_row_size(cur->type, n_dim*n_head),
|
||||
cur->nb[1],
|
||||
cur->nb[2],
|
||||
n_dim/2 * ggml_element_size(cur));
|
||||
second = ggml_rope_ext(
|
||||
ctx0,
|
||||
@@ -826,6 +826,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
{
|
||||
builder = std::make_unique<clip_graph_kimivl>(ctx, img);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_KIMIK25:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_kimik25>(ctx, img);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_COGVLM:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_cogvlm>(ctx, img);
|
||||
@@ -1139,6 +1143,22 @@ struct clip_model_loader {
|
||||
hparams.set_limit_image_tokens(8, 1024);
|
||||
hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
|
||||
} break;
|
||||
case PROJECTOR_TYPE_KIMIK25:
|
||||
{
|
||||
hparams.rope_theta = 10000.0f;
|
||||
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
|
||||
|
||||
int min_pixels = 0, max_pixels = 0;
|
||||
get_u32(KEY_IMAGE_MIN_PIXELS, min_pixels, false);
|
||||
get_u32(KEY_IMAGE_MAX_PIXELS, max_pixels, false);
|
||||
if (min_pixels > 0 && max_pixels > 0) {
|
||||
hparams.image_min_pixels = min_pixels;
|
||||
hparams.image_max_pixels = max_pixels;
|
||||
hparams.warmup_image_size = static_cast<int>(std::sqrt(max_pixels));
|
||||
} else {
|
||||
hparams.set_limit_image_tokens(2, 4096);
|
||||
}
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
{
|
||||
// default value (used by all model sizes in gemma 3 family)
|
||||
@@ -1668,6 +1688,7 @@ struct clip_model_loader {
|
||||
model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
case PROJECTOR_TYPE_KIMIK25:
|
||||
{
|
||||
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
|
||||
model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
|
||||
@@ -3165,6 +3186,23 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
||||
res_imgs->entries.push_back(std::move(res));
|
||||
} break;
|
||||
|
||||
case PROJECTOR_TYPE_KIMIK25:
|
||||
{
|
||||
GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
|
||||
const clip_image_size target_size = img_tool::calc_size_preserved_ratio(
|
||||
original_size,
|
||||
params.patch_size * params.n_merge,
|
||||
params.image_min_pixels,
|
||||
params.image_max_pixels);
|
||||
const std::array<uint8_t, 3> pad_color = {0, 0, 0};
|
||||
|
||||
clip_image_u8 resized_img;
|
||||
img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BICUBIC, true, pad_color);
|
||||
clip_image_f32_ptr res(clip_image_f32_init());
|
||||
normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std);
|
||||
res_imgs->entries.push_back(std::move(res));
|
||||
} break;
|
||||
|
||||
case PROJECTOR_TYPE_MLP:
|
||||
case PROJECTOR_TYPE_MLP_NORM:
|
||||
case PROJECTOR_TYPE_LDP:
|
||||
@@ -3373,6 +3411,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LFM2:
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
case PROJECTOR_TYPE_KIMIK25:
|
||||
{
|
||||
// dynamic size
|
||||
int out_patch_size = params.patch_size * ctx->model.hparams.n_merge;
|
||||
@@ -3714,6 +3753,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
case PROJECTOR_TYPE_KIMIK25:
|
||||
case PROJECTOR_TYPE_LIGHTONOCR:
|
||||
{
|
||||
// set the 2D positions
|
||||
@@ -3850,6 +3890,47 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
|
||||
}
|
||||
|
||||
// Debug: dump final embeddings if MTMD_DEBUG_EMBEDDINGS is set
|
||||
if (std::getenv("MTMD_DEBUG_EMBEDDINGS") != nullptr) {
|
||||
const int64_t n_embd = embeddings->ne[0];
|
||||
const int64_t n_tokens = embeddings->ne[1];
|
||||
std::vector<float> emb_data(n_embd * n_tokens);
|
||||
ggml_backend_tensor_get(embeddings, emb_data.data(), 0, ggml_nbytes(embeddings));
|
||||
|
||||
LOG_INF("\n=== MTMD_DEBUG_EMBEDDINGS ===\n");
|
||||
LOG_INF("Shape: [%lld, %lld]\n", (long long)n_embd, (long long)n_tokens);
|
||||
|
||||
// Print first few values of first token
|
||||
LOG_INF("Token 0 (first 16 values): ");
|
||||
for (int i = 0; i < std::min((int64_t)16, n_embd); i++) {
|
||||
LOG_INF("%.6f ", emb_data[i]);
|
||||
}
|
||||
LOG_INF("\n");
|
||||
|
||||
// Print last few values of first token
|
||||
if (n_embd > 16) {
|
||||
LOG_INF("Token 0 (last 16 values): ");
|
||||
for (int64_t i = n_embd - 16; i < n_embd; i++) {
|
||||
LOG_INF("%.6f ", emb_data[i]);
|
||||
}
|
||||
LOG_INF("\n");
|
||||
}
|
||||
|
||||
// Compute and print statistics
|
||||
float sum = 0.0f, sum_sq = 0.0f, min_val = emb_data[0], max_val = emb_data[0];
|
||||
for (size_t i = 0; i < emb_data.size(); i++) {
|
||||
sum += emb_data[i];
|
||||
sum_sq += emb_data[i] * emb_data[i];
|
||||
min_val = std::min(min_val, emb_data[i]);
|
||||
max_val = std::max(max_val, emb_data[i]);
|
||||
}
|
||||
float mean = sum / emb_data.size();
|
||||
float variance = (sum_sq / emb_data.size()) - (mean * mean);
|
||||
LOG_INF("Stats: mean=%.6f, std=%.6f, min=%.6f, max=%.6f, sum=%.6f\n",
|
||||
mean, sqrtf(variance), min_val, max_val, sum);
|
||||
LOG_INF("=== END MTMD_DEBUG_EMBEDDINGS ===\n\n");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -3896,6 +3977,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
return ctx->model.mm_2_w->ne[1];
|
||||
case PROJECTOR_TYPE_LFM2:
|
||||
case PROJECTOR_TYPE_KIMIVL:
|
||||
case PROJECTOR_TYPE_KIMIK25:
|
||||
return ctx->model.mm_2_w->ne[1];
|
||||
case PROJECTOR_TYPE_COGVLM:
|
||||
return ctx->model.mm_4h_to_h_w->ne[1];
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
#include "models.h"
|
||||
#include <cstring>
|
||||
#include <cmath>
|
||||
|
||||
// note: this is similar to clip_graph::resize_position_embeddings, major difference is having
|
||||
// the w/h in ne[1] and ne[2] instead of assuming with sqrt. Could try storing the tensor in 2D instead
|
||||
// with a w*h? Also the permute is a bit different at (2, 1, 0, 3) instead of (2, 0, 1, 3).
|
||||
ggml_tensor * clip_graph_kimik25::resize_position_embeddings_3d(uint32_t interpolation_mode) {
|
||||
ggml_tensor * pos_embd = model.position_embeddings;
|
||||
const int height = img.ny / patch_size;
|
||||
const int width = img.nx / patch_size;
|
||||
const uint32_t mode = interpolation_mode;
|
||||
|
||||
GGML_ASSERT(pos_embd);
|
||||
|
||||
const int64_t stored_c = pos_embd->ne[0]; // C = 1152
|
||||
const int64_t orig_w = pos_embd->ne[1]; // W = 64
|
||||
const int64_t orig_h = pos_embd->ne[2]; // H = 64
|
||||
|
||||
GGML_ASSERT(stored_c == n_embd);
|
||||
|
||||
if (height == (int)orig_h && width == (int)orig_w) {
|
||||
// No interpolation needed, just flatten to [C, H*W]
|
||||
return ggml_cont_2d(ctx0, pos_embd, n_embd, width * height);
|
||||
}
|
||||
|
||||
pos_embd = ggml_permute(ctx0, pos_embd, 2, 1, 0, 3);
|
||||
pos_embd = ggml_interpolate(ctx0, pos_embd, height, width, n_embd, 1, mode);
|
||||
pos_embd = ggml_permute(ctx0, pos_embd, 2, 1, 0, 3);
|
||||
pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height);
|
||||
return pos_embd;
|
||||
}
|
||||
|
||||
ggml_cgraph * clip_graph_kimik25::build() {
|
||||
ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
|
||||
ggml_set_name(pos_h, "pos_h");
|
||||
ggml_set_input(pos_h);
|
||||
|
||||
ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
|
||||
ggml_set_name(pos_w, "pos_w");
|
||||
ggml_set_input(pos_w);
|
||||
|
||||
ggml_tensor * learned_pos_embd = resize_position_embeddings_3d(GGML_SCALE_MODE_BICUBIC);
|
||||
|
||||
// Kimi-K2.5 uses interleaved 2D RoPE pattern natively, but
|
||||
// Q / K are permuted during conversion to use split format.
|
||||
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
|
||||
cur = build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
|
||||
return cur;
|
||||
};
|
||||
|
||||
ggml_tensor * inp = build_inp();
|
||||
|
||||
// I don't know why, but doing this in the build_vit lead to the ggml_add not occurring?
|
||||
// Doing it manually here does work.
|
||||
inp = ggml_add(ctx0, inp, learned_pos_embd);
|
||||
|
||||
ggml_tensor * cur = build_vit(
|
||||
inp, n_patches,
|
||||
NORM_TYPE_NORMAL,
|
||||
hparams.ffn_op,
|
||||
nullptr,
|
||||
add_pos);
|
||||
|
||||
cb(cur, "vit_out", -1);
|
||||
|
||||
{
|
||||
// patch_merger
|
||||
const int scale_factor = model.hparams.n_merge;
|
||||
cur = build_patch_merge_permute(cur, scale_factor);
|
||||
|
||||
// projection norm
|
||||
int proj_inp_dim = cur->ne[0];
|
||||
int n_merged_patches = cur->ne[1];
|
||||
cur = ggml_view_2d(ctx0, cur,
|
||||
n_embd, n_merged_patches * scale_factor * scale_factor,
|
||||
ggml_row_size(cur->type, n_embd), 0);
|
||||
cur = ggml_norm(ctx0, cur, hparams.eps);
|
||||
cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
|
||||
cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
|
||||
cur = ggml_view_2d(ctx0, cur,
|
||||
proj_inp_dim, n_merged_patches,
|
||||
ggml_row_size(cur->type, proj_inp_dim), 0);
|
||||
cb(cur, "proj_inp_normed", -1);
|
||||
|
||||
// projection mlp
|
||||
cur = build_ffn(cur,
|
||||
model.mm_1_w, model.mm_1_b,
|
||||
nullptr, nullptr,
|
||||
model.mm_2_w, model.mm_2_b,
|
||||
FFN_GELU,
|
||||
-1);
|
||||
|
||||
cb(cur, "proj_out", -1);
|
||||
}
|
||||
|
||||
// build the graph
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
@@ -109,3 +109,10 @@ struct clip_graph_mobilenetv5 : clip_graph {
|
||||
ggml_tensor * inp,
|
||||
const mobilenetv5_block & block);
|
||||
};
|
||||
|
||||
struct clip_graph_kimik25 : clip_graph {
|
||||
clip_graph_kimik25(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||
ggml_cgraph * build() override;
|
||||
|
||||
ggml_tensor * resize_position_embeddings_3d(uint32_t interpolation_mode);
|
||||
};
|
||||
|
||||
@@ -19,7 +19,7 @@ Set of LLM REST APIs and a web UI to interact with llama.cpp.
|
||||
* Speculative decoding
|
||||
* Easy-to-use web UI
|
||||
|
||||
For the ful list of features, please refer to [server's changelog](https://github.com/ggml-org/llama.cpp/issues/9291)
|
||||
For the full list of features, please refer to [server's changelog](https://github.com/ggml-org/llama.cpp/issues/9291)
|
||||
|
||||
## Usage
|
||||
|
||||
|
||||
Binary file not shown.
@@ -139,6 +139,6 @@ sequenceDiagram
|
||||
|
||||
Note over settingsStore: UI-only (not synced):
|
||||
rect rgb(255, 240, 240)
|
||||
Note over settingsStore: systemMessage, custom (JSON)<br/>showStatistics, enableContinueGeneration<br/>autoMicOnEmpty, disableAutoScroll<br/>apiKey, pdfAsImage, disableReasoningFormat
|
||||
Note over settingsStore: systemMessage, custom (JSON)<br/>showStatistics, enableContinueGeneration<br/>autoMicOnEmpty, disableAutoScroll<br/>apiKey, pdfAsImage, disableReasoningParsing, showRawOutputSwitch
|
||||
end
|
||||
```
|
||||
|
||||
@@ -14,11 +14,11 @@
|
||||
--popover-foreground: oklch(0.145 0 0);
|
||||
--primary: oklch(0.205 0 0);
|
||||
--primary-foreground: oklch(0.985 0 0);
|
||||
--secondary: oklch(0.97 0 0);
|
||||
--secondary: oklch(0.95 0 0);
|
||||
--secondary-foreground: oklch(0.205 0 0);
|
||||
--muted: oklch(0.97 0 0);
|
||||
--muted-foreground: oklch(0.556 0 0);
|
||||
--accent: oklch(0.97 0 0);
|
||||
--accent: oklch(0.95 0 0);
|
||||
--accent-foreground: oklch(0.205 0 0);
|
||||
--destructive: oklch(0.577 0.245 27.325);
|
||||
--border: oklch(0.875 0 0);
|
||||
@@ -37,7 +37,7 @@
|
||||
--sidebar-accent-foreground: oklch(0.205 0 0);
|
||||
--sidebar-border: oklch(0.922 0 0);
|
||||
--sidebar-ring: oklch(0.708 0 0);
|
||||
--code-background: oklch(0.975 0 0);
|
||||
--code-background: oklch(0.985 0 0);
|
||||
--code-foreground: oklch(0.145 0 0);
|
||||
--layer-popover: 1000000;
|
||||
}
|
||||
@@ -51,7 +51,7 @@
|
||||
--popover-foreground: oklch(0.985 0 0);
|
||||
--primary: oklch(0.922 0 0);
|
||||
--primary-foreground: oklch(0.205 0 0);
|
||||
--secondary: oklch(0.269 0 0);
|
||||
--secondary: oklch(0.29 0 0);
|
||||
--secondary-foreground: oklch(0.985 0 0);
|
||||
--muted: oklch(0.269 0 0);
|
||||
--muted-foreground: oklch(0.708 0 0);
|
||||
@@ -116,12 +116,62 @@
|
||||
--color-sidebar-ring: var(--sidebar-ring);
|
||||
}
|
||||
|
||||
:root {
|
||||
--chat-form-area-height: 8rem;
|
||||
--chat-form-area-offset: 2rem;
|
||||
--max-message-height: max(24rem, min(80dvh, calc(100dvh - var(--chat-form-area-height) - 12rem)));
|
||||
}
|
||||
|
||||
@media (min-width: 640px) {
|
||||
:root {
|
||||
--chat-form-area-height: 24rem;
|
||||
--chat-form-area-offset: 12rem;
|
||||
}
|
||||
}
|
||||
|
||||
@layer base {
|
||||
* {
|
||||
@apply border-border outline-ring/50;
|
||||
}
|
||||
|
||||
body {
|
||||
@apply bg-background text-foreground;
|
||||
scrollbar-width: thin;
|
||||
scrollbar-gutter: stable;
|
||||
}
|
||||
|
||||
/* Global scrollbar styling - visible only on hover */
|
||||
* {
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: transparent transparent;
|
||||
transition: scrollbar-color 0.2s ease;
|
||||
}
|
||||
|
||||
*:hover {
|
||||
scrollbar-color: hsl(var(--muted-foreground) / 0.3) transparent;
|
||||
}
|
||||
|
||||
*::-webkit-scrollbar {
|
||||
width: 6px;
|
||||
height: 6px;
|
||||
}
|
||||
|
||||
*::-webkit-scrollbar-track {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
*::-webkit-scrollbar-thumb {
|
||||
background: transparent;
|
||||
border-radius: 3px;
|
||||
transition: background 0.2s ease;
|
||||
}
|
||||
|
||||
*:hover::-webkit-scrollbar-thumb {
|
||||
background: hsl(var(--muted-foreground) / 0.3);
|
||||
}
|
||||
|
||||
*::-webkit-scrollbar-thumb:hover {
|
||||
background: hsl(var(--muted-foreground) / 0.5);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
<script lang="ts">
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
icon: Component;
|
||||
tooltip: string;
|
||||
variant?: 'default' | 'destructive' | 'outline' | 'secondary' | 'ghost' | 'link';
|
||||
size?: 'default' | 'sm' | 'lg' | 'icon';
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
onclick: () => void;
|
||||
'aria-label'?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
icon,
|
||||
tooltip,
|
||||
variant = 'ghost',
|
||||
size = 'sm',
|
||||
class: className = '',
|
||||
disabled = false,
|
||||
onclick,
|
||||
'aria-label': ariaLabel
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<Button
|
||||
{variant}
|
||||
{size}
|
||||
{disabled}
|
||||
{onclick}
|
||||
class="h-6 w-6 p-0 {className} flex"
|
||||
aria-label={ariaLabel || tooltip}
|
||||
>
|
||||
{@const IconComponent = icon}
|
||||
|
||||
<IconComponent class="h-3 w-3" />
|
||||
</Button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
<p>{tooltip}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
@@ -0,0 +1,18 @@
|
||||
<script lang="ts">
|
||||
import { Copy } from '@lucide/svelte';
|
||||
import { copyToClipboard } from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
ariaLabel?: string;
|
||||
canCopy?: boolean;
|
||||
text: string;
|
||||
}
|
||||
|
||||
let { ariaLabel = 'Copy to clipboard', canCopy = true, text }: Props = $props();
|
||||
</script>
|
||||
|
||||
<Copy
|
||||
class="h-3 w-3 flex-shrink-0 cursor-{canCopy ? 'pointer' : 'not-allowed'}"
|
||||
aria-label={ariaLabel}
|
||||
onclick={() => canCopy && copyToClipboard(text)}
|
||||
/>
|
||||
@@ -0,0 +1,26 @@
|
||||
<script lang="ts">
|
||||
import { X } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
|
||||
interface Props {
|
||||
id: string;
|
||||
onRemove?: (id: string) => void;
|
||||
class?: string;
|
||||
}
|
||||
|
||||
let { id, onRemove, class: className = '' }: Props = $props();
|
||||
</script>
|
||||
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
class="h-6 w-6 bg-white/20 p-0 hover:bg-white/30 {className}"
|
||||
onclick={(e: MouseEvent) => {
|
||||
e.stopPropagation();
|
||||
onRemove?.(id);
|
||||
}}
|
||||
aria-label="Remove file"
|
||||
>
|
||||
<X class="h-3 w-3" />
|
||||
</Button>
|
||||
@@ -0,0 +1,46 @@
|
||||
<script lang="ts">
|
||||
import { Eye } from '@lucide/svelte';
|
||||
import ActionIconCopyToClipboard from '$lib/components/app/actions/ActionIconCopyToClipboard.svelte';
|
||||
import { FileTypeText } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
code: string;
|
||||
language: string;
|
||||
disabled?: boolean;
|
||||
onPreview?: (code: string, language: string) => void;
|
||||
}
|
||||
|
||||
let { code, language, disabled = false, onPreview }: Props = $props();
|
||||
|
||||
const showPreview = $derived(language?.toLowerCase() === FileTypeText.HTML);
|
||||
|
||||
function handlePreview() {
|
||||
if (disabled) return;
|
||||
onPreview?.(code, language);
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="code-block-actions">
|
||||
<div class="copy-code-btn" class:opacity-50={disabled} class:!cursor-not-allowed={disabled}>
|
||||
<ActionIconCopyToClipboard
|
||||
text={code}
|
||||
canCopy={!disabled}
|
||||
ariaLabel={disabled ? 'Code incomplete' : 'Copy code'}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{#if showPreview}
|
||||
<button
|
||||
class="preview-code-btn"
|
||||
class:opacity-50={disabled}
|
||||
class:!cursor-not-allowed={disabled}
|
||||
title={disabled ? 'Code incomplete' : 'Preview code'}
|
||||
aria-label="Preview code"
|
||||
aria-disabled={disabled}
|
||||
type="button"
|
||||
onclick={handlePreview}
|
||||
>
|
||||
<Eye size={16} />
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
@@ -0,0 +1,19 @@
|
||||
/**
|
||||
*
|
||||
* ACTIONS
|
||||
*
|
||||
* Small interactive components for user actions.
|
||||
*
|
||||
*/
|
||||
|
||||
/** Styled icon button for action triggers with tooltip. */
|
||||
export { default as ActionIcon } from './ActionIcon.svelte';
|
||||
|
||||
/** Code block actions component (copy, preview). */
|
||||
export { default as ActionIconsCodeBlock } from './ActionIconsCodeBlock.svelte';
|
||||
|
||||
/** Copy-to-clipboard icon button with click handler. */
|
||||
export { default as ActionIconCopyToClipboard } from './ActionIconCopyToClipboard.svelte';
|
||||
|
||||
/** Remove/delete icon button with X icon. */
|
||||
export { default as ActionIconRemove } from './ActionIconRemove.svelte';
|
||||
@@ -0,0 +1,44 @@
|
||||
<script lang="ts">
|
||||
import { BadgeInfo } from '$lib/components/app';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { copyToClipboard } from '$lib/utils';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
icon: Component;
|
||||
value: string | number;
|
||||
tooltipLabel?: string;
|
||||
}
|
||||
|
||||
let { class: className = '', icon: Icon, value, tooltipLabel }: Props = $props();
|
||||
|
||||
function handleClick() {
|
||||
void copyToClipboard(String(value));
|
||||
}
|
||||
</script>
|
||||
|
||||
{#if tooltipLabel}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<BadgeInfo class={className} onclick={handleClick}>
|
||||
{#snippet icon()}
|
||||
<Icon class="h-3 w-3" />
|
||||
{/snippet}
|
||||
|
||||
{value}
|
||||
</BadgeInfo>
|
||||
</Tooltip.Trigger>
|
||||
<Tooltip.Content>
|
||||
<p>{tooltipLabel}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{:else}
|
||||
<BadgeInfo class={className} onclick={handleClick}>
|
||||
{#snippet icon()}
|
||||
<Icon class="h-3 w-3" />
|
||||
{/snippet}
|
||||
|
||||
{value}
|
||||
</BadgeInfo>
|
||||
{/if}
|
||||
@@ -0,0 +1,27 @@
|
||||
<script lang="ts">
|
||||
import { cn } from '$lib/components/ui/utils';
|
||||
import type { Snippet } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
children: Snippet;
|
||||
class?: string;
|
||||
icon?: Snippet;
|
||||
onclick?: () => void;
|
||||
}
|
||||
|
||||
let { children, class: className = '', icon, onclick }: Props = $props();
|
||||
</script>
|
||||
|
||||
<button
|
||||
class={cn(
|
||||
'inline-flex cursor-pointer items-center gap-1 rounded-sm bg-muted-foreground/15 px-1.5 py-0.75',
|
||||
className
|
||||
)}
|
||||
{onclick}
|
||||
>
|
||||
{#if icon}
|
||||
{@render icon()}
|
||||
{/if}
|
||||
|
||||
{@render children()}
|
||||
</button>
|
||||
@@ -0,0 +1,39 @@
|
||||
<script lang="ts">
|
||||
import { ModelModality } from '$lib/enums';
|
||||
import { MODALITY_ICONS, MODALITY_LABELS } from '$lib/constants/icons';
|
||||
import { cn } from '$lib/components/ui/utils';
|
||||
|
||||
type DisplayableModality = ModelModality.VISION | ModelModality.AUDIO;
|
||||
|
||||
interface Props {
|
||||
modalities: ModelModality[];
|
||||
class?: string;
|
||||
}
|
||||
|
||||
let { modalities, class: className = '' }: Props = $props();
|
||||
|
||||
// Filter to only modalities that have icons (VISION, AUDIO)
|
||||
const displayableModalities = $derived(
|
||||
modalities.filter(
|
||||
(m): m is DisplayableModality => m === ModelModality.VISION || m === ModelModality.AUDIO
|
||||
)
|
||||
);
|
||||
</script>
|
||||
|
||||
{#each displayableModalities as modality, index (index)}
|
||||
{@const IconComponent = MODALITY_ICONS[modality]}
|
||||
{@const label = MODALITY_LABELS[modality]}
|
||||
|
||||
<span
|
||||
class={cn(
|
||||
'inline-flex items-center gap-1 rounded-md bg-muted px-2 py-1 text-xs font-medium',
|
||||
className
|
||||
)}
|
||||
>
|
||||
{#if IconComponent}
|
||||
<IconComponent class="h-3 w-3" />
|
||||
{/if}
|
||||
|
||||
{label}
|
||||
</span>
|
||||
{/each}
|
||||
@@ -0,0 +1,16 @@
|
||||
/**
|
||||
*
|
||||
* BADGES & INDICATORS
|
||||
*
|
||||
* Small visual indicators for status and metadata.
|
||||
*
|
||||
*/
|
||||
|
||||
/** Badge displaying chat statistics (tokens, timing). */
|
||||
export { default as BadgeChatStatistic } from './BadgeChatStatistic.svelte';
|
||||
|
||||
/** Generic info badge with optional tooltip and click handler. */
|
||||
export { default as BadgeInfo } from './BadgeInfo.svelte';
|
||||
|
||||
/** Badge indicating model modality (vision, audio, tools). */
|
||||
export { default as BadgeModality } from './BadgeModality.svelte';
|
||||
@@ -27,11 +27,13 @@
|
||||
interface Props {
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
initialMessage?: string;
|
||||
isLoading?: boolean;
|
||||
onFileRemove?: (fileId: string) => void;
|
||||
onFileUpload?: (files: File[]) => void;
|
||||
onSend?: (message: string, files?: ChatUploadedFile[]) => Promise<boolean>;
|
||||
onStop?: () => void;
|
||||
onSystemPromptAdd?: (draft: { message: string; files: ChatUploadedFile[] }) => void;
|
||||
showHelperText?: boolean;
|
||||
uploadedFiles?: ChatUploadedFile[];
|
||||
}
|
||||
@@ -39,11 +41,13 @@
|
||||
let {
|
||||
class: className,
|
||||
disabled = false,
|
||||
initialMessage = '',
|
||||
isLoading = false,
|
||||
onFileRemove,
|
||||
onFileUpload,
|
||||
onSend,
|
||||
onStop,
|
||||
onSystemPromptAdd,
|
||||
showHelperText = true,
|
||||
uploadedFiles = $bindable([])
|
||||
}: Props = $props();
|
||||
@@ -53,15 +57,28 @@
|
||||
let currentConfig = $derived(config());
|
||||
let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined);
|
||||
let isRecording = $state(false);
|
||||
let message = $state('');
|
||||
let message = $state(initialMessage);
|
||||
let pasteLongTextToFileLength = $derived.by(() => {
|
||||
const n = Number(currentConfig.pasteLongTextToFileLen);
|
||||
return Number.isNaN(n) ? Number(SETTING_CONFIG_DEFAULT.pasteLongTextToFileLen) : n;
|
||||
});
|
||||
let previousIsLoading = $state(isLoading);
|
||||
let previousInitialMessage = $state(initialMessage);
|
||||
let recordingSupported = $state(false);
|
||||
let textareaRef: ChatFormTextarea | undefined = $state(undefined);
|
||||
|
||||
// Sync message when initialMessage prop changes (e.g., after draft restoration)
|
||||
$effect(() => {
|
||||
if (initialMessage !== previousInitialMessage) {
|
||||
message = initialMessage;
|
||||
previousInitialMessage = initialMessage;
|
||||
}
|
||||
});
|
||||
|
||||
function handleSystemPromptClick() {
|
||||
onSystemPromptAdd?.({ message, files: uploadedFiles });
|
||||
}
|
||||
|
||||
// Check if model is selected (in ROUTER mode)
|
||||
let conversationModel = $derived(
|
||||
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
|
||||
@@ -308,6 +325,7 @@
|
||||
onFileUpload={handleFileUpload}
|
||||
onMicClick={handleMicClick}
|
||||
onStop={handleStop}
|
||||
onSystemPromptClick={handleSystemPromptClick}
|
||||
/>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
+21
-1
@@ -1,5 +1,6 @@
|
||||
<script lang="ts">
|
||||
import { Paperclip } from '@lucide/svelte';
|
||||
import { MessageSquare } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
@@ -11,6 +12,7 @@
|
||||
hasAudioModality?: boolean;
|
||||
hasVisionModality?: boolean;
|
||||
onFileUpload?: () => void;
|
||||
onSystemPromptClick?: () => void;
|
||||
}
|
||||
|
||||
let {
|
||||
@@ -18,7 +20,8 @@
|
||||
disabled = false,
|
||||
hasAudioModality = false,
|
||||
hasVisionModality = false,
|
||||
onFileUpload
|
||||
onFileUpload,
|
||||
onSystemPromptClick
|
||||
}: Props = $props();
|
||||
|
||||
const fileUploadTooltipText = $derived.by(() => {
|
||||
@@ -118,6 +121,23 @@
|
||||
</Tooltip.Content>
|
||||
{/if}
|
||||
</Tooltip.Root>
|
||||
<DropdownMenu.Separator />
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<DropdownMenu.Item
|
||||
class="flex cursor-pointer items-center gap-2"
|
||||
onclick={() => onSystemPromptClick?.()}
|
||||
>
|
||||
<MessageSquare class="h-4 w-4" />
|
||||
|
||||
<span>System Prompt</span>
|
||||
</DropdownMenu.Item>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
<p>Add a custom system message for this conversation</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
</DropdownMenu.Content>
|
||||
</DropdownMenu.Root>
|
||||
</div>
|
||||
|
||||
+4
-1
@@ -27,6 +27,7 @@
|
||||
onFileUpload?: () => void;
|
||||
onMicClick?: () => void;
|
||||
onStop?: () => void;
|
||||
onSystemPromptClick?: () => void;
|
||||
}
|
||||
|
||||
let {
|
||||
@@ -39,7 +40,8 @@
|
||||
uploadedFiles = [],
|
||||
onFileUpload,
|
||||
onMicClick,
|
||||
onStop
|
||||
onStop,
|
||||
onSystemPromptClick
|
||||
}: Props = $props();
|
||||
|
||||
let currentConfig = $derived(config());
|
||||
@@ -170,6 +172,7 @@
|
||||
{hasAudioModality}
|
||||
{hasVisionModality}
|
||||
{onFileUpload}
|
||||
{onSystemPromptClick}
|
||||
/>
|
||||
|
||||
<ModelsSelector
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
<script lang="ts">
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { goto } from '$app/navigation';
|
||||
import {
|
||||
chatStore,
|
||||
pendingEditMessageId,
|
||||
clearPendingEditMessageId,
|
||||
removeSystemPromptPlaceholder
|
||||
} from '$lib/stores/chat.svelte';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { DatabaseService } from '$lib/services';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { SYSTEM_MESSAGE_PLACEHOLDER } from '$lib/constants/ui';
|
||||
import { copyToClipboard, isIMEComposing, formatMessageForClipboard } from '$lib/utils';
|
||||
import ChatMessageAssistant from './ChatMessageAssistant.svelte';
|
||||
import ChatMessageUser from './ChatMessageUser.svelte';
|
||||
@@ -92,8 +101,30 @@
|
||||
return null;
|
||||
});
|
||||
|
||||
function handleCancelEdit() {
|
||||
// Auto-start edit mode if this message is the pending edit target
|
||||
$effect(() => {
|
||||
const pendingId = pendingEditMessageId();
|
||||
|
||||
if (pendingId && pendingId === message.id && !isEditing) {
|
||||
handleEdit();
|
||||
clearPendingEditMessageId();
|
||||
}
|
||||
});
|
||||
|
||||
async function handleCancelEdit() {
|
||||
isEditing = false;
|
||||
|
||||
// If canceling a new system message with placeholder content, remove it without deleting children
|
||||
if (message.role === 'system') {
|
||||
const conversationDeleted = await removeSystemPromptPlaceholder(message.id);
|
||||
|
||||
if (conversationDeleted) {
|
||||
goto('/');
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
editedContent = message.content;
|
||||
editedExtras = message.extra ? [...message.extra] : [];
|
||||
editedUploadedFiles = [];
|
||||
@@ -114,8 +145,17 @@
|
||||
onCopy?.(message);
|
||||
}
|
||||
|
||||
function handleConfirmDelete() {
|
||||
onDelete?.(message);
|
||||
async function handleConfirmDelete() {
|
||||
if (message.role === 'system') {
|
||||
const conversationDeleted = await removeSystemPromptPlaceholder(message.id);
|
||||
|
||||
if (conversationDeleted) {
|
||||
goto('/');
|
||||
}
|
||||
} else {
|
||||
onDelete?.(message);
|
||||
}
|
||||
|
||||
showDeleteDialog = false;
|
||||
}
|
||||
|
||||
@@ -126,7 +166,12 @@
|
||||
|
||||
function handleEdit() {
|
||||
isEditing = true;
|
||||
editedContent = message.content;
|
||||
// Clear placeholder content for system messages
|
||||
editedContent =
|
||||
message.role === 'system' && message.content === SYSTEM_MESSAGE_PLACEHOLDER
|
||||
? ''
|
||||
: message.content;
|
||||
textareaElement?.focus();
|
||||
editedExtras = message.extra ? [...message.extra] : [];
|
||||
editedUploadedFiles = [];
|
||||
|
||||
@@ -166,7 +211,26 @@
|
||||
}
|
||||
|
||||
async function handleSaveEdit() {
|
||||
if (message.role === 'user' || message.role === 'system') {
|
||||
if (message.role === 'system') {
|
||||
// System messages: update in place without branching
|
||||
const newContent = editedContent.trim();
|
||||
|
||||
// If content is empty or still the placeholder, remove without deleting children
|
||||
if (!newContent) {
|
||||
const conversationDeleted = await removeSystemPromptPlaceholder(message.id);
|
||||
isEditing = false;
|
||||
if (conversationDeleted) {
|
||||
goto('/');
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
await DatabaseService.updateMessage(message.id, { content: newContent });
|
||||
const index = conversationsStore.findMessageIndex(message.id);
|
||||
if (index !== -1) {
|
||||
conversationsStore.updateMessageAtIndex(index, { content: newContent });
|
||||
}
|
||||
} else if (message.role === 'user') {
|
||||
const finalExtras = await getMergedExtras();
|
||||
onEditWithBranching?.(message, editedContent.trim(), finalExtras);
|
||||
} else {
|
||||
|
||||
+20
-3
@@ -5,6 +5,7 @@
|
||||
ChatMessageBranchingControls,
|
||||
DialogConfirmation
|
||||
} from '$lib/components/app';
|
||||
import { Switch } from '$lib/components/ui/switch';
|
||||
|
||||
interface Props {
|
||||
role: 'user' | 'assistant';
|
||||
@@ -26,6 +27,9 @@
|
||||
onConfirmDelete: () => void;
|
||||
onNavigateToSibling?: (siblingId: string) => void;
|
||||
onShowDeleteDialogChange: (show: boolean) => void;
|
||||
showRawOutputSwitch?: boolean;
|
||||
rawOutputEnabled?: boolean;
|
||||
onRawOutputToggle?: (enabled: boolean) => void;
|
||||
}
|
||||
|
||||
let {
|
||||
@@ -42,7 +46,10 @@
|
||||
onRegenerate,
|
||||
role,
|
||||
siblingInfo = null,
|
||||
showDeleteDialog
|
||||
showDeleteDialog,
|
||||
showRawOutputSwitch = false,
|
||||
rawOutputEnabled = false,
|
||||
onRawOutputToggle
|
||||
}: Props = $props();
|
||||
|
||||
function handleConfirmDelete() {
|
||||
@@ -51,9 +58,9 @@
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="relative {justify === 'start' ? 'mt-2' : ''} flex h-6 items-center justify-{justify}">
|
||||
<div class="relative {justify === 'start' ? 'mt-2' : ''} flex h-6 items-center justify-between">
|
||||
<div
|
||||
class="absolute top-0 {actionsPosition === 'left'
|
||||
class="{actionsPosition === 'left'
|
||||
? 'left-0'
|
||||
: 'right-0'} flex items-center gap-2 opacity-100 transition-opacity"
|
||||
>
|
||||
@@ -81,6 +88,16 @@
|
||||
<ActionButton icon={Trash2} tooltip="Delete" onclick={onDelete} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if showRawOutputSwitch}
|
||||
<div class="flex items-center gap-2">
|
||||
<span class="text-xs text-muted-foreground">Show raw output</span>
|
||||
<Switch
|
||||
checked={rawOutputEnabled}
|
||||
onCheckedChange={(checked) => onRawOutputToggle?.(checked)}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<DialogConfirmation
|
||||
|
||||
+7
-1
@@ -90,6 +90,9 @@
|
||||
|
||||
const processingState = useProcessingState();
|
||||
|
||||
// Local state for raw output toggle (per message)
|
||||
let showRawOutput = $state(false);
|
||||
|
||||
let currentConfig = $derived(config());
|
||||
let isRouter = $derived(isRouterMode());
|
||||
let displayedModel = $derived((): string | null => {
|
||||
@@ -238,7 +241,7 @@
|
||||
</div>
|
||||
</div>
|
||||
{:else if message.role === 'assistant'}
|
||||
{#if config().disableReasoningFormat}
|
||||
{#if showRawOutput}
|
||||
<pre class="raw-output">{messageContent || ''}</pre>
|
||||
{:else}
|
||||
<MarkdownContent content={messageContent || ''} />
|
||||
@@ -352,6 +355,9 @@
|
||||
{onConfirmDelete}
|
||||
{onNavigateToSibling}
|
||||
{onShowDeleteDialogChange}
|
||||
showRawOutputSwitch={currentConfig.showRawOutputSwitch}
|
||||
rawOutputEnabled={showRawOutput}
|
||||
onRawOutputToggle={(enabled) => (showRawOutput = enabled)}
|
||||
/>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
+1
-1
@@ -116,7 +116,7 @@
|
||||
|
||||
<Button class="h-8 px-3" onclick={onSaveEdit} disabled={!editedContent.trim()} size="sm">
|
||||
<Check class="mr-1 h-3 w-3" />
|
||||
Send
|
||||
Save
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
chatStore,
|
||||
errorDialog,
|
||||
isLoading,
|
||||
isChatStreaming,
|
||||
isEditing,
|
||||
getAddFilesHandler
|
||||
} from '$lib/stores/chat.svelte';
|
||||
@@ -71,6 +72,8 @@
|
||||
|
||||
let emptyFileNames = $state<string[]>([]);
|
||||
|
||||
let initialMessage = $state('');
|
||||
|
||||
let isEmpty = $derived(
|
||||
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
|
||||
);
|
||||
@@ -79,7 +82,7 @@
|
||||
let isServerLoading = $derived(serverLoading());
|
||||
let hasPropsError = $derived(!!serverError());
|
||||
|
||||
let isCurrentConversationLoading = $derived(isLoading());
|
||||
let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming());
|
||||
|
||||
let isRouter = $derived(isRouterMode());
|
||||
|
||||
@@ -221,6 +224,14 @@
|
||||
}
|
||||
}
|
||||
|
||||
async function handleSystemPromptAdd(draft: { message: string; files: ChatUploadedFile[] }) {
|
||||
if (draft.message || draft.files.length > 0) {
|
||||
chatStore.savePendingDraft(draft.message, draft.files);
|
||||
}
|
||||
|
||||
await chatStore.addSystemPrompt();
|
||||
}
|
||||
|
||||
function handleScroll() {
|
||||
if (disableAutoScroll || !chatScrollContainer) return;
|
||||
|
||||
@@ -343,6 +354,12 @@
|
||||
if (!disableAutoScroll) {
|
||||
setTimeout(() => scrollChatToBottom('instant'), INITIAL_SCROLL_DELAY);
|
||||
}
|
||||
|
||||
const pendingDraft = chatStore.consumePendingDraft();
|
||||
if (pendingDraft) {
|
||||
initialMessage = pendingDraft.message;
|
||||
uploadedFiles = pendingDraft.files;
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
@@ -428,11 +445,13 @@
|
||||
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl pb-4">
|
||||
<ChatForm
|
||||
disabled={hasPropsError || isEditing()}
|
||||
{initialMessage}
|
||||
isLoading={isCurrentConversationLoading}
|
||||
onFileRemove={handleFileRemove}
|
||||
onFileUpload={handleFileUpload}
|
||||
onSend={handleSendMessage}
|
||||
onStop={() => chatStore.stopGeneration()}
|
||||
onSystemPromptAdd={handleSystemPromptAdd}
|
||||
showHelperText={false}
|
||||
bind:uploadedFiles
|
||||
/>
|
||||
@@ -486,11 +505,13 @@
|
||||
<div in:fly={{ y: 10, duration: 250, delay: hasPropsError ? 0 : 300 }}>
|
||||
<ChatForm
|
||||
disabled={hasPropsError}
|
||||
{initialMessage}
|
||||
isLoading={isCurrentConversationLoading}
|
||||
onFileRemove={handleFileRemove}
|
||||
onFileUpload={handleFileUpload}
|
||||
onSend={handleSendMessage}
|
||||
onStop={() => chatStore.stopGeneration()}
|
||||
onSystemPromptAdd={handleSystemPromptAdd}
|
||||
showHelperText={true}
|
||||
bind:uploadedFiles
|
||||
/>
|
||||
|
||||
@@ -254,8 +254,13 @@
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'disableReasoningFormat',
|
||||
label: 'Show raw LLM output',
|
||||
key: 'disableReasoningParsing',
|
||||
label: 'Disable reasoning content parsing',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
key: 'showRawOutputSwitch',
|
||||
label: 'Enable raw output toggle',
|
||||
type: 'checkbox'
|
||||
},
|
||||
{
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
<script lang="ts">
|
||||
import ChevronsUpDownIcon from '@lucide/svelte/icons/chevrons-up-down';
|
||||
import * as Collapsible from '$lib/components/ui/collapsible/index.js';
|
||||
import { buttonVariants } from '$lib/components/ui/button/index.js';
|
||||
import { Card } from '$lib/components/ui/card';
|
||||
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
|
||||
import type { Snippet } from 'svelte';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
open?: boolean;
|
||||
class?: string;
|
||||
icon?: Component;
|
||||
iconClass?: string;
|
||||
title: string;
|
||||
subtitle?: string;
|
||||
isStreaming?: boolean;
|
||||
onToggle?: () => void;
|
||||
children: Snippet;
|
||||
}
|
||||
|
||||
let {
|
||||
open = $bindable(false),
|
||||
class: className = '',
|
||||
icon: Icon,
|
||||
iconClass = 'h-4 w-4',
|
||||
title,
|
||||
subtitle,
|
||||
isStreaming = false,
|
||||
onToggle,
|
||||
children
|
||||
}: Props = $props();
|
||||
|
||||
let contentContainer: HTMLDivElement | undefined = $state();
|
||||
const autoScroll = createAutoScrollController();
|
||||
|
||||
$effect(() => {
|
||||
autoScroll.setContainer(contentContainer);
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
// Only auto-scroll when open and streaming
|
||||
autoScroll.updateInterval(open && isStreaming);
|
||||
});
|
||||
|
||||
function handleScroll() {
|
||||
autoScroll.handleScroll();
|
||||
}
|
||||
</script>
|
||||
|
||||
<Collapsible.Root
|
||||
{open}
|
||||
onOpenChange={(value) => {
|
||||
open = value;
|
||||
onToggle?.();
|
||||
}}
|
||||
class={className}
|
||||
>
|
||||
<Card class="gap-0 border-muted bg-muted/30 py-0">
|
||||
<Collapsible.Trigger class="flex w-full cursor-pointer items-center justify-between p-3">
|
||||
<div class="flex items-center gap-2 text-muted-foreground">
|
||||
{#if Icon}
|
||||
<Icon class={iconClass} />
|
||||
{/if}
|
||||
|
||||
<span class="font-mono text-sm font-medium">{title}</span>
|
||||
|
||||
{#if subtitle}
|
||||
<span class="text-xs italic">{subtitle}</span>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<div
|
||||
class={buttonVariants({
|
||||
variant: 'ghost',
|
||||
size: 'sm',
|
||||
class: 'h-6 w-6 p-0 text-muted-foreground hover:text-foreground'
|
||||
})}
|
||||
>
|
||||
<ChevronsUpDownIcon class="h-4 w-4" />
|
||||
|
||||
<span class="sr-only">Toggle content</span>
|
||||
</div>
|
||||
</Collapsible.Trigger>
|
||||
|
||||
<Collapsible.Content>
|
||||
<div
|
||||
bind:this={contentContainer}
|
||||
class="overflow-y-auto border-t border-muted px-3 pb-3"
|
||||
onscroll={handleScroll}
|
||||
style="min-height: var(--min-message-height); max-height: var(--max-message-height);"
|
||||
>
|
||||
{@render children()}
|
||||
</div>
|
||||
</Collapsible.Content>
|
||||
</Card>
|
||||
</Collapsible.Root>
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,95 @@
|
||||
<script lang="ts">
|
||||
import hljs from 'highlight.js';
|
||||
import { browser } from '$app/environment';
|
||||
import { mode } from 'mode-watcher';
|
||||
|
||||
import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
|
||||
import githubLightCss from 'highlight.js/styles/github.css?inline';
|
||||
|
||||
interface Props {
|
||||
code: string;
|
||||
language?: string;
|
||||
class?: string;
|
||||
maxHeight?: string;
|
||||
maxWidth?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
code,
|
||||
language = 'text',
|
||||
class: className = '',
|
||||
maxHeight = '60vh',
|
||||
maxWidth = ''
|
||||
}: Props = $props();
|
||||
|
||||
let highlightedHtml = $state('');
|
||||
|
||||
function loadHighlightTheme(isDark: boolean) {
|
||||
if (!browser) return;
|
||||
|
||||
const existingThemes = document.querySelectorAll('style[data-highlight-theme-preview]');
|
||||
existingThemes.forEach((style) => style.remove());
|
||||
|
||||
const style = document.createElement('style');
|
||||
style.setAttribute('data-highlight-theme-preview', 'true');
|
||||
style.textContent = isDark ? githubDarkCss : githubLightCss;
|
||||
|
||||
document.head.appendChild(style);
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
const currentMode = mode.current;
|
||||
const isDark = currentMode === 'dark';
|
||||
|
||||
loadHighlightTheme(isDark);
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (!code) {
|
||||
highlightedHtml = '';
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// Check if the language is supported
|
||||
const lang = language.toLowerCase();
|
||||
const isSupported = hljs.getLanguage(lang);
|
||||
|
||||
if (isSupported) {
|
||||
const result = hljs.highlight(code, { language: lang });
|
||||
highlightedHtml = result.value;
|
||||
} else {
|
||||
// Try auto-detection or fallback to plain text
|
||||
const result = hljs.highlightAuto(code);
|
||||
highlightedHtml = result.value;
|
||||
}
|
||||
} catch {
|
||||
// Fallback to escaped plain text
|
||||
highlightedHtml = code.replace(/&/g, '&').replace(/</g, '<').replace(/>/g, '>');
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div
|
||||
class="code-preview-wrapper rounded-lg border border-border bg-muted {className}"
|
||||
style="max-height: {maxHeight}; max-width: {maxWidth};"
|
||||
>
|
||||
<!-- Needs to be formatted as single line for proper rendering -->
|
||||
<pre class="m-0"><code class="hljs text-sm leading-relaxed">{@html highlightedHtml}</code></pre>
|
||||
</div>
|
||||
|
||||
<style>
|
||||
.code-preview-wrapper {
|
||||
font-family:
|
||||
ui-monospace, SFMono-Regular, 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas,
|
||||
'Liberation Mono', Menlo, monospace;
|
||||
}
|
||||
|
||||
.code-preview-wrapper pre {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.code-preview-wrapper code {
|
||||
background: transparent;
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,79 @@
|
||||
/**
|
||||
*
|
||||
* CONTENT RENDERING
|
||||
*
|
||||
* Components for rendering rich content: markdown, code, and previews.
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* **MarkdownContent** - Rich markdown renderer
|
||||
*
|
||||
* Renders markdown content with syntax highlighting, LaTeX math,
|
||||
* tables, links, and code blocks. Optimized for streaming with
|
||||
* incremental block-based rendering.
|
||||
*
|
||||
* **Features:**
|
||||
* - GFM (GitHub Flavored Markdown): tables, task lists, strikethrough
|
||||
* - LaTeX math via KaTeX (`$inline$` and `$$block$$`)
|
||||
* - Syntax highlighting (highlight.js) with language detection
|
||||
* - Code copy buttons with click feedback
|
||||
* - External links open in new tab with security attrs
|
||||
* - Image attachment resolution from message extras
|
||||
* - Dark/light theme support (auto-switching)
|
||||
* - Streaming-optimized incremental rendering
|
||||
* - Code preview dialog for large blocks
|
||||
*
|
||||
* @example
|
||||
* ```svelte
|
||||
* <MarkdownContent content={message.content} attachments={message.extra} />
|
||||
* ```
|
||||
*/
|
||||
export { default as MarkdownContent } from './MarkdownContent.svelte';
|
||||
|
||||
/**
|
||||
* **SyntaxHighlightedCode** - Code syntax highlighting
|
||||
*
|
||||
* Renders code with syntax highlighting using highlight.js.
|
||||
* Supports theme switching and scrollable containers.
|
||||
*
|
||||
* **Features:**
|
||||
* - Auto language detection with fallback
|
||||
* - Dark/light theme auto-switching
|
||||
* - Scrollable container with configurable max dimensions
|
||||
* - Monospace font styling
|
||||
* - Preserves whitespace and formatting
|
||||
*
|
||||
* @example
|
||||
* ```svelte
|
||||
* <SyntaxHighlightedCode code={jsonString} language="json" />
|
||||
* ```
|
||||
*/
|
||||
export { default as SyntaxHighlightedCode } from './SyntaxHighlightedCode.svelte';
|
||||
|
||||
/**
|
||||
* **CollapsibleContentBlock** - Expandable content card
|
||||
*
|
||||
* Reusable collapsible card with header, icon, and auto-scroll.
|
||||
* Used for tool calls and reasoning blocks in chat messages.
|
||||
*
|
||||
* **Features:**
|
||||
* - Collapsible content with smooth animation
|
||||
* - Custom icon and title display
|
||||
* - Optional subtitle/status text
|
||||
* - Auto-scroll during streaming (pauses on user scroll)
|
||||
* - Configurable max height with overflow scroll
|
||||
*
|
||||
* @example
|
||||
* ```svelte
|
||||
* <CollapsibleContentBlock
|
||||
* bind:open
|
||||
* icon={BrainIcon}
|
||||
* title="Thinking..."
|
||||
* isStreaming={true}
|
||||
* >
|
||||
* {reasoningContent}
|
||||
* </CollapsibleContentBlock>
|
||||
* ```
|
||||
*/
|
||||
export { default as CollapsibleContentBlock } from './CollapsibleContentBlock.svelte';
|
||||
@@ -17,9 +17,13 @@
|
||||
let { conversations, messageCountMap = new Map(), mode, onCancel, onConfirm }: Props = $props();
|
||||
|
||||
let searchQuery = $state('');
|
||||
let selectedIds = $state.raw<SvelteSet<string>>(new SvelteSet(conversations.map((c) => c.id)));
|
||||
let selectedIds = $state.raw<SvelteSet<string>>(getInitialSelectedIds());
|
||||
let lastClickedId = $state<string | null>(null);
|
||||
|
||||
function getInitialSelectedIds(): SvelteSet<string> {
|
||||
return new SvelteSet(conversations.map((c) => c.id));
|
||||
}
|
||||
|
||||
let filteredConversations = $derived(
|
||||
conversations.filter((conv) => {
|
||||
const name = conv.name || 'Untitled conversation';
|
||||
@@ -92,7 +96,7 @@
|
||||
}
|
||||
|
||||
function handleCancel() {
|
||||
selectedIds = new SvelteSet(conversations.map((c) => c.id));
|
||||
selectedIds = getInitialSelectedIds();
|
||||
searchQuery = '';
|
||||
lastClickedId = null;
|
||||
|
||||
@@ -100,7 +104,7 @@
|
||||
}
|
||||
|
||||
export function reset() {
|
||||
selectedIds = new SvelteSet(conversations.map((c) => c.id));
|
||||
selectedIds = getInitialSelectedIds();
|
||||
searchQuery = '';
|
||||
lastClickedId = null;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
<script lang="ts">
|
||||
import { ChevronLeft, ChevronRight } from '@lucide/svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
children?: import('svelte').Snippet;
|
||||
gapSize?: string;
|
||||
onScrollableChange?: (isScrollable: boolean) => void;
|
||||
}
|
||||
|
||||
let { class: className = '', children, gapSize = '3', onScrollableChange }: Props = $props();
|
||||
|
||||
let canScrollLeft = $state(false);
|
||||
let canScrollRight = $state(false);
|
||||
let scrollContainer: HTMLDivElement | undefined = $state();
|
||||
|
||||
function scrollLeft(event?: MouseEvent) {
|
||||
event?.stopPropagation();
|
||||
event?.preventDefault();
|
||||
|
||||
if (!scrollContainer) return;
|
||||
|
||||
scrollContainer.scrollBy({ left: scrollContainer.clientWidth * -0.67, behavior: 'smooth' });
|
||||
}
|
||||
|
||||
function scrollRight(event?: MouseEvent) {
|
||||
event?.stopPropagation();
|
||||
event?.preventDefault();
|
||||
|
||||
if (!scrollContainer) return;
|
||||
|
||||
scrollContainer.scrollBy({ left: scrollContainer.clientWidth * 0.67, behavior: 'smooth' });
|
||||
}
|
||||
|
||||
function updateScrollButtons() {
|
||||
if (!scrollContainer) return;
|
||||
|
||||
const { scrollLeft, scrollWidth, clientWidth } = scrollContainer;
|
||||
|
||||
canScrollLeft = scrollLeft > 0;
|
||||
canScrollRight = scrollLeft < scrollWidth - clientWidth - 1;
|
||||
|
||||
const isScrollable = scrollWidth > clientWidth;
|
||||
onScrollableChange?.(isScrollable);
|
||||
}
|
||||
|
||||
export function resetScroll() {
|
||||
if (scrollContainer) {
|
||||
scrollContainer.scrollLeft = 0;
|
||||
setTimeout(() => {
|
||||
updateScrollButtons();
|
||||
}, 0);
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (scrollContainer) {
|
||||
setTimeout(() => {
|
||||
updateScrollButtons();
|
||||
}, 0);
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="relative {className}">
|
||||
<button
|
||||
class="absolute top-1/2 left-4 z-10 flex h-6 w-6 -translate-y-1/2 items-center justify-center rounded-full bg-foreground/15 shadow-md backdrop-blur-xs transition-opacity hover:bg-foreground/35 {canScrollLeft
|
||||
? 'opacity-100'
|
||||
: 'pointer-events-none opacity-0'}"
|
||||
onclick={scrollLeft}
|
||||
aria-label="Scroll left"
|
||||
>
|
||||
<ChevronLeft class="h-4 w-4" />
|
||||
</button>
|
||||
|
||||
<div
|
||||
class="scrollbar-hide flex items-start gap-{gapSize} overflow-x-auto"
|
||||
bind:this={scrollContainer}
|
||||
onscroll={updateScrollButtons}
|
||||
>
|
||||
{@render children?.()}
|
||||
</div>
|
||||
|
||||
<button
|
||||
class="absolute top-1/2 right-4 z-10 flex h-6 w-6 -translate-y-1/2 items-center justify-center rounded-full bg-foreground/15 shadow-md backdrop-blur-xs transition-opacity hover:bg-foreground/35 {canScrollRight
|
||||
? 'opacity-100'
|
||||
: 'pointer-events-none opacity-0'}"
|
||||
onclick={scrollRight}
|
||||
aria-label="Scroll right"
|
||||
>
|
||||
<ChevronRight class="h-4 w-4" />
|
||||
</button>
|
||||
</div>
|
||||
@@ -11,7 +11,9 @@
|
||||
|
||||
let baseClasses =
|
||||
'px-1 pointer-events-none inline-flex select-none items-center gap-0.5 font-sans text-md font-medium opacity-0 transition-opacity -my-1';
|
||||
let variantClasses = variant === 'destructive' ? 'text-destructive' : 'text-muted-foreground';
|
||||
let variantClasses = $derived(
|
||||
variant === 'destructive' ? 'text-destructive' : 'text-muted-foreground'
|
||||
);
|
||||
</script>
|
||||
|
||||
<kbd class="{baseClasses} {variantClasses} {className}">
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
<script lang="ts">
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
|
||||
interface Props {
|
||||
text: string;
|
||||
class?: string;
|
||||
}
|
||||
|
||||
let { text, class: className = '' }: Props = $props();
|
||||
|
||||
let textElement: HTMLSpanElement | undefined = $state();
|
||||
let isTruncated = $state(false);
|
||||
|
||||
function checkTruncation() {
|
||||
if (textElement) {
|
||||
isTruncated = textElement.scrollWidth > textElement.clientWidth;
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (textElement) {
|
||||
checkTruncation();
|
||||
|
||||
const observer = new ResizeObserver(checkTruncation);
|
||||
observer.observe(textElement);
|
||||
|
||||
return () => observer.disconnect();
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
{#if isTruncated}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger class={className}>
|
||||
<span bind:this={textElement} class="block truncate">
|
||||
{text}
|
||||
</span>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content class="z-[9999]">
|
||||
<p>{text}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{:else}
|
||||
<span bind:this={textElement} class="{className} block truncate">
|
||||
{text}
|
||||
</span>
|
||||
{/if}
|
||||
@@ -0,0 +1,45 @@
|
||||
/**
|
||||
*
|
||||
* MISC
|
||||
*
|
||||
* Miscellaneous utility components.
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* **ConversationSelection** - Multi-select conversation picker
|
||||
*
|
||||
* List of conversations with checkboxes for multi-selection.
|
||||
* Used in import/export dialogs for selecting conversations.
|
||||
*
|
||||
* **Features:**
|
||||
* - Search/filter conversations by name
|
||||
* - Select all / deselect all controls
|
||||
* - Shift-click for range selection
|
||||
* - Message count display per conversation
|
||||
* - Mode-specific UI (export vs import)
|
||||
*/
|
||||
export { default as ConversationSelection } from './ConversationSelection.svelte';
|
||||
|
||||
/**
|
||||
* Horizontal scrollable carousel with navigation arrows.
|
||||
* Used for displaying items in a horizontally scrollable container
|
||||
* with left/right navigation buttons that appear on hover.
|
||||
*/
|
||||
export { default as HorizontalScrollCarousel } from './HorizontalScrollCarousel.svelte';
|
||||
|
||||
/**
|
||||
* **TruncatedText** - Text with ellipsis and tooltip
|
||||
*
|
||||
* Displays text with automatic truncation and full content in tooltip.
|
||||
* Useful for long names or paths in constrained spaces.
|
||||
*/
|
||||
export { default as TruncatedText } from './TruncatedText.svelte';
|
||||
|
||||
/**
|
||||
* **KeyboardShortcutInfo** - Keyboard shortcut hint display
|
||||
*
|
||||
* Displays keyboard shortcut hints (e.g., "⌘ + Enter").
|
||||
* Supports special keys like shift, cmd, and custom text.
|
||||
*/
|
||||
export { default as KeyboardShortcutInfo } from './KeyboardShortcutInfo.svelte';
|
||||
@@ -0,0 +1,86 @@
|
||||
<script lang="ts">
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { KeyboardShortcutInfo } from '$lib/components/app';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
interface ActionItem {
|
||||
icon: Component;
|
||||
label: string;
|
||||
onclick: (event: Event) => void;
|
||||
variant?: 'default' | 'destructive';
|
||||
disabled?: boolean;
|
||||
shortcut?: string[];
|
||||
separator?: boolean;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
triggerIcon: Component;
|
||||
triggerTooltip?: string;
|
||||
triggerClass?: string;
|
||||
actions: ActionItem[];
|
||||
align?: 'start' | 'center' | 'end';
|
||||
open?: boolean;
|
||||
}
|
||||
|
||||
let {
|
||||
triggerIcon,
|
||||
triggerTooltip,
|
||||
triggerClass = '',
|
||||
actions,
|
||||
align = 'end',
|
||||
open = $bindable(false)
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
<DropdownMenu.Root bind:open>
|
||||
<DropdownMenu.Trigger
|
||||
class="flex h-6 w-6 cursor-pointer items-center justify-center rounded-md p-0 text-sm font-medium transition-colors hover:bg-accent hover:text-accent-foreground focus:bg-accent focus:text-accent-foreground focus:outline-none disabled:pointer-events-none disabled:opacity-50 data-[state=open]:bg-accent data-[state=open]:text-accent-foreground {triggerClass}"
|
||||
onclick={(e) => e.stopPropagation()}
|
||||
>
|
||||
{#if triggerTooltip}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
{@render iconComponent(triggerIcon, 'h-3 w-3')}
|
||||
<span class="sr-only">{triggerTooltip}</span>
|
||||
</Tooltip.Trigger>
|
||||
<Tooltip.Content>
|
||||
<p>{triggerTooltip}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{:else}
|
||||
{@render iconComponent(triggerIcon, 'h-3 w-3')}
|
||||
{/if}
|
||||
</DropdownMenu.Trigger>
|
||||
|
||||
<DropdownMenu.Content {align} class="z-[999999] w-48">
|
||||
{#each actions as action, index (action.label)}
|
||||
{#if action.separator && index > 0}
|
||||
<DropdownMenu.Separator />
|
||||
{/if}
|
||||
|
||||
<DropdownMenu.Item
|
||||
onclick={action.onclick}
|
||||
variant={action.variant}
|
||||
disabled={action.disabled}
|
||||
class="flex items-center justify-between hover:[&>kbd]:opacity-100"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
{@render iconComponent(
|
||||
action.icon,
|
||||
`h-4 w-4 ${action.variant === 'destructive' ? 'text-destructive' : ''}`
|
||||
)}
|
||||
{action.label}
|
||||
</div>
|
||||
|
||||
{#if action.shortcut}
|
||||
<KeyboardShortcutInfo keys={action.shortcut} variant={action.variant} />
|
||||
{/if}
|
||||
</DropdownMenu.Item>
|
||||
{/each}
|
||||
</DropdownMenu.Content>
|
||||
</DropdownMenu.Root>
|
||||
|
||||
{#snippet iconComponent(IconComponent: Component, className: string)}
|
||||
<IconComponent class={className} />
|
||||
{/snippet}
|
||||
@@ -0,0 +1,50 @@
|
||||
<script lang="ts">
|
||||
import type { Snippet } from 'svelte';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import { SearchInput } from '$lib/components/app';
|
||||
|
||||
interface Props {
|
||||
placeholder?: string;
|
||||
searchValue?: string;
|
||||
onSearchChange?: (value: string) => void;
|
||||
onSearchKeyDown?: (event: KeyboardEvent) => void;
|
||||
emptyMessage?: string;
|
||||
isEmpty?: boolean;
|
||||
children: Snippet;
|
||||
footer?: Snippet;
|
||||
}
|
||||
|
||||
let {
|
||||
placeholder = 'Search...',
|
||||
searchValue = $bindable(''),
|
||||
onSearchChange,
|
||||
onSearchKeyDown,
|
||||
emptyMessage = 'No items found',
|
||||
isEmpty = false,
|
||||
children,
|
||||
footer
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="sticky top-0 z-10 mb-2 bg-popover p-1 pt-2">
|
||||
<SearchInput
|
||||
{placeholder}
|
||||
bind:value={searchValue}
|
||||
onInput={onSearchChange}
|
||||
onKeyDown={onSearchKeyDown}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="overflow-y-auto">
|
||||
{@render children()}
|
||||
|
||||
{#if isEmpty}
|
||||
<div class="px-2 py-3 text-center text-sm text-muted-foreground">{emptyMessage}</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
{#if footer}
|
||||
<DropdownMenu.Separator />
|
||||
|
||||
{@render footer()}
|
||||
{/if}
|
||||
@@ -0,0 +1,65 @@
|
||||
/**
|
||||
*
|
||||
* NAVIGATION & MENUS
|
||||
*
|
||||
* Components for dropdown menus and action selection.
|
||||
*
|
||||
*/
|
||||
|
||||
/**
|
||||
* **DropdownMenuSearchable** - Searchable content for dropdown menus
|
||||
*
|
||||
* Renders a search input with filtered content area, empty state, and optional footer.
|
||||
* Designed to be injected into any dropdown container (DropdownMenu.Content,
|
||||
* DropdownMenu.SubContent, etc.) without providing its own Root.
|
||||
*
|
||||
* **Features:**
|
||||
* - Search/filter input
|
||||
* - Keyboard navigation support
|
||||
* - Custom content and footer via snippets
|
||||
* - Empty state message
|
||||
*
|
||||
* @example
|
||||
* ```svelte
|
||||
* <DropdownMenu.Root>
|
||||
* <DropdownMenu.Trigger>...</DropdownMenu.Trigger>
|
||||
* <DropdownMenu.Content class="pt-0">
|
||||
* <DropdownMenuSearchable
|
||||
* bind:searchValue
|
||||
* placeholder="Search..."
|
||||
* isEmpty={filteredItems.length === 0}
|
||||
* >
|
||||
* {#each items as item}<Item {item} />{/each}
|
||||
* </DropdownMenuSearchable>
|
||||
* </DropdownMenu.Content>
|
||||
* </DropdownMenu.Root>
|
||||
* ```
|
||||
*/
|
||||
export { default as DropdownMenuSearchable } from './DropdownMenuSearchable.svelte';
|
||||
|
||||
/**
|
||||
* **DropdownMenuActions** - Multi-action dropdown menu
|
||||
*
|
||||
* Dropdown menu for multiple action options with icons and shortcuts.
|
||||
* Supports destructive variants and keyboard shortcut hints.
|
||||
*
|
||||
* **Features:**
|
||||
* - Configurable trigger icon with tooltip
|
||||
* - Action items with icons and labels
|
||||
* - Destructive variant styling
|
||||
* - Keyboard shortcut display
|
||||
* - Separator support between groups
|
||||
*
|
||||
* @example
|
||||
* ```svelte
|
||||
* <DropdownMenuActions
|
||||
* triggerIcon={MoreHorizontal}
|
||||
* triggerTooltip="More actions"
|
||||
* actions={[
|
||||
* { icon: Edit, label: 'Edit', onclick: handleEdit },
|
||||
* { icon: Trash, label: 'Delete', onclick: handleDelete, variant: 'destructive' }
|
||||
* ]}
|
||||
* />
|
||||
* ```
|
||||
*/
|
||||
export { default as DropdownMenuActions } from './DropdownMenuActions.svelte';
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user