ggml-webgpu: FlashAttention refactor + standardize quantization support (#23834)

* Start work on flash_attn refactor

* Refactor

* Split k/v quantization

* Refactor and abstract quantization logic for flash_attn and mul_mat

* Add quantization support to tile path

* formatting

* Move to functions, add a check
This commit is contained in:
Reese Levine
2026-06-03 22:05:04 -07:00
committed by GitHub
parent 3c7450cee1
commit e8c54893f2
11 changed files with 983 additions and 947 deletions
+5 -2
View File
@@ -10,8 +10,11 @@ file(MAKE_DIRECTORY ${SHADER_OUTPUT_DIR})
message(STATUS "Shader output dir: ${SHADER_OUTPUT_DIR}")
# Find all WGSL files
file(GLOB WGSL_SHADER_FILES "${SHADER_DIR}/*.wgsl")
# Find all WGSL sources (both *.wgsl and *.tmpl files in SHADER_DIR).
# CONFIGURE_DEPENDS makes the build re-check the glob so that shader files
# added or removed after the initial configure are picked up without a manual
# re-run of CMake.
file(GLOB WGSL_SHADER_FILES CONFIGURE_DEPENDS
    "${SHADER_DIR}/*.wgsl"
    "${SHADER_DIR}/*.tmpl"
)
# Generate the header using a Python script
add_custom_command(
+349 -314
View File
@@ -18,6 +18,9 @@
#define GGML_WEBGPU_F32_SIZE_BYTES 4
#define GGML_WEBGPU_I32_SIZE_BYTES 4
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
#define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN 20u
#define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE 32u
#define GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE 64u
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u
// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
#define GGML_WEBGPU_KV_SEQ_PAD 256u
@@ -546,16 +549,10 @@ struct ggml_webgpu_unary_pipeline_key_hash {
/** FlashAttention */
enum ggml_webgpu_flash_attn_path : uint32_t {
GGML_WEBGPU_FLASH_ATTN_PATH_NONE = 0u,
GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 1u,
GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 2u,
GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 3u,
};
struct ggml_webgpu_flash_attn_pipeline_key {
struct ggml_webgpu_flash_attn_common_pipeline_key {
ggml_type q_type;
ggml_type kv_type;
ggml_type k_type;
ggml_type v_type;
ggml_type dst_type;
uint32_t head_dim_qk;
uint32_t head_dim_v;
@@ -564,93 +561,224 @@ struct ggml_webgpu_flash_attn_pipeline_key {
bool has_mask;
bool has_sinks;
bool uses_logit_softcap;
uint32_t path;
bool operator==(const ggml_webgpu_flash_attn_common_pipeline_key & other) const {
return q_type == other.q_type && k_type == other.k_type && v_type == other.v_type &&
dst_type == other.dst_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask &&
has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap;
}
};
// Mixes the shared flash-attention key fields into `seed`. The fields hashed here are the same
// ones compared by ggml_webgpu_flash_attn_common_pipeline_key::operator==.
inline void ggml_webgpu_flash_attn_hash_common_pipeline_key(size_t &                                           seed,
                                                            const ggml_webgpu_flash_attn_common_pipeline_key & key) {
    const auto mix = [&seed](const auto & field) {
        ggml_webgpu_hash_combine(seed, field);
    };

    // Tensor types.
    mix(key.q_type);
    mix(key.k_type);
    mix(key.v_type);
    mix(key.dst_type);

    // Head dimensions.
    mix(key.head_dim_qk);
    mix(key.head_dim_v);

    // Shader feature flags.
    mix(key.kv_direct);
    mix(key.kv_overlap);
    mix(key.has_mask);
    mix(key.has_sinks);
    mix(key.uses_logit_softcap);
}
// Cache key for the vec flash-attention pipelines; it holds only the shared key fields.
struct ggml_webgpu_flash_attn_vec_pipeline_key {
    ggml_webgpu_flash_attn_common_pipeline_key common;

    // Two vec keys are equal exactly when their shared parts are equal.
    bool operator==(const ggml_webgpu_flash_attn_vec_pipeline_key & rhs) const {
        const bool same_common = (common == rhs.common);
        return same_common;
    }
};
// Hash functor for ggml_webgpu_flash_attn_vec_pipeline_key.
struct ggml_webgpu_flash_attn_vec_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_flash_attn_vec_pipeline_key & key) const {
        // Start from a zero seed and fold in the shared key fields.
        size_t digest = 0;
        ggml_webgpu_flash_attn_hash_common_pipeline_key(digest, key.common);
        return digest;
    }
};
struct ggml_webgpu_flash_attn_pipeline_key {
ggml_webgpu_flash_attn_common_pipeline_key common;
bool use_sg_matrix;
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type &&
head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct &&
kv_overlap == other.kv_overlap && has_mask == other.has_mask && has_sinks == other.has_sinks &&
uses_logit_softcap == other.uses_logit_softcap && path == other.path;
return common == other.common && use_sg_matrix == other.use_sg_matrix;
}
};
struct ggml_webgpu_flash_attn_pipeline_key_hash {
size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.q_type);
ggml_webgpu_hash_combine(seed, key.kv_type);
ggml_webgpu_hash_combine(seed, key.dst_type);
ggml_webgpu_hash_combine(seed, key.head_dim_qk);
ggml_webgpu_hash_combine(seed, key.head_dim_v);
ggml_webgpu_hash_combine(seed, key.kv_direct);
ggml_webgpu_hash_combine(seed, key.kv_overlap);
ggml_webgpu_hash_combine(seed, key.has_mask);
ggml_webgpu_hash_combine(seed, key.has_sinks);
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
ggml_webgpu_hash_combine(seed, key.path);
ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common);
ggml_webgpu_hash_combine(seed, key.use_sg_matrix);
return seed;
}
};
// Tuning values chosen for a vec flash-attention pipeline; both default to zero.
struct ggml_webgpu_flash_attn_vec_decisions {
    uint32_t kv_tile{ 0 };  // forwarded to the shader as the KV_TILE define
    uint32_t wg_size{ 0 };  // forwarded to the shader as the WG_SIZE define
};
struct ggml_webgpu_flash_attn_decisions {
uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE;
uint32_t q_tile = 0;
uint32_t kv_tile = 0;
uint32_t wg_size = 0;
bool kv_direct = false;
bool kv_overlap = false;
bool use_sg_matrix = false;
uint32_t q_tile = 0;
uint32_t kv_tile = 0;
uint32_t wg_size = 0;
};
inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u;
inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE = 4u;
inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
if (key.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC || key.kv_type != GGML_TYPE_F16 ||
key.head_dim_qk != key.head_dim_v) {
return 1u;
}
switch (key.head_dim_qk) {
case 64:
case 192:
case 576:
return 2u;
case 96:
return 4u;
default:
return 1u;
}
inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) {
constexpr uintptr_t ptr_base_addr = 0x1000u;
const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor;
return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs;
}
inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key(
const ggml_webgpu_shader_lib_context & context,
const ggml_webgpu_flash_attn_decisions & decisions) {
const bool has_mask = context.src3 != nullptr;
const bool has_sinks = context.src4 != nullptr;
bool kv_direct = false;
if (decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH;
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) {
kv_direct_align = context.sg_mat_k;
}
kv_direct = (context.src1->type == GGML_TYPE_F16) &&
(context.src0->ne[0] % std::max(1u, kv_direct_align) == 0) &&
(context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
// Checks whether the part of the tensor's offset that falls below `storage_offset_alignment`,
// expressed in elements of the tensor's type, is a multiple of the KV vector width.
// The mask below relies on `storage_offset_alignment` being a power of two.
inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) {
    const size_t   residual_bytes = ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1);
    const uint32_t residual_elems = (uint32_t) (residual_bytes / ggml_type_size(K->type));
    return (residual_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH) == 0u;
}
// Two-tensor convenience overload: true only when both K and V pass the single-tensor check.
inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K,
                                                      const ggml_tensor * V,
                                                      size_t              storage_offset_alignment) {
    if (!ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment)) {
        return false;
    }
    return ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment);
}
// Decides whether the KV_DIRECT shader variant can be used.
//
// Requires K and V to be F16, Q->ne[0] to be a multiple of `kv_direct_align`, and K->ne[1] to be
// a multiple of GGML_WEBGPU_KV_SEQ_PAD.
//
// A `kv_direct_align` of zero is treated as 1 (no alignment constraint) so the modulo is never
// evaluated with a zero divisor.
inline bool ggml_webgpu_flash_attn_kv_direct(
    const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, uint32_t kv_direct_align) {
    const uint32_t align = kv_direct_align == 0u ? 1u : kv_direct_align;
    return K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && (Q->ne[0] % align == 0) &&
           (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
}
// Builds the shared part of a flash-attention pipeline key from the shader-lib context.
// src0/src1/src2 are read as Q/K/V, src3/src4 as the optional mask/sinks tensors, and
// `kv_direct_align` is forwarded to ggml_webgpu_flash_attn_kv_direct.
inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_common_pipeline_key(
    const ggml_webgpu_shader_lib_context & context,
    uint32_t                               kv_direct_align) {
    auto * q = context.src0;
    auto * k = context.src1;
    auto * v = context.src2;

    ggml_webgpu_flash_attn_common_pipeline_key result = {};

    // Tensor types.
    result.q_type   = q->type;
    result.k_type   = k->type;
    result.v_type   = v->type;
    result.dst_type = context.dst->type;

    // Head dimensions come from the first dimension of Q and V.
    result.head_dim_qk = (uint32_t) q->ne[0];
    result.head_dim_v  = (uint32_t) v->ne[0];

    // Shader feature flags.
    result.kv_direct          = ggml_webgpu_flash_attn_kv_direct(q, k, v, kv_direct_align);
    result.kv_overlap         = ggml_webgpu_tensor_overlap(k, v);
    result.has_mask           = context.src3 != nullptr;
    result.has_sinks          = context.src4 != nullptr;
    result.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;

    return result;
}
inline std::vector<std::string> ggml_webgpu_flash_attn_common_defines(
const ggml_webgpu_flash_attn_common_pipeline_key & key,
std::string & variant,
uint32_t q_tile,
uint32_t kv_tile,
uint32_t wg_size) {
std::vector<std::string> defines;
switch (key.k_type) {
case GGML_TYPE_F32:
defines.push_back("K_F32");
break;
case GGML_TYPE_F16:
defines.push_back("K_F16");
break;
case GGML_TYPE_Q4_0:
defines.push_back("K_Q4_0");
break;
case GGML_TYPE_Q8_0:
defines.push_back("K_Q8_0");
break;
default:
GGML_ABORT("Unsupported K type for flash attention shader");
}
variant += std::string("_k") + ggml_type_name(key.k_type);
switch (key.v_type) {
case GGML_TYPE_F32:
defines.push_back("V_F32");
break;
case GGML_TYPE_F16:
defines.push_back("V_F16");
break;
case GGML_TYPE_Q4_0:
defines.push_back("V_Q4_0");
break;
case GGML_TYPE_Q8_0:
defines.push_back("V_Q8_0");
break;
default:
GGML_ABORT("Unsupported V type for flash attention shader");
}
variant += std::string("_v") + ggml_type_name(key.v_type);
switch (key.q_type) {
case GGML_TYPE_F32:
defines.push_back("Q_F32");
break;
case GGML_TYPE_F16:
defines.push_back("Q_F16");
break;
default:
GGML_ABORT("Unsupported Q type for flash attention shader");
}
variant += std::string("_q") + ggml_type_name(key.q_type);
switch (key.dst_type) {
case GGML_TYPE_F32:
defines.push_back("DST_F32");
break;
case GGML_TYPE_F16:
defines.push_back("DST_F16");
break;
default:
GGML_ABORT("Unsupported dst type for flash attention shader");
}
variant += std::string("_dst") + ggml_type_name(key.dst_type);
if (key.has_mask) {
defines.push_back("MASK");
variant += "_mask";
}
if (key.has_sinks) {
defines.push_back("SINKS");
variant += "_sinks";
}
if (key.uses_logit_softcap) {
defines.push_back("LOGIT_SOFTCAP");
variant += "_lgsc";
}
if (key.kv_direct) {
defines.push_back("KV_DIRECT");
variant += "_kvdirect";
}
if (key.kv_overlap) {
defines.push_back("KV_OVERLAP");
variant += "_kv_overlap";
}
ggml_webgpu_flash_attn_pipeline_key key = {};
key.q_type = context.src0->type;
key.kv_type = context.src1->type;
key.dst_type = context.dst->type;
key.head_dim_qk = (uint32_t) context.src0->ne[0];
key.head_dim_v = (uint32_t) context.src2->ne[0];
key.kv_direct = kv_direct;
key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2);
key.has_mask = has_mask;
key.has_sinks = has_sinks;
key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
key.path = decisions.path;
return key;
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
if (ggml_is_quantized(key.k_type) || ggml_is_quantized(key.v_type)) {
defines.push_back("U32_DEQUANT_HELPERS");
}
return defines;
}
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
@@ -688,29 +816,18 @@ struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
}
};
// This is exposed because it's necessary in supports_op
// Note: this will slightly overestimate memory usage for vec path
// since row_max and exp_sum shmem are not needed.
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
uint32_t kv_tile,
uint32_t head_dim_qk,
uint32_t head_dim_v,
bool has_mask,
bool kv_direct,
uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) {
bool kv_direct) {
const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
size_t f16_elems = 0;
size_t f32_elems = 0;
if (path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
f32_elems += head_dim_qk; // q_shmem
if (!kv_direct) {
f32_elems += kv_tile * max_head_dim; // kv_shmem
}
f32_elems += head_dim_v; // o_shmem
if (has_mask) {
f32_elems += kv_tile; // mask_shmem
}
f32_elems += kv_tile; // inter_shmem
return f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
}
f32_elems += q_tile * head_dim_qk; // q_shmem
if (!kv_direct) {
f32_elems += kv_tile * max_head_dim; // kv_shmem
@@ -725,25 +842,20 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
}
inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context & context,
const ggml_webgpu_flash_attn_pipeline_key & key) {
const size_t limit_bytes = context.wg_mem_limit_bytes;
uint32_t q_tile = context.sg_mat_m;
uint32_t kv_granularity = std::max(1u, context.sg_mat_n);
if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
q_tile = GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
kv_granularity = 1u;
} else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
q_tile = 1u;
kv_granularity = 8u;
}
const size_t base_q_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, key.head_dim_qk, key.head_dim_v,
key.has_mask, key.kv_direct, key.path);
inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(size_t limit_bytes,
uint32_t q_tile,
uint32_t kv_granularity,
uint32_t head_dim_qk,
uint32_t head_dim_v,
bool has_mask,
bool kv_direct) {
const size_t base_q_bytes =
ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, head_dim_qk, head_dim_v, has_mask, kv_direct);
if (limit_bytes <= base_q_bytes) {
return 0;
}
const size_t one_kv_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, key.head_dim_qk, key.head_dim_v,
key.has_mask, key.kv_direct, key.path);
const size_t one_kv_bytes =
ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, head_dim_qk, head_dim_v, has_mask, kv_direct);
const size_t bytes_per_kv = one_kv_bytes - base_q_bytes;
if (bytes_per_kv == 0) {
return 0;
@@ -752,105 +864,32 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
return (uint32_t) ((max_kv_tile / kv_granularity) * kv_granularity);
}
inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
const ggml_webgpu_shader_lib_context & context,
size_t storage_offset_alignment) {
ggml_webgpu_flash_attn_decisions decisions = {};
const size_t alignment = std::max<size_t>(1u, storage_offset_alignment);
const auto * K = context.src1;
const auto * V = context.src2;
GGML_ASSERT(K != nullptr);
GGML_ASSERT(V != nullptr);
inline uint32_t ggml_webgpu_flash_attn_get_vec_kv_tile(size_t wg_mem_limit_bytes,
uint32_t head_dim_qk,
uint32_t head_dim_v,
bool has_mask,
bool kv_direct) {
const uint32_t max_kv_tile =
ggml_webgpu_flash_attn_max_kv_tile(wg_mem_limit_bytes, 1u, 1u, head_dim_qk, head_dim_v, has_mask, kv_direct);
GGML_ASSERT(max_kv_tile > 0);
const auto flash_attn_tensor_offset = [](const ggml_tensor * tensor) -> size_t {
constexpr uintptr_t ptr_base_addr = 0x1000u;
const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor;
return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs;
};
const uint32_t k_offset_elems =
(uint32_t) ((flash_attn_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type));
const uint32_t v_offset_elems =
(uint32_t) ((flash_attn_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type));
const bool f16_vec4_aligned = (k_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u) &&
(v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u);
const bool kv_vec_type_supported =
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
const uint32_t kv_vec_head_align =
K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : (uint32_t) ggml_blck_size(K->type);
const bool kv_vec_head_dims_aligned =
context.src0->ne[0] % kv_vec_head_align == 0 && context.src2->ne[0] % kv_vec_head_align == 0;
// Compile with enough invocations to cover the largest reported subgroup.
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && kv_vec_head_dims_aligned &&
kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
(context.src2->type == K->type);
const bool tile_can_dispatch_all_q_rows =
context.max_subgroup_size > 0 &&
context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size;
const bool use_subgroup_matrix = context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 &&
context.src0->ne[0] % context.sg_mat_k == 0 &&
context.src2->ne[0] % context.sg_mat_n == 0;
const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 &&
V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
(context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
tile_can_dispatch_all_q_rows && !use_vec;
decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX :
GGML_WEBGPU_FLASH_ATTN_PATH_NONE;
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) {
return decisions;
}
const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions);
decisions.kv_direct = key.kv_direct;
const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
// invalidate if even the smallest kv_tile doesn't fit in shared memory
if (max_kv_tile == 0) {
decisions.path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE;
return decisions;
}
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
decisions.q_tile = 1u;
decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile));
decisions.kv_tile = (decisions.kv_tile / 8u) * 8u;
decisions.wg_size = context.max_subgroup_size;
if (decisions.kv_direct) {
decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
decisions.kv_tile -= 8u;
}
}
return decisions;
}
decisions.q_tile =
decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m;
decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
std::min(64u, max_kv_tile) :
std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
std::min(std::max(1u, context.max_wg_size),
std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE,
GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size)) :
std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
if (decisions.kv_tile == 0) {
return decisions;
}
if (decisions.kv_direct) {
GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
decisions.kv_tile -=
decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? context.min_subgroup_size : context.sg_mat_n;
uint32_t kv_tile = std::min(GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE, max_kv_tile);
if (kv_direct) {
kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
kv_tile -= 1u;
}
}
return decisions;
return kv_tile;
}
// Reports whether the subgroup-matrix flash-attention path is usable.
//
// The path requires subgroup-matrix support, non-zero subgroup-matrix K/N dimensions,
// Q->ne[0] divisible by `sg_mat_k`, and V->ne[0] divisible by `sg_mat_n`.
//
// The zero checks come first so the modulo operations are never evaluated with a zero divisor.
inline bool ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(bool                supports_subgroup_matrix,
                                                                uint32_t            sg_mat_k,
                                                                uint32_t            sg_mat_n,
                                                                const ggml_tensor * Q,
                                                                const ggml_tensor * V) {
    if (!supports_subgroup_matrix || sg_mat_k == 0u || sg_mat_n == 0u) {
        return false;
    }
    return Q->ne[0] % sg_mat_k == 0 && V->ne[0] % sg_mat_n == 0;
}
/** Matrix Multiplication **/
@@ -1123,6 +1162,10 @@ class ggml_webgpu_shader_lib {
concat_pipelines; // type
std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
repeat_pipelines; // type
std::unordered_map<ggml_webgpu_flash_attn_vec_pipeline_key,
webgpu_pipeline,
ggml_webgpu_flash_attn_vec_pipeline_key_hash>
flash_attn_vec_pipelines;
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
flash_attn_pipelines;
std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
@@ -1835,10 +1878,10 @@ class ggml_webgpu_shader_lib {
ggml_webgpu_mul_mat_vec_pipeline_key key = {};
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
1 :
0;
key.use_mmvq =
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
@@ -1971,11 +2014,11 @@ class ggml_webgpu_shader_lib {
ggml_webgpu_mul_mat_pipeline_key key = {};
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
key.use_subgroup_matrix = context.supports_subgroup_matrix;
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
key.use_subgroup_matrix = context.supports_subgroup_matrix;
auto it = mul_mat_fast_pipelines.find(key);
if (it != mul_mat_fast_pipelines.end()) {
@@ -2148,10 +2191,10 @@ class ggml_webgpu_shader_lib {
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
key.n_experts = context.src0->ne[2];
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 &&
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
1 :
0;
auto it = mul_mat_id_pipelines.find(key);
if (it != mul_mat_id_pipelines.end()) {
@@ -2271,10 +2314,10 @@ class ggml_webgpu_shader_lib {
key.src0_type = context.src0->type;
key.src1_type = context.src1->type;
key.n_experts = context.src0->ne[2];
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
1 :
0;
1 :
0;
auto it = mul_mat_id_vec_pipelines.find(key);
if (it != mul_mat_id_vec_pipelines.end()) {
@@ -2664,119 +2707,62 @@ class ggml_webgpu_shader_lib {
return repeat_pipelines[key];
}
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context,
size_t storage_offset_alignment) {
const ggml_webgpu_flash_attn_decisions decisions =
ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment);
GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_NONE);
ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions);
auto it = flash_attn_pipelines.find(key);
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
const bool can_use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(
context.supports_subgroup_matrix, context.sg_mat_k, context.sg_mat_n, context.src0, context.src2);
ggml_webgpu_flash_attn_decisions decisions = {};
decisions.use_sg_matrix = can_use_subgroup_matrix;
decisions.q_tile = decisions.use_sg_matrix ? context.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
ggml_webgpu_flash_attn_pipeline_key key = {};
key.common =
ggml_webgpu_flash_attn_make_common_pipeline_key(context, decisions.use_sg_matrix ? context.sg_mat_k : 1u);
key.common.kv_direct = decisions.use_sg_matrix && key.common.kv_direct;
key.use_sg_matrix = decisions.use_sg_matrix;
const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(
context.wg_mem_limit_bytes, decisions.q_tile, decisions.use_sg_matrix ? context.sg_mat_n : 1u,
key.common.head_dim_qk, key.common.head_dim_v, key.common.has_mask, key.common.kv_direct);
GGML_ASSERT(max_kv_tile > 0);
decisions.kv_tile = decisions.use_sg_matrix ?
std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES) :
std::min(GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE, max_kv_tile);
decisions.wg_size =
decisions.use_sg_matrix ?
std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE) :
std::min(context.max_wg_size, std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE,
GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size));
if (key.common.kv_direct) {
decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
decisions.kv_tile -= decisions.use_sg_matrix ? context.sg_mat_n : context.min_subgroup_size;
}
}
auto it = flash_attn_pipelines.find(key);
if (it != flash_attn_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC ? "flash_attn_vec" :
decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? "flash_attn_tile" :
"flash_attn";
switch (key.kv_type) {
case GGML_TYPE_F32:
defines.push_back("KV_F32");
break;
case GGML_TYPE_F16:
defines.push_back("KV_F16");
break;
case GGML_TYPE_Q4_0:
defines.push_back("KV_Q4_0");
break;
case GGML_TYPE_Q8_0:
defines.push_back("KV_Q8_0");
break;
default:
GGML_ABORT("Unsupported KV type for flash attention shader");
}
variant += std::string("_") + ggml_type_name(key.kv_type);
switch (key.q_type) {
case GGML_TYPE_F32:
defines.push_back("Q_F32");
break;
case GGML_TYPE_F16:
defines.push_back("Q_F16");
break;
default:
GGML_ABORT("Unsupported Q type for flash attention shader");
}
variant += std::string("_q") + ggml_type_name(key.q_type);
switch (key.dst_type) {
case GGML_TYPE_F32:
defines.push_back("DST_F32");
break;
case GGML_TYPE_F16:
defines.push_back("DST_F16");
break;
default:
GGML_ABORT("Unsupported dst type for flash attention shader");
}
variant += std::string("_dst") + ggml_type_name(key.dst_type);
if (key.has_mask) {
defines.push_back("MASK");
if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
defines.push_back("BLK");
variant += "_mask_blk";
} else {
variant += "_mask";
}
}
if (key.has_sinks) {
defines.push_back("SINKS");
variant += "_sinks";
}
if (key.uses_logit_softcap) {
defines.push_back("LOGIT_SOFTCAP");
variant += "_lgsc";
}
if (key.kv_direct) {
defines.push_back("KV_DIRECT");
variant += "_kvdirect";
}
if (key.kv_overlap) {
defines.push_back("KV_OVERLAP");
variant += "_kv_overlap";
}
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
const char * shader_src = wgsl_flash_attn;
if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
defines.push_back("KV_GRANULARITY=8");
defines.push_back(std::string("VEC_NE=") + std::to_string(ggml_webgpu_flash_attn_pick_vec_ne(key)) + "u");
shader_src = wgsl_flash_attn_vec_split;
} else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
std::string variant = decisions.use_sg_matrix ? "flash_attn" : "flash_attn_tile";
std::vector<std::string> defines = ggml_webgpu_flash_attn_common_defines(key.common, variant, decisions.q_tile,
decisions.kv_tile, decisions.wg_size);
const char * shader_src = nullptr;
if (!key.use_sg_matrix) {
shader_src = wgsl_flash_attn_tile;
defines.push_back("MIN_SUBGROUP_SIZE=" + std::to_string(context.min_subgroup_size) + "u");
defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v)));
variant += "_tile_sg" + std::to_string(context.min_subgroup_size) + "_" +
std::to_string(context.max_subgroup_size);
} else {
shader_src = wgsl_flash_attn;
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
}
auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
pipeline_decisions->kv_overlap = key.kv_overlap;
defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile));
defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile));
defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size));
auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
webgpu_pipeline pipeline =
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
pipeline.context = pipeline_decisions;
@@ -2784,6 +2770,55 @@ class ggml_webgpu_shader_lib {
return flash_attn_pipelines[key];
}
// Returns the vec flash-attention pipeline for `context`, creating and caching it on first use.
//
// The cache key is the shared flash-attention key built with the KV vector width as the
// kv_direct alignment. On a miss, the KV tile and workgroup size are chosen, the shader defines
// and variant label are assembled, and the preprocessed `wgsl_flash_attn_vec_split` shader is
// compiled. The chosen decisions are attached to the pipeline as its context.
webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
    ggml_webgpu_flash_attn_vec_pipeline_key key = {};
    key.common = ggml_webgpu_flash_attn_make_common_pipeline_key(context, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH);

    auto it = flash_attn_vec_pipelines.find(key);
    if (it != flash_attn_vec_pipelines.end()) {
        return it->second;
    }

    ggml_webgpu_flash_attn_vec_decisions decisions = {};
    decisions.kv_tile =
        ggml_webgpu_flash_attn_get_vec_kv_tile(context.wg_mem_limit_bytes, key.common.head_dim_qk,
                                               key.common.head_dim_v, key.common.has_mask, key.common.kv_direct);
    decisions.wg_size = context.max_subgroup_size;

    // The vec path always uses a Q tile of 1.
    std::string              variant = "flash_attn_vec";
    std::vector<std::string> defines =
        ggml_webgpu_flash_attn_common_defines(key.common, variant, 1u, decisions.kv_tile, decisions.wg_size);

    if (key.common.has_mask) {
        defines.push_back("BLK");
        // The common helper appends "_mask" and then further suffixes (head dims at least), so
        // the label does not end in "_mask". Extend that token in place to "_mask_blk" instead
        // of trimming the tail of the string, which would cut into the later suffixes.
        const std::string mask_tag = "_mask";
        const size_t      mask_pos = variant.find(mask_tag);
        if (mask_pos != std::string::npos) {
            variant.insert(mask_pos + mask_tag.size(), "_blk");
        }
    }

    // VEC_NE defaults to 1 and is only raised for F16 K/V with equal QK and V head dims.
    uint32_t vec_ne = 1u;
    if (key.common.k_type == GGML_TYPE_F16 && key.common.v_type == GGML_TYPE_F16 &&
        key.common.head_dim_qk == key.common.head_dim_v) {
        switch (key.common.head_dim_qk) {
            case 64:
            case 192:
            case 576:
                vec_ne = 2u;
                break;
            case 96:
                vec_ne = 4u;
                break;
            default:
                break;
        }
    }
    defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");

    auto            pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>(decisions);
    webgpu_pipeline pipeline =
        ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant);
    pipeline.context              = pipeline_decisions;
    flash_attn_vec_pipelines[key] = pipeline;
    return flash_attn_vec_pipelines[key];
}
webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) {
ggml_webgpu_flash_attn_blk_pipeline_key key = {};
key.kv_tile = kv_tile;
+219 -187
View File
@@ -1755,13 +1755,50 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
return ggml_backend_webgpu_build_multi(ctx, dispatches);
}
static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
ggml_tensor * Q,
ggml_tensor * K,
ggml_tensor * V,
ggml_tensor * mask,
ggml_tensor * sinks,
ggml_tensor * dst) {
// Per-op state bundle returned by ggml_webgpu_flash_attn_prepare.
// All scalar members are zero/false by default.
struct ggml_webgpu_flash_attn_op {
    ggml_webgpu_shader_lib_context    shader_lib_ctx{};
    std::vector<uint32_t>             params;
    std::vector<wgpu::BindGroupEntry> entries;
    size_t                            kv_bind_offset{ 0 };
    size_t                            kv_bind_size{ 0 };
    bool                              has_mask{ false };
    bool                              has_sinks{ false };
    bool                              kv_overlap{ false };
};
static bool ggml_webgpu_flash_attn_use_vec_path(const webgpu_global_context & global_ctx,
const ggml_tensor * Q,
const ggml_tensor * K,
const ggml_tensor * V) {
const size_t storage_offset_alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
const bool k_float_vec4_aligned = (K->type != GGML_TYPE_F16 && K->type != GGML_TYPE_F32) ||
ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment);
const bool v_float_vec4_aligned = (V->type != GGML_TYPE_F16 && V->type != GGML_TYPE_F32) ||
ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment);
const bool k_vec_type_supported =
K->type == GGML_TYPE_F32 || K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
const bool v_vec_type_supported =
V->type == GGML_TYPE_F32 || V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q8_0;
const uint32_t k_vec_head_align = (K->type == GGML_TYPE_F32 || K->type == GGML_TYPE_F16) ?
GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH :
(uint32_t) ggml_blck_size(K->type);
const uint32_t v_vec_head_align = (V->type == GGML_TYPE_F32 || V->type == GGML_TYPE_F16) ?
GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH :
(uint32_t) ggml_blck_size(V->type);
const bool kv_vec_head_dims_aligned = Q->ne[0] % k_vec_head_align == 0 && V->ne[0] % v_vec_head_align == 0;
return global_ctx->capabilities.supports_subgroups && (Q->ne[1] < GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN) &&
kv_vec_head_dims_aligned && k_vec_type_supported && v_vec_type_supported && k_float_vec4_aligned &&
v_float_vec4_aligned;
}
static ggml_webgpu_flash_attn_op ggml_webgpu_flash_attn_prepare(webgpu_context & ctx,
ggml_tensor * Q,
ggml_tensor * K,
ggml_tensor * V,
ggml_tensor * mask,
ggml_tensor * sinks,
ggml_tensor * dst) {
float scale = ggml_get_op_params_f32(dst, 0);
float max_bias = ggml_get_op_params_f32(dst, 1);
float logit_softcap = ggml_get_op_params_f32(dst, 2);
@@ -1772,47 +1809,43 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = Q;
shader_lib_ctx.src1 = K;
shader_lib_ctx.src2 = V;
shader_lib_ctx.src3 = mask;
shader_lib_ctx.src4 = sinks;
shader_lib_ctx.dst = dst;
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size;
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
const int has_mask = (mask != nullptr);
const int has_sinks = (sinks != nullptr);
const bool kv_overlap = decisions->kv_overlap;
ggml_webgpu_flash_attn_op op = {};
op.shader_lib_ctx.src0 = Q;
op.shader_lib_ctx.src1 = K;
op.shader_lib_ctx.src2 = V;
op.shader_lib_ctx.src3 = mask;
op.shader_lib_ctx.src4 = sinks;
op.shader_lib_ctx.dst = dst;
op.shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
op.shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
op.shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
op.shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
op.shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
op.shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
op.shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
op.shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size;
op.shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
size_t kv_bind_offset = 0;
size_t kv_bind_size = 0;
if (kv_overlap) {
op.has_mask = mask != nullptr;
op.has_sinks = sinks != nullptr;
op.kv_overlap = ggml_webgpu_tensor_overlap(K, V);
uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
if (op.kv_overlap) {
const ggml_webgpu_merged_binding_range merged_range = ggml_webgpu_tensor_merged_binding_range(ctx, { K, V });
kv_bind_offset = merged_range.offset;
kv_bind_size = merged_range.size;
op.kv_bind_offset = merged_range.offset;
op.kv_bind_size = merged_range.size;
offset_k = ggml_webgpu_tensor_merged_element_offset(K, merged_range);
offset_v = ggml_webgpu_tensor_merged_element_offset(V, merged_range);
}
std::vector<uint32_t> params = {
op.params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
offset_k,
offset_v,
has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
op.has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
op.has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) Q->ne[2], // number of heads
(uint32_t) Q->ne[1], // sequence length (Q)
@@ -1826,7 +1859,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
(uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1
(uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2
(uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3
has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3
op.has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3
(uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA)
ggml_webgpu_u32_from_f32(scale), // scale (possibly adjusted for logit softcap)
ggml_webgpu_u32_from_f32(max_bias),
@@ -1834,32 +1867,56 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
ggml_webgpu_u32_from_f32(n_head_log2),
ggml_webgpu_u32_from_f32(m0),
ggml_webgpu_u32_from_f32(m1)
};
std::vector<wgpu::BindGroupEntry> entries = {
op.entries = {
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q),
};
if (kv_overlap) {
entries.push_back(
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size));
if (op.kv_overlap) {
op.entries.push_back(
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), op.kv_bind_offset, op.kv_bind_size));
} else {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K));
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V));
op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K));
op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V));
}
uint32_t binding_index = kv_overlap ? 2u : 3u;
if (has_mask) {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask));
uint32_t binding_index = op.kv_overlap ? 2u : 3u;
if (op.has_mask) {
op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask));
}
if (has_sinks) {
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks));
if (op.has_sinks) {
op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks));
}
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst));
op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst));
if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
return op;
}
// Picks the number of workgroups used to split the KV sequence on the vec path.
//
// Starting from 1, the count is doubled while twice the span covered by the current count
// (2 * count * kv_tile) is still shorter than seq_len_kv and the count is below vec_nwg_cap.
// The result is finally clamped to vec_nwg_cap.
static uint32_t ggml_webgpu_flash_attn_vec_nwg(uint32_t vec_nwg_cap, uint32_t kv_tile, uint32_t seq_len_kv) {
    const uint64_t tile_span = (uint64_t) kv_tile;
    const uint64_t kv_len    = (uint64_t) seq_len_kv;

    uint32_t count = 1u;
    for (; count < vec_nwg_cap; count <<= 1) {
        // Stop doubling once the doubled coverage reaches the KV length.
        if ((2u * count * tile_span) >= kv_len) {
            break;
        }
    }
    return std::min(count, vec_nwg_cap);
}
// Encodes FlashAttention as a single dispatch using the pipeline returned by
// get_flash_attn_pipeline(). The workgroup count is
// ceil(Q rows / q_tile) * Q->ne[2] * Q->ne[3], with q_tile taken from the pipeline's decisions.
static webgpu_encoded_op ggml_webgpu_flash_attn_direct(webgpu_context & ctx, const ggml_webgpu_flash_attn_op & op) {
    webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(op.shader_lib_ctx);
    // The pipeline carries the tiling decisions that were made when it was created.
    auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());

    const auto * q = op.shader_lib_ctx.src0;

    uint32_t tiles_per_head = CEIL_DIV(q->ne[1], decisions->q_tile);
    uint32_t num_workgroups = tiles_per_head * q->ne[2] * q->ne[3];

    return ggml_backend_webgpu_build(ctx, pipeline, op.params, op.entries, num_workgroups);
}
static webgpu_encoded_op ggml_webgpu_flash_attn_vec(webgpu_context & ctx,
ggml_tensor * Q,
ggml_tensor * K,
ggml_tensor * V,
ggml_tensor * mask,
ggml_tensor * sinks,
ggml_tensor * dst,
ggml_webgpu_flash_attn_op op) {
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_vec_pipeline(op.shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_flash_attn_vec_decisions *>(pipeline.context.get());
wgpu::Buffer blk_buf = {};
uint64_t blk_size_bytes = 0;
@@ -1868,13 +1925,8 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
uint32_t blk_batch_count = 0;
const uint32_t vec_nwg_cap = ctx->global_ctx->capabilities.min_subgroup_size;
uint32_t nwg = 1u;
const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile);
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
nwg <<= 1;
}
nwg = std::min(nwg, vec_nwg_cap);
const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
uint32_t nwg = ggml_webgpu_flash_attn_vec_nwg(vec_nwg_cap, decisions->kv_tile, (uint32_t) K->ne[1]);
const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
const bool use_vec_reduce = nwg > 1u;
GGML_ASSERT(nrows <= UINT32_MAX);
@@ -1910,7 +1962,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
webgpu_pipeline blk_pipeline;
std::vector<uint32_t> blk_params;
std::vector<wgpu::BindGroupEntry> blk_entries;
if (has_mask) {
if (op.has_mask) {
blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile);
blk_nblk1 = (uint32_t) Q->ne[1];
blk_buf = ggml_webgpu_tensor_buf(dst);
@@ -1918,7 +1970,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx;
const ggml_webgpu_shader_lib_context blk_shader_ctx = op.shader_lib_ctx;
blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx, decisions->kv_tile);
blk_params = {
@@ -1938,8 +1990,8 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes);
}
std::vector<uint32_t> split_params = params;
if (has_mask) {
std::vector<uint32_t> split_params = op.params;
if (op.has_mask) {
split_params.push_back(0u); // blk_base
split_params.push_back(blk_nblk0); // blk_nblk0
split_params.push_back(blk_nblk1); // blk_nblk1
@@ -1952,9 +2004,9 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q),
ggml_webgpu_tensor_binding_size(ctx, Q)),
};
if (kv_overlap) {
if (op.kv_overlap) {
split_entries.push_back(
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size));
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), op.kv_bind_offset, op.kv_bind_size));
} else {
split_entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K),
ggml_webgpu_tensor_align_offset(ctx, K),
@@ -1963,18 +2015,18 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
ggml_webgpu_tensor_align_offset(ctx, V),
ggml_webgpu_tensor_binding_size(ctx, V)));
}
uint32_t split_binding_index = kv_overlap ? 2u : 3u;
if (has_mask) {
uint32_t split_binding_index = op.kv_overlap ? 2u : 3u;
if (op.has_mask) {
split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask),
ggml_webgpu_tensor_align_offset(ctx, mask),
ggml_webgpu_tensor_binding_size(ctx, mask)));
}
if (has_sinks) {
if (op.has_sinks) {
split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(sinks),
ggml_webgpu_tensor_align_offset(ctx, sinks),
ggml_webgpu_tensor_binding_size(ctx, sinks)));
}
if (has_mask) {
if (op.has_mask) {
split_entries.push_back(
ggml_webgpu_make_bind_group_entry(split_binding_index++, blk_buf, blk_entries[1].offset, blk_size_bytes));
}
@@ -1993,7 +2045,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
reduce_sg_size,
(uint32_t) std::min<uint64_t>((uint64_t) nwg * reduce_sg_size,
ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx;
ggml_webgpu_shader_lib_context reduce_shader_ctx = op.shader_lib_ctx;
reduce_shader_ctx.max_wg_size = reduce_wg_size;
reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx);
@@ -2020,7 +2072,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
std::vector<webgpu_dispatch_desc> dispatches;
if (has_mask) {
if (op.has_mask) {
dispatches.push_back({
blk_pipeline, std::move(blk_params), std::move(blk_entries), { blk_nblk0, blk_nblk1 * blk_batch_count }
});
@@ -2037,6 +2089,20 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
return ggml_backend_webgpu_build_multi(ctx, dispatches);
}
// Entry point for encoding GGML_OP_FLASH_ATTN_EXT.
// Builds the shared parameters and bind group entries once, then encodes either the vec path
// (when ggml_webgpu_flash_attn_use_vec_path() accepts Q/K/V) or the direct single-dispatch path.
static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
                                                ggml_tensor *    Q,
                                                ggml_tensor *    K,
                                                ggml_tensor *    V,
                                                ggml_tensor *    mask,
                                                ggml_tensor *    sinks,
                                                ggml_tensor *    dst) {
    ggml_webgpu_flash_attn_op prepared = ggml_webgpu_flash_attn_prepare(ctx, Q, K, V, mask, sinks, dst);

    const bool take_vec_path = ggml_webgpu_flash_attn_use_vec_path(ctx->global_ctx, Q, K, V);
    if (!take_vec_path) {
        return ggml_webgpu_flash_attn_direct(ctx, prepared);
    }
    // The vec helper takes the op by value, so hand over ownership instead of copying.
    return ggml_webgpu_flash_attn_vec(ctx, Q, K, V, mask, sinks, dst, std::move(prepared));
}
static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool is_unary = dst->op == GGML_OP_UNARY;
@@ -3553,70 +3619,43 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
break;
case GGML_OP_FLASH_ATTN_EXT:
{
const ggml_tensor * Q = tensor->src[0];
const ggml_tensor * K = tensor->src[1];
const ggml_tensor * V = tensor->src[2];
const ggml_tensor * mask = tensor->src[3];
const ggml_tensor * sinks = tensor->src[4];
if (Q && K && V) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = const_cast<ggml_tensor *>(Q);
shader_lib_ctx.src1 = const_cast<ggml_tensor *>(K);
shader_lib_ctx.src2 = const_cast<ggml_tensor *>(V);
shader_lib_ctx.src3 = const_cast<ggml_tensor *>(mask);
shader_lib_ctx.src4 = const_cast<ggml_tensor *>(sinks);
shader_lib_ctx.dst = const_cast<ggml_tensor *>(tensor);
shader_lib_ctx.max_wg_size =
ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.wg_mem_limit_bytes =
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups;
shader_lib_ctx.supports_subgroup_matrix =
ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
const ggml_tensor * Q = tensor->src[0];
const ggml_tensor * K = tensor->src[1];
const ggml_tensor * V = tensor->src[2];
const ggml_tensor * mask = tensor->src[3];
const auto & capabilities = ctx->webgpu_global_ctx->capabilities;
if (ggml_webgpu_flash_attn_use_vec_path(ctx->webgpu_global_ctx, Q, K, V)) {
const bool kv_direct =
ggml_webgpu_flash_attn_kv_direct(Q, K, V, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH);
const uint32_t kv_tile = ggml_webgpu_flash_attn_get_vec_kv_tile(
capabilities.limits.maxComputeWorkgroupStorageSize, (uint32_t) Q->ne[0], (uint32_t) V->ne[0],
mask != nullptr, kv_direct);
const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
const uint32_t vec_nwg_cap = capabilities.min_subgroup_size;
uint32_t nwg = ggml_webgpu_flash_attn_vec_nwg(vec_nwg_cap, kv_tile, (uint32_t) K->ne[1]);
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
const uint32_t kv_tile = decisions.kv_tile;
const uint32_t vec_nwg_cap = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
uint32_t nwg = 1u;
const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile);
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
nwg <<= 1;
}
nwg = std::min(nwg, vec_nwg_cap);
const size_t align =
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
if (nwg > 1u) {
const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg;
const uint64_t tmp_stats_elems = nrows * 2u * nwg;
const size_t tmp_size_bytes = ROUNDUP_POW2(
(tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
res += tmp_size_bytes + align;
} else {
res += WEBGPU_STORAGE_BUF_BINDING_MULT + align;
}
if (mask != nullptr) {
const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u);
const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
const size_t blk_size_bytes =
ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
res += blk_size_bytes + align;
}
res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT);
const size_t align = capabilities.limits.minStorageBufferOffsetAlignment;
const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
if (nwg > 1u) {
const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg;
const uint64_t tmp_stats_elems = nrows * 2u * nwg;
const size_t tmp_size_bytes = ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float),
WEBGPU_STORAGE_BUF_BINDING_MULT);
res += tmp_size_bytes + align;
} else {
res += WEBGPU_STORAGE_BUF_BINDING_MULT + align;
}
if (mask != nullptr) {
const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u);
const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
const size_t blk_size_bytes =
ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
res += blk_size_bytes + align;
}
res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT);
}
}
break;
@@ -4139,70 +4178,63 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
break;
case GGML_OP_FLASH_ATTN_EXT:
{
// conservative support checks for whether the more resource-intensive shader paths
// can be used, to avoid cases where flash_attn is assigned to the CPU later on
supports_op = src0->type == GGML_TYPE_F32 &&
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
src2->type == src1->type && op->type == GGML_TYPE_F32;
(src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16 ||
src2->type == GGML_TYPE_Q4_0 || src2->type == GGML_TYPE_Q8_0) &&
op->type == GGML_TYPE_F32;
if (!supports_op) {
break;
}
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
shader_lib_ctx.src0 = src0;
shader_lib_ctx.src1 = src1;
shader_lib_ctx.src2 = src2;
shader_lib_ctx.src3 = op->src[3];
shader_lib_ctx.src4 = op->src[4];
shader_lib_ctx.dst = const_cast<ggml_tensor *>(op);
shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups;
shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
shader_lib_ctx.max_wg_size =
ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.wg_mem_limit_bytes =
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
const bool has_mask = op->src[3] != nullptr;
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) {
if (ggml_webgpu_tensor_overlap(src1, src2) && src1->type != src2->type &&
!ggml_is_quantized(src1->type) && !ggml_is_quantized(src2->type)) {
supports_op = false;
break;
}
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask,
decisions.kv_direct, decisions.path);
if (min_bytes > limit_bytes) {
supports_op = false;
}
break;
}
const auto & capabilities = ctx->webgpu_global_ctx->capabilities;
const size_t storage_offset_alignment = capabilities.limits.minStorageBufferOffsetAlignment;
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask,
decisions.kv_direct, decisions.path);
if (min_bytes > limit_bytes) {
supports_op = false;
}
break;
}
// subgroup matrix path requirements
const bool use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(
capabilities.supports_subgroup_matrix, capabilities.sg_mat_k, capabilities.sg_mat_n, src0, src2);
if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
// tile path requirements
const bool float_vec4_aligned =
((src1->type != GGML_TYPE_F16 && src1->type != GGML_TYPE_F32) ||
ggml_webgpu_flash_attn_float_vec4_aligned(src1, storage_offset_alignment)) &&
((src2->type != GGML_TYPE_F16 && src2->type != GGML_TYPE_F32) ||
ggml_webgpu_flash_attn_float_vec4_aligned(src2, storage_offset_alignment));
const uint32_t k_tile_head_align = (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16) ?
GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH :
(uint32_t) ggml_blck_size(src1->type);
const uint32_t v_tile_head_align = (src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16) ?
GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH :
(uint32_t) ggml_blck_size(src2->type);
const bool tile_kv_head_dims_aligned =
src0->ne[0] % k_tile_head_align == 0 && src2->ne[0] % v_tile_head_align == 0;
const bool tile_can_dispatch_all_q_rows =
capabilities.limits.maxComputeInvocationsPerWorkgroup >=
GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * capabilities.max_subgroup_size;
const bool use_tile = !use_subgroup_matrix && capabilities.supports_subgroups && float_vec4_aligned &&
tile_kv_head_dims_aligned && tile_can_dispatch_all_q_rows;
if (!use_subgroup_matrix && !use_tile) {
supports_op = false;
break;
}
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask,
decisions.kv_direct, decisions.path);
if (min_bytes > limit_bytes) {
supports_op = false;
}
const uint32_t q_tile =
use_subgroup_matrix ? capabilities.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
const uint32_t kv_granularity = use_subgroup_matrix ? capabilities.sg_mat_n : 1u;
const bool kv_direct = use_subgroup_matrix ?
ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) :
false;
const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(
capabilities.limits.maxComputeWorkgroupStorageSize, q_tile, kv_granularity, (uint32_t) src0->ne[0],
(uint32_t) src2->ne[0], op->src[3] != nullptr, kv_direct);
supports_op = max_kv_tile > 0;
break;
}
case GGML_OP_RMS_NORM:
+37 -7
View File
@@ -37,15 +37,33 @@ static std::string trim(const std::string & s) {
}
static std::string trim_value(std::istream & is) {
std::string str;
std::getline(is, str);
return trim(str);
std::ostringstream ss;
ss << is.rdbuf();
return trim(ss.str());
}
// Returns true when `c` may appear in an identifier: an alphanumeric character or '_'.
static bool isIdentChar(char c) {
    if (c == '_') {
        return true;
    }
    // Cast through unsigned char so negative char values are not passed to std::isalnum.
    return std::isalnum(static_cast<unsigned char>(c)) != 0;
}
// Returns true when the last non-whitespace character of `line` is a backslash,
// i.e. the line is continued on the next physical line.
static bool endsWithContinuation(const std::string & line) {
    for (size_t pos = line.size(); pos-- > 0;) {
        const unsigned char ch = static_cast<unsigned char>(line[pos]);
        if (!std::isspace(ch)) {
            return ch == '\\';
        }
    }
    // Empty or whitespace-only line.
    return false;
}
// Removes a trailing line-continuation backslash from `line`, together with any whitespace
// that follows it. Lines that do not end in a continuation are left untouched.
static void stripContinuation(std::string & line) {
    for (size_t pos = line.size(); pos-- > 0;) {
        const unsigned char ch = static_cast<unsigned char>(line[pos]);
        if (std::isspace(ch)) {
            continue;
        }
        if (ch == '\\') {
            // Drop the backslash and everything after it.
            line.erase(pos);
        }
        return;
    }
}
static std::string expandMacrosRecursiveInternal(const std::string & line,
const std::unordered_map<std::string, std::string> & macros,
std::unordered_set<std::string> & visiting);
@@ -595,19 +613,31 @@ class Preprocessor {
std::string line;
while (std::getline(in, line)) {
std::string t = trim(line);
std::string logical = line;
std::string t = trim(logical);
if (!t.empty() && t[0] == '#') {
while (endsWithContinuation(logical)) {
stripContinuation(logical);
if (!std::getline(in, line)) {
break;
}
logical += "\n";
logical += line;
}
t = trim(logical);
}
if (!t.empty() && t[0] == '#') {
bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode);
if (mode == DirectiveMode::IncludesOnly && !handled) {
out << line << "\n";
out << logical << "\n";
}
} else {
if (mode == DirectiveMode::IncludesOnly) {
out << line << "\n";
out << logical << "\n";
} else if (condActive(cond)) {
// Expand macros in the line before outputting
std::string expanded = expandMacrosRecursive(line, macros);
std::string expanded = expandMacrosRecursive(logical, macros);
out << expanded << "\n";
}
}
+62 -209
View File
@@ -4,12 +4,23 @@ enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;
#ifdef KV_F32
#define KV_TYPE f32
#elif defined(KV_Q4_0) || defined(KV_Q8_0)
#define KV_TYPE u32
#define BYTE_HELPERS
#include "common_decls.tmpl"
#ifdef K_F32
#define K_TYPE f32
#elif defined(K_Q4_0) || defined(K_Q8_0)
#define K_TYPE u32
#else
#define KV_TYPE f16
#define K_TYPE f16
#endif
#ifdef V_F32
#define V_TYPE f32
#elif defined(V_Q4_0) || defined(V_Q8_0)
#define V_TYPE u32
#else
#define V_TYPE f16
#endif
// Default values
@@ -30,76 +41,6 @@ enable chromium_experimental_subgroup_matrix;
// Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE.
#define KV_BLOCKS (KV_TILE / SG_MAT_N)
// Quantization constants/helpers
#define BLOCK_SIZE 32
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
// number of quantized elements processed per thread
#if defined(KV_Q4_0)
#define NQ 16
// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights
#define F16_PER_BLOCK 9
#define BLOCK_SIZE_BYTES 18u
#define WEIGHTS_PER_F16 4
#elif defined(KV_Q8_0)
#define NQ 8
// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights
#define F16_PER_BLOCK 17
#define BLOCK_SIZE_BYTES 34u
#define WEIGHTS_PER_F16 2
#endif
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
// Ok not to put these in a define block, compiler will remove if unused
// Extracts byte `index` (0 = least significant) of `value` as an unsigned value.
fn get_byte(value: u32, index: u32) -> u32 {
    let shift = index * 8u;
    return (value >> shift) & 0xFFu;
}
// Extracts byte `index` (0 = least significant) of `value` and sign-extends it to i32.
fn get_byte_i32(value: u32, index: u32) -> i32 {
    let byte = (value >> (index * 8u)) & 0xFFu;
    // Move the byte to the top of the word, then shift back arithmetically to sign-extend.
    let top = bitcast<i32>(byte << 24u);
    return top >> 24u;
}
#if defined(KV_Q4_0) || defined(KV_Q8_0)
// Reads the 16-bit value at `byte_offset` in K (bound as 32-bit words).
// Bit 1 of the offset selects the low or high half of the containing word.
fn load_k_u16_at(byte_offset: u32) -> u32 {
    let word_idx = byte_offset >> 2u;
    let bit_shift = select(0u, 16u, (byte_offset & 2u) != 0u);
    return (K[word_idx] >> bit_shift) & 0xFFFFu;
}
// Reads a 32-bit value starting at an arbitrary `byte_offset` in K (bound as 32-bit words).
// When the offset is not word-aligned, the result is stitched from two adjacent words.
fn load_k_u32_at(byte_offset: u32) -> u32 {
    let word_idx = byte_offset >> 2u;
    let bit_shift = (byte_offset & 3u) << 3u;
    let lo = K[word_idx];
    if (bit_shift != 0u) {
        let hi = K[word_idx + 1u];
        return (lo >> bit_shift) | (hi << (32u - bit_shift));
    }
    return lo;
}
// Reads the 16-bit value at `byte_offset` in V (bound as 32-bit words).
// Bit 1 of the offset selects the low or high half of the containing word.
fn load_v_u16_at(byte_offset: u32) -> u32 {
    let word_idx = byte_offset >> 2u;
    let bit_shift = select(0u, 16u, (byte_offset & 2u) != 0u);
    return (V[word_idx] >> bit_shift) & 0xFFFFu;
}
// Reads a 32-bit value starting at an arbitrary `byte_offset` in V (bound as 32-bit words).
// When the offset is not word-aligned, the result is stitched from two adjacent words.
fn load_v_u32_at(byte_offset: u32) -> u32 {
    let word_idx = byte_offset >> 2u;
    let bit_shift = (byte_offset & 3u) << 3u;
    let lo = V[word_idx];
    if (bit_shift != 0u) {
        let hi = V[word_idx + 1u];
        return (lo >> bit_shift) | (hi << (32u - bit_shift));
    }
    return lo;
}
// Interprets the low 16 bits of `bits` as an IEEE half-precision value and returns it as f16.
fn f16_from_u16(bits: u32) -> f16 {
    return f16(unpack2x16float(bits).x);
}
#endif
struct Params {
offset_q: u32,
offset_k: u32,
@@ -139,11 +80,11 @@ struct Params {
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
#ifdef KV_OVERLAP
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>;
#define V K
#else
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>;
@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>;
#endif
#if defined(MASK) && defined(SINKS)
@@ -238,10 +179,47 @@ fn load_f32x4(buf: ptr<storage, array<vec4<f32>>, read_write>, scalar_index: u32
return (*buf)[scalar_index >> 2u];
}
fn load_kvx4(buf: ptr<storage, array<vec4<KV_TYPE>>, read_write>, scalar_index: u32) -> vec4<KV_TYPE> {
fn load_kx4(buf: ptr<storage, array<vec4<K_TYPE>>, read_write>, scalar_index: u32) -> vec4<K_TYPE> {
return (*buf)[scalar_index >> 2u];
}
#ifndef KV_DIRECT
#define QUANT_SHMEM kv_shmem
#define QUANT_OUT_TYPE f16
#include "quant_inner_loops.tmpl"
#include "flash_attn_quant_staging.tmpl"
#if !defined(K_Q4_0) && !defined(K_Q8_0)
// Stages one KV_TILE x HEAD_DIM_QK tile of non-quantized K into kv_shmem as f16.
// Each invocation starts at element `local_x` and strides by WG_SIZE. `kv_tile` is the first
// global K row of the tile; rows at or beyond params.seq_len_kv are written as zero.
// `kv_count` is not used by this variant.
fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) {
    for (var i = local_x; i < KV_TILE * HEAD_DIM_QK; i += WG_SIZE) {
        let row = i / HEAD_DIM_QK;
        let col = i % HEAD_DIM_QK;
        let src_row = kv_tile + row;
        let src_base = k_head_offset + src_row * params.stride_k1;
        let in_range = src_row < params.seq_len_kv && col < HEAD_DIM_QK;
        let src_val = K[src_base + col];
        kv_shmem[i] = f16(select(0.0, src_val, in_range));
    }
}
#endif
#if !defined(V_Q4_0) && !defined(V_Q8_0)
// Stages one KV_TILE x HEAD_DIM_V tile of non-quantized V into kv_shmem as f16.
// Each invocation starts at element `local_x` and strides by WG_SIZE. `kv_tile` is the first
// global V row of the tile; rows at or beyond params.seq_len_kv are written as zero.
// `kv_count` is not used by this variant.
fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) {
    for (var i = local_x; i < KV_TILE * HEAD_DIM_V; i += WG_SIZE) {
        let row = i / HEAD_DIM_V;
        let col = i % HEAD_DIM_V;
        let src_row = kv_tile + row;
        let src_base = v_head_offset + src_row * params.stride_v1;
        let in_range = src_row < params.seq_len_kv && col < HEAD_DIM_V;
        let src_val = V[src_base + col];
        kv_shmem[i] = f16(select(0.0, src_val, in_range));
    }
}
#endif
#endif
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@@ -311,77 +289,15 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
}
for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile);
// clear inter_shmem to ensure zero-initialized accumulators
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
inter_shmem[elem_idx] = 0.0;
}
// load k tile into shared memory
#if defined(KV_Q4_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let k_row = blck_idx / BLOCKS_K;
let global_k_row = kv_tile + k_row;
let block_k = blck_idx % BLOCKS_K;
let row_offset = k_row * HEAD_DIM_QK;
if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_k_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_k_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_lo;
kv_shmem[row_offset + idx + 16u] = q_hi;
}
}
}
}
#elif defined(KV_Q8_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let k_row = blck_idx / BLOCKS_K;
let global_k_row = kv_tile + k_row;
let block_k = blck_idx % BLOCKS_K;
let row_offset = k_row * HEAD_DIM_QK;
if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_k_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_k_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_val;
}
}
}
}
#elif defined(KV_DIRECT)
// Direct global loads for KV
#else
for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
let k_row = elem_idx / HEAD_DIM_QK;
let k_col = elem_idx % HEAD_DIM_QK;
let global_k_row = kv_tile + k_row;
let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
kv_shmem[elem_idx] = f16(select(
0.0,
K[global_k_row_offset + k_col],
global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK));
}
#ifndef KV_DIRECT
load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset);
#endif
workgroupBarrier();
@@ -520,71 +436,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
}
// load v tile into shared memory
#if defined(KV_Q4_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let v_row = blck_idx / BLOCKS_V;
let global_v_row = kv_tile + v_row;
let block_k = blck_idx % BLOCKS_V;
let row_offset = v_row * HEAD_DIM_V;
if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_v_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_v_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_lo;
kv_shmem[row_offset + idx + 16u] = q_hi;
}
}
}
}
#elif defined(KV_Q8_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let v_row = blck_idx / BLOCKS_V;
let global_v_row = kv_tile + v_row;
let block_k = blck_idx % BLOCKS_V;
let row_offset = v_row * HEAD_DIM_V;
if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_v_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_v_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_val;
}
}
}
}
#elif defined(KV_DIRECT)
// Direct global loads for KV
#else
for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) {
let v_row = elem_idx / HEAD_DIM_V;
let v_col = elem_idx % HEAD_DIM_V;
let global_v_row = kv_tile + v_row;
let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
kv_shmem[elem_idx] = f16(select(
0.0,
V[global_v_row_offset + v_col],
global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V));
}
#ifndef KV_DIRECT
load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset);
#endif
workgroupBarrier();
@@ -0,0 +1,124 @@
#define BLOCK_SIZE 32
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
#if defined(K_Q4_0)
#define K_NQ 16
#define K_BLOCK_SIZE_BYTES 18u
#define K_BYTES_PER_THREAD 8u
#define K_BYTES_PER_INNER_LOOP 4u
#elif defined(K_Q8_0)
#define K_NQ 16
#define K_BLOCK_SIZE_BYTES 34u
#define K_BYTES_PER_THREAD 16u
#define K_BYTES_PER_INNER_LOOP 4u
#endif
#if defined(V_Q4_0)
#define V_NQ 16
#define V_BLOCK_SIZE_BYTES 18u
#define V_BYTES_PER_THREAD 8u
#define V_BYTES_PER_INNER_LOOP 4u
#elif defined(V_Q8_0)
#define V_NQ 16
#define V_BLOCK_SIZE_BYTES 34u
#define V_BYTES_PER_THREAD 16u
#define V_BYTES_PER_INNER_LOOP 4u
#endif
#if defined(K_Q4_0) || defined(K_Q8_0)
// Reads a 16-bit value from the u32-typed K buffer at `byte_offset`.
// Only bit 1 of the offset picks the half-word (shift of 0 or 16 bits), so the
// offset is expected to be 2-byte aligned. The result is zero-extended to u32.
fn load_k_u16_at(byte_offset: u32) -> u32 {
    let half_shift = (byte_offset & 2u) * 8u;
    return (K[byte_offset / 4u] >> half_shift) & 0xFFFFu;
}
// Reads 32 bits from the u32-typed K buffer starting at an arbitrary byte offset.
// An unaligned read is assembled from the upper bytes of the containing word
// and the lower bytes of the following word.
fn load_k_u32_at(byte_offset: u32) -> u32 {
    let base_word = byte_offset / 4u;
    let bit_shift = (byte_offset & 3u) * 8u;
    let low_word = K[base_word];
    if (bit_shift != 0u) {
        let high_word = K[base_word + 1u];
        return (low_word >> bit_shift) | (high_word << (32u - bit_shift));
    }
    // Aligned case: return the word directly. This also avoids the 32-bit
    // left shift the unaligned formula would need when bit_shift is 0.
    return low_word;
}
#endif
#if defined(V_Q4_0) || defined(V_Q8_0)
// Reads a 16-bit value from the u32-typed V buffer at `byte_offset`.
// Only bit 1 of the offset picks the half-word (shift of 0 or 16 bits), so the
// offset is expected to be 2-byte aligned. The result is zero-extended to u32.
fn load_v_u16_at(byte_offset: u32) -> u32 {
    let half_shift = (byte_offset & 2u) * 8u;
    return (V[byte_offset / 4u] >> half_shift) & 0xFFFFu;
}
// Reads 32 bits from the u32-typed V buffer starting at an arbitrary byte offset.
// An unaligned read is assembled from the upper bytes of the containing word
// and the lower bytes of the following word.
fn load_v_u32_at(byte_offset: u32) -> u32 {
    let base_word = byte_offset / 4u;
    let bit_shift = (byte_offset & 3u) * 8u;
    let low_word = V[base_word];
    if (bit_shift != 0u) {
        let high_word = V[base_word + 1u];
        return (low_word >> bit_shift) | (high_word << (32u - bit_shift));
    }
    // Aligned case: return the word directly. This also avoids the 32-bit
    // left shift the unaligned formula would need when bit_shift is 0.
    return low_word;
}
#endif
// Reinterprets the low 16 bits of `bits` as an IEEE half-precision value.
// unpack2x16float decodes both 16-bit halves of the word to f32; only the
// first (low-half) component is kept and narrowed back to f16.
fn f16_from_u16(bits: u32) -> f16 {
    let low_half = unpack2x16float(bits).x;
    return f16(low_half);
}
#if defined(K_Q4_0) || defined(K_Q8_0)
// Dequantizes the Q4_0 / Q8_0 K rows of one KV tile into shared memory through
// the dequant_*_packed_to_shmem helpers.
//   local_x       - index of this invocation inside the workgroup
//   kv_count      - number of rows to load; rows at or past kv_count are not written
//   kv_tile       - index of the first K row of this tile
//   k_head_offset - base offset of the current head; it is combined with block
//                   indices below, so it is counted in blocks, not bytes
fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) {
// Each iteration covers K_NQ consecutive elements; invocations stride by WG_SIZE * K_NQ.
for (var elem_idx = local_x * K_NQ; elem_idx < kv_count * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) {
// blck_idx: BLOCK_SIZE-element block within the tile.
// block_offset: which K_NQ-sized part of that block this iteration handles.
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / K_NQ;
// Split the tile-wide block index into a row and a block within that row.
let k_row = blck_idx / BLOCKS_K;
let global_k_row = kv_tile + k_row;
let block_k = blck_idx % BLOCKS_K;
let row_offset = k_row * HEAD_DIM_QK;
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES;
// The block begins with a 16-bit scale; the packed quants follow at byte 2.
let d = f16_from_u16(load_k_u16_at(block_byte_base));
// The byte offset into the quants doubles as the element offset in shared
// memory, because the helpers store byte k of a packed word at dst_idx + k.
let thread_byte_offset = block_offset * K_BYTES_PER_THREAD;
let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset;
// Consume this invocation's bytes one 32-bit word at a time.
for (var j = 0u; j < K_BYTES_PER_THREAD / K_BYTES_PER_INNER_LOOP; j += 1u) {
let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * K_BYTES_PER_INNER_LOOP;
let q_packed = load_k_u32_at(q_byte_offset);
#if defined(K_Q4_0)
dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP);
#elif defined(K_Q8_0)
dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP);
#endif
}
}
}
#endif
#if defined(V_Q4_0) || defined(V_Q8_0)
// Dequantizes the Q4_0 / Q8_0 V rows of one KV tile into shared memory through
// the dequant_*_packed_to_shmem helpers.
//   local_x       - index of this invocation inside the workgroup
//   kv_count      - number of rows to load; rows at or past kv_count are not written
//   kv_tile       - index of the first V row of this tile
//   v_head_offset - base offset of the current head; it is combined with block
//                   indices below, so it is counted in blocks, not bytes
fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) {
// Each iteration covers V_NQ consecutive elements; invocations stride by WG_SIZE * V_NQ.
for (var elem_idx = local_x * V_NQ; elem_idx < kv_count * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) {
// blck_idx: BLOCK_SIZE-element block within the tile.
// block_offset: which V_NQ-sized part of that block this iteration handles.
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / V_NQ;
// Split the tile-wide block index into a row and a block within that row.
let v_row = blck_idx / BLOCKS_V;
let global_v_row = kv_tile + v_row;
let block_k = blck_idx % BLOCKS_V;
let row_offset = v_row * HEAD_DIM_V;
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES;
// The block begins with a 16-bit scale; the packed quants follow at byte 2.
let d = f16_from_u16(load_v_u16_at(block_byte_base));
// The byte offset into the quants doubles as the element offset in shared
// memory, because the helpers store byte k of a packed word at dst_idx + k.
let thread_byte_offset = block_offset * V_BYTES_PER_THREAD;
let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset;
// Consume this invocation's bytes one 32-bit word at a time.
for (var j = 0u; j < V_BYTES_PER_THREAD / V_BYTES_PER_INNER_LOOP; j += 1u) {
let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * V_BYTES_PER_INNER_LOOP;
let q_packed = load_v_u32_at(q_byte_offset);
#if defined(V_Q4_0)
dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP);
#elif defined(V_Q8_0)
dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP);
#endif
}
}
}
#endif
@@ -1,16 +1,29 @@
enable f16;
enable subgroups;
#define BYTE_HELPERS
#include "common_decls.tmpl"
#ifdef Q_F16
#define Q_TYPE f16
#else
#define Q_TYPE f32
#endif
#ifdef KV_F32
#define KV_TYPE f32
#ifdef K_F32
#define K_TYPE f32
#elif defined(K_Q4_0) || defined(K_Q8_0)
#define K_TYPE u32
#else
#define KV_TYPE f16
#define K_TYPE f16
#endif
#ifdef V_F32
#define V_TYPE f32
#elif defined(V_Q4_0) || defined(V_Q8_0)
#define V_TYPE u32
#else
#define V_TYPE f16
#endif
#ifdef DST_F16
@@ -21,7 +34,6 @@ enable subgroups;
#define HEAD_DIM_QK 64
#define HEAD_DIM_V 64
#define KV_STAGE_STRIDE 64
#define Q_TILE 4
#define KV_TILE 64
#define WG_SIZE 128
@@ -64,11 +76,23 @@ struct Params {
@group(0) @binding(0) var<storage, read_write> Q: array<Q_TYPE>;
#ifdef KV_OVERLAP
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
#if defined(K_Q4_0) || defined(K_Q8_0)
@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>;
#else
@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>;
#endif
#define V K
#else
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
#if defined(K_Q4_0) || defined(K_Q8_0)
@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>;
#else
@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>;
#endif
#if defined(V_Q4_0) || defined(V_Q8_0)
@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>;
#else
@group(0) @binding(2) var<storage, read_write> V: array<vec4<V_TYPE>>;
#endif
#endif
#if defined(MASK) && defined(SINKS)
@@ -121,10 +145,50 @@ const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u;
const V_CHUNKS: u32 = HEAD_DIM_V / 4u;
const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE;
const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE;
const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
var<workgroup> q_shmem: array<Q_TYPE, Q_TILE * HEAD_DIM_QK>;
var<workgroup> kv_shmem: array<KV_TYPE, KV_TILE * KV_STAGE_STRIDE>;
var<workgroup> p_shmem: array<KV_TYPE, Q_TILE * KV_TILE>;
var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
var<workgroup> p_shmem: array<f16, Q_TILE * KV_TILE>;
#define QUANT_SHMEM kv_shmem
#define QUANT_OUT_TYPE f16
#include "quant_inner_loops.tmpl"
#include "flash_attn_quant_staging.tmpl"
#if !defined(K_Q4_0) && !defined(K_Q8_0)
// Stages `kv_count` rows of K into kv_shmem as f16, one vec4 per step.
// K is bound as an array of vec4, so the scalar source offset is divided by
// four to get the vector index. Shared-memory rows are packed HEAD_DIM_QK apart.
// Rows at or past kv_count are left untouched.
fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) {
    var vec_slot = local_x;
    while (vec_slot < kv_count * Q_CHUNKS) {
        let tile_row = vec_slot / Q_CHUNKS;
        // First scalar column covered by this vec4.
        let col_base = (vec_slot % Q_CHUNKS) * 4u;
        let src_scalar = k_head_offset + (kv_tile + tile_row) * params.stride_k1 + col_base;
        let packed = K[src_scalar >> 2u];
        let dst_base = tile_row * HEAD_DIM_QK + col_base;
        kv_shmem[dst_base + 0u] = f16(packed.x);
        kv_shmem[dst_base + 1u] = f16(packed.y);
        kv_shmem[dst_base + 2u] = f16(packed.z);
        kv_shmem[dst_base + 3u] = f16(packed.w);
        vec_slot += WG_SIZE;
    }
}
#endif
#if !defined(V_Q4_0) && !defined(V_Q8_0)
// Stages `kv_count` rows of V into kv_shmem as f16, one vec4 per step.
// V is bound as an array of vec4, so the scalar source offset is divided by
// four to get the vector index. Shared-memory rows are packed HEAD_DIM_V apart.
// Rows at or past kv_count are left untouched.
fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) {
    var vec_slot = local_x;
    while (vec_slot < kv_count * V_CHUNKS) {
        let tile_row = vec_slot / V_CHUNKS;
        // First scalar column covered by this vec4.
        let col_base = (vec_slot % V_CHUNKS) * 4u;
        let src_scalar = v_head_offset + (kv_tile + tile_row) * params.stride_v1 + col_base;
        let packed = V[src_scalar >> 2u];
        let dst_base = tile_row * HEAD_DIM_V + col_base;
        kv_shmem[dst_base + 0u] = f16(packed.x);
        kv_shmem[dst_base + 1u] = f16(packed.y);
        kv_shmem[dst_base + 2u] = f16(packed.z);
        kv_shmem[dst_base + 3u] = f16(packed.w);
        vec_slot += WG_SIZE;
    }
}
#endif
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@@ -206,18 +270,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
local_scores[slot] = FLOAT_MIN;
}
for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) {
let kv_local = vec_idx_local / Q_CHUNKS;
let chunk = vec_idx_local % Q_CHUNKS;
let global_k_row = kv_tile + kv_local;
let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u;
let k4 = K[k_vec_index];
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
kv_shmem[kv_off + 0u] = KV_TYPE(k4.x);
kv_shmem[kv_off + 1u] = KV_TYPE(k4.y);
kv_shmem[kv_off + 2u] = KV_TYPE(k4.z);
kv_shmem[kv_off + 3u] = KV_TYPE(k4.w);
}
#ifndef KV_DIRECT
load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset);
#endif
workgroupBarrier();
@@ -238,8 +293,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
q_shmem[q_off + 1u],
q_shmem[q_off + 2u],
q_shmem[q_off + 3u]);
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
let kv = vec4<KV_TYPE>(
let kv_off = kv_local * HEAD_DIM_QK + chunk * 4u;
let kv = vec4<f16>(
kv_shmem[kv_off + 0u],
kv_shmem[kv_off + 1u],
kv_shmem[kv_off + 2u],
@@ -271,25 +326,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let kv_local = sg_inv_id + slot * subgroup_size;
if (row_active && kv_local < kv_count) {
let p = exp(local_scores[slot] - new_max);
p_shmem[subgroup_p_offset + kv_local] = KV_TYPE(p);
p_shmem[subgroup_p_offset + kv_local] = f16(p);
local_sum += p;
}
}
workgroupBarrier();
for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) {
let kv_local = vec_idx_local / V_CHUNKS;
let chunk = vec_idx_local % V_CHUNKS;
let global_v_row = kv_tile + kv_local;
let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u;
let v4 = V[v_vec_index];
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
kv_shmem[kv_off + 0u] = KV_TYPE(v4.x);
kv_shmem[kv_off + 1u] = KV_TYPE(v4.y);
kv_shmem[kv_off + 2u] = KV_TYPE(v4.z);
kv_shmem[kv_off + 3u] = KV_TYPE(v4.w);
}
#ifndef KV_DIRECT
load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset);
#endif
workgroupBarrier();
@@ -306,14 +352,14 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
var acc = out_regs[reg_idx];
for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) {
let p = p_shmem[subgroup_p_offset + kv_local];
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
let v4 = vec4<KV_TYPE>(
let p = f32(p_shmem[subgroup_p_offset + kv_local]);
let kv_off = kv_local * HEAD_DIM_V + chunk * 4u;
let v4 = vec4<f16>(
kv_shmem[kv_off + 0u],
kv_shmem[kv_off + 1u],
kv_shmem[kv_off + 2u],
kv_shmem[kv_off + 3u]);
acc += f32(p) * vec4<f32>(v4);
acc += p * vec4<f32>(v4);
}
out_regs[reg_idx] = acc;
}
@@ -2,10 +2,23 @@ diagnostic(off, subgroup_uniformity);
enable f16;
enable subgroups;
#ifdef KV_F32
#define KV_TYPE f32
#define BYTE_HELPERS
#include "common_decls.tmpl"
#ifdef K_F32
#define K_TYPE f32
#elif defined(K_Q4_0) || defined(K_Q8_0)
#define K_TYPE u32
#else
#define KV_TYPE f16
#define K_TYPE f16
#endif
#ifdef V_F32
#define V_TYPE f32
#elif defined(V_Q4_0) || defined(V_Q8_0)
#define V_TYPE u32
#else
#define V_TYPE f16
#endif
#ifdef Q_F16
@@ -32,28 +45,6 @@ enable subgroups;
#define KV_BLOCKS (KV_TILE / KV_GRANULARITY)
#define BLOCK_SIZE 32
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
#if defined(KV_Q4_0)
#define NQ 16
#define F16_PER_BLOCK 9
#define WEIGHTS_PER_F16 4
#elif defined(KV_Q8_0)
#define NQ 8
#define F16_PER_BLOCK 17
#define WEIGHTS_PER_F16 2
#endif
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
// Returns byte number `index` of `value` (0 = least significant), zero-extended.
fn get_byte(value: u32, index: u32) -> u32 {
    let bit_pos = index * 8;
    return (value >> bit_pos) & 0xFF;
}
// Returns byte number `index` of `value` (0 = least significant) as a signed
// 8-bit quantity widened to i32.
fn get_byte_i32(value: u32, index: u32) -> i32 {
    let raw_byte = (value >> (index * 8)) & 0xFF;
    // Move the byte into the top 8 bits so that the signed right shift below
    // replicates its sign bit across the upper 24 bits.
    let top_aligned = bitcast<i32>(raw_byte << 24);
    return top_aligned >> 24;
}
struct Params {
offset_q: u32,
offset_k: u32,
@@ -103,22 +94,22 @@ struct Params {
@group(0) @binding(0) var<storage, read_write> Q: array<Q_TYPE>;
#ifdef KV_OVERLAP
#if defined(KV_Q4_0) || defined(KV_Q8_0)
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
#if defined(K_Q4_0) || defined(K_Q8_0)
@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>;
#else
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>;
#endif
#define V K
#else
#if defined(KV_Q4_0) || defined(KV_Q8_0)
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
#if defined(K_Q4_0) || defined(K_Q8_0)
@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>;
#else
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>;
#endif
#if defined(KV_Q4_0) || defined(KV_Q8_0)
@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
#if defined(V_Q4_0) || defined(V_Q8_0)
@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>;
#else
@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
@group(0) @binding(2) var<storage, read_write> V: array<vec4<V_TYPE>>;
#endif
#endif
#if defined(MASK) && defined(SINKS)
@@ -244,6 +235,49 @@ fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool)
return v;
}
#ifndef KV_DIRECT
#define QUANT_SHMEM kv_shmem
#define QUANT_OUT_TYPE f32
#include "quant_inner_loops.tmpl"
#include "flash_attn_quant_staging.tmpl"
#if !defined(K_Q4_0) && !defined(K_Q8_0)
// Cooperatively copies one K tile into kv_shmem as f32, four elements per step.
// K is bound as an array of vec4, so the scalar source offset is divided by
// four to get the vector index. A vec4 is stored as zeros when its row is at or
// past params.seq_len_kv or when it would run past HEAD_DIM_QK.
// `kv_count` is part of the shared loader signature but is not read here.
fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) {
    var base = local_x * 4u;
    while (base < KV_TILE * HEAD_DIM_QK) {
        let tile_row = base / HEAD_DIM_QK;
        let tile_col = base % HEAD_DIM_QK;
        let src_row = kv_tile + tile_row;
        let src_scalar = k_head_offset + src_row * params.stride_k1 + tile_col;
        let whole_vec_valid = src_row < params.seq_len_kv && (tile_col + 3u) < HEAD_DIM_QK;
        let packed = select(vec4<K_TYPE>(0.0), K[src_scalar >> 2u], whole_vec_valid);
        kv_shmem[base + 0u] = f32(packed.x);
        kv_shmem[base + 1u] = f32(packed.y);
        kv_shmem[base + 2u] = f32(packed.z);
        kv_shmem[base + 3u] = f32(packed.w);
        base += WG_SIZE * 4u;
    }
}
#endif
#if !defined(V_Q4_0) && !defined(V_Q8_0)
// Cooperatively copies one V tile into kv_shmem as f32, four elements per step.
// V is bound as an array of vec4, so the scalar source offset is divided by
// four to get the vector index. A vec4 is stored as zeros when its row is at or
// past params.seq_len_kv or when it would run past HEAD_DIM_V.
// `kv_count` is part of the shared loader signature but is not read here.
fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) {
    var base = local_x * 4u;
    while (base < KV_TILE * HEAD_DIM_V) {
        let tile_row = base / HEAD_DIM_V;
        let tile_col = base % HEAD_DIM_V;
        let src_row = kv_tile + tile_row;
        let src_scalar = v_head_offset + src_row * params.stride_v1 + tile_col;
        let whole_vec_valid = src_row < params.seq_len_kv && (tile_col + 3u) < HEAD_DIM_V;
        let packed = select(vec4<V_TYPE>(0.0), V[src_scalar >> 2u], whole_vec_valid);
        kv_shmem[base + 0u] = f32(packed.x);
        kv_shmem[base + 1u] = f32(packed.y);
        kv_shmem[base + 2u] = f32(packed.z);
        kv_shmem[base + 3u] = f32(packed.w);
        base += WG_SIZE * 4u;
    }
}
#endif
#endif
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@@ -308,6 +342,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
}
for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) {
let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile);
#ifdef BLK
let q_blk = q_row_start;
let kv_blk = kv_tile / KV_TILE;
@@ -324,76 +359,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
}
// load k tile into shared memory
#if defined(KV_Q4_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let k_row = blck_idx / BLOCKS_K;
let global_k_row = kv_tile + k_row;
let block_k = blck_idx % BLOCKS_K;
let row_offset = k_row * HEAD_DIM_QK;
if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = K[base_idx];
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = K[base_idx + 1u + block_offset + j];
let q_1 = K[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d);
let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d);
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_lo;
kv_shmem[row_offset + idx + 16u] = q_hi;
}
}
}
}
#elif defined(KV_Q8_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let k_row = blck_idx / BLOCKS_K;
let global_k_row = kv_tile + k_row;
let block_k = blck_idx % BLOCKS_K;
let row_offset = k_row * HEAD_DIM_QK;
if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = K[base_idx];
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = K[base_idx + 1u + block_offset + j];
let q_1 = K[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * f32(d);
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_val;
}
}
}
}
#elif defined(KV_DIRECT)
// Direct global loads for KV
#else
for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) {
let k_row = elem_idx / HEAD_DIM_QK;
let k_col = elem_idx % HEAD_DIM_QK;
let global_k_row = kv_tile + k_row;
let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK;
let vec_idx = (global_k_row_offset + k_col) >> 2u;
let k4 = select(vec4<KV_TYPE>(0.0), K[vec_idx], in_bounds);
kv_shmem[elem_idx + 0u] = f32(k4.x);
kv_shmem[elem_idx + 1u] = f32(k4.y);
kv_shmem[elem_idx + 2u] = f32(k4.z);
kv_shmem[elem_idx + 3u] = f32(k4.w);
}
#ifndef KV_DIRECT
load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset);
#endif
workgroupBarrier();
@@ -510,76 +477,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
}
// load v tile into shared memory
#if defined(KV_Q4_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let v_row = blck_idx / BLOCKS_V;
let global_v_row = kv_tile + v_row;
let block_k = blck_idx % BLOCKS_V;
let row_offset = v_row * HEAD_DIM_V;
if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = V[base_idx];
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = V[base_idx + 1u + block_offset + j];
let q_1 = V[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d);
let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d);
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_lo;
kv_shmem[row_offset + idx + 16u] = q_hi;
}
}
}
}
#elif defined(KV_Q8_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let v_row = blck_idx / BLOCKS_V;
let global_v_row = kv_tile + v_row;
let block_k = blck_idx % BLOCKS_V;
let row_offset = v_row * HEAD_DIM_V;
if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let base_idx = global_block_idx * F16_PER_BLOCK;
let d = V[base_idx];
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_0 = V[base_idx + 1u + block_offset + j];
let q_1 = V[base_idx + 1u + block_offset + j + 1];
let q_packed = bitcast<u32>(vec2(q_0, q_1));
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f32(q_byte) * f32(d);
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_val;
}
}
}
}
#elif defined(KV_DIRECT)
// Direct global loads for KV
#else
for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) {
let v_row = elem_idx / HEAD_DIM_V;
let v_col = elem_idx % HEAD_DIM_V;
let global_v_row = kv_tile + v_row;
let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V;
let vec_idx = (global_v_row_offset + v_col) >> 2u;
let v4 = select(vec4<KV_TYPE>(0.0), V[vec_idx], in_bounds);
kv_shmem[elem_idx + 0u] = f32(v4.x);
kv_shmem[elem_idx + 1u] = f32(v4.y);
kv_shmem[elem_idx + 2u] = f32(v4.z);
kv_shmem[elem_idx + 3u] = f32(v4.w);
}
#ifndef KV_DIRECT
load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset);
#endif
workgroupBarrier();
@@ -25,6 +25,10 @@ fn store_shmem(val: f16, idx: u32) {
}
#endif // SCALAR
#define QUANT_SHMEM shmem
#define QUANT_OUT_TYPE f16
#include "quant_inner_loops.tmpl"
#ifdef INIT_SRC0_SHMEM_FLOAT
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
@@ -124,14 +128,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo;
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
}
dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP);
}
}
}
@@ -314,12 +311,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val;
}
dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP);
}
}
}
@@ -0,0 +1,21 @@
#ifdef U32_DEQUANT_HELPERS
// Dequantizes one 32-bit word of packed 4-bit quants into QUANT_SHMEM.
// Each of the four bytes holds two values: the low nibble is stored at
// dst_idx + k and the high nibble 16 elements further on. Both are offset by
// -8 and multiplied by the scale `d`, with the arithmetic done in QUANT_OUT_TYPE.
fn dequant_q4_0_packed_to_shmem(q_packed: u32, d: f16, dst_idx: u32) {
    let block_scale = QUANT_OUT_TYPE(d);
    let zero_point = QUANT_OUT_TYPE(8.0);
    for (var byte_i = 0u; byte_i < 4u; byte_i++) {
        let packed_byte = get_byte(q_packed, byte_i);
        let high_nibble = (packed_byte >> 4) & 0xFu;
        let low_nibble = packed_byte & 0xFu;
        QUANT_SHMEM[dst_idx + byte_i] = (QUANT_OUT_TYPE(low_nibble) - zero_point) * block_scale;
        QUANT_SHMEM[dst_idx + byte_i + 16u] = (QUANT_OUT_TYPE(high_nibble) - zero_point) * block_scale;
    }
}
// Dequantizes one 32-bit word of packed signed 8-bit quants into QUANT_SHMEM.
// Byte k is sign-extended, multiplied by the scale `d`, and stored at
// dst_idx + k, with the arithmetic done in QUANT_OUT_TYPE.
fn dequant_q8_0_packed_to_shmem(q_packed: u32, d: f16, dst_idx: u32) {
    let block_scale = QUANT_OUT_TYPE(d);
    for (var byte_i = 0u; byte_i < 4u; byte_i++) {
        let signed_quant = get_byte_i32(q_packed, byte_i);
        QUANT_SHMEM[dst_idx + byte_i] = QUANT_OUT_TYPE(signed_quant) * block_scale;
    }
}
#endif