[SYCL] Support Q4_1, Q5_0, Q5_1 in Flash-attention (#23812)

* support Q4_1, Q5_0, Q5_1

* update unit test case
This commit is contained in:
Neo Zhang
2026-06-01 14:53:53 +08:00
committed by GitHub
parent 4162522688
commit a51142497a
3 changed files with 1016 additions and 1015 deletions
+1012 -1012
View File
File diff suppressed because it is too large Load Diff
+1
View File
@@ -45,6 +45,7 @@ namespace syclexp = sycl::ext::oneapi::experimental;
#define GGML_COMMON_IMPL_SYCL
#define SYCL_FLASH_ATTN // Remove this define to disable flash attention in the build.
#define SYCL_FAST_FP16 // Do not remove: removing this define breaks the build of fattn-tile.hpp.
#define GGML_SYCL_FA_ALL_QUANTS // Define to enable all quantization types in flash attention; undefine to support only F16, Q4_0 and Q8_0.
/* suppress warning spam */
#pragma clang diagnostic push
+3 -3
View File
@@ -1031,7 +1031,7 @@ void launch_fattn(
auto KV_max_ptr_ct1 = KV_max.ptr;
cgh.parallel_for(sycl::nd_range<3>(blocks_num_KV_max * block_dim_KV_max, block_dim_KV_max),
[=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] {
GGML_UNUSED(item_ct1);
flash_attn_mask_to_KV_max<ncols1, warp_size>(
mask_data_ct0, KV_max_ptr_ct1, iter_k, s31, s33,
@@ -1149,7 +1149,7 @@ void launch_fattn(
auto K_ne_ct6 = K->ne[2];
cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine),
[=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] {
GGML_UNUSED(item_ct1);
flash_attn_stream_k_fixup<DV, ncols1, ncols2>(KQV_data_ct0, dst_tmp_meta_ptr_ct1,
Q_ne_ct2, Q_ne_ct3, Q_ne_ct4,
@@ -1169,7 +1169,7 @@ void launch_fattn(
auto KQV_data_ct2 = (float *) KQV->data;
cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine),
[=](sycl::nd_item<3> item_ct1) {
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] {
GGML_UNUSED(item_ct1);
flash_attn_combine_results<DV>(
dst_tmp_ptr_ct0, dst_tmp_meta_ptr_ct1, KQV_data_ct2, parallel_blocks,