mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-09 07:16:44 +02:00
[SYCL] Support Q4_1, Q5_0, Q5_1 in Flash-attention (#23812)
* support Q4_1, Q5_0, Q5_1 * update ut case
This commit is contained in:
+1012
-1012
File diff suppressed because it is too large
Load Diff
@@ -45,6 +45,7 @@ namespace syclexp = sycl::ext::oneapi::experimental;
|
||||
#define GGML_COMMON_IMPL_SYCL
|
||||
#define SYCL_FLASH_ATTN // Remove this define to disable FLASH_ATTENTION in the build.
|
||||
#define SYCL_FAST_FP16 // Do not change: removing this define breaks the build of fattn-tile.hpp.
|
||||
#define GGML_SYCL_FA_ALL_QUANTS // Define to enable all quantization types in flash attention; undefine to support only F16, Q4_0 and Q8_0.
|
||||
|
||||
/* suppress warning spam */
|
||||
#pragma clang diagnostic push
|
||||
|
||||
@@ -1031,7 +1031,7 @@ void launch_fattn(
|
||||
auto KV_max_ptr_ct1 = KV_max.ptr;
|
||||
|
||||
cgh.parallel_for(sycl::nd_range<3>(blocks_num_KV_max * block_dim_KV_max, block_dim_KV_max),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] {
|
||||
GGML_UNUSED(item_ct1);
|
||||
flash_attn_mask_to_KV_max<ncols1, warp_size>(
|
||||
mask_data_ct0, KV_max_ptr_ct1, iter_k, s31, s33,
|
||||
@@ -1149,7 +1149,7 @@ void launch_fattn(
|
||||
auto K_ne_ct6 = K->ne[2];
|
||||
|
||||
cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] {
|
||||
GGML_UNUSED(item_ct1);
|
||||
flash_attn_stream_k_fixup<DV, ncols1, ncols2>(KQV_data_ct0, dst_tmp_meta_ptr_ct1,
|
||||
Q_ne_ct2, Q_ne_ct3, Q_ne_ct4,
|
||||
@@ -1169,7 +1169,7 @@ void launch_fattn(
|
||||
auto KQV_data_ct2 = (float *) KQV->data;
|
||||
|
||||
cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] {
|
||||
GGML_UNUSED(item_ct1);
|
||||
flash_attn_combine_results<DV>(
|
||||
dst_tmp_ptr_ct0, dst_tmp_meta_ptr_ct1, KQV_data_ct2, parallel_blocks,
|
||||
|
||||
Reference in New Issue
Block a user