mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-01 01:57:43 +02:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0cec84f999 | |||
| b2e1427c9b | |||
| 4d99d45084 | |||
| 10e5b148b0 | |||
| 90b2731894 | |||
| aa2d278a11 | |||
| 6c770d16ca | |||
| 8d880ac012 | |||
| 0f1e9d14cc | |||
| 1274fbee9e | |||
| a7b3dee7a5 | |||
| ec947d2b16 | |||
| 0cd4f4720b | |||
| af237f3026 | |||
| 1a5631beaa | |||
| 1dab5f5a44 |
+2
-2
@@ -2427,11 +2427,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
);
|
||||
}
|
||||
if (split_arg.size() == 1) {
|
||||
std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), std::stoul(split_arg[0]) * 1024*1024);
|
||||
std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), std::stoull(split_arg[0]) * 1024*1024);
|
||||
return;
|
||||
}
|
||||
for (size_t i = 0; i < split_arg.size(); i++) {
|
||||
params.fit_params_target[i] = std::stoul(split_arg[i]) * 1024*1024;
|
||||
params.fit_params_target[i] = std::stoull(split_arg[i]) * 1024*1024;
|
||||
}
|
||||
}
|
||||
).set_env("LLAMA_ARG_FIT_TARGET"));
|
||||
|
||||
+2
-2
@@ -1620,8 +1620,8 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
||||
build_chat_peg_parser([](common_chat_peg_builder & p) { return p.content(p.rest()) + p.end(); }) :
|
||||
src_parser;
|
||||
|
||||
if (src_parser.empty()) {
|
||||
LOG_WRN("No parser definition detected, assuming pure content parser.");
|
||||
if (src_parser.empty()) {
|
||||
LOG_DBG("No parser definition detected, assuming pure content parser.");
|
||||
}
|
||||
|
||||
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str());
|
||||
|
||||
@@ -790,7 +790,7 @@ public:
|
||||
} else if (target.is_array()) {
|
||||
size_t sel_index;
|
||||
try {
|
||||
sel_index = std::stoul(sel);
|
||||
sel_index = std::stoull(sel);
|
||||
} catch (const std::invalid_argument & e) {
|
||||
sel_index = target.size();
|
||||
}
|
||||
|
||||
+20
-4
@@ -4390,15 +4390,31 @@ class Qwen3Model(Qwen2Model):
|
||||
hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
|
||||
self.origin_hf_arch = hparams.get('architectures', [None])[0]
|
||||
|
||||
# a bit hacky, but currently the only way to detect if this is a rerank model
|
||||
# ref: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
|
||||
if self._is_qwen3_reranker():
|
||||
self._find_rerank_config()
|
||||
|
||||
def _is_qwen3_reranker(self) -> bool:
|
||||
readme_path = self.dir_model / "README.md"
|
||||
readme_text = ""
|
||||
if readme_path.exists():
|
||||
with readme_path.open("r", encoding="utf-8") as f:
|
||||
readme_text = f.read()
|
||||
if "# Qwen3-Reranker" in readme_text:
|
||||
self._find_rerank_config()
|
||||
|
||||
name_hints = [
|
||||
str(self.dir_model.name),
|
||||
str(self.hparams.get("_name_or_path", "")),
|
||||
str(self.hparams.get("model_type", "")),
|
||||
str(self.origin_hf_arch or ""),
|
||||
]
|
||||
name_hints = [hint.lower() for hint in name_hints if hint]
|
||||
|
||||
if "# qwen3-reranker" in readme_text.lower() or "# qwen3-vl-reranker" in readme_text.lower():
|
||||
return True
|
||||
|
||||
if any("qwen3-reranker" in hint or "qwen3-vl-reranker" in hint for hint in name_hints):
|
||||
return True
|
||||
|
||||
return "sequenceclassification" in (self.origin_hf_arch or "").lower()
|
||||
|
||||
def set_vocab(self):
|
||||
# deal with intern-s1-mini
|
||||
|
||||
+7
-1
@@ -599,7 +599,13 @@ If KleidiAI is enabled, the output will contain a line similar to:
|
||||
```
|
||||
load_tensors: CPU_KLEIDIAI model buffer size = 3474.00 MiB
|
||||
```
|
||||
KleidiAI's microkernels implement optimized tensor operations using Arm CPU features such as dotprod, int8mm and SME. llama.cpp selects the most efficient kernel based on runtime CPU feature detection. However, on platforms that support SME, you must manually enable SME microkernels by setting the environment variable `GGML_KLEIDIAI_SME=1`.
|
||||
KleidiAI’s microkernels implement optimized tensor operations using Arm CPU features such as dotprod, int8mm, SVE, and SME. Llama.cpp selects the most efficient kernels at runtime based on detected CPU capabilities.
|
||||
On CPUs that support SME, SME microkernels are enabled automatically using runtime detection.
|
||||
The environment variable GGML_KLEIDIAI_SME can be used to control SME behavior:
|
||||
- Not set: enable SME automatically if supported and detected.
|
||||
- 0: disable SME.
|
||||
- <n> > 0: enable SME and assume <n> available SME units (override auto detection).
|
||||
If SME is not supported by the CPU, SME microkernels are always disabled.
|
||||
|
||||
Depending on your build target, other higher priority backends may be enabled by default. To ensure the CPU backend is used, you must disable the higher priority backends either at compile time, e.g. -DGGML_METAL=OFF, or during run-time using the command line option `--device none`.
|
||||
|
||||
|
||||
+9
-9
@@ -23,7 +23,7 @@ Legend:
|
||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
@@ -31,7 +31,7 @@ Legend:
|
||||
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
@@ -47,7 +47,7 @@ Legend:
|
||||
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GATED_DELTA_NET | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GATED_DELTA_NET | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
@@ -64,7 +64,7 @@ Legend:
|
||||
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
@@ -76,7 +76,7 @@ Legend:
|
||||
| OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 |
|
||||
| PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_1D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_1D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
@@ -86,7 +86,7 @@ Legend:
|
||||
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ROPE | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
@@ -97,13 +97,13 @@ Legend:
|
||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
|
||||
+1689
-6836
File diff suppressed because it is too large
Load Diff
+1137
-12992
File diff suppressed because it is too large
Load Diff
@@ -633,7 +633,7 @@ class SchemaConverter:
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
|
||||
|
||||
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
|
||||
items = schema.get('items') or schema['prefixItems']
|
||||
items = schema.get('items', schema.get('prefixItems'))
|
||||
if isinstance(items, list):
|
||||
return self._add_rule(
|
||||
rule_name,
|
||||
|
||||
@@ -8,7 +8,12 @@ extern "C" {
|
||||
|
||||
#define RPC_PROTO_MAJOR_VERSION 3
|
||||
#define RPC_PROTO_MINOR_VERSION 6
|
||||
#define RPC_PROTO_PATCH_VERSION 0
|
||||
#define RPC_PROTO_PATCH_VERSION 1
|
||||
|
||||
#ifdef __cplusplus
|
||||
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION");
|
||||
#endif
|
||||
|
||||
#define GGML_RPC_MAX_SERVERS 16
|
||||
|
||||
// backend API
|
||||
|
||||
@@ -202,8 +202,9 @@
|
||||
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
|
||||
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
#define ggml_quantize_mat_q8_K_4x1_generic ggml_quantize_mat_q8_K_4x1
|
||||
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -520,7 +520,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||
/* .required_cpu = */ CPU_FEATURE_I8MM,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
@@ -631,7 +631,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||
/* .required_cpu = */ CPU_FEATURE_I8MM,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
@@ -801,7 +801,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||
/* .required_cpu = */ CPU_FEATURE_I8MM,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q8_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+1202
-3
File diff suppressed because it is too large
Load Diff
@@ -28,13 +28,17 @@ template <int K, int N> struct block {
|
||||
// control size
|
||||
static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
|
||||
static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
|
||||
static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<4,16> size/padding");
|
||||
static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
|
||||
static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
|
||||
static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 16, "wrong block<8,16> size/padding");
|
||||
|
||||
using block_q4_0x4 = block<4, 4>;
|
||||
using block_q4_0x8 = block<4, 8>;
|
||||
using block_q4_0x16 = block<4, 16>;
|
||||
using block_q8_0x4 = block<8, 4>;
|
||||
using block_q8_0x8 = block<8, 8>;
|
||||
using block_q8_0x16 = block<8, 16>;
|
||||
|
||||
struct block_q4_Kx8 {
|
||||
ggml_half d[8]; // super-block scale for quantized scales
|
||||
@@ -44,7 +48,14 @@ struct block_q4_Kx8 {
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
|
||||
struct block_q4_Kx16 {
|
||||
ggml_half d[16]; // super-block scale for quantized scales
|
||||
ggml_half dmin[16]; // super-block scale for quantized mins
|
||||
uint8_t scales[192]; // scales and mins, quantized with 6 bits
|
||||
uint8_t qs[2048]; // 4--bit quants
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q4_Kx16) == sizeof(ggml_half) * 32 + K_SCALE_SIZE * 16 + QK_K * 8, "wrong q4_K block size/padding");
|
||||
struct block_q2_Kx8 {
|
||||
ggml_half d[8]; // super-block scale for quantized scales
|
||||
ggml_half dmin[8]; // super-block scale for quantized mins
|
||||
@@ -53,6 +64,13 @@ struct block_q2_Kx8 {
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
|
||||
struct block_q2_Kx16 {
|
||||
ggml_half d[16]; // Super-block scale for quantized scales
|
||||
ggml_half dmin[16]; // Super-block scale for quantized mins
|
||||
uint8_t scales[256]; // Sub-block scales (16 cols * 16 sub-blocks)
|
||||
uint8_t qs[1024]; // Data (16 cols * 64 bytes per block)
|
||||
};
|
||||
static_assert(sizeof(block_q2_Kx16) == sizeof(ggml_half) * 32 + QK_K + QK_K * 4, "wrong q2_K block size/padding");
|
||||
|
||||
struct block_q5_Kx8 {
|
||||
ggml_half d[8]; // super-block scale for quantized scales
|
||||
@@ -97,6 +115,12 @@ struct block_iq4_nlx8 {
|
||||
|
||||
static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding");
|
||||
|
||||
struct block_iq4_nlx16 {
|
||||
ggml_half d[16]; // deltas for 16 iq4_nl blocks
|
||||
uint8_t qs[QK4_NL * 8]; // nibbles / quants for 16 iq4_nl blocks
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_iq4_nlx16) == 16 * sizeof(ggml_half) + QK4_NL * 8, "wrong iq4_nlx16 block size/padding");
|
||||
struct block_mxfp4x4 {
|
||||
uint8_t e[4];
|
||||
uint8_t qs[QK_MXFP4 * 2];
|
||||
@@ -109,7 +133,6 @@ struct block_mxfp4x8 {
|
||||
};
|
||||
static_assert(sizeof(block_mxfp4x8) == 8 + QK_MXFP4 * 4, "wrong mxfp4x8 block size/padding");
|
||||
|
||||
|
||||
#if defined(__cplusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
@@ -132,6 +155,8 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -146,10 +171,22 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
#if defined __riscv_zvfh
|
||||
void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
#endif
|
||||
|
||||
// Native implementations
|
||||
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
@@ -170,6 +207,8 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -184,10 +223,22 @@ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
#if defined __riscv_zvfh
|
||||
void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
#endif
|
||||
|
||||
#if defined(__cplusplus)
|
||||
} // extern "C"
|
||||
|
||||
@@ -75,6 +75,10 @@ struct ggml_metal {
|
||||
// abort ggml_metal_graph_compute if callback returns true
|
||||
ggml_abort_callback abort_callback;
|
||||
void * abort_callback_data;
|
||||
|
||||
// error state - set when a command buffer fails during synchronize
|
||||
// once set, graph_compute will return GGML_STATUS_FAILED until the backend is recreated
|
||||
bool has_error;
|
||||
};
|
||||
|
||||
ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
|
||||
@@ -158,6 +162,8 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
|
||||
res->capture_started = false;
|
||||
res->capture_scope = nil;
|
||||
|
||||
res->has_error = false;
|
||||
|
||||
res->gf = nil;
|
||||
res->encode_async = nil;
|
||||
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
||||
@@ -246,7 +252,8 @@ void ggml_metal_synchronize(ggml_metal_t ctx) {
|
||||
if (status == MTLCommandBufferStatusError) {
|
||||
GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
||||
}
|
||||
GGML_ABORT("fatal error");
|
||||
ctx->has_error = true;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -262,7 +269,15 @@ void ggml_metal_synchronize(ggml_metal_t ctx) {
|
||||
if (status == MTLCommandBufferStatusError) {
|
||||
GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
||||
}
|
||||
GGML_ABORT("fatal error");
|
||||
|
||||
// release this and all remaining command buffers before returning
|
||||
for (size_t j = i; j < ctx->cmd_bufs_ext.count; ++j) {
|
||||
[ctx->cmd_bufs_ext[j] release];
|
||||
}
|
||||
[ctx->cmd_bufs_ext removeAllObjects];
|
||||
|
||||
ctx->has_error = true;
|
||||
return;
|
||||
}
|
||||
|
||||
[cmd_buf release];
|
||||
@@ -414,6 +429,11 @@ bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, con
|
||||
}
|
||||
|
||||
enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) {
|
||||
if (ctx->has_error) {
|
||||
GGML_LOG_ERROR("%s: backend is in error state from a previous command buffer failure - recreate the backend to recover\n", __func__);
|
||||
return GGML_STATUS_FAILED;
|
||||
}
|
||||
|
||||
// number of nodes encoded by the main thread (empirically determined)
|
||||
const int n_main = MAX(64, 0.1*gf->n_nodes);
|
||||
|
||||
|
||||
@@ -874,4 +874,95 @@ static bool fast_fp16_available(const int cc) {
|
||||
return true; //Intel GPUs always support FP16.
|
||||
}
|
||||
|
||||
enum class block_reduce_method {
|
||||
MAX,
|
||||
SUM,
|
||||
};
|
||||
|
||||
template<block_reduce_method method_t, typename T, int warp_size>
|
||||
struct block_reduce_policy;
|
||||
|
||||
template <typename T, typename... Ts>
|
||||
inline constexpr bool is_any = (std::is_same_v<T, Ts> || ...);
|
||||
|
||||
template<typename...>
|
||||
inline constexpr bool ggml_sycl_dependent_false_v = false;
|
||||
|
||||
#define WARP_32_SIZE 32
|
||||
|
||||
template <typename T, int warp_size> struct block_reduce_policy<block_reduce_method::SUM, T, warp_size> {
|
||||
static T reduce(T val) {
|
||||
if constexpr (is_any<T, float, sycl::float2, sycl::half2, int>) {
|
||||
return warp_reduce_sum<warp_size>(val);
|
||||
} else {
|
||||
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce sum");
|
||||
}
|
||||
}
|
||||
|
||||
static T sentinel() {
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
return 0.0f;
|
||||
} else if constexpr (std::is_same_v<T, sycl::float2>) {
|
||||
return sycl::float2(0.0f, 0.0f);
|
||||
} else if constexpr (std::is_same_v<T, sycl::half2>) {
|
||||
return sycl::half2(0.0f, 0.0f);
|
||||
} else if constexpr (std::is_same_v<T, int>) {
|
||||
return 0;
|
||||
} else {
|
||||
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce sum");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int warp_size> struct block_reduce_policy<block_reduce_method::MAX, T, warp_size> {
|
||||
static T reduce(T val) {
|
||||
if constexpr (is_any<T, float, sycl::half2>) {
|
||||
return warp_reduce_max<warp_size>(val);
|
||||
} else {
|
||||
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce max");
|
||||
}
|
||||
}
|
||||
|
||||
static T sentinel() {
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
return -INFINITY;
|
||||
} else if constexpr (std::is_same_v<T, sycl::half2>) {
|
||||
return sycl::half2(-INFINITY, -INFINITY);
|
||||
} else {
|
||||
static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce max");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <block_reduce_method reduce_method_t, int warp_size, typename T>
|
||||
static T block_reduce(T val, T * shared_vals, int block_size_template) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
val = block_reduce_policy<reduce_method_t, T,warp_size>::reduce(val);
|
||||
const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
|
||||
const int nthreads = item_ct1.get_local_range(2);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
|
||||
if (block_size > warp_size) {
|
||||
assert((block_size <= 1024) && (block_size % warp_size) == 0);
|
||||
const int warp_id = item_ct1.get_local_id(2) / warp_size;
|
||||
const int lane_id = item_ct1.get_local_id(2) % warp_size;
|
||||
if (lane_id == 0) {
|
||||
shared_vals[warp_id] = val;
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
size_t nreduce = nwarps / WARP_SIZE;
|
||||
float tmp = 0.f;
|
||||
if (lane_id < (static_cast<int>(block_size) / warp_size)) {
|
||||
for (size_t i = 0; i < nreduce; i += 1)
|
||||
{
|
||||
tmp += shared_vals[lane_id + i * WARP_SIZE];
|
||||
}
|
||||
}
|
||||
return block_reduce_policy<reduce_method_t, T, warp_size>::reduce(tmp);
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
||||
#endif // GGML_SYCL_COMMON_HPP
|
||||
|
||||
@@ -39,6 +39,11 @@ template<typename dst_t, typename src_t>
|
||||
return sycl::ext::oneapi::bfloat16(float(x));
|
||||
} else if constexpr (std::is_same_v<src_t, sycl::ext::oneapi::bfloat16>) {
|
||||
return static_cast<float>(x);
|
||||
} else if constexpr (std::is_same_v<src_t, sycl::float2> && std::is_same_v<dst_t, sycl::half2>) {
|
||||
return x.template convert<sycl::half, sycl::rounding_mode::rte>();
|
||||
} else if constexpr (std::is_same_v<src_t, sycl::float2> &&
|
||||
std::is_same_v<dst_t, sycl::vec<sycl::ext::oneapi::bfloat16, 2>>) {
|
||||
return {x.x, x.y};
|
||||
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
|
||||
return int32_t(x);
|
||||
} else {
|
||||
@@ -46,4 +51,5 @@ template<typename dst_t, typename src_t>
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#endif // GGML_SYCL_CONVERT_HPP
|
||||
|
||||
@@ -9,23 +9,32 @@
|
||||
#define SYCL_LOCAL_ID_CALC(ITEM, IDX) \
|
||||
(ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX))
|
||||
|
||||
static void acc_f32(const float * x, const float * y, float * dst, const int64_t ne,
|
||||
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
|
||||
const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int64_t i = SYCL_LOCAL_ID_CALC(item_ct1, 2);
|
||||
|
||||
static void acc_f32(const float * x, const float * y, float * dst, const int ne,
|
||||
const int ne10, const int ne11, const int ne12,
|
||||
const int nb1, const int nb2, int offset, const sycl::nd_item<1> &item_ct1) {
|
||||
const int i = SYCL_LOCAL_ID_CALC(item_ct1, 0);
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
int src1_idx = i - offset;
|
||||
int oz = src1_idx / nb2;
|
||||
int oy = (src1_idx - (oz * nb2)) / nb1;
|
||||
int ox = src1_idx % nb1;
|
||||
if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
|
||||
dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
|
||||
} else {
|
||||
dst[i] = x[i];
|
||||
|
||||
int64_t src1_idx = i - offset;
|
||||
|
||||
int64_t tmp = src1_idx;
|
||||
const int64_t i13 = tmp / s13;
|
||||
tmp -= i13 * s13;
|
||||
const int64_t i12 = tmp / s12;
|
||||
tmp -= i12 * s12;
|
||||
const int64_t i11 = tmp / s11;
|
||||
tmp -= i11 * s11;
|
||||
const int64_t i10 = tmp;
|
||||
|
||||
float val = x[i];
|
||||
if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {
|
||||
val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];
|
||||
}
|
||||
dst[i] = val;
|
||||
}
|
||||
|
||||
/* Unary OP funcs */
|
||||
@@ -364,18 +373,15 @@ static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const
|
||||
|
||||
namespace ggml_sycl_detail {
|
||||
static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
||||
const int n_elements, const int ne10, const int ne11,
|
||||
const int ne12, const int nb1, const int nb2,
|
||||
const int offset, queue_ptr stream) {
|
||||
int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) *
|
||||
sycl::range<1>(SYCL_ACC_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_ACC_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
|
||||
item_ct1);
|
||||
});
|
||||
const int64_t n_elements, const int64_t ne10, const int64_t ne11,
|
||||
const int64_t ne12, const int64_t ne13, const int64_t s1, const int64_t s2, const int64_t s3,
|
||||
const int64_t offset, queue_ptr stream) {
|
||||
const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -402,25 +408,19 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
|
||||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
#else
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
#endif
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
switch (dst->type) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
auto data_pts = cast_data<sycl::half>(dst);
|
||||
kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
auto data_pts = cast_data<float>(dst);
|
||||
@@ -434,14 +434,10 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx,
|
||||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
#else
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
#endif
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
@@ -463,7 +459,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
|
||||
GGML_ASSERT(src0->type == src1->type);
|
||||
}
|
||||
switch (dst->type) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
sycl::half * src0_p = (sycl::half *) src0_d;
|
||||
@@ -484,7 +479,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
|
||||
std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
float * src0_p = (float *) src0_d;
|
||||
@@ -513,13 +507,9 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
|
||||
|
||||
template<typename KernelInvoker, typename... Args>
|
||||
static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
#else
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
#endif
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
@@ -530,7 +520,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
|
||||
const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
|
||||
const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
|
||||
switch (dst->type) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
auto data_pts = cast_data<sycl::half>(dst);
|
||||
@@ -539,7 +528,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
|
||||
main_stream, std::forward<Args>(args)...);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
auto data_pts = cast_data<float>(dst);
|
||||
@@ -868,22 +856,31 @@ static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tens
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const float * src1_d = (const float *) src1->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
const float * src1_dd = static_cast<const float*>(dst->src[1]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
|
||||
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
|
||||
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
|
||||
int offset = dst->op_params[3] / 4; // offset in bytes
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
GGML_ASSERT(dst->nb[0] == ggml_element_size(dst));
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(dst));
|
||||
|
||||
ggml_sycl_detail::acc_f32_sycl(src0_dd, src1_dd, dst_dd, (int)ggml_nelements(dst), (int)dst->src[1]->ne[0], (int)dst->src[1]->ne[1], (int)dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
|
||||
const int64_t s1 = dst->op_params[0] / sizeof(float);
|
||||
const int64_t s2 = dst->op_params[1] / sizeof(float);
|
||||
const int64_t s3 = dst->op_params[2] / sizeof(float);
|
||||
const int64_t offset = dst->op_params[3] / sizeof(float);
|
||||
|
||||
ggml_sycl_detail::acc_f32_sycl(src0_d, src1_d, dst_d, ggml_nelements(dst),
|
||||
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
|
||||
s1, s2, s3, offset, stream);
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
@@ -4145,6 +4145,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
case GGML_OP_ROPE:
|
||||
ggml_sycl_rope(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ROPE_BACK:
|
||||
ggml_sycl_rope_back(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_IM2COL:
|
||||
ggml_sycl_im2col(ctx, dst);
|
||||
break;
|
||||
@@ -4851,6 +4854,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return max_bias == 0.0f;
|
||||
}
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_ROPE_BACK:
|
||||
case GGML_OP_IM2COL:
|
||||
return true;
|
||||
case GGML_OP_UPSCALE:
|
||||
@@ -4872,8 +4876,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
k > 0 && k <= 32;
|
||||
}
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_ACC:
|
||||
return true;
|
||||
case GGML_OP_ACC:
|
||||
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
|
||||
case GGML_OP_PAD:
|
||||
// TODO: add circular padding support for syscl, see https://github.com/ggml-org/llama.cpp/pull/16985
|
||||
if (ggml_get_op_params_i32(op, 8) != 0) {
|
||||
|
||||
+65
-63
@@ -202,47 +202,34 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
|
||||
}
|
||||
}
|
||||
|
||||
static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
|
||||
const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
|
||||
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
|
||||
item_ct1.get_local_id(1);
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int nthreads = item_ct1.get_local_range(2);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
template<int warp_size>
|
||||
static void l2_norm_f32(const float * x, float * dst, const int ncols,
|
||||
const int64_t stride_row, const int64_t stride_channel,
|
||||
const int64_t stride_sample, const float eps,
|
||||
const sycl::nd_item<3>& item_ct1, float* s_sum, const int block_size) {
|
||||
const int nrows = item_ct1.get_group_range(2);
|
||||
const int nchannels = item_ct1.get_group_range(1);
|
||||
|
||||
const int row = item_ct1.get_group(2);
|
||||
const int channel = item_ct1.get_group(1);
|
||||
const int sample = item_ct1.get_group(0);
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
|
||||
x += sample*stride_sample + channel*stride_channel + row*stride_row;
|
||||
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[row * ncols + col];
|
||||
const float xi = x[col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
// sum up partial sums
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
if (block_size > WARP_SIZE) {
|
||||
|
||||
int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
||||
int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
s_sum[warp_id] = tmp;
|
||||
}
|
||||
/*
|
||||
DPCT1118:3: SYCL group functions and algorithms must be encountered in
|
||||
converged control flow. You may need to adjust the code.
|
||||
*/
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
size_t nreduce = nwarps / WARP_SIZE;
|
||||
tmp = 0.f;
|
||||
for (size_t i = 0; i < nreduce; i += 1)
|
||||
{
|
||||
tmp += s_sum[lane_id + i * WARP_SIZE];
|
||||
}
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
}
|
||||
|
||||
const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
|
||||
tmp = block_reduce<block_reduce_method::SUM, warp_size>(tmp, s_sum, block_size);
|
||||
const float scale = sycl::rsqrt(sycl::fmax(tmp, eps * eps));
|
||||
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[row * ncols + col] = scale * x[row * ncols + col];
|
||||
dst[col] = scale * x[col];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -369,42 +356,50 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
|
||||
}
|
||||
}
|
||||
|
||||
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
const int nrows, const float eps,
|
||||
queue_ptr stream, int device) {
|
||||
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
||||
template<int warp_size>
|
||||
static void l2_norm_f32_sycl(const float * x,
|
||||
float * dst,
|
||||
const int ncols,
|
||||
const int nrows,
|
||||
const int nchannels,
|
||||
const int nsamples,
|
||||
const int64_t stride_row,
|
||||
const int64_t stride_channel,
|
||||
const int64_t stride_sample,
|
||||
const float eps,
|
||||
queue_ptr stream,
|
||||
int device) {
|
||||
const dpct::dim3 blocks_num(nrows, nchannels, nsamples);
|
||||
|
||||
if (ncols < 1024) {
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
const dpct::dim3 block_dims(warp_size, 1, 1);
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
sycl::nd_range<3>(blocks_num * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
nullptr, WARP_SIZE);
|
||||
[[sycl::reqd_sub_group_size(warp_size)]] {
|
||||
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
|
||||
nullptr, warp_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
else {
|
||||
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
|
||||
assert(work_group_size % (warp_size * warp_size) == 0);
|
||||
const sycl::range<3> block_dims(1, 1, work_group_size);
|
||||
/*
|
||||
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
|
||||
the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if needed.
|
||||
*/
|
||||
int lsm_size = block_dims[2] > warp_size ? work_group_size / warp_size * sizeof(float): 0;
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
|
||||
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(lsm_size),
|
||||
cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
sycl::nd_range<3>(blocks_num * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
[[sycl::reqd_sub_group_size(warp_size)]] {
|
||||
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample,
|
||||
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -634,21 +629,28 @@ void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * d
|
||||
}
|
||||
|
||||
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
||||
const int64_t ne00 = dst->src[0]->ne[0];
|
||||
const int64_t nrows = ggml_nrows(dst->src[0]);
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
GGML_TENSOR_UNARY_OP_LOCALS;
|
||||
|
||||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
|
||||
const size_t ts0 = ggml_type_size(src0->type);
|
||||
GGML_ASSERT(nb00 == ts0);
|
||||
const int64_t s01 = nb01 / ts0;
|
||||
const int64_t s02 = nb02 / ts0;
|
||||
const int64_t s03 = nb03 / ts0;
|
||||
|
||||
/*support both WARP_SIZE or WARP_32_SIZE in code
|
||||
choose by hardware for better performance
|
||||
*/
|
||||
l2_norm_f32_sycl<WARP_SIZE>(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream, ctx.device);
|
||||
}
|
||||
|
||||
+447
-283
@@ -1,4 +1,5 @@
|
||||
#include "rope.hpp"
|
||||
#include "convert.hpp"
|
||||
#include "ggml-sycl/common.hpp"
|
||||
#include "ggml.h"
|
||||
|
||||
@@ -15,366 +16,489 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
|
||||
}
|
||||
|
||||
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
|
||||
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
||||
static void rope_yarn(
|
||||
float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
|
||||
float * cos_theta, float * sin_theta) {
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
template <bool forward>
|
||||
static void rope_yarn(const float theta_extrap, const float freq_scale,
|
||||
const rope_corr_dims corr_dims, const int64_t i0,
|
||||
const float ext_factor, float mscale, float &cos_theta,
|
||||
float &sin_theta) {
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float theta = theta_interp;
|
||||
if (ext_factor != 0.0f) {
|
||||
float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
|
||||
float ramp_mix =
|
||||
rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
|
||||
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
|
||||
}
|
||||
*cos_theta = sycl::cos(theta) * mscale;
|
||||
*sin_theta = sycl::sin(theta) * mscale;
|
||||
cos_theta = sycl::cos(theta) * mscale;
|
||||
sin_theta = sycl::sin(theta) * mscale;
|
||||
if (!forward) {
|
||||
sin_theta *= -1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, bool has_ff>
|
||||
static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
|
||||
const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
|
||||
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
|
||||
const sycl::nd_item<3> & item_ct1) {
|
||||
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
|
||||
template <bool forward, bool has_ff, typename T, typename D>
|
||||
static void rope_norm(const T *x, D *dst, const int ne00, const int ne01,
|
||||
const int ne02, const int s01, const int s02,
|
||||
const int s03, const int s1, const int s2, const int s3,
|
||||
const int n_dims, const int32_t *pos,
|
||||
const float freq_scale, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float theta_scale, const float *freq_factors,
|
||||
const int64_t *row_indices, const int set_rows_stride) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
||||
item_ct1.get_local_id(1));
|
||||
|
||||
if (i0 >= ne0) {
|
||||
if (i0 >= ne00) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
|
||||
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
|
||||
const int row0 = row % ne1;
|
||||
const int channel0 = row / ne1;
|
||||
const uint32_t i3 = row_dst / (ne01 * ne02);
|
||||
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
||||
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
||||
|
||||
const int i = row * ne0 + i0;
|
||||
const int i2 = channel0 * s2 + row0 * s1 + i0;
|
||||
int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
|
||||
if (set_rows_stride != 0) {
|
||||
idst = i1 * s1 + i0;
|
||||
idst += row_indices[i2] * set_rows_stride;
|
||||
}
|
||||
|
||||
const auto &store_coaelsced = [&](float x0, float x1) {
|
||||
if constexpr (std::is_same_v<float, D>) {
|
||||
sycl::float2 v = sycl::float2(x0, x1);
|
||||
ggml_sycl_memcpy_1<8>(dst + idst, &v);
|
||||
} else if constexpr (std::is_same_v<sycl::half, D>) {
|
||||
sycl::half2 v = sycl::half2(x0, x1);
|
||||
ggml_sycl_memcpy_1<4>(dst + idst, &v);
|
||||
}
|
||||
};
|
||||
if (i0 >= n_dims) {
|
||||
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
|
||||
store_coaelsced(x[ix + 0], x[ix + 1]);
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
|
||||
const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
|
||||
ext_factor, attn_factor, cos_theta, sin_theta);
|
||||
|
||||
const float x0 = x[i2 + 0];
|
||||
const float x1 = x[i2 + 1];
|
||||
const float x0 = x[ix + 0];
|
||||
const float x1 = x[ix + 1];
|
||||
|
||||
dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[i + 1] = x0 * sin_theta + x1 * cos_theta;
|
||||
store_coaelsced(x0 * cos_theta - x1 * sin_theta,
|
||||
x0 * sin_theta + x1 * cos_theta);
|
||||
}
|
||||
|
||||
template <typename T, bool has_ff>
|
||||
static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
|
||||
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
|
||||
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
|
||||
const sycl::nd_item<3> & item_ct1) {
|
||||
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
|
||||
template <bool forward, bool has_ff, typename T, typename D>
|
||||
static void rope_neox(const T *x, D *dst, const int ne00, const int ne01,
|
||||
const int ne02, const int s01, const int s02,
|
||||
const int s03, const int s1, const int s2, const int s3,
|
||||
const int n_dims, const int32_t *pos,
|
||||
const float freq_scale, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float theta_scale, const float *freq_factors,
|
||||
const int64_t *row_indices, const int set_rows_stride) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
||||
item_ct1.get_local_id(1));
|
||||
|
||||
if (i0 >= ne0) {
|
||||
if (i0 >= ne00) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
|
||||
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
|
||||
const int row0 = row % ne1;
|
||||
const int channel0 = row / ne1;
|
||||
const uint32_t i3 = row_dst / (ne01 * ne02);
|
||||
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
||||
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
||||
|
||||
const int i = row * ne0 + i0 / 2;
|
||||
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
|
||||
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
|
||||
if (set_rows_stride != 0) {
|
||||
idst = i1 * s1 + i0 / 2;
|
||||
idst += row_indices[i2] * set_rows_stride;
|
||||
}
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
|
||||
dst[idst + i0 / 2 + 0] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 0]);
|
||||
dst[idst + i0 / 2 + 1] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 1]);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
|
||||
const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
|
||||
ext_factor, attn_factor, cos_theta, sin_theta);
|
||||
|
||||
const float x0 = x[i2 + 0];
|
||||
const float x1 = x[i2 + n_dims / 2];
|
||||
const float x0 = x[ix + 0];
|
||||
const float x1 = x[ix + n_dims / 2];
|
||||
|
||||
dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
|
||||
dst[idst + 0] = ggml_sycl_cast<D>(x0 * cos_theta - x1 * sin_theta);
|
||||
dst[idst + n_dims / 2] = ggml_sycl_cast<D>(x0 * sin_theta + x1 * cos_theta);
|
||||
}
|
||||
|
||||
template <typename T, bool has_ff>
|
||||
static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
||||
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
|
||||
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float theta_scale, const float * freq_factors, const mrope_sections sections,
|
||||
const bool is_imrope, const sycl::nd_item<3> & item_ct1) {
|
||||
// get index pos
|
||||
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
|
||||
if (i0 >= ne0) {
|
||||
template <bool forward, bool has_ff, typename T>
|
||||
static void rope_multi(const T *x, T *dst, const int ne00, const int ne01,
|
||||
const int ne02, const int s01, const int s02,
|
||||
const int s03, const int s1, const int s2, const int s3,
|
||||
const int n_dims, const int32_t *pos,
|
||||
const float freq_scale, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float theta_scale, const float *freq_factors,
|
||||
const mrope_sections sections, const bool is_imrope) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
||||
item_ct1.get_local_id(1));
|
||||
|
||||
if (i0 >= ne00) {
|
||||
return;
|
||||
}
|
||||
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
|
||||
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
const int idst = (row_dst * ne0) + (i0 / 2);
|
||||
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
|
||||
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
|
||||
const uint32_t i3 = row_dst / (ne01 * ne02);
|
||||
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
||||
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
||||
|
||||
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
|
||||
if (i0 >= n_dims) {
|
||||
*reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
|
||||
dst[idst + i0 / 2 + 0] = x[ix + i0 / 2 + 0];
|
||||
dst[idst + i0 / 2 + 1] = x[ix + i0 / 2 + 1];
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
||||
const int sect_dims =
|
||||
sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
||||
const int sec_w = sections.v[1] + sections.v[0];
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
|
||||
|
||||
float theta_base = 0.0;
|
||||
if (is_imrope) {
|
||||
if (sector % 3 == 1 && sector < 3 * sections.v[1]) {
|
||||
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
|
||||
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {
|
||||
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
|
||||
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
|
||||
theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
|
||||
theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
|
||||
theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
} else {
|
||||
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
|
||||
theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
}
|
||||
} else {
|
||||
if (sector < sections.v[0]) {
|
||||
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sections.v[0] && sector < sec_w) {
|
||||
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
|
||||
}
|
||||
else if (sector >= sec_w + sections.v[2]) {
|
||||
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
|
||||
theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
} else if (sector >= sections.v[0] && sector < sec_w) {
|
||||
theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
} else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
||||
theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
} else if (sector >= sec_w + sections.v[2]) {
|
||||
theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
|
||||
}
|
||||
}
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
const float x0 = x[ix + 0];
|
||||
const float x1 = x[ix + n_dims/2];
|
||||
|
||||
// store results in dst
|
||||
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
|
||||
ext_factor, attn_factor, cos_theta, sin_theta);
|
||||
|
||||
const float x0 = x[ix + 0];
|
||||
const float x1 = x[ix + n_dims / 2];
|
||||
|
||||
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[idst + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
|
||||
}
|
||||
|
||||
template <bool forward, bool has_ff, typename T>
|
||||
static void rope_vision(const T *x, T *dst, const int ne00, const int ne01,
|
||||
const int ne02, const int s01, const int s02,
|
||||
const int s03, const int s1, const int s2, const int s3,
|
||||
const int n_dims, const int32_t *pos,
|
||||
const float freq_scale, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float theta_scale, const float *freq_factors,
|
||||
const mrope_sections sections) {
|
||||
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
|
||||
const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
||||
item_ct1.get_local_id(1));
|
||||
|
||||
|
||||
template <typename T, bool has_ff>
|
||||
static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
||||
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
|
||||
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float theta_scale, const float * freq_factors, const mrope_sections sections,
|
||||
const sycl::nd_item<3> & item_ct1) {
|
||||
// get index pos
|
||||
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
|
||||
if (i0 >= ne0) {
|
||||
if (i0 >= ne00) {
|
||||
return;
|
||||
}
|
||||
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
|
||||
const int row_x = row_dst % ne1;
|
||||
const int channel_x = row_dst / ne1;
|
||||
const int idst = (row_dst * ne0) + (i0 / 2);
|
||||
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
|
||||
|
||||
const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
|
||||
const uint32_t i3 = row_dst / (ne01 * ne02);
|
||||
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
|
||||
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
|
||||
|
||||
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
|
||||
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
|
||||
|
||||
const int sect_dims = sections.v[0] + sections.v[1];
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
const int sec_w = sections.v[1] + sections.v[0];
|
||||
const int sector = (i0 / 2) % sect_dims;
|
||||
|
||||
float theta_base = 0.0f;
|
||||
float theta_base = 0.0;
|
||||
if (sector < sections.v[0]) {
|
||||
const int p = sector;
|
||||
theta_base = pos[channel_x] * sycl::pow(theta_scale, (float) p);
|
||||
} else {
|
||||
theta_base = pos[i2] * dpct::pow(theta_scale, p);
|
||||
} else if (sector >= sections.v[0] && sector < sec_w) {
|
||||
const int p = sector - sections.v[0];
|
||||
theta_base = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p);
|
||||
theta_base = pos[i2 + ne02] * dpct::pow(theta_scale, p);
|
||||
}
|
||||
|
||||
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
|
||||
ext_factor, attn_factor, cos_theta, sin_theta);
|
||||
|
||||
const float x0 = x[ix + 0];
|
||||
const float x1 = x[ix + n_dims];
|
||||
|
||||
// store results in dst
|
||||
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
|
||||
const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
|
||||
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float * freq_factors, queue_ptr stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||
const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
||||
const sycl::range<3> block_nums(1, num_blocks_x, nr);
|
||||
template <bool forward, typename T, typename D>
|
||||
static void
|
||||
rope_norm_sycl(const T *x, D *dst, const int ne00, const int ne01,
|
||||
const int ne02, const int s01, const int s02, const int s03,
|
||||
const int s1, const int s2, const int s3, const int n_dims,
|
||||
const int nr, const int32_t *pos, const float freq_scale,
|
||||
const float freq_base, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float *freq_factors, const int64_t *row_indices,
|
||||
const int set_rows_stride, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ne00 % 2 == 0);
|
||||
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_x =
|
||||
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
||||
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
if (freq_factors == nullptr) {
|
||||
/*
|
||||
DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
|
||||
the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if needed.
|
||||
*/
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, item_ct1);
|
||||
});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
GGML_UNUSED(item_ct1);
|
||||
rope_norm<forward, false>(
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
||||
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, row_indices, set_rows_stride);
|
||||
});
|
||||
} else {
|
||||
/*
|
||||
DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
|
||||
the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if needed.
|
||||
*/
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, item_ct1);
|
||||
});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
GGML_UNUSED(item_ct1);
|
||||
rope_norm<forward, true>(
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
||||
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, row_indices, set_rows_stride);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
|
||||
const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
|
||||
const float freq_base, const float ext_factor, const float attn_factor,
|
||||
const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||
const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
||||
const sycl::range<3> block_nums(1, num_blocks_x, nr);
|
||||
template <bool forward, typename T, typename D>
|
||||
static void
|
||||
rope_neox_sycl(const T *x, D *dst, const int ne00, const int ne01,
|
||||
const int ne02, const int s01, const int s02, const int s03,
|
||||
const int s1, const int s2, const int s3, const int n_dims,
|
||||
const int nr, const int32_t *pos, const float freq_scale,
|
||||
const float freq_base, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float *freq_factors, const int64_t *row_indices,
|
||||
const int set_rows_stride, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ne00 % 2 == 0);
|
||||
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_x =
|
||||
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
||||
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
if (freq_factors == nullptr) {
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, item_ct1);
|
||||
});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
GGML_UNUSED(item_ct1);
|
||||
rope_neox<forward, false>(
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
||||
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, row_indices, set_rows_stride);
|
||||
});
|
||||
} else {
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, item_ct1);
|
||||
});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
GGML_UNUSED(item_ct1);
|
||||
rope_neox<forward, true>(
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
||||
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, row_indices, set_rows_stride);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
||||
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
|
||||
const float freq_scale, const float freq_base, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
|
||||
const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
||||
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
|
||||
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
|
||||
template <bool forward, typename T>
|
||||
static void
|
||||
rope_multi_sycl(const T *x, T *dst, const int ne00, const int ne01,
|
||||
const int ne02, const int s01, const int s02, const int s03,
|
||||
const int s1, const int s2, const int s3, const int n_dims,
|
||||
const int nr, const int32_t *pos, const float freq_scale,
|
||||
const float freq_base, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float *freq_factors, const mrope_sections sections,
|
||||
const bool is_imrope, dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ne00 % 2 == 0);
|
||||
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_x =
|
||||
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
||||
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||
|
||||
const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
|
||||
// Add FP16 capability check if T could be sycl::half
|
||||
if constexpr (std::is_same_v<T, sycl::half>) {
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
}
|
||||
// launch kernel
|
||||
if (freq_factors == nullptr) {
|
||||
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
||||
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
|
||||
});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
GGML_UNUSED(item_ct1);
|
||||
rope_multi<forward, false, T>(
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
||||
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, sections, is_imrope);
|
||||
});
|
||||
} else {
|
||||
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
||||
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
|
||||
});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
GGML_UNUSED(item_ct1);
|
||||
rope_multi<forward, true, T>(
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
||||
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, sections, is_imrope);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <bool forward, typename T>
|
||||
static void
|
||||
rope_vision_sycl(const T *x, T *dst, const int ne00, const int ne01,
|
||||
const int ne02, const int s01, const int s02, const int s03,
|
||||
const int s1, const int s2, const int s3, const int n_dims,
|
||||
const int nr, const int32_t *pos, const float freq_scale,
|
||||
const float freq_base, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims,
|
||||
const float *freq_factors, const mrope_sections sections,
|
||||
dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ne00 % 2 == 0);
|
||||
const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_x =
|
||||
(ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
||||
const dpct::dim3 block_nums(nr, n_blocks_x, 1);
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||
|
||||
|
||||
// rope vision
|
||||
template <typename T>
|
||||
static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
||||
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
|
||||
const float freq_scale, const float freq_base, const float ext_factor,
|
||||
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
|
||||
const mrope_sections sections, queue_ptr stream) {
|
||||
GGML_ASSERT(ne0 % 2 == 0);
|
||||
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
||||
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
|
||||
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
|
||||
|
||||
const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
|
||||
// Add FP16 capability check if T could be sycl::half
|
||||
if constexpr (std::is_same_v<T, sycl::half>) {
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
}
|
||||
// launch kernel
|
||||
if (freq_factors == nullptr) {
|
||||
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
||||
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
||||
});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
GGML_UNUSED(item_ct1);
|
||||
rope_vision<forward, false, T>(
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
||||
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, sections);
|
||||
});
|
||||
} else {
|
||||
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
||||
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
||||
});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
GGML_UNUSED(item_ct1);
|
||||
rope_vision<forward, true, T>(
|
||||
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
|
||||
pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, sections);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
template <bool forward>
|
||||
void ggml_sycl_op_rope_impl(ggml_backend_sycl_context &ctx, ggml_tensor *dst,
|
||||
const ggml_tensor *set_rows = nullptr) {
|
||||
const ggml_tensor *src0 = dst->src[0];
|
||||
const ggml_tensor *src1 = dst->src[1];
|
||||
const ggml_tensor *src2 = dst->src[2];
|
||||
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
const int64_t ne00 = dst->src[0]->ne[0]; // head dims
|
||||
const int64_t ne01 = dst->src[0]->ne[1]; // num heads
|
||||
const int64_t ne02 = dst->src[0]->ne[2]; // num heads
|
||||
const int64_t nr = ggml_nrows(dst->src[0]);
|
||||
const float *src0_d = (const float *)src0->data;
|
||||
const float *src1_d = (const float *)src1->data;
|
||||
|
||||
const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type);
|
||||
const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type);
|
||||
void *dst_d = dst->data;
|
||||
const int64_t *row_indices = nullptr;
|
||||
ggml_type dst_type = dst->type;
|
||||
int set_rows_stride = 0;
|
||||
|
||||
if (set_rows != nullptr) {
|
||||
GGML_ASSERT(forward);
|
||||
dst_d = set_rows->data;
|
||||
row_indices = (const int64_t *)set_rows->src[1]->data;
|
||||
dst_type = set_rows->type;
|
||||
set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
|
||||
}
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src0->type == dst->type ||
|
||||
(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));
|
||||
|
||||
const int64_t ne00 = src0->ne[0]; // head dims
|
||||
const int64_t ne01 = src0->ne[1]; // num heads
|
||||
const int64_t ne02 = src0->ne[2]; // num heads
|
||||
const int64_t nr = ggml_nrows(src0);
|
||||
|
||||
const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
|
||||
const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
|
||||
const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);
|
||||
|
||||
const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
|
||||
const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
|
||||
const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);
|
||||
|
||||
const int n_dims = ((int32_t *)dst->op_params)[1];
|
||||
const int mode = ((int32_t *)dst->op_params)[2];
|
||||
const int n_ctx_orig = ((int32_t *)dst->op_params)[4];
|
||||
mrope_sections sections;
|
||||
|
||||
// RoPE alteration for extended context
|
||||
float freq_base;
|
||||
float freq_scale;
|
||||
float ext_factor;
|
||||
@@ -382,13 +506,13 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
||||
memcpy(§ions.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
||||
memcpy(&freq_base, (int32_t *)dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *)dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *)dst->op_params + 7, sizeof(float));
|
||||
memcpy(&attn_factor, (int32_t *)dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (int32_t *)dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (int32_t *)dst->op_params + 10, sizeof(float));
|
||||
memcpy(§ions.v, (int32_t *)dst->op_params + 11, sizeof(int) * 4);
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
@@ -396,82 +520,122 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_mrope) {
|
||||
GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
|
||||
GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 ||
|
||||
sections.v[2] > 0);
|
||||
}
|
||||
|
||||
if (is_vision) {
|
||||
GGML_ASSERT(n_dims == ne00/2);
|
||||
GGML_ASSERT(n_dims == ne00 / 2);
|
||||
}
|
||||
|
||||
const int32_t * pos = (const int32_t *) dst->src[1]->data;
|
||||
const int32_t *pos = (const int32_t *)src1_d;
|
||||
|
||||
const float * freq_factors = nullptr;
|
||||
if (dst->src[2] != nullptr) {
|
||||
freq_factors = (const float *) dst->src[2]->data;
|
||||
const float *freq_factors = nullptr;
|
||||
if (src2 != nullptr) {
|
||||
freq_factors = (const float *)src2->data;
|
||||
}
|
||||
|
||||
rope_corr_dims corr_dims;
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast,
|
||||
beta_slow, corr_dims.v);
|
||||
|
||||
// compute
|
||||
if (is_neox) {
|
||||
GGML_SYCL_DEBUG("%s: neox path\n", __func__);
|
||||
if (dst->src[0]->type == GGML_TYPE_F32) {
|
||||
rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
|
||||
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
|
||||
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||
rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
|
||||
n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
||||
main_stream);
|
||||
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
|
||||
rope_neox_sycl<forward, float, float>(
|
||||
(const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
|
||||
s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
|
||||
set_rows_stride, stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
|
||||
rope_neox_sycl<forward, float, sycl::half>(
|
||||
(const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
|
||||
s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
||||
row_indices, set_rows_stride, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
|
||||
rope_neox_sycl<forward, sycl::half, sycl::half>(
|
||||
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
|
||||
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
||||
row_indices, set_rows_stride, stream);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
GGML_ABORT("Fatal error: Tensor type unsupported!");
|
||||
}
|
||||
} else if (is_mrope && !is_vision) {
|
||||
GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
|
||||
if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||
rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
|
||||
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, sections, is_imrope, main_stream);
|
||||
} else if (dst->src[0]->type == GGML_TYPE_F32) {
|
||||
rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
|
||||
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
|
||||
is_imrope, main_stream);
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_multi_sycl<forward>((const float *)src0_d, (float *)dst_d,
|
||||
ne00, ne01, ne02, s01, s02, s03, s1, s2,
|
||||
s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||
ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, sections, is_imrope, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_multi_sycl<forward>(
|
||||
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
|
||||
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
||||
sections, is_imrope, stream);
|
||||
} else {
|
||||
GGML_ABORT("Fatal error: Tensor type unsupported!");
|
||||
}
|
||||
} else if (is_vision) {
|
||||
GGML_SYCL_DEBUG("%s: vision path\n", __func__);
|
||||
if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||
rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01,
|
||||
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
||||
freq_factors, sections, main_stream);
|
||||
} else if (dst->src[0]->type == GGML_TYPE_F32) {
|
||||
rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
|
||||
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
|
||||
main_stream);
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
rope_vision_sycl<forward>(
|
||||
(const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
|
||||
s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||
ext_factor, attn_factor, corr_dims, freq_factors, sections,
|
||||
stream);
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
rope_vision_sycl<forward>(
|
||||
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
|
||||
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
||||
sections, stream);
|
||||
} else {
|
||||
GGML_ABORT("Fatal error: Tensor type unsupported!");
|
||||
}
|
||||
} else {
|
||||
GGML_SYCL_DEBUG("%s: norm path\n", __func__);
|
||||
if (dst->src[0]->type == GGML_TYPE_F32) {
|
||||
rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
|
||||
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
|
||||
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||
rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
|
||||
n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
||||
main_stream);
|
||||
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
|
||||
rope_norm_sycl<forward, float, float>(
|
||||
(const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
|
||||
s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
|
||||
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
|
||||
set_rows_stride, stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
|
||||
rope_norm_sycl<forward, float, sycl::half>(
|
||||
(const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
|
||||
s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
||||
row_indices, set_rows_stride, stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
|
||||
rope_norm_sycl<forward, sycl::half, sycl::half>(
|
||||
(const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
|
||||
ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
|
||||
freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
|
||||
row_indices, set_rows_stride, stream);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
GGML_ABORT("Fatal error: Tensor type unsupported!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
void ggml_sycl_rope(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
|
||||
ggml_sycl_op_rope(ctx, dst);
|
||||
|
||||
ggml_sycl_op_rope_impl<true>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_rope_back(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
|
||||
ggml_sycl_op_rope_impl<false>(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_rope_fused(ggml_backend_sycl_context &ctx, ggml_tensor *rope,
|
||||
ggml_tensor *set_rows) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, rope, /*num_src=*/3);
|
||||
ggml_sycl_op_rope_impl<true>(ctx, rope, set_rows);
|
||||
}
|
||||
|
||||
@@ -15,6 +15,12 @@
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#define SYCL_ROPE_BLOCK_SIZE 256
|
||||
|
||||
void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
|
||||
|
||||
void ggml_sycl_rope_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_rope_fused(ggml_backend_sycl_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows);
|
||||
|
||||
#endif // GGML_SYCL_ROPE_HPP
|
||||
|
||||
@@ -42,11 +42,20 @@
|
||||
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
|
||||
|
||||
// Matrix-vector multiplication parameters
|
||||
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
|
||||
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
|
||||
|
||||
// Must be multiple of 4 to work with vectorized paths, and must divide
|
||||
// mul_mat_vec wg size
|
||||
#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
|
||||
#define WEBGPU_MUL_MAT_VEC_TILE_K 256
|
||||
#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
|
||||
#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256
|
||||
|
||||
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
|
||||
#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256
|
||||
|
||||
// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
|
||||
#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
|
||||
// Requires at least two (and multiple of 2) k-quant blocks per tile
|
||||
#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512
|
||||
|
||||
// default size for legacy matrix multiplication
|
||||
#define WEBGPU_MUL_MAT_WG_SIZE 256
|
||||
@@ -199,7 +208,8 @@ struct ggml_webgpu_binary_pipeline_key {
|
||||
bool src_overlap;
|
||||
|
||||
bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
|
||||
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
|
||||
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap &&
|
||||
src_overlap == other.src_overlap;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -749,6 +759,36 @@ class ggml_webgpu_shader_lib {
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = "mul_mat_vec";
|
||||
|
||||
// src0 type (matrix row)
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("SRC0_INNER_TYPE=f32");
|
||||
defines.push_back("MUL_ACC_FLOAT");
|
||||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
defines.push_back("MUL_ACC_FLOAT");
|
||||
variant += "_f16";
|
||||
break;
|
||||
default:
|
||||
{
|
||||
// Quantized types: use helpers but accumulate in f16
|
||||
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
|
||||
std::string src0_name = src0_traits->type_name;
|
||||
std::string type_upper = src0_name;
|
||||
variant += "_" + src0_name;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
defines.push_back("BYTE_HELPERS");
|
||||
defines.push_back("MUL_ACC_" + type_upper);
|
||||
|
||||
// For fast path we always dequantize from f16 inside the shader
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// src1 type (vector)
|
||||
switch (context.src1->type) {
|
||||
case GGML_TYPE_F32:
|
||||
@@ -763,39 +803,21 @@ class ggml_webgpu_shader_lib {
|
||||
GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
|
||||
}
|
||||
|
||||
// src0 type (matrix row)
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("SRC0_INNER_TYPE=f32");
|
||||
defines.push_back("MUL_ACC_FLOAT");
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
defines.push_back("MUL_ACC_FLOAT");
|
||||
break;
|
||||
default:
|
||||
{
|
||||
// Quantized types: use helpers but accumulate in f16
|
||||
const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
|
||||
std::string src0_name = src0_traits->type_name;
|
||||
std::string type_upper = src0_name;
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
defines.push_back("BYTE_HELPERS");
|
||||
defines.push_back("MUL_ACC_" + type_upper);
|
||||
|
||||
// For fast path we always dequantize from f16 inside the shader
|
||||
defines.push_back("SRC0_INNER_TYPE=f16");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// VEC/SCALAR controls
|
||||
defines.push_back(key.vectorized ? "VEC" : "SCALAR");
|
||||
|
||||
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
|
||||
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K;
|
||||
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
|
||||
uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
|
||||
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
|
||||
|
||||
if (key.src0_type >= GGML_TYPE_Q2_K) {
|
||||
tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
|
||||
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
|
||||
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
|
||||
tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
|
||||
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
|
||||
defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
|
||||
@@ -1061,10 +1083,10 @@ class ggml_webgpu_shader_lib {
|
||||
|
||||
webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_binary_pipeline_key key = {
|
||||
.type = context.dst->type,
|
||||
.op = context.dst->op,
|
||||
.inplace = context.inplace,
|
||||
.overlap = context.overlap,
|
||||
.type = context.dst->type,
|
||||
.op = context.dst->op,
|
||||
.inplace = context.inplace,
|
||||
.overlap = context.overlap,
|
||||
.src_overlap = context.src_overlap,
|
||||
};
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-webgpu-shader-lib.hpp"
|
||||
#include "pre_wgsl.hpp"
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
# include <emscripten/emscripten.h>
|
||||
@@ -20,12 +19,18 @@
|
||||
#include <condition_variable>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
# include <iomanip>
|
||||
#endif
|
||||
#if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE)
|
||||
# include <iostream>
|
||||
#endif
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
|
||||
@@ -70,22 +75,21 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim
|
||||
#endif // GGML_WEBGPU_CPU_PROFILE
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24
|
||||
# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 32
|
||||
# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps
|
||||
#endif
|
||||
|
||||
/* Constants */
|
||||
|
||||
#define WEBGPU_NUM_PARAM_BUFS 48u
|
||||
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u
|
||||
#define WEBGPU_NUM_PARAM_BUFS 96u
|
||||
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u
|
||||
#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0
|
||||
// Maximum number of in-flight submissions per-thread, to avoid exhausting the
|
||||
// parameter buffer pool
|
||||
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
|
||||
#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
|
||||
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
|
||||
#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16
|
||||
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
|
||||
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
|
||||
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
|
||||
|
||||
// For operations which process a row in parallel, this seems like a reasonable
|
||||
// default
|
||||
@@ -118,14 +122,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
|
||||
wgpu::BufferUsage usage,
|
||||
const char * label);
|
||||
|
||||
struct webgpu_pool_bufs {
|
||||
wgpu::Buffer host_buf;
|
||||
wgpu::Buffer dev_buf;
|
||||
};
|
||||
|
||||
// Holds a pool of parameter buffers for WebGPU operations
|
||||
struct webgpu_buf_pool {
|
||||
std::vector<webgpu_pool_bufs> free;
|
||||
std::vector<wgpu::Buffer> free;
|
||||
|
||||
// The pool must be synchronized because
|
||||
// 1. The memset pool is shared globally by every ggml buffer,
|
||||
@@ -138,7 +137,6 @@ struct webgpu_buf_pool {
|
||||
size_t cur_pool_size;
|
||||
size_t max_pool_size;
|
||||
wgpu::Device device;
|
||||
wgpu::BufferUsage host_buf_usage;
|
||||
wgpu::BufferUsage dev_buf_usage;
|
||||
size_t buf_size;
|
||||
bool should_grow;
|
||||
@@ -147,53 +145,47 @@ struct webgpu_buf_pool {
|
||||
int num_bufs,
|
||||
size_t buf_size,
|
||||
wgpu::BufferUsage dev_buf_usage,
|
||||
wgpu::BufferUsage host_buf_usage,
|
||||
bool should_grow = false,
|
||||
size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) {
|
||||
this->max_pool_size = max_pool_size;
|
||||
this->cur_pool_size = num_bufs;
|
||||
this->device = device;
|
||||
this->host_buf_usage = host_buf_usage;
|
||||
this->dev_buf_usage = dev_buf_usage;
|
||||
this->buf_size = buf_size;
|
||||
this->should_grow = should_grow;
|
||||
this->max_pool_size = max_pool_size;
|
||||
this->cur_pool_size = num_bufs;
|
||||
this->device = device;
|
||||
this->dev_buf_usage = dev_buf_usage;
|
||||
this->buf_size = buf_size;
|
||||
this->should_grow = should_grow;
|
||||
for (int i = 0; i < num_bufs; i++) {
|
||||
wgpu::Buffer host_buf;
|
||||
wgpu::Buffer dev_buf;
|
||||
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
|
||||
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
|
||||
free.push_back({ host_buf, dev_buf });
|
||||
free.push_back(dev_buf);
|
||||
}
|
||||
}
|
||||
|
||||
webgpu_pool_bufs alloc_bufs() {
|
||||
wgpu::Buffer alloc_bufs() {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
if (!free.empty()) {
|
||||
webgpu_pool_bufs bufs = free.back();
|
||||
wgpu::Buffer buf = free.back();
|
||||
free.pop_back();
|
||||
return bufs;
|
||||
return buf;
|
||||
}
|
||||
|
||||
// Try growing the pool if no free buffers
|
||||
if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
|
||||
cur_pool_size++;
|
||||
wgpu::Buffer host_buf;
|
||||
wgpu::Buffer dev_buf;
|
||||
ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
|
||||
ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
|
||||
|
||||
if (!(host_buf && dev_buf)) {
|
||||
if (!dev_buf) {
|
||||
GGML_ABORT("webgpu_buf_pool: failed to allocate buffers");
|
||||
}
|
||||
return webgpu_pool_bufs{ host_buf, dev_buf };
|
||||
return dev_buf;
|
||||
}
|
||||
cv.wait(lock, [this] { return !free.empty(); });
|
||||
webgpu_pool_bufs bufs = free.back();
|
||||
wgpu::Buffer buf = free.back();
|
||||
free.pop_back();
|
||||
return bufs;
|
||||
return buf;
|
||||
}
|
||||
|
||||
void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
|
||||
void free_bufs(std::vector<wgpu::Buffer> bufs) {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
free.insert(free.end(), bufs.begin(), bufs.end());
|
||||
cv.notify_all();
|
||||
@@ -201,12 +193,9 @@ struct webgpu_buf_pool {
|
||||
|
||||
void cleanup() {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
for (auto & bufs : free) {
|
||||
if (bufs.host_buf) {
|
||||
bufs.host_buf.Destroy();
|
||||
}
|
||||
if (bufs.dev_buf) {
|
||||
bufs.dev_buf.Destroy();
|
||||
for (auto & buf : free) {
|
||||
if (buf) {
|
||||
buf.Destroy();
|
||||
}
|
||||
}
|
||||
free.clear();
|
||||
@@ -280,10 +269,9 @@ struct webgpu_gpu_profile_buf_pool {
|
||||
#endif
|
||||
|
||||
struct webgpu_command {
|
||||
uint32_t num_kernels;
|
||||
wgpu::CommandBuffer commands;
|
||||
std::vector<webgpu_pool_bufs> params_bufs;
|
||||
std::optional<webgpu_pool_bufs> set_rows_error_bufs;
|
||||
uint32_t num_kernels;
|
||||
wgpu::CommandBuffer commands;
|
||||
std::vector<wgpu::Buffer> params_bufs;
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
webgpu_gpu_profile_bufs timestamp_query_bufs;
|
||||
std::string pipeline_name;
|
||||
@@ -358,6 +346,13 @@ struct webgpu_global_context_struct {
|
||||
|
||||
typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
|
||||
|
||||
struct webgpu_submission {
|
||||
wgpu::FutureWaitInfo submit_done;
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
std::vector<wgpu::FutureWaitInfo> profile_futures;
|
||||
#endif
|
||||
};
|
||||
|
||||
// All the base objects needed to run operations on a WebGPU device
|
||||
struct webgpu_context_struct {
|
||||
// Points to global instances owned by ggml_backend_webgpu_reg_context
|
||||
@@ -366,7 +361,8 @@ struct webgpu_context_struct {
|
||||
std::unique_ptr<ggml_webgpu_shader_lib> shader_lib;
|
||||
|
||||
webgpu_buf_pool param_buf_pool;
|
||||
webgpu_buf_pool set_rows_error_buf_pool;
|
||||
wgpu::Buffer set_rows_dev_error_buf;
|
||||
wgpu::Buffer set_rows_host_error_buf;
|
||||
|
||||
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
|
||||
|
||||
@@ -458,67 +454,105 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
|
||||
/** End WebGPU object initializations */
|
||||
|
||||
/** WebGPU Actions */
|
||||
static void erase_completed(std::vector<wgpu::FutureWaitInfo> & futures) {
|
||||
|
||||
static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) {
|
||||
switch (status) {
|
||||
case wgpu::WaitStatus::Success:
|
||||
return true;
|
||||
case wgpu::WaitStatus::TimedOut:
|
||||
if (allow_timeout) {
|
||||
return false;
|
||||
}
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n");
|
||||
return false;
|
||||
case wgpu::WaitStatus::Error:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
|
||||
return false;
|
||||
default:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
static void ggml_backend_webgpu_erase_completed_futures(std::vector<wgpu::FutureWaitInfo> & futures) {
|
||||
futures.erase(std::remove_if(futures.begin(), futures.end(),
|
||||
[](const wgpu::FutureWaitInfo & info) { return info.completed; }),
|
||||
futures.end());
|
||||
}
|
||||
|
||||
// Wait for the queue to finish processing all submitted work
|
||||
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
|
||||
std::vector<wgpu::FutureWaitInfo> & futures,
|
||||
bool block = true) {
|
||||
// If we have too many in-flight submissions, wait on the oldest one first.
|
||||
static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx,
|
||||
std::vector<wgpu::FutureWaitInfo> & futures,
|
||||
bool block) {
|
||||
if (futures.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t timeout_ms = block ? UINT64_MAX : 0;
|
||||
while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &futures[0], UINT64_MAX);
|
||||
if (waitStatus == wgpu::WaitStatus::Error) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
|
||||
if (block) {
|
||||
while (!futures.empty()) {
|
||||
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
|
||||
if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
|
||||
ggml_backend_webgpu_erase_completed_futures(futures);
|
||||
}
|
||||
}
|
||||
if (futures[0].completed) {
|
||||
futures.erase(futures.begin());
|
||||
} else {
|
||||
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
|
||||
if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
|
||||
ggml_backend_webgpu_erase_completed_futures(futures);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// Wait for the queue to finish processing all submitted work
|
||||
static void ggml_backend_webgpu_wait(webgpu_global_context & ctx,
|
||||
std::vector<webgpu_submission> & subs,
|
||||
bool block = true) {
|
||||
// If we have too many in-flight submissions, wait on the oldest one first.
|
||||
if (subs.empty()) {
|
||||
return;
|
||||
}
|
||||
while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX);
|
||||
if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
|
||||
#endif
|
||||
subs.erase(subs.begin());
|
||||
}
|
||||
}
|
||||
|
||||
if (futures.empty()) {
|
||||
if (subs.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (block) {
|
||||
while (!futures.empty()) {
|
||||
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
|
||||
switch (waitStatus) {
|
||||
case wgpu::WaitStatus::Success:
|
||||
// WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
|
||||
erase_completed(futures);
|
||||
break;
|
||||
case wgpu::WaitStatus::Error:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
|
||||
break;
|
||||
default:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
|
||||
break;
|
||||
for (auto & sub : subs) {
|
||||
while (!sub.submit_done.completed) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX);
|
||||
ggml_backend_webgpu_handle_wait_status(waitStatus);
|
||||
}
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true);
|
||||
#endif
|
||||
}
|
||||
subs.clear();
|
||||
} else {
|
||||
// Poll once and return
|
||||
auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
|
||||
switch (waitStatus) {
|
||||
case wgpu::WaitStatus::Success:
|
||||
// WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
|
||||
erase_completed(futures);
|
||||
break;
|
||||
case wgpu::WaitStatus::TimedOut:
|
||||
break;
|
||||
case wgpu::WaitStatus::Error:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
|
||||
break;
|
||||
default:
|
||||
GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
|
||||
break;
|
||||
// Poll each submit future once and remove completed submissions.
|
||||
for (auto sub = subs.begin(); sub != subs.end();) {
|
||||
auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0);
|
||||
ggml_backend_webgpu_handle_wait_status(waitStatus, true);
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false);
|
||||
if (sub->submit_done.completed && sub->profile_futures.empty()) {
|
||||
#else
|
||||
if (sub->submit_done.completed) {
|
||||
#endif
|
||||
sub = subs.erase(sub);
|
||||
} else {
|
||||
++sub;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -554,14 +588,12 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
|
||||
}
|
||||
#endif
|
||||
|
||||
static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
|
||||
webgpu_global_context ctx,
|
||||
std::vector<webgpu_command> commands,
|
||||
webgpu_buf_pool & param_buf_pool,
|
||||
webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
|
||||
static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & ctx,
|
||||
std::vector<webgpu_command> & commands,
|
||||
webgpu_buf_pool & param_buf_pool) {
|
||||
std::vector<wgpu::CommandBuffer> command_buffers;
|
||||
std::vector<webgpu_pool_bufs> params_bufs;
|
||||
std::vector<webgpu_pool_bufs> set_rows_error_bufs;
|
||||
std::vector<wgpu::Buffer> params_bufs;
|
||||
webgpu_submission submission;
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
|
||||
#endif
|
||||
@@ -569,14 +601,9 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
|
||||
for (const auto & command : commands) {
|
||||
command_buffers.push_back(command.commands);
|
||||
params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());
|
||||
if (command.set_rows_error_bufs) {
|
||||
set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
|
||||
}
|
||||
}
|
||||
ctx->queue.Submit(command_buffers.size(), command_buffers.data());
|
||||
|
||||
std::vector<wgpu::FutureWaitInfo> futures;
|
||||
|
||||
wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
|
||||
wgpu::CallbackMode::AllowSpontaneous,
|
||||
[¶m_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
|
||||
@@ -586,27 +613,7 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
|
||||
// Free the staged buffers
|
||||
param_buf_pool.free_bufs(params_bufs);
|
||||
});
|
||||
futures.push_back({ p_f });
|
||||
|
||||
for (const auto & bufs : set_rows_error_bufs) {
|
||||
wgpu::Future f = bufs.host_buf.MapAsync(
|
||||
wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
|
||||
[set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||
if (status != wgpu::MapAsyncStatus::Success) {
|
||||
GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
|
||||
} else {
|
||||
const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange();
|
||||
if (*error_data) {
|
||||
GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
|
||||
}
|
||||
// We can't unmap in here due to WebGPU reentrancy limitations.
|
||||
if (set_rows_error_buf_pool) {
|
||||
set_rows_error_buf_pool->free_bufs({ bufs });
|
||||
}
|
||||
}
|
||||
});
|
||||
futures.push_back({ f });
|
||||
}
|
||||
submission.submit_done = { p_f };
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
for (const auto & command : commands) {
|
||||
@@ -623,14 +630,14 @@ static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
|
||||
// WebGPU timestamps are in ns; convert to ms
|
||||
double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
|
||||
ctx->shader_gpu_time_ms[label] += elapsed_ms;
|
||||
// We can't unmap in here due to WebGPU reentrancy limitations.
|
||||
ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
|
||||
}
|
||||
// We can't unmap in here due to WebGPU reentrancy limitations.
|
||||
ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
|
||||
});
|
||||
futures.push_back({ f });
|
||||
submission.profile_futures.push_back({ f });
|
||||
}
|
||||
#endif
|
||||
return futures;
|
||||
return submission;
|
||||
}
|
||||
|
||||
static webgpu_command ggml_backend_webgpu_build_multi(
|
||||
@@ -639,32 +646,21 @@ static webgpu_command ggml_backend_webgpu_build_multi(
|
||||
const std::vector<webgpu_pipeline> & pipelines,
|
||||
const std::vector<std::vector<uint32_t>> & params_list,
|
||||
const std::vector<std::vector<wgpu::BindGroupEntry>> & bind_group_entries_list,
|
||||
const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list,
|
||||
const std::optional<webgpu_pool_bufs> & set_rows_error_bufs = std::nullopt) {
|
||||
const std::vector<std::pair<uint32_t, uint32_t>> & workgroups_list) {
|
||||
GGML_ASSERT(pipelines.size() == params_list.size());
|
||||
GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());
|
||||
GGML_ASSERT(pipelines.size() == workgroups_list.size());
|
||||
|
||||
std::vector<webgpu_pool_bufs> params_bufs_list;
|
||||
std::vector<wgpu::BindGroup> bind_groups;
|
||||
std::vector<wgpu::Buffer> params_bufs_list;
|
||||
std::vector<wgpu::BindGroup> bind_groups;
|
||||
|
||||
for (size_t i = 0; i < pipelines.size(); i++) {
|
||||
webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs();
|
||||
|
||||
ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0,
|
||||
params_bufs.host_buf.GetSize());
|
||||
uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
|
||||
for (size_t j = 0; j < params_list[i].size(); j++) {
|
||||
_params[j] = params_list[i][j];
|
||||
}
|
||||
params_bufs.host_buf.Unmap();
|
||||
wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs();
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = bind_group_entries_list[i];
|
||||
uint32_t params_binding_num = entries.size();
|
||||
entries.push_back({ .binding = params_binding_num,
|
||||
.buffer = params_bufs.dev_buf,
|
||||
.offset = 0,
|
||||
.size = params_bufs.dev_buf.GetSize() });
|
||||
entries.push_back(
|
||||
{ .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() });
|
||||
|
||||
wgpu::BindGroupDescriptor bind_group_desc;
|
||||
bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0);
|
||||
@@ -677,15 +673,8 @@ static webgpu_command ggml_backend_webgpu_build_multi(
|
||||
}
|
||||
|
||||
wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
|
||||
for (const auto & params_bufs : params_bufs_list) {
|
||||
encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
|
||||
}
|
||||
|
||||
// If there are SET_ROWS operations in this submission, copy their error
|
||||
// buffers to the host.
|
||||
if (set_rows_error_bufs) {
|
||||
encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
|
||||
set_rows_error_bufs->host_buf.GetSize());
|
||||
for (size_t i = 0; i < params_bufs_list.size(); i++) {
|
||||
ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
|
||||
}
|
||||
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
@@ -718,7 +707,6 @@ static webgpu_command ggml_backend_webgpu_build_multi(
|
||||
webgpu_command result = {};
|
||||
result.commands = commands;
|
||||
result.params_bufs = params_bufs_list;
|
||||
result.set_rows_error_bufs = set_rows_error_bufs;
|
||||
result.num_kernels = pipelines.size();
|
||||
#ifdef GGML_WEBGPU_GPU_PROFILE
|
||||
result.timestamp_query_bufs = ts_bufs;
|
||||
@@ -734,13 +722,13 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_global_context &
|
||||
std::vector<uint32_t> params,
|
||||
std::vector<wgpu::BindGroupEntry> bind_group_entries,
|
||||
uint32_t wg_x,
|
||||
uint32_t wg_y = 1,
|
||||
std::optional<webgpu_pool_bufs> set_rows_error_bufs = std::nullopt) {
|
||||
uint32_t wg_y = 1) {
|
||||
return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,
|
||||
{
|
||||
pipeline
|
||||
},
|
||||
{ params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs);
|
||||
{ std::move(params) }, { std::move(bind_group_entries) },
|
||||
{ { wg_x, wg_y } });
|
||||
}
|
||||
|
||||
static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
|
||||
@@ -757,8 +745,9 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
|
||||
|
||||
webgpu_command command =
|
||||
ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
|
||||
auto futures = ggml_backend_webgpu_submit(ctx, { command }, ctx->memset_buf_pool);
|
||||
ggml_backend_webgpu_wait(ctx, futures);
|
||||
std::vector<webgpu_command> commands = { command };
|
||||
std::vector<webgpu_submission> sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
|
||||
ggml_backend_webgpu_wait(ctx, sub);
|
||||
}
|
||||
|
||||
/** End WebGPU Actions */
|
||||
@@ -805,7 +794,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
|
||||
std::cout << "\nggml_webgpu: gpu breakdown:\n";
|
||||
for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
|
||||
double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
|
||||
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
|
||||
std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2)
|
||||
<< pct << "%)\n";
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -978,14 +968,6 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
|
||||
|
||||
auto * decisions = static_cast<ggml_webgpu_set_rows_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
std::optional<webgpu_pool_bufs> error_bufs = std::nullopt;
|
||||
if (decisions->i64_idx) {
|
||||
error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
|
||||
if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
|
||||
error_bufs->host_buf.Unmap();
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
|
||||
@@ -1018,8 +1000,10 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
|
||||
};
|
||||
|
||||
if (decisions->i64_idx) {
|
||||
entries.push_back(
|
||||
{ .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() });
|
||||
entries.push_back({ .binding = 3,
|
||||
.buffer = ctx->set_rows_dev_error_buf,
|
||||
.offset = 0,
|
||||
.size = ctx->set_rows_dev_error_buf.GetSize() });
|
||||
}
|
||||
|
||||
uint32_t threads;
|
||||
@@ -1029,8 +1013,7 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
|
||||
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
|
||||
}
|
||||
uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1,
|
||||
error_bufs);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1);
|
||||
}
|
||||
|
||||
// Workgroup size is a common constant
|
||||
@@ -1108,12 +1091,26 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
use_fast = (src0->type == GGML_TYPE_F16);
|
||||
break;
|
||||
case GGML_TYPE_F32:
|
||||
// TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q8_1:
|
||||
case GGML_TYPE_Q6_K:
|
||||
use_fast = true;
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
// we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat
|
||||
use_fast = !is_vec;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -1187,17 +1184,18 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
if (use_fast && is_vec) {
|
||||
auto decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
|
||||
auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t batches = dst->ne[2] * dst->ne[3];
|
||||
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
|
||||
uint32_t total_wg = output_groups * batches;
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
} else if (use_fast) {
|
||||
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
|
||||
auto * decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
// Fast-path tiled/subgroup calculations
|
||||
uint32_t wg_m, wg_n;
|
||||
uint32_t wg_m;
|
||||
uint32_t wg_n;
|
||||
if (decisions->use_subgroup_matrix) {
|
||||
uint32_t wg_m_sg_tile =
|
||||
decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m;
|
||||
@@ -1215,7 +1213,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
|
||||
} else { // legacy
|
||||
auto decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
uint32_t wg_size = decisions->wg_size;
|
||||
uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
|
||||
compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
|
||||
@@ -1514,10 +1512,10 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
uint32_t dim = (uint32_t) dst->op_params[0];
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
@@ -1538,28 +1536,22 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
|
||||
(uint32_t) dst->ne[2],
|
||||
(uint32_t) dst->ne[3],
|
||||
dim,
|
||||
(uint32_t)src0->ne[dim]
|
||||
(uint32_t) src0->ne[dim]
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{
|
||||
.binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0)
|
||||
},
|
||||
{
|
||||
.binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1)
|
||||
},
|
||||
{
|
||||
.binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst)
|
||||
}
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) },
|
||||
{ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
|
||||
};
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {
|
||||
@@ -1569,9 +1561,9 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
|
||||
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
@@ -1623,7 +1615,12 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
||||
|
||||
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
||||
float freq_base;
|
||||
float freq_scale;
|
||||
float ext_factor;
|
||||
float attn_factor;
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
||||
@@ -2172,19 +2169,12 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
||||
case GGML_OP_SOFT_MAX:
|
||||
return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
|
||||
case GGML_OP_UNARY:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_CLAMP:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_FILL:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_LOG:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_SQR:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_SQRT:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_SIN:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_COS:
|
||||
return ggml_webgpu_unary_op(ctx, src0, node);
|
||||
case GGML_OP_PAD:
|
||||
@@ -2192,7 +2182,6 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
||||
case GGML_OP_ARGMAX:
|
||||
return ggml_webgpu_argmax(ctx, src0, node);
|
||||
case GGML_OP_ARGSORT:
|
||||
return ggml_webgpu_argsort(ctx, src0, node);
|
||||
case GGML_OP_TOP_K:
|
||||
// we reuse the same argsort implementation for top_k
|
||||
return ggml_webgpu_argsort(ctx, src0, node);
|
||||
@@ -2214,33 +2203,51 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
|
||||
|
||||
WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
|
||||
|
||||
std::vector<webgpu_command> commands;
|
||||
std::vector<wgpu::FutureWaitInfo> futures;
|
||||
uint32_t num_batched_kernels = 0;
|
||||
std::vector<webgpu_command> commands;
|
||||
std::vector<webgpu_submission> subs;
|
||||
uint32_t num_batched_kernels = 0;
|
||||
bool contains_set_rows = false;
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
|
||||
contains_set_rows = true;
|
||||
}
|
||||
if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
|
||||
commands.push_back(*cmd);
|
||||
num_batched_kernels += cmd.value().num_kernels;
|
||||
}
|
||||
|
||||
if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
|
||||
num_batched_kernels = 0;
|
||||
std::vector<wgpu::FutureWaitInfo> compute_futures = ggml_backend_webgpu_submit(
|
||||
ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
|
||||
futures.insert(futures.end(), compute_futures.begin(), compute_futures.end());
|
||||
num_batched_kernels = 0;
|
||||
subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
|
||||
// Process events and check for completed submissions
|
||||
ctx->global_ctx->instance.ProcessEvents();
|
||||
ggml_backend_webgpu_wait(ctx->global_ctx, futures, false);
|
||||
ggml_backend_webgpu_wait(ctx->global_ctx, subs, false);
|
||||
commands.clear();
|
||||
}
|
||||
}
|
||||
if (!commands.empty()) {
|
||||
auto new_futures =
|
||||
ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
|
||||
futures.insert(futures.end(), new_futures.begin(), new_futures.end());
|
||||
subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
|
||||
commands.clear();
|
||||
}
|
||||
|
||||
ggml_backend_webgpu_wait(ctx->global_ctx, futures);
|
||||
// If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking.
|
||||
if (contains_set_rows) {
|
||||
wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder();
|
||||
encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0,
|
||||
ctx->set_rows_host_error_buf.GetSize());
|
||||
wgpu::CommandBuffer set_rows_commands = encoder.Finish();
|
||||
ctx->global_ctx->queue.Submit(1, &set_rows_commands);
|
||||
ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0,
|
||||
ctx->set_rows_host_error_buf.GetSize());
|
||||
const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange();
|
||||
if (*error_data) {
|
||||
GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
|
||||
}
|
||||
ctx->set_rows_host_error_buf.Unmap();
|
||||
}
|
||||
|
||||
ggml_backend_webgpu_wait(ctx->global_ctx, subs);
|
||||
WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
@@ -2859,10 +2866,12 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
|
||||
webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true);
|
||||
webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS,
|
||||
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
|
||||
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf,
|
||||
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf");
|
||||
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_host_error_buf,
|
||||
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
|
||||
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
|
||||
|
||||
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
|
||||
|
||||
@@ -11,7 +11,7 @@ fn store_shmem(val: vec4<f16>, idx: u32) {
|
||||
shmem[idx + 2] = val.z;
|
||||
shmem[idx + 3] = val.w;
|
||||
}
|
||||
#endif
|
||||
#endif // VEC
|
||||
|
||||
#ifdef SCALAR
|
||||
#define VEC_SIZE 1
|
||||
@@ -23,7 +23,7 @@ fn store_shmem(val: vec4<f16>, idx: u32) {
|
||||
fn store_shmem(val: f16, idx: u32) {
|
||||
shmem[idx] = val;
|
||||
}
|
||||
#endif
|
||||
#endif // SCALAR
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_FLOAT
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
@@ -40,7 +40,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
store_shmem(SHMEM_TYPE(src0_val), elem_idx);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // INIT_SRC0_SHMEM_FLOAT
|
||||
|
||||
#ifdef INIT_SRC1_SHMEM_FLOAT
|
||||
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
|
||||
@@ -57,7 +57,7 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
|
||||
store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // INIT_SRC1_SHMEM_FLOAT
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_0
|
||||
const BLOCK_SIZE = 32u;
|
||||
@@ -100,4 +100,667 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // INIT_SRC0_SHMEM_Q4_0
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_1
|
||||
const BLOCK_SIZE = 32u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let d = src0[scale_idx];
|
||||
let m = src0[scale_idx + 1u];
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 2u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_lo = f16(q_byte & 0xF) * d + m;
|
||||
let q_hi = f16((q_byte >> 4) & 0xF) * d + m;
|
||||
shmem[shmem_idx + j * 2 + k] = q_lo;
|
||||
shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q4_1
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q5_0
|
||||
// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
|
||||
const BLOCK_SIZE = 32u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
// tile_k is defined as 32u, so blocks_k ends up being 1 always
|
||||
override BLOCKS_K = TILE_K / BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let d = src0[scale_idx];
|
||||
let qh0 = src0[scale_idx + 1u];
|
||||
let qh1 = src0[scale_idx + 2u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
|
||||
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
|
||||
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
|
||||
let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
|
||||
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d;
|
||||
|
||||
shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
|
||||
shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q5_0
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q5_1
|
||||
// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block
|
||||
const BLOCK_SIZE = 32u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
// tile_k is defined as 32u, so blocks_k ends up being 1 always
|
||||
override BLOCKS_K = TILE_K / BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let d = src0[scale_idx];
|
||||
let m = src0[scale_idx + 1u];
|
||||
let qh0 = src0[scale_idx + 2u];
|
||||
let qh1 = src0[scale_idx + 3u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
|
||||
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
|
||||
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
|
||||
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
|
||||
let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi)) * d + m;
|
||||
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = (f16((q_byte & 0xF) | qh_lo)) * d + m;
|
||||
|
||||
shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight
|
||||
shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q5_1
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q8_0
|
||||
const BLOCK_SIZE = 32u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights
|
||||
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let d = src0[scale_idx];
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
|
||||
let q_0 = src0[scale_idx + 1u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
|
||||
let q_val = f16(q_byte) * d;
|
||||
shmem[shmem_idx + j * 2 + k] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q8_0
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q8_1
|
||||
const BLOCK_SIZE = 32u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights
|
||||
const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
let d = src0[scale_idx];
|
||||
let m = src0[scale_idx + 1u];
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j+=2) {
|
||||
let q_0 = src0[scale_idx + 2u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
|
||||
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
|
||||
let q_val = f16(q_byte) * d + m;
|
||||
shmem[shmem_idx + j * 2 + k] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q8_1
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q2_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 42u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
// Use standard thread layout instead of lane/row_group
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let d = src0[scale_idx + 40u];
|
||||
let dmin = src0[scale_idx + 41u];
|
||||
|
||||
// Decode the element at position k_in_block
|
||||
let block_of_32 = k_in_block / 32u;
|
||||
let pos_in_32 = k_in_block % 32u;
|
||||
|
||||
let q_b_idx = (block_of_32 / 4u) * 32u;
|
||||
let shift = (block_of_32 % 4u) * 2u;
|
||||
let k = (pos_in_32 / 16u) * 16u;
|
||||
let l = pos_in_32 % 16u;
|
||||
|
||||
let is = k_in_block / 16u;
|
||||
|
||||
let sc_0 = src0[scale_idx + 2u * (is / 4u)];
|
||||
let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];
|
||||
let sc_packed = bitcast<u32>(vec2(sc_0, sc_1));
|
||||
let sc = get_byte(sc_packed, is % 4u);
|
||||
|
||||
let dl = d * f16(sc & 0xFu);
|
||||
let ml = dmin * f16(sc >> 4u);
|
||||
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
|
||||
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
let qs_val = (q_byte >> shift) & 3u;
|
||||
|
||||
let q_val = f16(qs_val) * dl - ml;
|
||||
shmem[elem_idx] = q_val;
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q2_K
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q3_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 55u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let d = src0[scale_idx + 54u];
|
||||
|
||||
// Load and unpack scales
|
||||
let kmask1: u32 = 0x03030303u;
|
||||
let kmask2: u32 = 0x0f0f0f0fu;
|
||||
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0u; i < 4u; i++) {
|
||||
let scale_0 = src0[scale_idx + 48u + (2u*i)];
|
||||
let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];
|
||||
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
|
||||
}
|
||||
|
||||
var tmp: u32 = scale_vals[2];
|
||||
scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u);
|
||||
scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u);
|
||||
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u);
|
||||
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u);
|
||||
|
||||
// Load hmask and qs arrays
|
||||
var hmask_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0u; i < 8u; i++) {
|
||||
let hmask_0 = src0[scale_idx + (2u*i)];
|
||||
let hmask_1 = src0[scale_idx + (2u*i) + 1u];
|
||||
hmask_vals[i] = bitcast<u32>(vec2(hmask_0, hmask_1));
|
||||
}
|
||||
|
||||
var qs_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0u; i < 16u; i++) {
|
||||
let qs_0 = src0[scale_idx + 16u + (2u*i)];
|
||||
let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];
|
||||
qs_vals[i] = bitcast<u32>(vec2(qs_0, qs_1));
|
||||
}
|
||||
|
||||
let half = k_in_block / 128u; // 0 or 1
|
||||
let pos_in_half = k_in_block % 128u; // 0-127
|
||||
let shift_group = pos_in_half / 32u; // 0-3
|
||||
let pos_in_32 = pos_in_half % 32u; // 0-31
|
||||
let k_group = pos_in_32 / 16u; // 0 or 1
|
||||
let l = pos_in_32 % 16u; // 0-15
|
||||
|
||||
let q_b_idx = half * 32u; // 0 or 32
|
||||
let shift = shift_group * 2u; // 0, 2, 4, 6
|
||||
let k = k_group * 16u; // 0 or 16
|
||||
let is = k_in_block / 16u; // 0-15
|
||||
|
||||
// m increments every 32 elements across entire 256 element block
|
||||
let m_shift = k_in_block / 32u; // 0-7
|
||||
let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128
|
||||
|
||||
let sc = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
let dl = d * (f16(sc) - 32.0);
|
||||
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let hm_idx = k + l;
|
||||
|
||||
let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u);
|
||||
let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u);
|
||||
|
||||
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
|
||||
let qs_val = (q_byte >> shift) & 3u;
|
||||
|
||||
let q_val = (f16(qs_val) - f16(hm)) * dl;
|
||||
shmem[elem_idx] = q_val;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // INIT_SRC0_SHMEM_Q3_K
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 72u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let d = src0[scale_idx];
|
||||
let dmin = src0[scale_idx + 1u];
|
||||
|
||||
// Load packed scales
|
||||
var scale_vals: array<u32, 3>;
|
||||
for (var i: u32 = 0u; i < 3u; i++) {
|
||||
let scale_0 = src0[scale_idx + 2u + (2u*i)];
|
||||
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
|
||||
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
|
||||
}
|
||||
|
||||
// Map k_in_block to loop structure:
|
||||
// Outer loop over 64-element groups (alternating q_b_idx)
|
||||
// Inner loop over 2 shifts per group
|
||||
let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx)
|
||||
let pos_in_64 = k_in_block % 64u; // 0-63
|
||||
let shift_group = pos_in_64 / 32u; // 0 or 1
|
||||
let l = pos_in_64 % 32u; // 0-31
|
||||
|
||||
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
|
||||
let shift = shift_group * 4u; // 0 or 4
|
||||
let is = k_in_block / 32u; // 0-7
|
||||
|
||||
var sc: u32;
|
||||
var mn: u32;
|
||||
|
||||
if (is < 4u) {
|
||||
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
|
||||
sc = sc_byte & 63u;
|
||||
mn = min_byte & 63u;
|
||||
} else {
|
||||
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
|
||||
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
|
||||
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
|
||||
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
|
||||
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
|
||||
}
|
||||
|
||||
let dl = d * f16(sc);
|
||||
let ml = dmin * f16(mn);
|
||||
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
|
||||
let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
let qs_val = (q_byte >> shift) & 0xFu;
|
||||
|
||||
let q_val = f16(qs_val) * dl - ml;
|
||||
shmem[elem_idx] = q_val;
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q4_K
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q5_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 88u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let d = src0[scale_idx];
|
||||
let dmin = src0[scale_idx + 1u];
|
||||
|
||||
// Load packed scales
|
||||
var scale_vals: array<u32, 3>;
|
||||
for (var i: u32 = 0u; i < 3u; i++) {
|
||||
let scale_0 = src0[scale_idx + 2u + (2u*i)];
|
||||
let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
|
||||
scale_vals[i] = bitcast<u32>(vec2(scale_0, scale_1));
|
||||
}
|
||||
|
||||
// The original loop processes elements in groups of 64
|
||||
// Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
|
||||
// But u increments EVERY 32 elements (after each l loop)
|
||||
let group_of_64 = k_in_block / 64u; // 0-3
|
||||
let pos_in_64 = k_in_block % 64u; // 0-63
|
||||
let shift_group = pos_in_64 / 32u; // 0 or 1
|
||||
let l = pos_in_64 % 32u; // 0-31
|
||||
|
||||
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
|
||||
let shift = shift_group * 4u; // 0 or 4
|
||||
let is = k_in_block / 32u; // 0-7
|
||||
|
||||
// u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128)
|
||||
let u_shift = k_in_block / 32u; // 0-7
|
||||
let u: u32 = 1u << u_shift;
|
||||
|
||||
var sc: u32;
|
||||
var mn: u32;
|
||||
|
||||
if (is < 4u) {
|
||||
let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
|
||||
sc = sc_byte & 63u;
|
||||
mn = min_byte & 63u;
|
||||
} else {
|
||||
let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
|
||||
let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
|
||||
let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
|
||||
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
|
||||
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
|
||||
}
|
||||
|
||||
let dl = d * f16(sc);
|
||||
let ml = dmin * f16(mn);
|
||||
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)];
|
||||
let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
|
||||
let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)];
|
||||
let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh_0, qh_1));
|
||||
|
||||
let qh_byte = get_byte(qh_packed, l % 4u);
|
||||
|
||||
let qs_val = (q_byte >> shift) & 0xFu;
|
||||
let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
|
||||
|
||||
let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml;
|
||||
shmem[elem_idx] = q_val;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // INIT_SRC0_SHMEM_Q5_K
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q6_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 105u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let scale_idx = src0_idx * F16_PER_BLOCK;
|
||||
|
||||
let half = k_in_block / 128u;
|
||||
let pos_in_half = k_in_block % 128u;
|
||||
let quarter = pos_in_half / 32u;
|
||||
let l = pos_in_half % 32u;
|
||||
|
||||
let ql_b_idx = half * 64u;
|
||||
let qh_b_idx = half * 32u;
|
||||
let sc_b_idx = half * 8u;
|
||||
|
||||
// Load only ql13 word needed
|
||||
let ql13_flat = ql_b_idx + l;
|
||||
let ql13_word = ql13_flat / 4u;
|
||||
let ql13 = bitcast<u32>(vec2(
|
||||
src0[scale_idx + 2u * ql13_word],
|
||||
src0[scale_idx + 2u * ql13_word + 1u]
|
||||
));
|
||||
let ql13_b = get_byte(ql13, ql13_flat % 4u);
|
||||
|
||||
// Load only ql24 word needed
|
||||
let ql24_flat = ql_b_idx + l + 32u;
|
||||
let ql24_word = ql24_flat / 4u;
|
||||
let ql24 = bitcast<u32>(vec2(
|
||||
src0[scale_idx + 2u * ql24_word],
|
||||
src0[scale_idx + 2u * ql24_word + 1u]
|
||||
));
|
||||
let ql24_b = get_byte(ql24, ql24_flat % 4u);
|
||||
|
||||
// Load only qh word needed
|
||||
let qh_flat = qh_b_idx + l;
|
||||
let qh_word = qh_flat / 4u;
|
||||
let qh = bitcast<u32>(vec2(
|
||||
src0[scale_idx + 64u + 2u * qh_word],
|
||||
src0[scale_idx + 64u + 2u * qh_word + 1u]
|
||||
));
|
||||
let qh_b = get_byte(qh, qh_flat % 4u);
|
||||
|
||||
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
|
||||
let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
|
||||
let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0);
|
||||
let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0);
|
||||
|
||||
// Load only the scale word needed
|
||||
let is = l / 16u;
|
||||
let sc_idx = sc_b_idx + is + quarter * 2u;
|
||||
let sc_word = sc_idx / 4u;
|
||||
let sc = bitcast<u32>(vec2(
|
||||
src0[scale_idx + 96u + 2u * sc_word],
|
||||
src0[scale_idx + 96u + 2u * sc_word + 1u]
|
||||
));
|
||||
let sc_val = get_byte_i32(sc, sc_idx % 4u);
|
||||
|
||||
let d = src0[scale_idx + 104u];
|
||||
|
||||
var q_val: f16;
|
||||
if (quarter == 0u) {
|
||||
q_val = q1;
|
||||
} else if (quarter == 1u) {
|
||||
q_val = q2;
|
||||
} else if (quarter == 2u) {
|
||||
q_val = q3;
|
||||
} else {
|
||||
q_val = q4;
|
||||
}
|
||||
|
||||
shmem[elem_idx] = d * f16(sc_val) * q_val;
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q6_K
|
||||
|
||||
@@ -50,6 +50,7 @@ fn get_local_m(thread_id: u32) -> u32 {
|
||||
const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
|
||||
const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
|
||||
const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
|
||||
|
||||
var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
|
||||
|
||||
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
enable f16;
|
||||
|
||||
#include "common_decls.tmpl"
|
||||
@@ -84,6 +83,294 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q4_1
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 10u;
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
var local_sum = 0.0;
|
||||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
let m = f32(src0[scale_idx + 1u]);
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 2u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
|
||||
let q_lo = f32(q_byte & 0xF) * d + m;
|
||||
local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
|
||||
local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
|
||||
}
|
||||
}
|
||||
}
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q5_0
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 11u;
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
var local_sum = 0.0;
|
||||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
let qh0 = src0[scale_idx + 1u];
|
||||
let qh1 = src0[scale_idx + 2u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
|
||||
let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
|
||||
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
|
||||
let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
|
||||
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
|
||||
|
||||
local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
|
||||
local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef MUL_ACC_Q5_1
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 12u;
|
||||
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
var local_sum = 0.0;
|
||||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
let m = src0[scale_idx + 1u];
|
||||
let qh0 = src0[scale_idx + 2u];
|
||||
let qh1 = src0[scale_idx + 3u];
|
||||
let qh_packed = bitcast<u32>(vec2(qh0, qh1));
|
||||
|
||||
for (var j = 0u; j < 2; j++) {
|
||||
let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
|
||||
let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
|
||||
let j_adjusted = j + (block_offset / 2u);
|
||||
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
|
||||
let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
|
||||
let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m);
|
||||
let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
|
||||
let q_lo = f32((q_byte & 0xF) | qh_lo) * d + f32(m);
|
||||
|
||||
local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
|
||||
local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef MUL_ACC_Q8_0
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 17u;
|
||||
const WEIGHTS_PER_F16 = 2u;
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
var local_sum = 0.0;
|
||||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 1 + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d;
|
||||
local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#ifdef MUL_ACC_Q8_1
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
const F16_PER_BLOCK = 18u;
|
||||
const WEIGHTS_PER_F16 = 2u;
|
||||
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
|
||||
|
||||
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
var local_sum = 0.0;
|
||||
for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
|
||||
// each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
|
||||
let d = f32(src0[scale_idx]);
|
||||
let m = src0[scale_idx + 1u];
|
||||
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = src0[scale_idx + 2u + block_offset + j];
|
||||
let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k: u32 = 0; k < 4; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * d + f32(m);
|
||||
local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q6_K
|
||||
|
||||
const BLOCK_SIZE = 256u;
|
||||
const F16_PER_BLOCK = 105u;
|
||||
|
||||
fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {
|
||||
let aligned = byte_offset & ~3u;
|
||||
let idx = bbase + aligned / 2u;
|
||||
return bitcast<u32>(vec2(src0[idx], src0[idx + 1u]));
|
||||
}
|
||||
|
||||
fn byte_of(v: u32, b: u32) -> u32 {
|
||||
return (v >> (b * 8u)) & 0xFFu;
|
||||
}
|
||||
|
||||
fn sbyte_of(v: u32, b: u32) -> i32 {
|
||||
let raw = i32((v >> (b * 8u)) & 0xFFu);
|
||||
return select(raw, raw - 256, raw >= 128);
|
||||
}
|
||||
|
||||
fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
let tid = tig / 2u;
|
||||
let ix = tig % 2u;
|
||||
let ip = tid / 8u;
|
||||
let il = tid % 8u;
|
||||
let l0 = 4u * il;
|
||||
let is = 8u * ip + l0 / 16u;
|
||||
|
||||
let y_offset = 128u * ip + l0;
|
||||
let q_offset_l = 64u * ip + l0;
|
||||
let q_offset_h = 32u * ip + l0;
|
||||
|
||||
let nb = tile_size / BLOCK_SIZE;
|
||||
let k_block_start = k_outer / BLOCK_SIZE;
|
||||
|
||||
// Aligned scale byte position (is can be odd)
|
||||
let sc_base_byte = 192u + (is & ~3u);
|
||||
let sc_byte_pos = is & 3u;
|
||||
|
||||
var local_sum = 0.0;
|
||||
|
||||
for (var i = ix; i < nb; i += 2u) {
|
||||
let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK;
|
||||
|
||||
let d_raw = load_u32_at(bbase, 208u);
|
||||
let d = f32(bitcast<vec2<f16>>(d_raw)[0]);
|
||||
|
||||
let ql1_u32 = load_u32_at(bbase, q_offset_l);
|
||||
let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u);
|
||||
let qh_u32 = load_u32_at(bbase, 128u + q_offset_h);
|
||||
let sc_u32_0 = load_u32_at(bbase, sc_base_byte);
|
||||
let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);
|
||||
|
||||
let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
|
||||
let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
|
||||
let sc4 = sbyte_of(sc_u32_1, sc_byte_pos);
|
||||
let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u);
|
||||
|
||||
var sums = vec4<f32>(0.0, 0.0, 0.0, 0.0);
|
||||
|
||||
for (var l = 0u; l < 4u; l++) {
|
||||
let y_base = i * BLOCK_SIZE + y_offset + l;
|
||||
let yl0 = f32(shared_vector[y_base]);
|
||||
let yl1 = f32(shared_vector[y_base + 32u]);
|
||||
let yl2 = f32(shared_vector[y_base + 64u]);
|
||||
let yl3 = f32(shared_vector[y_base + 96u]);
|
||||
|
||||
let q1b = byte_of(ql1_u32, l);
|
||||
let q2b = byte_of(ql2_u32, l);
|
||||
let qhb = byte_of(qh_u32, l);
|
||||
|
||||
let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32);
|
||||
let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32);
|
||||
let dq2 = f32(i32((q1b >> 4u) | ((qhb & 0x30u) )) - 32);
|
||||
let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32);
|
||||
|
||||
sums[0] += yl0 * dq0;
|
||||
sums[1] += yl1 * dq1;
|
||||
sums[2] += yl2 * dq2;
|
||||
sums[3] += yl3 * dq3;
|
||||
}
|
||||
|
||||
local_sum += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) +
|
||||
sums[2] * f32(sc4) + sums[3] * f32(sc6));
|
||||
}
|
||||
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
struct MulMatParams {
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
@@ -191,4 +478,3 @@ fn main(
|
||||
dst[dst_idx / VEC_SIZE] = store_val(group_base);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1087,6 +1087,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
|
||||
LLM_TENSOR_TOKEN_EMBD,
|
||||
LLM_TENSOR_OUTPUT_NORM,
|
||||
LLM_TENSOR_OUTPUT,
|
||||
LLM_TENSOR_CLS_OUT,
|
||||
LLM_TENSOR_ATTN_NORM,
|
||||
LLM_TENSOR_ATTN_Q,
|
||||
LLM_TENSOR_ATTN_Q_NORM,
|
||||
|
||||
@@ -601,7 +601,7 @@ const char * llama_grammar_parser::parse_sequence(
|
||||
throw std::runtime_error(std::string("expecting an int at ") + pos);
|
||||
}
|
||||
const char * int_end = parse_int(pos);
|
||||
uint64_t min_times = std::stoul(std::string(pos, int_end - pos));
|
||||
uint64_t min_times = std::stoull(std::string(pos, int_end - pos));
|
||||
pos = parse_space(int_end, is_nested);
|
||||
|
||||
uint64_t max_times = UINT64_MAX; // default: no max limit
|
||||
@@ -614,7 +614,7 @@ const char * llama_grammar_parser::parse_sequence(
|
||||
|
||||
if (is_digit_char(*pos)) {
|
||||
const char * int_end = parse_int(pos);
|
||||
max_times = std::stoul(std::string(pos, int_end - pos));
|
||||
max_times = std::stoull(std::string(pos, int_end - pos));
|
||||
pos = parse_space(int_end, is_nested);
|
||||
}
|
||||
|
||||
|
||||
+2
-2
@@ -250,7 +250,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
||||
|
||||
const bool last = (
|
||||
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
|
||||
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
|
||||
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL)) // qwen3 reranking & embedding models use last token
|
||||
);
|
||||
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
@@ -2552,7 +2552,7 @@ void llm_graph_context::build_pooling(
|
||||
}
|
||||
|
||||
// softmax for qwen3 reranker
|
||||
if (arch == LLM_ARCH_QWEN3) {
|
||||
if (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL) {
|
||||
cur = ggml_soft_max(ctx0, cur);
|
||||
}
|
||||
} break;
|
||||
|
||||
+469
-285
File diff suppressed because it is too large
Load Diff
@@ -168,8 +168,9 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
|
||||
GGML_ASSERT(n_seqs != 0);
|
||||
GGML_ASSERT(ubatch.equal_seqs());
|
||||
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||
GGML_ASSERT(d_inner % n_head == 0);
|
||||
GGML_ASSERT(d_inner % (n_group*d_state) == 0);
|
||||
GGML_ASSERT(d_inner % n_head == 0);
|
||||
GGML_ASSERT(d_inner % d_state == 0);
|
||||
GGML_ASSERT(d_inner % n_group == 0);
|
||||
|
||||
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
|
||||
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
|
||||
|
||||
@@ -569,6 +569,55 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"array with empty items",
|
||||
R"""({
|
||||
"type": "array",
|
||||
"items": {}
|
||||
})""",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
item ::= object
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= "[" space (item ("," space item)*)? "]" space
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"array with empty items and prefixItems",
|
||||
R"""({
|
||||
"type": "array",
|
||||
"items": {},
|
||||
"prefixItems": { "type": "string" }
|
||||
})""",
|
||||
R"""(
|
||||
array ::= "[" space ( value ("," space value)* )? "]" space
|
||||
boolean ::= ("true" | "false") space
|
||||
char ::= [^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})
|
||||
decimal-part ::= [0-9]{1,16}
|
||||
integral-part ::= [0] | [1-9] [0-9]{0,15}
|
||||
item ::= object
|
||||
null ::= "null" space
|
||||
number ::= ("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space
|
||||
object ::= "{" space ( string ":" space value ("," space string ":" space value)* )? "}" space
|
||||
root ::= "[" space (item ("," space item)*)? "]" space
|
||||
space ::= | " " | "\n"{1,2} [ \t]{0,20}
|
||||
string ::= "\"" char* "\"" space
|
||||
value ::= object | array | string | number | boolean | null
|
||||
)"""
|
||||
});
|
||||
|
||||
test({
|
||||
SUCCESS,
|
||||
"number",
|
||||
|
||||
+19
-31
@@ -18,6 +18,13 @@
|
||||
#include <algorithm>
|
||||
#include <filesystem>
|
||||
|
||||
// result of parsing --tensor-type option
|
||||
// (changes to this struct must be reflected in src/llama-quant.cpp)
|
||||
struct tensor_type_option {
|
||||
std::string name;
|
||||
ggml_type type = GGML_TYPE_COUNT;
|
||||
};
|
||||
|
||||
struct quant_option {
|
||||
std::string name;
|
||||
llama_ftype ftype;
|
||||
@@ -65,12 +72,6 @@ static const std::vector<quant_option> QUANT_OPTIONS = {
|
||||
{ "COPY", LLAMA_FTYPE_ALL_F32, "only copy tensors, no quantizing", },
|
||||
};
|
||||
|
||||
// Quantization types. Changes to this struct must be replicated in llama-quantize.cpp
|
||||
struct tensor_quantization {
|
||||
std::string name;
|
||||
ggml_type quant = GGML_TYPE_COUNT;
|
||||
};
|
||||
|
||||
static const char * const LLM_KV_QUANTIZE_IMATRIX_FILE = "quantize.imatrix.file";
|
||||
static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = "quantize.imatrix.dataset";
|
||||
static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
|
||||
@@ -413,7 +414,7 @@ static ggml_type parse_ggml_type(const char * arg) {
|
||||
return GGML_TYPE_COUNT;
|
||||
}
|
||||
|
||||
static bool parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) {
|
||||
static bool parse_tensor_type(const char * data, std::vector<tensor_type_option> & tensor_type) {
|
||||
const char * sep = strchr(data, '=');
|
||||
if (sep == nullptr) {
|
||||
printf("\n%s: malformed tensor type '%s'\n\n", __func__, data);
|
||||
@@ -433,11 +434,11 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
|
||||
std::string tn(data, tn_len);
|
||||
std::transform(tn.begin(), tn.end(), tn.begin(), tolower);
|
||||
sep++;
|
||||
tensor_quantization tqz;
|
||||
tqz.name = tn;
|
||||
tqz.quant = parse_ggml_type(sep);
|
||||
tensor_type.emplace_back(std::move(tqz));
|
||||
if (tqz.quant == GGML_TYPE_COUNT) {
|
||||
tensor_type_option tensor_type_opt;
|
||||
tensor_type_opt.name = tn;
|
||||
tensor_type_opt.type = parse_ggml_type(sep);
|
||||
tensor_type.emplace_back(std::move(tensor_type_opt));
|
||||
if (tensor_type_opt.type == GGML_TYPE_COUNT) {
|
||||
printf("\n%s: invalid quantization type '%s'\n\n", __func__, sep);
|
||||
return false;
|
||||
}
|
||||
@@ -445,7 +446,7 @@ static bool parse_tensor_type(const char * data, std::vector<tensor_quantization
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool parse_tensor_type_file(const char * filename, std::vector<tensor_quantization> & tensor_type) {
|
||||
static bool parse_tensor_type_file(const char * filename, std::vector<tensor_type_option> & tensor_type) {
|
||||
std::ifstream file(filename);
|
||||
if (!file) {
|
||||
printf("\n%s: failed to open file '%s': %s\n\n", __func__, filename, std::strerror(errno));
|
||||
@@ -501,7 +502,7 @@ int main(int argc, char ** argv) {
|
||||
std::string imatrix_file;
|
||||
std::vector<std::string> included_weights, excluded_weights;
|
||||
std::vector<llama_model_kv_override> kv_overrides;
|
||||
std::vector<tensor_quantization> tensor_types;
|
||||
std::vector<tensor_type_option> tensor_type_opts;
|
||||
std::vector<int> prune_layers;
|
||||
|
||||
for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
|
||||
@@ -526,11 +527,11 @@ int main(int argc, char ** argv) {
|
||||
usage(argv[0]);
|
||||
}
|
||||
} else if (strcmp(argv[arg_idx], "--tensor-type") == 0) {
|
||||
if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) {
|
||||
if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_type_opts)) {
|
||||
usage(argv[0]);
|
||||
}
|
||||
} else if (strcmp(argv[arg_idx], "--tensor-type-file") == 0) {
|
||||
if (arg_idx == argc-1 || !parse_tensor_type_file(argv[++arg_idx], tensor_types)) {
|
||||
if (arg_idx == argc-1 || !parse_tensor_type_file(argv[++arg_idx], tensor_type_opts)) {
|
||||
usage(argv[0]);
|
||||
}
|
||||
} else if (strcmp(argv[arg_idx], "--prune-layers") == 0) {
|
||||
@@ -624,8 +625,8 @@ int main(int argc, char ** argv) {
|
||||
kv_overrides.back().key[0] = 0;
|
||||
params.kv_overrides = &kv_overrides;
|
||||
}
|
||||
if (!tensor_types.empty()) {
|
||||
params.tensor_types = &tensor_types;
|
||||
if (!tensor_type_opts.empty()) {
|
||||
params.tensor_types = &tensor_type_opts;
|
||||
}
|
||||
if (!prune_layers.empty()) {
|
||||
params.prune_layers = &prune_layers;
|
||||
@@ -692,18 +693,6 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
if (!params.dry_run &&
|
||||
(
|
||||
params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS ||
|
||||
params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S ||
|
||||
params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_M
|
||||
) && imatrix_data.empty()) {
|
||||
fprintf(stderr, "\n==========================================================================================================\n");
|
||||
fprintf(stderr, "Please do not use IQ1_S, IQ1_M, IQ2_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n");
|
||||
fprintf(stderr, "==========================================================================================================\n\n\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!params.dry_run) {
|
||||
if (std::error_code ec; std::filesystem::equivalent(fname_inp, fname_out, ec)) {
|
||||
fprintf(stderr, "%s: error: input and output files are the same: '%s'\n", __func__, fname_inp.c_str());
|
||||
@@ -753,4 +742,3 @@ int main(int argc, char ** argv) {
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@@ -2530,9 +2530,24 @@ private:
|
||||
slot.n_prompt_tokens_processed++;
|
||||
|
||||
// process the last few tokens of the prompt separately in order to allow for a checkpoint to be created.
|
||||
const int n_last = std::min(n_batch, 512);
|
||||
if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
|
||||
break;
|
||||
// create checkpoints that many tokens before the end of the prompt:
|
||||
// - 4 + n_ubatch
|
||||
// - 4
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/20288
|
||||
{
|
||||
static const int checkpoint_offsets[] = {4 + n_ubatch, 4};
|
||||
|
||||
bool should_break = false;
|
||||
for (int offset : checkpoint_offsets) {
|
||||
const int n_last = std::min(n_batch, offset);
|
||||
if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
|
||||
should_break = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (should_break) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2554,18 +2569,27 @@ private:
|
||||
slot.init_sampler();
|
||||
SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
|
||||
} else {
|
||||
// only do non-end checkpoints if the "checkpoint every n tokens" option is set
|
||||
do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0;
|
||||
if (do_checkpoint) {
|
||||
llama_pos last_checkpoint = 0;
|
||||
if (!slot.prompt.checkpoints.empty()) {
|
||||
last_checkpoint = slot.prompt.checkpoints.back().n_tokens;
|
||||
}
|
||||
do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt;
|
||||
if (slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch) {
|
||||
// near the end of the prompt
|
||||
do_checkpoint = do_checkpoint && true;
|
||||
} else {
|
||||
// only do non-end checkpoints if the "checkpoint every n tokens" option is set
|
||||
do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0;
|
||||
|
||||
if (do_checkpoint) {
|
||||
SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens());
|
||||
llama_pos last_checkpoint = 0;
|
||||
if (!slot.prompt.checkpoints.empty()) {
|
||||
last_checkpoint = slot.prompt.checkpoints.back().n_tokens;
|
||||
}
|
||||
|
||||
do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt;
|
||||
|
||||
if (do_checkpoint) {
|
||||
SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user