mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-09 07:16:44 +02:00
ggml-webgpu: FlashAttention refactor + standardize quantization support (#23834)
* Start work on flash_attn refactor * Refactor * Split k/v quantization * Refactor and abstract quantization logic for flash_attn and mul_mat * Add quantization support to tile path * formatting * Move to functions, add a check
This commit is contained in:
@@ -10,8 +10,11 @@ file(MAKE_DIRECTORY ${SHADER_OUTPUT_DIR})
|
||||
|
||||
message(STATUS "Shader output dir: ${SHADER_OUTPUT_DIR}")
|
||||
|
||||
# Find all WGSL files
|
||||
file(GLOB WGSL_SHADER_FILES "${SHADER_DIR}/*.wgsl")
|
||||
# Find all WGSL sources
|
||||
file(GLOB WGSL_SHADER_FILES
|
||||
"${SHADER_DIR}/*.wgsl"
|
||||
"${SHADER_DIR}/*.tmpl"
|
||||
)
|
||||
|
||||
# Generate the header using a Python script
|
||||
add_custom_command(
|
||||
|
||||
@@ -18,6 +18,9 @@
|
||||
#define GGML_WEBGPU_F32_SIZE_BYTES 4
|
||||
#define GGML_WEBGPU_I32_SIZE_BYTES 4
|
||||
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
|
||||
#define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN 20u
|
||||
#define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE 32u
|
||||
#define GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE 64u
|
||||
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u
|
||||
// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
|
||||
#define GGML_WEBGPU_KV_SEQ_PAD 256u
|
||||
@@ -546,16 +549,10 @@ struct ggml_webgpu_unary_pipeline_key_hash {
|
||||
|
||||
/** FlashAttention */
|
||||
|
||||
enum ggml_webgpu_flash_attn_path : uint32_t {
|
||||
GGML_WEBGPU_FLASH_ATTN_PATH_NONE = 0u,
|
||||
GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 1u,
|
||||
GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 2u,
|
||||
GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 3u,
|
||||
};
|
||||
|
||||
struct ggml_webgpu_flash_attn_pipeline_key {
|
||||
struct ggml_webgpu_flash_attn_common_pipeline_key {
|
||||
ggml_type q_type;
|
||||
ggml_type kv_type;
|
||||
ggml_type k_type;
|
||||
ggml_type v_type;
|
||||
ggml_type dst_type;
|
||||
uint32_t head_dim_qk;
|
||||
uint32_t head_dim_v;
|
||||
@@ -564,93 +561,224 @@ struct ggml_webgpu_flash_attn_pipeline_key {
|
||||
bool has_mask;
|
||||
bool has_sinks;
|
||||
bool uses_logit_softcap;
|
||||
uint32_t path;
|
||||
|
||||
bool operator==(const ggml_webgpu_flash_attn_common_pipeline_key & other) const {
|
||||
return q_type == other.q_type && k_type == other.k_type && v_type == other.v_type &&
|
||||
dst_type == other.dst_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
|
||||
kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask &&
|
||||
has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap;
|
||||
}
|
||||
};
|
||||
|
||||
inline void ggml_webgpu_flash_attn_hash_common_pipeline_key(size_t & seed,
|
||||
const ggml_webgpu_flash_attn_common_pipeline_key & key) {
|
||||
ggml_webgpu_hash_combine(seed, key.q_type);
|
||||
ggml_webgpu_hash_combine(seed, key.k_type);
|
||||
ggml_webgpu_hash_combine(seed, key.v_type);
|
||||
ggml_webgpu_hash_combine(seed, key.dst_type);
|
||||
ggml_webgpu_hash_combine(seed, key.head_dim_qk);
|
||||
ggml_webgpu_hash_combine(seed, key.head_dim_v);
|
||||
ggml_webgpu_hash_combine(seed, key.kv_direct);
|
||||
ggml_webgpu_hash_combine(seed, key.kv_overlap);
|
||||
ggml_webgpu_hash_combine(seed, key.has_mask);
|
||||
ggml_webgpu_hash_combine(seed, key.has_sinks);
|
||||
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
|
||||
}
|
||||
|
||||
struct ggml_webgpu_flash_attn_vec_pipeline_key {
|
||||
ggml_webgpu_flash_attn_common_pipeline_key common;
|
||||
|
||||
bool operator==(const ggml_webgpu_flash_attn_vec_pipeline_key & other) const { return common == other.common; }
|
||||
};
|
||||
|
||||
struct ggml_webgpu_flash_attn_vec_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_flash_attn_vec_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_flash_attn_pipeline_key {
|
||||
ggml_webgpu_flash_attn_common_pipeline_key common;
|
||||
bool use_sg_matrix;
|
||||
|
||||
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
|
||||
return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type &&
|
||||
head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct &&
|
||||
kv_overlap == other.kv_overlap && has_mask == other.has_mask && has_sinks == other.has_sinks &&
|
||||
uses_logit_softcap == other.uses_logit_softcap && path == other.path;
|
||||
return common == other.common && use_sg_matrix == other.use_sg_matrix;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_flash_attn_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.q_type);
|
||||
ggml_webgpu_hash_combine(seed, key.kv_type);
|
||||
ggml_webgpu_hash_combine(seed, key.dst_type);
|
||||
ggml_webgpu_hash_combine(seed, key.head_dim_qk);
|
||||
ggml_webgpu_hash_combine(seed, key.head_dim_v);
|
||||
ggml_webgpu_hash_combine(seed, key.kv_direct);
|
||||
ggml_webgpu_hash_combine(seed, key.kv_overlap);
|
||||
ggml_webgpu_hash_combine(seed, key.has_mask);
|
||||
ggml_webgpu_hash_combine(seed, key.has_sinks);
|
||||
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
|
||||
ggml_webgpu_hash_combine(seed, key.path);
|
||||
ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common);
|
||||
ggml_webgpu_hash_combine(seed, key.use_sg_matrix);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_flash_attn_vec_decisions {
|
||||
uint32_t kv_tile = 0;
|
||||
uint32_t wg_size = 0;
|
||||
};
|
||||
|
||||
struct ggml_webgpu_flash_attn_decisions {
|
||||
uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE;
|
||||
uint32_t q_tile = 0;
|
||||
uint32_t kv_tile = 0;
|
||||
uint32_t wg_size = 0;
|
||||
bool kv_direct = false;
|
||||
bool kv_overlap = false;
|
||||
bool use_sg_matrix = false;
|
||||
uint32_t q_tile = 0;
|
||||
uint32_t kv_tile = 0;
|
||||
uint32_t wg_size = 0;
|
||||
};
|
||||
|
||||
inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u;
|
||||
inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE = 4u;
|
||||
|
||||
inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
|
||||
if (key.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC || key.kv_type != GGML_TYPE_F16 ||
|
||||
key.head_dim_qk != key.head_dim_v) {
|
||||
return 1u;
|
||||
}
|
||||
|
||||
switch (key.head_dim_qk) {
|
||||
case 64:
|
||||
case 192:
|
||||
case 576:
|
||||
return 2u;
|
||||
case 96:
|
||||
return 4u;
|
||||
default:
|
||||
return 1u;
|
||||
}
|
||||
inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) {
|
||||
constexpr uintptr_t ptr_base_addr = 0x1000u;
|
||||
const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor;
|
||||
return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs;
|
||||
}
|
||||
|
||||
inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key(
|
||||
const ggml_webgpu_shader_lib_context & context,
|
||||
const ggml_webgpu_flash_attn_decisions & decisions) {
|
||||
const bool has_mask = context.src3 != nullptr;
|
||||
const bool has_sinks = context.src4 != nullptr;
|
||||
bool kv_direct = false;
|
||||
if (decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
|
||||
uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH;
|
||||
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) {
|
||||
kv_direct_align = context.sg_mat_k;
|
||||
}
|
||||
kv_direct = (context.src1->type == GGML_TYPE_F16) &&
|
||||
(context.src0->ne[0] % std::max(1u, kv_direct_align) == 0) &&
|
||||
(context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
|
||||
inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) {
|
||||
const uint32_t offset_elems =
|
||||
(uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) / ggml_type_size(K->type));
|
||||
return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u;
|
||||
}
|
||||
|
||||
inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K,
|
||||
const ggml_tensor * V,
|
||||
size_t storage_offset_alignment) {
|
||||
return ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment) &&
|
||||
ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment);
|
||||
}
|
||||
|
||||
inline bool ggml_webgpu_flash_attn_kv_direct(
|
||||
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, uint32_t kv_direct_align) {
|
||||
return K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && (Q->ne[0] % kv_direct_align == 0) &&
|
||||
(K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
|
||||
}
|
||||
|
||||
inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_common_pipeline_key(
|
||||
const ggml_webgpu_shader_lib_context & context,
|
||||
uint32_t kv_direct_align) {
|
||||
ggml_webgpu_flash_attn_common_pipeline_key key = {};
|
||||
key.q_type = context.src0->type;
|
||||
key.k_type = context.src1->type;
|
||||
key.v_type = context.src2->type;
|
||||
key.dst_type = context.dst->type;
|
||||
key.head_dim_qk = (uint32_t) context.src0->ne[0];
|
||||
key.head_dim_v = (uint32_t) context.src2->ne[0];
|
||||
key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align);
|
||||
key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2);
|
||||
key.has_mask = context.src3 != nullptr;
|
||||
key.has_sinks = context.src4 != nullptr;
|
||||
key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
|
||||
return key;
|
||||
}
|
||||
|
||||
inline std::vector<std::string> ggml_webgpu_flash_attn_common_defines(
|
||||
const ggml_webgpu_flash_attn_common_pipeline_key & key,
|
||||
std::string & variant,
|
||||
uint32_t q_tile,
|
||||
uint32_t kv_tile,
|
||||
uint32_t wg_size) {
|
||||
std::vector<std::string> defines;
|
||||
|
||||
switch (key.k_type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("K_F32");
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("K_F16");
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
defines.push_back("K_Q4_0");
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
defines.push_back("K_Q8_0");
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported K type for flash attention shader");
|
||||
}
|
||||
variant += std::string("_k") + ggml_type_name(key.k_type);
|
||||
|
||||
switch (key.v_type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("V_F32");
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("V_F16");
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
defines.push_back("V_Q4_0");
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
defines.push_back("V_Q8_0");
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported V type for flash attention shader");
|
||||
}
|
||||
variant += std::string("_v") + ggml_type_name(key.v_type);
|
||||
|
||||
switch (key.q_type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("Q_F32");
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("Q_F16");
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported Q type for flash attention shader");
|
||||
}
|
||||
variant += std::string("_q") + ggml_type_name(key.q_type);
|
||||
|
||||
switch (key.dst_type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("DST_F32");
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("DST_F16");
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported dst type for flash attention shader");
|
||||
}
|
||||
variant += std::string("_dst") + ggml_type_name(key.dst_type);
|
||||
|
||||
if (key.has_mask) {
|
||||
defines.push_back("MASK");
|
||||
variant += "_mask";
|
||||
}
|
||||
if (key.has_sinks) {
|
||||
defines.push_back("SINKS");
|
||||
variant += "_sinks";
|
||||
}
|
||||
if (key.uses_logit_softcap) {
|
||||
defines.push_back("LOGIT_SOFTCAP");
|
||||
variant += "_lgsc";
|
||||
}
|
||||
if (key.kv_direct) {
|
||||
defines.push_back("KV_DIRECT");
|
||||
variant += "_kvdirect";
|
||||
}
|
||||
if (key.kv_overlap) {
|
||||
defines.push_back("KV_OVERLAP");
|
||||
variant += "_kv_overlap";
|
||||
}
|
||||
|
||||
ggml_webgpu_flash_attn_pipeline_key key = {};
|
||||
key.q_type = context.src0->type;
|
||||
key.kv_type = context.src1->type;
|
||||
key.dst_type = context.dst->type;
|
||||
key.head_dim_qk = (uint32_t) context.src0->ne[0];
|
||||
key.head_dim_v = (uint32_t) context.src2->ne[0];
|
||||
key.kv_direct = kv_direct;
|
||||
key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2);
|
||||
key.has_mask = has_mask;
|
||||
key.has_sinks = has_sinks;
|
||||
key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
|
||||
key.path = decisions.path;
|
||||
return key;
|
||||
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
|
||||
variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
|
||||
|
||||
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
|
||||
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
|
||||
|
||||
defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
|
||||
defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
|
||||
|
||||
if (ggml_is_quantized(key.k_type) || ggml_is_quantized(key.v_type)) {
|
||||
defines.push_back("U32_DEQUANT_HELPERS");
|
||||
}
|
||||
|
||||
return defines;
|
||||
}
|
||||
|
||||
struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
|
||||
@@ -688,29 +816,18 @@ struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
|
||||
}
|
||||
};
|
||||
|
||||
// This is exposed because it's necessary in supports_op
|
||||
// Note: this will slightly overestimate memory usage for vec path
|
||||
// since row_max and exp_sum shmem are not needed.
|
||||
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
|
||||
uint32_t kv_tile,
|
||||
uint32_t head_dim_qk,
|
||||
uint32_t head_dim_v,
|
||||
bool has_mask,
|
||||
bool kv_direct,
|
||||
uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) {
|
||||
bool kv_direct) {
|
||||
const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
|
||||
size_t f16_elems = 0;
|
||||
size_t f32_elems = 0;
|
||||
if (path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||
f32_elems += head_dim_qk; // q_shmem
|
||||
if (!kv_direct) {
|
||||
f32_elems += kv_tile * max_head_dim; // kv_shmem
|
||||
}
|
||||
f32_elems += head_dim_v; // o_shmem
|
||||
if (has_mask) {
|
||||
f32_elems += kv_tile; // mask_shmem
|
||||
}
|
||||
f32_elems += kv_tile; // inter_shmem
|
||||
return f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||
}
|
||||
|
||||
f32_elems += q_tile * head_dim_qk; // q_shmem
|
||||
if (!kv_direct) {
|
||||
f32_elems += kv_tile * max_head_dim; // kv_shmem
|
||||
@@ -725,25 +842,20 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
|
||||
return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
|
||||
}
|
||||
|
||||
inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context & context,
|
||||
const ggml_webgpu_flash_attn_pipeline_key & key) {
|
||||
const size_t limit_bytes = context.wg_mem_limit_bytes;
|
||||
uint32_t q_tile = context.sg_mat_m;
|
||||
uint32_t kv_granularity = std::max(1u, context.sg_mat_n);
|
||||
if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
|
||||
q_tile = GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
|
||||
kv_granularity = 1u;
|
||||
} else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||
q_tile = 1u;
|
||||
kv_granularity = 8u;
|
||||
}
|
||||
const size_t base_q_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, key.head_dim_qk, key.head_dim_v,
|
||||
key.has_mask, key.kv_direct, key.path);
|
||||
inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(size_t limit_bytes,
|
||||
uint32_t q_tile,
|
||||
uint32_t kv_granularity,
|
||||
uint32_t head_dim_qk,
|
||||
uint32_t head_dim_v,
|
||||
bool has_mask,
|
||||
bool kv_direct) {
|
||||
const size_t base_q_bytes =
|
||||
ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, head_dim_qk, head_dim_v, has_mask, kv_direct);
|
||||
if (limit_bytes <= base_q_bytes) {
|
||||
return 0;
|
||||
}
|
||||
const size_t one_kv_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, key.head_dim_qk, key.head_dim_v,
|
||||
key.has_mask, key.kv_direct, key.path);
|
||||
const size_t one_kv_bytes =
|
||||
ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, head_dim_qk, head_dim_v, has_mask, kv_direct);
|
||||
const size_t bytes_per_kv = one_kv_bytes - base_q_bytes;
|
||||
if (bytes_per_kv == 0) {
|
||||
return 0;
|
||||
@@ -752,105 +864,32 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_
|
||||
return (uint32_t) ((max_kv_tile / kv_granularity) * kv_granularity);
|
||||
}
|
||||
|
||||
inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions(
|
||||
const ggml_webgpu_shader_lib_context & context,
|
||||
size_t storage_offset_alignment) {
|
||||
ggml_webgpu_flash_attn_decisions decisions = {};
|
||||
const size_t alignment = std::max<size_t>(1u, storage_offset_alignment);
|
||||
const auto * K = context.src1;
|
||||
const auto * V = context.src2;
|
||||
GGML_ASSERT(K != nullptr);
|
||||
GGML_ASSERT(V != nullptr);
|
||||
inline uint32_t ggml_webgpu_flash_attn_get_vec_kv_tile(size_t wg_mem_limit_bytes,
|
||||
uint32_t head_dim_qk,
|
||||
uint32_t head_dim_v,
|
||||
bool has_mask,
|
||||
bool kv_direct) {
|
||||
const uint32_t max_kv_tile =
|
||||
ggml_webgpu_flash_attn_max_kv_tile(wg_mem_limit_bytes, 1u, 1u, head_dim_qk, head_dim_v, has_mask, kv_direct);
|
||||
GGML_ASSERT(max_kv_tile > 0);
|
||||
|
||||
const auto flash_attn_tensor_offset = [](const ggml_tensor * tensor) -> size_t {
|
||||
constexpr uintptr_t ptr_base_addr = 0x1000u;
|
||||
const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor;
|
||||
return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs;
|
||||
};
|
||||
|
||||
const uint32_t k_offset_elems =
|
||||
(uint32_t) ((flash_attn_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type));
|
||||
const uint32_t v_offset_elems =
|
||||
(uint32_t) ((flash_attn_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type));
|
||||
const bool f16_vec4_aligned = (k_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u) &&
|
||||
(v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u);
|
||||
const bool kv_vec_type_supported =
|
||||
K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
|
||||
const uint32_t kv_vec_head_align =
|
||||
K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : (uint32_t) ggml_blck_size(K->type);
|
||||
const bool kv_vec_head_dims_aligned =
|
||||
context.src0->ne[0] % kv_vec_head_align == 0 && context.src2->ne[0] % kv_vec_head_align == 0;
|
||||
// Compile with enough invocations to cover the largest reported subgroup.
|
||||
const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && kv_vec_head_dims_aligned &&
|
||||
kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) &&
|
||||
(context.src2->type == K->type);
|
||||
const bool tile_can_dispatch_all_q_rows =
|
||||
context.max_subgroup_size > 0 &&
|
||||
context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size;
|
||||
const bool use_subgroup_matrix = context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 &&
|
||||
context.src0->ne[0] % context.sg_mat_k == 0 &&
|
||||
context.src2->ne[0] % context.sg_mat_n == 0;
|
||||
const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 &&
|
||||
V->type == GGML_TYPE_F16 && f16_vec4_aligned &&
|
||||
(context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
|
||||
(context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) &&
|
||||
tile_can_dispatch_all_q_rows && !use_vec;
|
||||
|
||||
decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC :
|
||||
use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE :
|
||||
use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX :
|
||||
GGML_WEBGPU_FLASH_ATTN_PATH_NONE;
|
||||
|
||||
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) {
|
||||
return decisions;
|
||||
}
|
||||
|
||||
const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions);
|
||||
decisions.kv_direct = key.kv_direct;
|
||||
const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key);
|
||||
// invalidate if even the smallest kv_tile doesn't fit in shared memory
|
||||
if (max_kv_tile == 0) {
|
||||
decisions.path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE;
|
||||
return decisions;
|
||||
}
|
||||
|
||||
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||
decisions.q_tile = 1u;
|
||||
decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile));
|
||||
decisions.kv_tile = (decisions.kv_tile / 8u) * 8u;
|
||||
decisions.wg_size = context.max_subgroup_size;
|
||||
if (decisions.kv_direct) {
|
||||
decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
|
||||
while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
|
||||
decisions.kv_tile -= 8u;
|
||||
}
|
||||
}
|
||||
return decisions;
|
||||
}
|
||||
|
||||
decisions.q_tile =
|
||||
decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m;
|
||||
decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
|
||||
std::min(64u, max_kv_tile) :
|
||||
std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
|
||||
decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ?
|
||||
std::min(std::max(1u, context.max_wg_size),
|
||||
std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE,
|
||||
GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size)) :
|
||||
std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
|
||||
|
||||
if (decisions.kv_tile == 0) {
|
||||
return decisions;
|
||||
}
|
||||
|
||||
if (decisions.kv_direct) {
|
||||
GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
|
||||
while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
|
||||
decisions.kv_tile -=
|
||||
decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? context.min_subgroup_size : context.sg_mat_n;
|
||||
uint32_t kv_tile = std::min(GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE, max_kv_tile);
|
||||
if (kv_direct) {
|
||||
kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
|
||||
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
|
||||
kv_tile -= 1u;
|
||||
}
|
||||
}
|
||||
return decisions;
|
||||
|
||||
return kv_tile;
|
||||
}
|
||||
|
||||
inline bool ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(bool supports_subgroup_matrix,
|
||||
uint32_t sg_mat_k,
|
||||
uint32_t sg_mat_n,
|
||||
const ggml_tensor * Q,
|
||||
const ggml_tensor * V) {
|
||||
return supports_subgroup_matrix && Q->ne[0] % sg_mat_k == 0 && V->ne[0] % sg_mat_n == 0;
|
||||
}
|
||||
|
||||
/** Matrix Multiplication **/
|
||||
@@ -1123,6 +1162,10 @@ class ggml_webgpu_shader_lib {
|
||||
concat_pipelines; // type
|
||||
std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
|
||||
repeat_pipelines; // type
|
||||
std::unordered_map<ggml_webgpu_flash_attn_vec_pipeline_key,
|
||||
webgpu_pipeline,
|
||||
ggml_webgpu_flash_attn_vec_pipeline_key_hash>
|
||||
flash_attn_vec_pipelines;
|
||||
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
|
||||
flash_attn_pipelines;
|
||||
std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
|
||||
@@ -1835,10 +1878,10 @@ class ggml_webgpu_shader_lib {
|
||||
ggml_webgpu_mul_mat_vec_pipeline_key key = {};
|
||||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
1 :
|
||||
0;
|
||||
key.use_mmvq =
|
||||
ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor);
|
||||
|
||||
@@ -1971,11 +2014,11 @@ class ggml_webgpu_shader_lib {
|
||||
ggml_webgpu_mul_mat_pipeline_key key = {};
|
||||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
key.use_subgroup_matrix = context.supports_subgroup_matrix;
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
key.use_subgroup_matrix = context.supports_subgroup_matrix;
|
||||
|
||||
auto it = mul_mat_fast_pipelines.find(key);
|
||||
if (it != mul_mat_fast_pipelines.end()) {
|
||||
@@ -2148,10 +2191,10 @@ class ggml_webgpu_shader_lib {
|
||||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
key.n_experts = context.src0->ne[2];
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 &&
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
1 :
|
||||
0;
|
||||
|
||||
auto it = mul_mat_id_pipelines.find(key);
|
||||
if (it != mul_mat_id_pipelines.end()) {
|
||||
@@ -2271,10 +2314,10 @@ class ggml_webgpu_shader_lib {
|
||||
key.src0_type = context.src0->type;
|
||||
key.src1_type = context.src1->type;
|
||||
key.n_experts = context.src0->ne[2];
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
|
||||
key.vectorized = (context.src0->ne[0] % 4 == 0 &&
|
||||
(context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
|
||||
1 :
|
||||
0;
|
||||
1 :
|
||||
0;
|
||||
|
||||
auto it = mul_mat_id_vec_pipelines.find(key);
|
||||
if (it != mul_mat_id_vec_pipelines.end()) {
|
||||
@@ -2664,119 +2707,62 @@ class ggml_webgpu_shader_lib {
|
||||
return repeat_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context,
|
||||
size_t storage_offset_alignment) {
|
||||
const ggml_webgpu_flash_attn_decisions decisions =
|
||||
ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment);
|
||||
GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_NONE);
|
||||
ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions);
|
||||
auto it = flash_attn_pipelines.find(key);
|
||||
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
const bool can_use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(
|
||||
context.supports_subgroup_matrix, context.sg_mat_k, context.sg_mat_n, context.src0, context.src2);
|
||||
ggml_webgpu_flash_attn_decisions decisions = {};
|
||||
decisions.use_sg_matrix = can_use_subgroup_matrix;
|
||||
decisions.q_tile = decisions.use_sg_matrix ? context.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
|
||||
|
||||
ggml_webgpu_flash_attn_pipeline_key key = {};
|
||||
key.common =
|
||||
ggml_webgpu_flash_attn_make_common_pipeline_key(context, decisions.use_sg_matrix ? context.sg_mat_k : 1u);
|
||||
key.common.kv_direct = decisions.use_sg_matrix && key.common.kv_direct;
|
||||
key.use_sg_matrix = decisions.use_sg_matrix;
|
||||
|
||||
const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(
|
||||
context.wg_mem_limit_bytes, decisions.q_tile, decisions.use_sg_matrix ? context.sg_mat_n : 1u,
|
||||
key.common.head_dim_qk, key.common.head_dim_v, key.common.has_mask, key.common.kv_direct);
|
||||
GGML_ASSERT(max_kv_tile > 0);
|
||||
|
||||
decisions.kv_tile = decisions.use_sg_matrix ?
|
||||
std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES) :
|
||||
std::min(GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE, max_kv_tile);
|
||||
decisions.wg_size =
|
||||
decisions.use_sg_matrix ?
|
||||
std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE) :
|
||||
std::min(context.max_wg_size, std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE,
|
||||
GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size));
|
||||
|
||||
if (key.common.kv_direct) {
|
||||
decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD);
|
||||
while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) {
|
||||
decisions.kv_tile -= decisions.use_sg_matrix ? context.sg_mat_n : context.min_subgroup_size;
|
||||
}
|
||||
}
|
||||
|
||||
auto it = flash_attn_pipelines.find(key);
|
||||
if (it != flash_attn_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
std::vector<std::string> defines;
|
||||
std::string variant = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC ? "flash_attn_vec" :
|
||||
decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? "flash_attn_tile" :
|
||||
"flash_attn";
|
||||
|
||||
switch (key.kv_type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("KV_F32");
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("KV_F16");
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
defines.push_back("KV_Q4_0");
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
defines.push_back("KV_Q8_0");
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported KV type for flash attention shader");
|
||||
}
|
||||
variant += std::string("_") + ggml_type_name(key.kv_type);
|
||||
|
||||
switch (key.q_type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("Q_F32");
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("Q_F16");
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported Q type for flash attention shader");
|
||||
}
|
||||
variant += std::string("_q") + ggml_type_name(key.q_type);
|
||||
|
||||
switch (key.dst_type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("DST_F32");
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("DST_F16");
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported dst type for flash attention shader");
|
||||
}
|
||||
variant += std::string("_dst") + ggml_type_name(key.dst_type);
|
||||
|
||||
if (key.has_mask) {
|
||||
defines.push_back("MASK");
|
||||
if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||
defines.push_back("BLK");
|
||||
variant += "_mask_blk";
|
||||
} else {
|
||||
variant += "_mask";
|
||||
}
|
||||
}
|
||||
if (key.has_sinks) {
|
||||
defines.push_back("SINKS");
|
||||
variant += "_sinks";
|
||||
}
|
||||
if (key.uses_logit_softcap) {
|
||||
defines.push_back("LOGIT_SOFTCAP");
|
||||
variant += "_lgsc";
|
||||
}
|
||||
if (key.kv_direct) {
|
||||
defines.push_back("KV_DIRECT");
|
||||
variant += "_kvdirect";
|
||||
}
|
||||
if (key.kv_overlap) {
|
||||
defines.push_back("KV_OVERLAP");
|
||||
variant += "_kv_overlap";
|
||||
}
|
||||
|
||||
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
|
||||
variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
|
||||
|
||||
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
|
||||
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
|
||||
|
||||
const char * shader_src = wgsl_flash_attn;
|
||||
if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||
defines.push_back("KV_GRANULARITY=8");
|
||||
defines.push_back(std::string("VEC_NE=") + std::to_string(ggml_webgpu_flash_attn_pick_vec_ne(key)) + "u");
|
||||
shader_src = wgsl_flash_attn_vec_split;
|
||||
} else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
|
||||
std::string variant = decisions.use_sg_matrix ? "flash_attn" : "flash_attn_tile";
|
||||
std::vector<std::string> defines = ggml_webgpu_flash_attn_common_defines(key.common, variant, decisions.q_tile,
|
||||
decisions.kv_tile, decisions.wg_size);
|
||||
const char * shader_src = nullptr;
|
||||
if (!key.use_sg_matrix) {
|
||||
shader_src = wgsl_flash_attn_tile;
|
||||
defines.push_back("MIN_SUBGROUP_SIZE=" + std::to_string(context.min_subgroup_size) + "u");
|
||||
defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
|
||||
defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v)));
|
||||
variant += "_tile_sg" + std::to_string(context.min_subgroup_size) + "_" +
|
||||
std::to_string(context.max_subgroup_size);
|
||||
} else {
|
||||
shader_src = wgsl_flash_attn;
|
||||
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
|
||||
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
|
||||
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
|
||||
}
|
||||
|
||||
auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
|
||||
pipeline_decisions->kv_overlap = key.kv_overlap;
|
||||
defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile));
|
||||
defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile));
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size));
|
||||
|
||||
auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions);
|
||||
webgpu_pipeline pipeline =
|
||||
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
|
||||
pipeline.context = pipeline_decisions;
|
||||
@@ -2784,6 +2770,55 @@ class ggml_webgpu_shader_lib {
|
||||
return flash_attn_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_flash_attn_vec_pipeline_key key = {};
|
||||
key.common = ggml_webgpu_flash_attn_make_common_pipeline_key(context, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH);
|
||||
|
||||
auto it = flash_attn_vec_pipelines.find(key);
|
||||
if (it != flash_attn_vec_pipelines.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
ggml_webgpu_flash_attn_vec_decisions decisions = {};
|
||||
decisions.kv_tile =
|
||||
ggml_webgpu_flash_attn_get_vec_kv_tile(context.wg_mem_limit_bytes, key.common.head_dim_qk,
|
||||
key.common.head_dim_v, key.common.has_mask, key.common.kv_direct);
|
||||
decisions.wg_size = context.max_subgroup_size;
|
||||
|
||||
std::string variant = "flash_attn_vec";
|
||||
std::vector<std::string> defines =
|
||||
ggml_webgpu_flash_attn_common_defines(key.common, variant, 1u, decisions.kv_tile, decisions.wg_size);
|
||||
if (key.common.has_mask) {
|
||||
defines.push_back("BLK");
|
||||
variant.resize(variant.size() - (sizeof("_mask") - 1));
|
||||
variant += "_mask_blk";
|
||||
}
|
||||
uint32_t vec_ne = 1u;
|
||||
if (key.common.k_type == GGML_TYPE_F16 && key.common.v_type == GGML_TYPE_F16 &&
|
||||
key.common.head_dim_qk == key.common.head_dim_v) {
|
||||
switch (key.common.head_dim_qk) {
|
||||
case 64:
|
||||
case 192:
|
||||
case 576:
|
||||
vec_ne = 2u;
|
||||
break;
|
||||
case 96:
|
||||
vec_ne = 4u;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
|
||||
|
||||
auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>(decisions);
|
||||
webgpu_pipeline pipeline =
|
||||
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant);
|
||||
pipeline.context = pipeline_decisions;
|
||||
flash_attn_vec_pipelines[key] = pipeline;
|
||||
return flash_attn_vec_pipelines[key];
|
||||
}
|
||||
|
||||
webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) {
|
||||
ggml_webgpu_flash_attn_blk_pipeline_key key = {};
|
||||
key.kv_tile = kv_tile;
|
||||
|
||||
@@ -1755,13 +1755,50 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx,
|
||||
return ggml_backend_webgpu_build_multi(ctx, dispatches);
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
ggml_tensor * Q,
|
||||
ggml_tensor * K,
|
||||
ggml_tensor * V,
|
||||
ggml_tensor * mask,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * dst) {
|
||||
struct ggml_webgpu_flash_attn_op {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
std::vector<uint32_t> params;
|
||||
std::vector<wgpu::BindGroupEntry> entries;
|
||||
size_t kv_bind_offset = 0;
|
||||
size_t kv_bind_size = 0;
|
||||
bool has_mask = false;
|
||||
bool has_sinks = false;
|
||||
bool kv_overlap = false;
|
||||
};
|
||||
|
||||
static bool ggml_webgpu_flash_attn_use_vec_path(const webgpu_global_context & global_ctx,
|
||||
const ggml_tensor * Q,
|
||||
const ggml_tensor * K,
|
||||
const ggml_tensor * V) {
|
||||
const size_t storage_offset_alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
|
||||
const bool k_float_vec4_aligned = (K->type != GGML_TYPE_F16 && K->type != GGML_TYPE_F32) ||
|
||||
ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment);
|
||||
const bool v_float_vec4_aligned = (V->type != GGML_TYPE_F16 && V->type != GGML_TYPE_F32) ||
|
||||
ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment);
|
||||
const bool k_vec_type_supported =
|
||||
K->type == GGML_TYPE_F32 || K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
|
||||
const bool v_vec_type_supported =
|
||||
V->type == GGML_TYPE_F32 || V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q8_0;
|
||||
const uint32_t k_vec_head_align = (K->type == GGML_TYPE_F32 || K->type == GGML_TYPE_F16) ?
|
||||
GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH :
|
||||
(uint32_t) ggml_blck_size(K->type);
|
||||
const uint32_t v_vec_head_align = (V->type == GGML_TYPE_F32 || V->type == GGML_TYPE_F16) ?
|
||||
GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH :
|
||||
(uint32_t) ggml_blck_size(V->type);
|
||||
const bool kv_vec_head_dims_aligned = Q->ne[0] % k_vec_head_align == 0 && V->ne[0] % v_vec_head_align == 0;
|
||||
|
||||
return global_ctx->capabilities.supports_subgroups && (Q->ne[1] < GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN) &&
|
||||
kv_vec_head_dims_aligned && k_vec_type_supported && v_vec_type_supported && k_float_vec4_aligned &&
|
||||
v_float_vec4_aligned;
|
||||
}
|
||||
|
||||
static ggml_webgpu_flash_attn_op ggml_webgpu_flash_attn_prepare(webgpu_context & ctx,
|
||||
ggml_tensor * Q,
|
||||
ggml_tensor * K,
|
||||
ggml_tensor * V,
|
||||
ggml_tensor * mask,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * dst) {
|
||||
float scale = ggml_get_op_params_f32(dst, 0);
|
||||
float max_bias = ggml_get_op_params_f32(dst, 1);
|
||||
float logit_softcap = ggml_get_op_params_f32(dst, 2);
|
||||
@@ -1772,47 +1809,43 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
|
||||
float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.src0 = Q;
|
||||
shader_lib_ctx.src1 = K;
|
||||
shader_lib_ctx.src2 = V;
|
||||
shader_lib_ctx.src3 = mask;
|
||||
shader_lib_ctx.src4 = sinks;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
|
||||
shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
|
||||
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
|
||||
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
|
||||
shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size;
|
||||
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
|
||||
shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
|
||||
const int has_mask = (mask != nullptr);
|
||||
const int has_sinks = (sinks != nullptr);
|
||||
const bool kv_overlap = decisions->kv_overlap;
|
||||
ggml_webgpu_flash_attn_op op = {};
|
||||
op.shader_lib_ctx.src0 = Q;
|
||||
op.shader_lib_ctx.src1 = K;
|
||||
op.shader_lib_ctx.src2 = V;
|
||||
op.shader_lib_ctx.src3 = mask;
|
||||
op.shader_lib_ctx.src4 = sinks;
|
||||
op.shader_lib_ctx.dst = dst;
|
||||
op.shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups;
|
||||
op.shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix;
|
||||
op.shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
op.shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||
op.shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
|
||||
op.shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
|
||||
op.shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
|
||||
op.shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size;
|
||||
op.shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
|
||||
|
||||
uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
|
||||
uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
|
||||
size_t kv_bind_offset = 0;
|
||||
size_t kv_bind_size = 0;
|
||||
if (kv_overlap) {
|
||||
op.has_mask = mask != nullptr;
|
||||
op.has_sinks = sinks != nullptr;
|
||||
op.kv_overlap = ggml_webgpu_tensor_overlap(K, V);
|
||||
|
||||
uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
|
||||
uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
|
||||
if (op.kv_overlap) {
|
||||
const ggml_webgpu_merged_binding_range merged_range = ggml_webgpu_tensor_merged_binding_range(ctx, { K, V });
|
||||
kv_bind_offset = merged_range.offset;
|
||||
kv_bind_size = merged_range.size;
|
||||
op.kv_bind_offset = merged_range.offset;
|
||||
op.kv_bind_size = merged_range.size;
|
||||
offset_k = ggml_webgpu_tensor_merged_element_offset(K, merged_range);
|
||||
offset_v = ggml_webgpu_tensor_merged_element_offset(V, merged_range);
|
||||
}
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
op.params = {
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
|
||||
offset_k,
|
||||
offset_v,
|
||||
has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
|
||||
has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
|
||||
op.has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
|
||||
op.has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
(uint32_t) Q->ne[2], // number of heads
|
||||
(uint32_t) Q->ne[1], // sequence length (Q)
|
||||
@@ -1826,7 +1859,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
(uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1
|
||||
(uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2
|
||||
(uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3
|
||||
has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3
|
||||
op.has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3
|
||||
(uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA)
|
||||
ggml_webgpu_u32_from_f32(scale), // scale (possibly adjusted for logit softcap)
|
||||
ggml_webgpu_u32_from_f32(max_bias),
|
||||
@@ -1834,32 +1867,56 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
ggml_webgpu_u32_from_f32(n_head_log2),
|
||||
ggml_webgpu_u32_from_f32(m0),
|
||||
ggml_webgpu_u32_from_f32(m1)
|
||||
|
||||
};
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
op.entries = {
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q),
|
||||
};
|
||||
if (kv_overlap) {
|
||||
entries.push_back(
|
||||
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size));
|
||||
if (op.kv_overlap) {
|
||||
op.entries.push_back(
|
||||
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), op.kv_bind_offset, op.kv_bind_size));
|
||||
} else {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V));
|
||||
op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K));
|
||||
op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V));
|
||||
}
|
||||
uint32_t binding_index = kv_overlap ? 2u : 3u;
|
||||
if (has_mask) {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask));
|
||||
uint32_t binding_index = op.kv_overlap ? 2u : 3u;
|
||||
if (op.has_mask) {
|
||||
op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask));
|
||||
}
|
||||
if (has_sinks) {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks));
|
||||
if (op.has_sinks) {
|
||||
op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks));
|
||||
}
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst));
|
||||
op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst));
|
||||
|
||||
if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
|
||||
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
|
||||
return op;
|
||||
}
|
||||
|
||||
static uint32_t ggml_webgpu_flash_attn_vec_nwg(uint32_t vec_nwg_cap, uint32_t kv_tile, uint32_t seq_len_kv) {
|
||||
uint32_t nwg = 1u;
|
||||
const uint64_t kv_span = (uint64_t) kv_tile;
|
||||
while ((2u * nwg * kv_span) < (uint64_t) seq_len_kv && nwg < vec_nwg_cap) {
|
||||
nwg <<= 1;
|
||||
}
|
||||
return std::min(nwg, vec_nwg_cap);
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_flash_attn_direct(webgpu_context & ctx, const ggml_webgpu_flash_attn_op & op) {
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(op.shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get());
|
||||
uint32_t wg_per_head = CEIL_DIV(op.shader_lib_ctx.src0->ne[1], decisions->q_tile);
|
||||
uint32_t wg_x = wg_per_head * op.shader_lib_ctx.src0->ne[2] * op.shader_lib_ctx.src0->ne[3];
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, op.params, op.entries, wg_x);
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_flash_attn_vec(webgpu_context & ctx,
|
||||
ggml_tensor * Q,
|
||||
ggml_tensor * K,
|
||||
ggml_tensor * V,
|
||||
ggml_tensor * mask,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * dst,
|
||||
ggml_webgpu_flash_attn_op op) {
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_vec_pipeline(op.shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_flash_attn_vec_decisions *>(pipeline.context.get());
|
||||
|
||||
wgpu::Buffer blk_buf = {};
|
||||
uint64_t blk_size_bytes = 0;
|
||||
@@ -1868,13 +1925,8 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
uint32_t blk_batch_count = 0;
|
||||
|
||||
const uint32_t vec_nwg_cap = ctx->global_ctx->capabilities.min_subgroup_size;
|
||||
uint32_t nwg = 1u;
|
||||
const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile);
|
||||
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
|
||||
nwg <<= 1;
|
||||
}
|
||||
nwg = std::min(nwg, vec_nwg_cap);
|
||||
const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
|
||||
uint32_t nwg = ggml_webgpu_flash_attn_vec_nwg(vec_nwg_cap, decisions->kv_tile, (uint32_t) K->ne[1]);
|
||||
const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
|
||||
const bool use_vec_reduce = nwg > 1u;
|
||||
GGML_ASSERT(nrows <= UINT32_MAX);
|
||||
|
||||
@@ -1910,7 +1962,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
webgpu_pipeline blk_pipeline;
|
||||
std::vector<uint32_t> blk_params;
|
||||
std::vector<wgpu::BindGroupEntry> blk_entries;
|
||||
if (has_mask) {
|
||||
if (op.has_mask) {
|
||||
blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile);
|
||||
blk_nblk1 = (uint32_t) Q->ne[1];
|
||||
blk_buf = ggml_webgpu_tensor_buf(dst);
|
||||
@@ -1918,7 +1970,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
|
||||
const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
|
||||
blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx;
|
||||
const ggml_webgpu_shader_lib_context blk_shader_ctx = op.shader_lib_ctx;
|
||||
blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx, decisions->kv_tile);
|
||||
|
||||
blk_params = {
|
||||
@@ -1938,8 +1990,8 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes);
|
||||
}
|
||||
|
||||
std::vector<uint32_t> split_params = params;
|
||||
if (has_mask) {
|
||||
std::vector<uint32_t> split_params = op.params;
|
||||
if (op.has_mask) {
|
||||
split_params.push_back(0u); // blk_base
|
||||
split_params.push_back(blk_nblk0); // blk_nblk0
|
||||
split_params.push_back(blk_nblk1); // blk_nblk1
|
||||
@@ -1952,9 +2004,9 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q),
|
||||
ggml_webgpu_tensor_binding_size(ctx, Q)),
|
||||
};
|
||||
if (kv_overlap) {
|
||||
if (op.kv_overlap) {
|
||||
split_entries.push_back(
|
||||
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size));
|
||||
ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), op.kv_bind_offset, op.kv_bind_size));
|
||||
} else {
|
||||
split_entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K),
|
||||
ggml_webgpu_tensor_align_offset(ctx, K),
|
||||
@@ -1963,18 +2015,18 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
ggml_webgpu_tensor_align_offset(ctx, V),
|
||||
ggml_webgpu_tensor_binding_size(ctx, V)));
|
||||
}
|
||||
uint32_t split_binding_index = kv_overlap ? 2u : 3u;
|
||||
if (has_mask) {
|
||||
uint32_t split_binding_index = op.kv_overlap ? 2u : 3u;
|
||||
if (op.has_mask) {
|
||||
split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask),
|
||||
ggml_webgpu_tensor_align_offset(ctx, mask),
|
||||
ggml_webgpu_tensor_binding_size(ctx, mask)));
|
||||
}
|
||||
if (has_sinks) {
|
||||
if (op.has_sinks) {
|
||||
split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(sinks),
|
||||
ggml_webgpu_tensor_align_offset(ctx, sinks),
|
||||
ggml_webgpu_tensor_binding_size(ctx, sinks)));
|
||||
}
|
||||
if (has_mask) {
|
||||
if (op.has_mask) {
|
||||
split_entries.push_back(
|
||||
ggml_webgpu_make_bind_group_entry(split_binding_index++, blk_buf, blk_entries[1].offset, blk_size_bytes));
|
||||
}
|
||||
@@ -1993,7 +2045,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
reduce_sg_size,
|
||||
(uint32_t) std::min<uint64_t>((uint64_t) nwg * reduce_sg_size,
|
||||
ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
|
||||
ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx;
|
||||
ggml_webgpu_shader_lib_context reduce_shader_ctx = op.shader_lib_ctx;
|
||||
reduce_shader_ctx.max_wg_size = reduce_wg_size;
|
||||
reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx);
|
||||
|
||||
@@ -2020,7 +2072,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
|
||||
std::vector<webgpu_dispatch_desc> dispatches;
|
||||
|
||||
if (has_mask) {
|
||||
if (op.has_mask) {
|
||||
dispatches.push_back({
|
||||
blk_pipeline, std::move(blk_params), std::move(blk_entries), { blk_nblk0, blk_nblk1 * blk_batch_count }
|
||||
});
|
||||
@@ -2037,6 +2089,20 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
return ggml_backend_webgpu_build_multi(ctx, dispatches);
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
|
||||
ggml_tensor * Q,
|
||||
ggml_tensor * K,
|
||||
ggml_tensor * V,
|
||||
ggml_tensor * mask,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * dst) {
|
||||
ggml_webgpu_flash_attn_op op = ggml_webgpu_flash_attn_prepare(ctx, Q, K, V, mask, sinks, dst);
|
||||
if (ggml_webgpu_flash_attn_use_vec_path(ctx->global_ctx, Q, K, V)) {
|
||||
return ggml_webgpu_flash_attn_vec(ctx, Q, K, V, mask, sinks, dst, std::move(op));
|
||||
}
|
||||
return ggml_webgpu_flash_attn_direct(ctx, op);
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
bool is_unary = dst->op == GGML_OP_UNARY;
|
||||
|
||||
@@ -3553,70 +3619,43 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
|
||||
break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
const ggml_tensor * Q = tensor->src[0];
|
||||
const ggml_tensor * K = tensor->src[1];
|
||||
const ggml_tensor * V = tensor->src[2];
|
||||
const ggml_tensor * mask = tensor->src[3];
|
||||
const ggml_tensor * sinks = tensor->src[4];
|
||||
if (Q && K && V) {
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.src0 = const_cast<ggml_tensor *>(Q);
|
||||
shader_lib_ctx.src1 = const_cast<ggml_tensor *>(K);
|
||||
shader_lib_ctx.src2 = const_cast<ggml_tensor *>(V);
|
||||
shader_lib_ctx.src3 = const_cast<ggml_tensor *>(mask);
|
||||
shader_lib_ctx.src4 = const_cast<ggml_tensor *>(sinks);
|
||||
shader_lib_ctx.dst = const_cast<ggml_tensor *>(tensor);
|
||||
shader_lib_ctx.max_wg_size =
|
||||
ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.wg_mem_limit_bytes =
|
||||
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||
shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups;
|
||||
shader_lib_ctx.supports_subgroup_matrix =
|
||||
ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
|
||||
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
|
||||
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
|
||||
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
|
||||
shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
|
||||
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
|
||||
const ggml_tensor * Q = tensor->src[0];
|
||||
const ggml_tensor * K = tensor->src[1];
|
||||
const ggml_tensor * V = tensor->src[2];
|
||||
const ggml_tensor * mask = tensor->src[3];
|
||||
const auto & capabilities = ctx->webgpu_global_ctx->capabilities;
|
||||
if (ggml_webgpu_flash_attn_use_vec_path(ctx->webgpu_global_ctx, Q, K, V)) {
|
||||
const bool kv_direct =
|
||||
ggml_webgpu_flash_attn_kv_direct(Q, K, V, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH);
|
||||
const uint32_t kv_tile = ggml_webgpu_flash_attn_get_vec_kv_tile(
|
||||
capabilities.limits.maxComputeWorkgroupStorageSize, (uint32_t) Q->ne[0], (uint32_t) V->ne[0],
|
||||
mask != nullptr, kv_direct);
|
||||
|
||||
const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
|
||||
shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
const uint32_t vec_nwg_cap = capabilities.min_subgroup_size;
|
||||
uint32_t nwg = ggml_webgpu_flash_attn_vec_nwg(vec_nwg_cap, kv_tile, (uint32_t) K->ne[1]);
|
||||
|
||||
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||
const uint32_t kv_tile = decisions.kv_tile;
|
||||
|
||||
const uint32_t vec_nwg_cap = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
|
||||
uint32_t nwg = 1u;
|
||||
const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile);
|
||||
while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
|
||||
nwg <<= 1;
|
||||
}
|
||||
nwg = std::min(nwg, vec_nwg_cap);
|
||||
|
||||
const size_t align =
|
||||
ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
|
||||
const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
|
||||
if (nwg > 1u) {
|
||||
const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg;
|
||||
const uint64_t tmp_stats_elems = nrows * 2u * nwg;
|
||||
const size_t tmp_size_bytes = ROUNDUP_POW2(
|
||||
(tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
res += tmp_size_bytes + align;
|
||||
} else {
|
||||
res += WEBGPU_STORAGE_BUF_BINDING_MULT + align;
|
||||
}
|
||||
if (mask != nullptr) {
|
||||
const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
|
||||
const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u);
|
||||
const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
|
||||
const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
|
||||
const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
|
||||
const size_t blk_size_bytes =
|
||||
ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
res += blk_size_bytes + align;
|
||||
}
|
||||
res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
const size_t align = capabilities.limits.minStorageBufferOffsetAlignment;
|
||||
const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
|
||||
if (nwg > 1u) {
|
||||
const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg;
|
||||
const uint64_t tmp_stats_elems = nrows * 2u * nwg;
|
||||
const size_t tmp_size_bytes = ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float),
|
||||
WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
res += tmp_size_bytes + align;
|
||||
} else {
|
||||
res += WEBGPU_STORAGE_BUF_BINDING_MULT + align;
|
||||
}
|
||||
if (mask != nullptr) {
|
||||
const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
|
||||
const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u);
|
||||
const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
|
||||
const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
|
||||
const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
|
||||
const size_t blk_size_bytes =
|
||||
ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
res += blk_size_bytes + align;
|
||||
}
|
||||
res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT);
|
||||
}
|
||||
}
|
||||
break;
|
||||
@@ -4139,70 +4178,63 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
// conservative support checks for whether the more resource-intensive shader paths
|
||||
// can be used, to avoid cases where flash_attn is assigned to the CPU later on
|
||||
supports_op = src0->type == GGML_TYPE_F32 &&
|
||||
(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
|
||||
src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
|
||||
src2->type == src1->type && op->type == GGML_TYPE_F32;
|
||||
(src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16 ||
|
||||
src2->type == GGML_TYPE_Q4_0 || src2->type == GGML_TYPE_Q8_0) &&
|
||||
op->type == GGML_TYPE_F32;
|
||||
if (!supports_op) {
|
||||
break;
|
||||
}
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.src0 = src0;
|
||||
shader_lib_ctx.src1 = src1;
|
||||
shader_lib_ctx.src2 = src2;
|
||||
shader_lib_ctx.src3 = op->src[3];
|
||||
shader_lib_ctx.src4 = op->src[4];
|
||||
shader_lib_ctx.dst = const_cast<ggml_tensor *>(op);
|
||||
shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups;
|
||||
shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
|
||||
shader_lib_ctx.max_wg_size =
|
||||
ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.wg_mem_limit_bytes =
|
||||
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
|
||||
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
|
||||
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
|
||||
shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
|
||||
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;
|
||||
|
||||
const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
|
||||
shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
|
||||
const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
|
||||
const bool has_mask = op->src[3] != nullptr;
|
||||
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) {
|
||||
if (ggml_webgpu_tensor_overlap(src1, src2) && src1->type != src2->type &&
|
||||
!ggml_is_quantized(src1->type) && !ggml_is_quantized(src2->type)) {
|
||||
supports_op = false;
|
||||
break;
|
||||
}
|
||||
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
|
||||
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
|
||||
decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask,
|
||||
decisions.kv_direct, decisions.path);
|
||||
if (min_bytes > limit_bytes) {
|
||||
supports_op = false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
const auto & capabilities = ctx->webgpu_global_ctx->capabilities;
|
||||
const size_t storage_offset_alignment = capabilities.limits.minStorageBufferOffsetAlignment;
|
||||
|
||||
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
|
||||
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
|
||||
decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask,
|
||||
decisions.kv_direct, decisions.path);
|
||||
if (min_bytes > limit_bytes) {
|
||||
supports_op = false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
// subgroup matrix path requirements
|
||||
const bool use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(
|
||||
capabilities.supports_subgroup_matrix, capabilities.sg_mat_k, capabilities.sg_mat_n, src0, src2);
|
||||
|
||||
if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
|
||||
// tile path requirements
|
||||
const bool float_vec4_aligned =
|
||||
((src1->type != GGML_TYPE_F16 && src1->type != GGML_TYPE_F32) ||
|
||||
ggml_webgpu_flash_attn_float_vec4_aligned(src1, storage_offset_alignment)) &&
|
||||
((src2->type != GGML_TYPE_F16 && src2->type != GGML_TYPE_F32) ||
|
||||
ggml_webgpu_flash_attn_float_vec4_aligned(src2, storage_offset_alignment));
|
||||
const uint32_t k_tile_head_align = (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16) ?
|
||||
GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH :
|
||||
(uint32_t) ggml_blck_size(src1->type);
|
||||
const uint32_t v_tile_head_align = (src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16) ?
|
||||
GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH :
|
||||
(uint32_t) ggml_blck_size(src2->type);
|
||||
const bool tile_kv_head_dims_aligned =
|
||||
src0->ne[0] % k_tile_head_align == 0 && src2->ne[0] % v_tile_head_align == 0;
|
||||
const bool tile_can_dispatch_all_q_rows =
|
||||
capabilities.limits.maxComputeInvocationsPerWorkgroup >=
|
||||
GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * capabilities.max_subgroup_size;
|
||||
const bool use_tile = !use_subgroup_matrix && capabilities.supports_subgroups && float_vec4_aligned &&
|
||||
tile_kv_head_dims_aligned && tile_can_dispatch_all_q_rows;
|
||||
|
||||
if (!use_subgroup_matrix && !use_tile) {
|
||||
supports_op = false;
|
||||
break;
|
||||
}
|
||||
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
|
||||
decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask,
|
||||
decisions.kv_direct, decisions.path);
|
||||
if (min_bytes > limit_bytes) {
|
||||
supports_op = false;
|
||||
}
|
||||
const uint32_t q_tile =
|
||||
use_subgroup_matrix ? capabilities.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
|
||||
const uint32_t kv_granularity = use_subgroup_matrix ? capabilities.sg_mat_n : 1u;
|
||||
const bool kv_direct = use_subgroup_matrix ?
|
||||
ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) :
|
||||
false;
|
||||
const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(
|
||||
capabilities.limits.maxComputeWorkgroupStorageSize, q_tile, kv_granularity, (uint32_t) src0->ne[0],
|
||||
(uint32_t) src2->ne[0], op->src[3] != nullptr, kv_direct);
|
||||
supports_op = max_kv_tile > 0;
|
||||
break;
|
||||
}
|
||||
case GGML_OP_RMS_NORM:
|
||||
|
||||
@@ -37,15 +37,33 @@ static std::string trim(const std::string & s) {
|
||||
}
|
||||
|
||||
static std::string trim_value(std::istream & is) {
|
||||
std::string str;
|
||||
std::getline(is, str);
|
||||
return trim(str);
|
||||
std::ostringstream ss;
|
||||
ss << is.rdbuf();
|
||||
return trim(ss.str());
|
||||
}
|
||||
|
||||
static bool isIdentChar(char c) {
|
||||
return std::isalnum(static_cast<unsigned char>(c)) || c == '_';
|
||||
}
|
||||
|
||||
static bool endsWithContinuation(const std::string & line) {
|
||||
size_t i = line.size();
|
||||
while (i > 0 && std::isspace((unsigned char) line[i - 1])) {
|
||||
i--;
|
||||
}
|
||||
return i > 0 && line[i - 1] == '\\';
|
||||
}
|
||||
|
||||
static void stripContinuation(std::string & line) {
|
||||
size_t i = line.size();
|
||||
while (i > 0 && std::isspace((unsigned char) line[i - 1])) {
|
||||
i--;
|
||||
}
|
||||
if (i > 0 && line[i - 1] == '\\') {
|
||||
line.erase(i - 1);
|
||||
}
|
||||
}
|
||||
|
||||
static std::string expandMacrosRecursiveInternal(const std::string & line,
|
||||
const std::unordered_map<std::string, std::string> & macros,
|
||||
std::unordered_set<std::string> & visiting);
|
||||
@@ -595,19 +613,31 @@ class Preprocessor {
|
||||
std::string line;
|
||||
|
||||
while (std::getline(in, line)) {
|
||||
std::string t = trim(line);
|
||||
std::string logical = line;
|
||||
std::string t = trim(logical);
|
||||
if (!t.empty() && t[0] == '#') {
|
||||
while (endsWithContinuation(logical)) {
|
||||
stripContinuation(logical);
|
||||
if (!std::getline(in, line)) {
|
||||
break;
|
||||
}
|
||||
logical += "\n";
|
||||
logical += line;
|
||||
}
|
||||
t = trim(logical);
|
||||
}
|
||||
|
||||
if (!t.empty() && t[0] == '#') {
|
||||
bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode);
|
||||
if (mode == DirectiveMode::IncludesOnly && !handled) {
|
||||
out << line << "\n";
|
||||
out << logical << "\n";
|
||||
}
|
||||
} else {
|
||||
if (mode == DirectiveMode::IncludesOnly) {
|
||||
out << line << "\n";
|
||||
out << logical << "\n";
|
||||
} else if (condActive(cond)) {
|
||||
// Expand macros in the line before outputting
|
||||
std::string expanded = expandMacrosRecursive(line, macros);
|
||||
std::string expanded = expandMacrosRecursive(logical, macros);
|
||||
out << expanded << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,12 +4,23 @@ enable f16;
|
||||
enable subgroups;
|
||||
enable chromium_experimental_subgroup_matrix;
|
||||
|
||||
#ifdef KV_F32
|
||||
#define KV_TYPE f32
|
||||
#elif defined(KV_Q4_0) || defined(KV_Q8_0)
|
||||
#define KV_TYPE u32
|
||||
#define BYTE_HELPERS
|
||||
#include "common_decls.tmpl"
|
||||
|
||||
#ifdef K_F32
|
||||
#define K_TYPE f32
|
||||
#elif defined(K_Q4_0) || defined(K_Q8_0)
|
||||
#define K_TYPE u32
|
||||
#else
|
||||
#define KV_TYPE f16
|
||||
#define K_TYPE f16
|
||||
#endif
|
||||
|
||||
#ifdef V_F32
|
||||
#define V_TYPE f32
|
||||
#elif defined(V_Q4_0) || defined(V_Q8_0)
|
||||
#define V_TYPE u32
|
||||
#else
|
||||
#define V_TYPE f16
|
||||
#endif
|
||||
|
||||
// Default values
|
||||
@@ -30,76 +41,6 @@ enable chromium_experimental_subgroup_matrix;
|
||||
// Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE.
|
||||
#define KV_BLOCKS (KV_TILE / SG_MAT_N)
|
||||
|
||||
// Quantization constants/helpers
|
||||
#define BLOCK_SIZE 32
|
||||
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
|
||||
#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
|
||||
// number of quantized elements processed per thread
|
||||
#if defined(KV_Q4_0)
|
||||
#define NQ 16
|
||||
// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights
|
||||
#define F16_PER_BLOCK 9
|
||||
#define BLOCK_SIZE_BYTES 18u
|
||||
#define WEIGHTS_PER_F16 4
|
||||
#elif defined(KV_Q8_0)
|
||||
#define NQ 8
|
||||
// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights
|
||||
#define F16_PER_BLOCK 17
|
||||
#define BLOCK_SIZE_BYTES 34u
|
||||
#define WEIGHTS_PER_F16 2
|
||||
#endif
|
||||
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
|
||||
|
||||
// Ok not to put these in a define block, compiler will remove if unused
|
||||
fn get_byte(value: u32, index: u32) -> u32 {
|
||||
return (value >> (index * 8)) & 0xFF;
|
||||
}
|
||||
|
||||
fn get_byte_i32(value: u32, index: u32) -> i32 {
|
||||
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
|
||||
}
|
||||
|
||||
#if defined(KV_Q4_0) || defined(KV_Q8_0)
|
||||
fn load_k_u16_at(byte_offset: u32) -> u32 {
|
||||
let word = K[byte_offset / 4u];
|
||||
let shift = (byte_offset & 2u) * 8u;
|
||||
return (word >> shift) & 0xFFFFu;
|
||||
}
|
||||
|
||||
fn load_k_u32_at(byte_offset: u32) -> u32 {
|
||||
let word_idx = byte_offset / 4u;
|
||||
let shift = (byte_offset & 3u) * 8u;
|
||||
let lo = K[word_idx];
|
||||
if (shift == 0u) {
|
||||
return lo;
|
||||
}
|
||||
let hi = K[word_idx + 1u];
|
||||
return (lo >> shift) | (hi << (32u - shift));
|
||||
}
|
||||
|
||||
fn load_v_u16_at(byte_offset: u32) -> u32 {
|
||||
let word = V[byte_offset / 4u];
|
||||
let shift = (byte_offset & 2u) * 8u;
|
||||
return (word >> shift) & 0xFFFFu;
|
||||
}
|
||||
|
||||
fn load_v_u32_at(byte_offset: u32) -> u32 {
|
||||
let word_idx = byte_offset / 4u;
|
||||
let shift = (byte_offset & 3u) * 8u;
|
||||
let lo = V[word_idx];
|
||||
if (shift == 0u) {
|
||||
return lo;
|
||||
}
|
||||
let hi = V[word_idx + 1u];
|
||||
return (lo >> shift) | (hi << (32u - shift));
|
||||
}
|
||||
|
||||
fn f16_from_u16(bits: u32) -> f16 {
|
||||
let packed = unpack2x16float(bits);
|
||||
return f16(packed[0]);
|
||||
}
|
||||
#endif
|
||||
|
||||
struct Params {
|
||||
offset_q: u32,
|
||||
offset_k: u32,
|
||||
@@ -139,11 +80,11 @@ struct Params {
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
|
||||
#ifdef KV_OVERLAP
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>;
|
||||
#define V K
|
||||
#else
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>;
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>;
|
||||
#endif
|
||||
|
||||
#if defined(MASK) && defined(SINKS)
|
||||
@@ -238,10 +179,47 @@ fn load_f32x4(buf: ptr<storage, array<vec4<f32>>, read_write>, scalar_index: u32
|
||||
return (*buf)[scalar_index >> 2u];
|
||||
}
|
||||
|
||||
fn load_kvx4(buf: ptr<storage, array<vec4<KV_TYPE>>, read_write>, scalar_index: u32) -> vec4<KV_TYPE> {
|
||||
fn load_kx4(buf: ptr<storage, array<vec4<K_TYPE>>, read_write>, scalar_index: u32) -> vec4<K_TYPE> {
|
||||
return (*buf)[scalar_index >> 2u];
|
||||
}
|
||||
|
||||
#ifndef KV_DIRECT
|
||||
#define QUANT_SHMEM kv_shmem
|
||||
#define QUANT_OUT_TYPE f16
|
||||
#include "quant_inner_loops.tmpl"
|
||||
#include "flash_attn_quant_staging.tmpl"
|
||||
|
||||
#if !defined(K_Q4_0) && !defined(K_Q8_0)
|
||||
fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) {
|
||||
for (var elem_idx = local_x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
|
||||
let k_row = elem_idx / HEAD_DIM_QK;
|
||||
let k_col = elem_idx % HEAD_DIM_QK;
|
||||
let global_k_row = kv_tile + k_row;
|
||||
let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
|
||||
kv_shmem[elem_idx] = f16(select(
|
||||
0.0,
|
||||
K[global_k_row_offset + k_col],
|
||||
global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if !defined(V_Q4_0) && !defined(V_Q8_0)
|
||||
fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) {
|
||||
for (var elem_idx = local_x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) {
|
||||
let v_row = elem_idx / HEAD_DIM_V;
|
||||
let v_col = elem_idx % HEAD_DIM_V;
|
||||
let global_v_row = kv_tile + v_row;
|
||||
let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
|
||||
kv_shmem[elem_idx] = f16(select(
|
||||
0.0,
|
||||
V[global_v_row_offset + v_col],
|
||||
global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@@ -311,77 +289,15 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
}
|
||||
|
||||
for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
|
||||
let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile);
|
||||
// clear inter_shmem to ensure zero-initialized accumulators
|
||||
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
|
||||
inter_shmem[elem_idx] = 0.0;
|
||||
}
|
||||
|
||||
// load k tile into shared memory
|
||||
#if defined(KV_Q4_0)
|
||||
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let k_row = blck_idx / BLOCKS_K;
|
||||
let global_k_row = kv_tile + k_row;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let row_offset = k_row * HEAD_DIM_QK;
|
||||
|
||||
if (global_k_row < params.seq_len_kv) {
|
||||
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
|
||||
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
|
||||
let d = f16_from_u16(load_k_u16_at(block_byte_base));
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_k_u32_at(q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_lo;
|
||||
kv_shmem[row_offset + idx + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_Q8_0)
|
||||
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let k_row = blck_idx / BLOCKS_K;
|
||||
let global_k_row = kv_tile + k_row;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let row_offset = k_row * HEAD_DIM_QK;
|
||||
|
||||
if (global_k_row < params.seq_len_kv) {
|
||||
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
|
||||
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
|
||||
let d = f16_from_u16(load_k_u16_at(block_byte_base));
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_k_u32_at(q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f16(q_byte) * d;
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_DIRECT)
|
||||
// Direct global loads for KV
|
||||
#else
|
||||
for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
|
||||
let k_row = elem_idx / HEAD_DIM_QK;
|
||||
let k_col = elem_idx % HEAD_DIM_QK;
|
||||
let global_k_row = kv_tile + k_row;
|
||||
let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
|
||||
kv_shmem[elem_idx] = f16(select(
|
||||
0.0,
|
||||
K[global_k_row_offset + k_col],
|
||||
global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK));
|
||||
}
|
||||
#ifndef KV_DIRECT
|
||||
load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset);
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
@@ -520,71 +436,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
}
|
||||
|
||||
// load v tile into shared memory
|
||||
#if defined(KV_Q4_0)
|
||||
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let v_row = blck_idx / BLOCKS_V;
|
||||
let global_v_row = kv_tile + v_row;
|
||||
let block_k = blck_idx % BLOCKS_V;
|
||||
let row_offset = v_row * HEAD_DIM_V;
|
||||
|
||||
if (global_v_row < params.seq_len_kv) {
|
||||
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
|
||||
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
|
||||
let d = f16_from_u16(load_v_u16_at(block_byte_base));
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_v_u32_at(q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_lo;
|
||||
kv_shmem[row_offset + idx + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_Q8_0)
|
||||
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let v_row = blck_idx / BLOCKS_V;
|
||||
let global_v_row = kv_tile + v_row;
|
||||
let block_k = blck_idx % BLOCKS_V;
|
||||
let row_offset = v_row * HEAD_DIM_V;
|
||||
|
||||
if (global_v_row < params.seq_len_kv) {
|
||||
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
|
||||
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
|
||||
let d = f16_from_u16(load_v_u16_at(block_byte_base));
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
|
||||
let q_packed = load_v_u32_at(q_byte_offset);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f16(q_byte) * d;
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_DIRECT)
|
||||
// Direct global loads for KV
|
||||
#else
|
||||
for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) {
|
||||
let v_row = elem_idx / HEAD_DIM_V;
|
||||
let v_col = elem_idx % HEAD_DIM_V;
|
||||
let global_v_row = kv_tile + v_row;
|
||||
let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
|
||||
kv_shmem[elem_idx] = f16(select(
|
||||
0.0,
|
||||
V[global_v_row_offset + v_col],
|
||||
global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V));
|
||||
}
|
||||
#ifndef KV_DIRECT
|
||||
load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset);
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
#define BLOCK_SIZE 32
|
||||
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
|
||||
#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
|
||||
|
||||
#if defined(K_Q4_0)
|
||||
#define K_NQ 16
|
||||
#define K_BLOCK_SIZE_BYTES 18u
|
||||
#define K_BYTES_PER_THREAD 8u
|
||||
#define K_BYTES_PER_INNER_LOOP 4u
|
||||
#elif defined(K_Q8_0)
|
||||
#define K_NQ 16
|
||||
#define K_BLOCK_SIZE_BYTES 34u
|
||||
#define K_BYTES_PER_THREAD 16u
|
||||
#define K_BYTES_PER_INNER_LOOP 4u
|
||||
#endif
|
||||
|
||||
#if defined(V_Q4_0)
|
||||
#define V_NQ 16
|
||||
#define V_BLOCK_SIZE_BYTES 18u
|
||||
#define V_BYTES_PER_THREAD 8u
|
||||
#define V_BYTES_PER_INNER_LOOP 4u
|
||||
#elif defined(V_Q8_0)
|
||||
#define V_NQ 16
|
||||
#define V_BLOCK_SIZE_BYTES 34u
|
||||
#define V_BYTES_PER_THREAD 16u
|
||||
#define V_BYTES_PER_INNER_LOOP 4u
|
||||
#endif
|
||||
|
||||
#if defined(K_Q4_0) || defined(K_Q8_0)
|
||||
fn load_k_u16_at(byte_offset: u32) -> u32 {
|
||||
let word = K[byte_offset / 4u];
|
||||
let shift = (byte_offset & 2u) * 8u;
|
||||
return (word >> shift) & 0xFFFFu;
|
||||
}
|
||||
|
||||
fn load_k_u32_at(byte_offset: u32) -> u32 {
|
||||
let word_idx = byte_offset / 4u;
|
||||
let shift = (byte_offset & 3u) * 8u;
|
||||
let lo = K[word_idx];
|
||||
if (shift == 0u) {
|
||||
return lo;
|
||||
}
|
||||
let hi = K[word_idx + 1u];
|
||||
return (lo >> shift) | (hi << (32u - shift));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(V_Q4_0) || defined(V_Q8_0)
|
||||
fn load_v_u16_at(byte_offset: u32) -> u32 {
|
||||
let word = V[byte_offset / 4u];
|
||||
let shift = (byte_offset & 2u) * 8u;
|
||||
return (word >> shift) & 0xFFFFu;
|
||||
}
|
||||
|
||||
fn load_v_u32_at(byte_offset: u32) -> u32 {
|
||||
let word_idx = byte_offset / 4u;
|
||||
let shift = (byte_offset & 3u) * 8u;
|
||||
let lo = V[word_idx];
|
||||
if (shift == 0u) {
|
||||
return lo;
|
||||
}
|
||||
let hi = V[word_idx + 1u];
|
||||
return (lo >> shift) | (hi << (32u - shift));
|
||||
}
|
||||
#endif
|
||||
|
||||
fn f16_from_u16(bits: u32) -> f16 {
|
||||
let packed = unpack2x16float(bits);
|
||||
return f16(packed[0]);
|
||||
}
|
||||
|
||||
#if defined(K_Q4_0) || defined(K_Q8_0)
|
||||
fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) {
|
||||
for (var elem_idx = local_x * K_NQ; elem_idx < kv_count * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / K_NQ;
|
||||
let k_row = blck_idx / BLOCKS_K;
|
||||
let global_k_row = kv_tile + k_row;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let row_offset = k_row * HEAD_DIM_QK;
|
||||
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
|
||||
let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES;
|
||||
let d = f16_from_u16(load_k_u16_at(block_byte_base));
|
||||
let thread_byte_offset = block_offset * K_BYTES_PER_THREAD;
|
||||
let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset;
|
||||
for (var j = 0u; j < K_BYTES_PER_THREAD / K_BYTES_PER_INNER_LOOP; j += 1u) {
|
||||
let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * K_BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_k_u32_at(q_byte_offset);
|
||||
#if defined(K_Q4_0)
|
||||
dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP);
|
||||
#elif defined(K_Q8_0)
|
||||
dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(V_Q4_0) || defined(V_Q8_0)
|
||||
fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) {
|
||||
for (var elem_idx = local_x * V_NQ; elem_idx < kv_count * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / V_NQ;
|
||||
let v_row = blck_idx / BLOCKS_V;
|
||||
let global_v_row = kv_tile + v_row;
|
||||
let block_k = blck_idx % BLOCKS_V;
|
||||
let row_offset = v_row * HEAD_DIM_V;
|
||||
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
|
||||
let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES;
|
||||
let d = f16_from_u16(load_v_u16_at(block_byte_base));
|
||||
let thread_byte_offset = block_offset * V_BYTES_PER_THREAD;
|
||||
let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset;
|
||||
for (var j = 0u; j < V_BYTES_PER_THREAD / V_BYTES_PER_INNER_LOOP; j += 1u) {
|
||||
let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * V_BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_v_u32_at(q_byte_offset);
|
||||
#if defined(V_Q4_0)
|
||||
dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP);
|
||||
#elif defined(V_Q8_0)
|
||||
dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -1,16 +1,29 @@
|
||||
enable f16;
|
||||
enable subgroups;
|
||||
|
||||
#define BYTE_HELPERS
|
||||
#include "common_decls.tmpl"
|
||||
|
||||
#ifdef Q_F16
|
||||
#define Q_TYPE f16
|
||||
#else
|
||||
#define Q_TYPE f32
|
||||
#endif
|
||||
|
||||
#ifdef KV_F32
|
||||
#define KV_TYPE f32
|
||||
#ifdef K_F32
|
||||
#define K_TYPE f32
|
||||
#elif defined(K_Q4_0) || defined(K_Q8_0)
|
||||
#define K_TYPE u32
|
||||
#else
|
||||
#define KV_TYPE f16
|
||||
#define K_TYPE f16
|
||||
#endif
|
||||
|
||||
#ifdef V_F32
|
||||
#define V_TYPE f32
|
||||
#elif defined(V_Q4_0) || defined(V_Q8_0)
|
||||
#define V_TYPE u32
|
||||
#else
|
||||
#define V_TYPE f16
|
||||
#endif
|
||||
|
||||
#ifdef DST_F16
|
||||
@@ -21,7 +34,6 @@ enable subgroups;
|
||||
|
||||
#define HEAD_DIM_QK 64
|
||||
#define HEAD_DIM_V 64
|
||||
#define KV_STAGE_STRIDE 64
|
||||
#define Q_TILE 4
|
||||
#define KV_TILE 64
|
||||
#define WG_SIZE 128
|
||||
@@ -64,11 +76,23 @@ struct Params {
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> Q: array<Q_TYPE>;
|
||||
#ifdef KV_OVERLAP
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
|
||||
#if defined(K_Q4_0) || defined(K_Q8_0)
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>;
|
||||
#else
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>;
|
||||
#endif
|
||||
#define V K
|
||||
#else
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
|
||||
#if defined(K_Q4_0) || defined(K_Q8_0)
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>;
|
||||
#else
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>;
|
||||
#endif
|
||||
#if defined(V_Q4_0) || defined(V_Q8_0)
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>;
|
||||
#else
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<vec4<V_TYPE>>;
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(MASK) && defined(SINKS)
|
||||
@@ -121,10 +145,50 @@ const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u;
|
||||
const V_CHUNKS: u32 = HEAD_DIM_V / 4u;
|
||||
const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE;
|
||||
const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE;
|
||||
const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
|
||||
|
||||
var<workgroup> q_shmem: array<Q_TYPE, Q_TILE * HEAD_DIM_QK>;
|
||||
var<workgroup> kv_shmem: array<KV_TYPE, KV_TILE * KV_STAGE_STRIDE>;
|
||||
var<workgroup> p_shmem: array<KV_TYPE, Q_TILE * KV_TILE>;
|
||||
var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
|
||||
var<workgroup> p_shmem: array<f16, Q_TILE * KV_TILE>;
|
||||
|
||||
#define QUANT_SHMEM kv_shmem
|
||||
#define QUANT_OUT_TYPE f16
|
||||
#include "quant_inner_loops.tmpl"
|
||||
#include "flash_attn_quant_staging.tmpl"
|
||||
|
||||
#if !defined(K_Q4_0) && !defined(K_Q8_0)
|
||||
fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) {
|
||||
for (var vec_idx_local = local_x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) {
|
||||
let kv_local = vec_idx_local / Q_CHUNKS;
|
||||
let chunk = vec_idx_local % Q_CHUNKS;
|
||||
let global_k_row = kv_tile + kv_local;
|
||||
let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u;
|
||||
let k4 = K[k_vec_index];
|
||||
let kv_off = kv_local * HEAD_DIM_QK + chunk * 4u;
|
||||
kv_shmem[kv_off + 0u] = f16(k4.x);
|
||||
kv_shmem[kv_off + 1u] = f16(k4.y);
|
||||
kv_shmem[kv_off + 2u] = f16(k4.z);
|
||||
kv_shmem[kv_off + 3u] = f16(k4.w);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if !defined(V_Q4_0) && !defined(V_Q8_0)
|
||||
fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) {
|
||||
for (var vec_idx_local = local_x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) {
|
||||
let kv_local = vec_idx_local / V_CHUNKS;
|
||||
let chunk = vec_idx_local % V_CHUNKS;
|
||||
let global_v_row = kv_tile + kv_local;
|
||||
let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u;
|
||||
let v4 = V[v_vec_index];
|
||||
let kv_off = kv_local * HEAD_DIM_V + chunk * 4u;
|
||||
kv_shmem[kv_off + 0u] = f16(v4.x);
|
||||
kv_shmem[kv_off + 1u] = f16(v4.y);
|
||||
kv_shmem[kv_off + 2u] = f16(v4.z);
|
||||
kv_shmem[kv_off + 3u] = f16(v4.w);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@@ -206,18 +270,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
local_scores[slot] = FLOAT_MIN;
|
||||
}
|
||||
|
||||
for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) {
|
||||
let kv_local = vec_idx_local / Q_CHUNKS;
|
||||
let chunk = vec_idx_local % Q_CHUNKS;
|
||||
let global_k_row = kv_tile + kv_local;
|
||||
let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u;
|
||||
let k4 = K[k_vec_index];
|
||||
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||
kv_shmem[kv_off + 0u] = KV_TYPE(k4.x);
|
||||
kv_shmem[kv_off + 1u] = KV_TYPE(k4.y);
|
||||
kv_shmem[kv_off + 2u] = KV_TYPE(k4.z);
|
||||
kv_shmem[kv_off + 3u] = KV_TYPE(k4.w);
|
||||
}
|
||||
#ifndef KV_DIRECT
|
||||
load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset);
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
@@ -238,8 +293,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
q_shmem[q_off + 1u],
|
||||
q_shmem[q_off + 2u],
|
||||
q_shmem[q_off + 3u]);
|
||||
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||
let kv = vec4<KV_TYPE>(
|
||||
let kv_off = kv_local * HEAD_DIM_QK + chunk * 4u;
|
||||
let kv = vec4<f16>(
|
||||
kv_shmem[kv_off + 0u],
|
||||
kv_shmem[kv_off + 1u],
|
||||
kv_shmem[kv_off + 2u],
|
||||
@@ -271,25 +326,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
let kv_local = sg_inv_id + slot * subgroup_size;
|
||||
if (row_active && kv_local < kv_count) {
|
||||
let p = exp(local_scores[slot] - new_max);
|
||||
p_shmem[subgroup_p_offset + kv_local] = KV_TYPE(p);
|
||||
p_shmem[subgroup_p_offset + kv_local] = f16(p);
|
||||
local_sum += p;
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) {
|
||||
let kv_local = vec_idx_local / V_CHUNKS;
|
||||
let chunk = vec_idx_local % V_CHUNKS;
|
||||
let global_v_row = kv_tile + kv_local;
|
||||
let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u;
|
||||
let v4 = V[v_vec_index];
|
||||
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||
kv_shmem[kv_off + 0u] = KV_TYPE(v4.x);
|
||||
kv_shmem[kv_off + 1u] = KV_TYPE(v4.y);
|
||||
kv_shmem[kv_off + 2u] = KV_TYPE(v4.z);
|
||||
kv_shmem[kv_off + 3u] = KV_TYPE(v4.w);
|
||||
}
|
||||
#ifndef KV_DIRECT
|
||||
load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset);
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
@@ -306,14 +352,14 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
|
||||
var acc = out_regs[reg_idx];
|
||||
for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) {
|
||||
let p = p_shmem[subgroup_p_offset + kv_local];
|
||||
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
|
||||
let v4 = vec4<KV_TYPE>(
|
||||
let p = f32(p_shmem[subgroup_p_offset + kv_local]);
|
||||
let kv_off = kv_local * HEAD_DIM_V + chunk * 4u;
|
||||
let v4 = vec4<f16>(
|
||||
kv_shmem[kv_off + 0u],
|
||||
kv_shmem[kv_off + 1u],
|
||||
kv_shmem[kv_off + 2u],
|
||||
kv_shmem[kv_off + 3u]);
|
||||
acc += f32(p) * vec4<f32>(v4);
|
||||
acc += p * vec4<f32>(v4);
|
||||
}
|
||||
out_regs[reg_idx] = acc;
|
||||
}
|
||||
|
||||
@@ -2,10 +2,23 @@ diagnostic(off, subgroup_uniformity);
|
||||
enable f16;
|
||||
enable subgroups;
|
||||
|
||||
#ifdef KV_F32
|
||||
#define KV_TYPE f32
|
||||
#define BYTE_HELPERS
|
||||
#include "common_decls.tmpl"
|
||||
|
||||
#ifdef K_F32
|
||||
#define K_TYPE f32
|
||||
#elif defined(K_Q4_0) || defined(K_Q8_0)
|
||||
#define K_TYPE u32
|
||||
#else
|
||||
#define KV_TYPE f16
|
||||
#define K_TYPE f16
|
||||
#endif
|
||||
|
||||
#ifdef V_F32
|
||||
#define V_TYPE f32
|
||||
#elif defined(V_Q4_0) || defined(V_Q8_0)
|
||||
#define V_TYPE u32
|
||||
#else
|
||||
#define V_TYPE f16
|
||||
#endif
|
||||
|
||||
#ifdef Q_F16
|
||||
@@ -32,28 +45,6 @@ enable subgroups;
|
||||
|
||||
#define KV_BLOCKS (KV_TILE / KV_GRANULARITY)
|
||||
|
||||
#define BLOCK_SIZE 32
|
||||
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
|
||||
#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
|
||||
#if defined(KV_Q4_0)
|
||||
#define NQ 16
|
||||
#define F16_PER_BLOCK 9
|
||||
#define WEIGHTS_PER_F16 4
|
||||
#elif defined(KV_Q8_0)
|
||||
#define NQ 8
|
||||
#define F16_PER_BLOCK 17
|
||||
#define WEIGHTS_PER_F16 2
|
||||
#endif
|
||||
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
|
||||
|
||||
fn get_byte(value: u32, index: u32) -> u32 {
|
||||
return (value >> (index * 8)) & 0xFF;
|
||||
}
|
||||
|
||||
fn get_byte_i32(value: u32, index: u32) -> i32 {
|
||||
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
|
||||
}
|
||||
|
||||
struct Params {
|
||||
offset_q: u32,
|
||||
offset_k: u32,
|
||||
@@ -103,22 +94,22 @@ struct Params {
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> Q: array<Q_TYPE>;
|
||||
#ifdef KV_OVERLAP
|
||||
#if defined(KV_Q4_0) || defined(KV_Q8_0)
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
|
||||
#if defined(K_Q4_0) || defined(K_Q8_0)
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>;
|
||||
#else
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>;
|
||||
#endif
|
||||
#define V K
|
||||
#else
|
||||
#if defined(KV_Q4_0) || defined(KV_Q8_0)
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
|
||||
#if defined(K_Q4_0) || defined(K_Q8_0)
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>;
|
||||
#else
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
|
||||
@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>;
|
||||
#endif
|
||||
#if defined(KV_Q4_0) || defined(KV_Q8_0)
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
|
||||
#if defined(V_Q4_0) || defined(V_Q8_0)
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>;
|
||||
#else
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
|
||||
@group(0) @binding(2) var<storage, read_write> V: array<vec4<V_TYPE>>;
|
||||
#endif
|
||||
#endif
|
||||
#if defined(MASK) && defined(SINKS)
|
||||
@@ -244,6 +235,49 @@ fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool)
|
||||
return v;
|
||||
}
|
||||
|
||||
#ifndef KV_DIRECT
|
||||
#define QUANT_SHMEM kv_shmem
|
||||
#define QUANT_OUT_TYPE f32
|
||||
#include "quant_inner_loops.tmpl"
|
||||
#include "flash_attn_quant_staging.tmpl"
|
||||
|
||||
#if !defined(K_Q4_0) && !defined(K_Q8_0)
|
||||
fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) {
|
||||
for (var elem_idx = local_x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) {
|
||||
let k_row = elem_idx / HEAD_DIM_QK;
|
||||
let k_col = elem_idx % HEAD_DIM_QK;
|
||||
let global_k_row = kv_tile + k_row;
|
||||
let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
|
||||
let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK;
|
||||
let vec_idx = (global_k_row_offset + k_col) >> 2u;
|
||||
let k4 = select(vec4<K_TYPE>(0.0), K[vec_idx], in_bounds);
|
||||
kv_shmem[elem_idx + 0u] = f32(k4.x);
|
||||
kv_shmem[elem_idx + 1u] = f32(k4.y);
|
||||
kv_shmem[elem_idx + 2u] = f32(k4.z);
|
||||
kv_shmem[elem_idx + 3u] = f32(k4.w);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if !defined(V_Q4_0) && !defined(V_Q8_0)
|
||||
fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) {
|
||||
for (var elem_idx = local_x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) {
|
||||
let v_row = elem_idx / HEAD_DIM_V;
|
||||
let v_col = elem_idx % HEAD_DIM_V;
|
||||
let global_v_row = kv_tile + v_row;
|
||||
let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
|
||||
let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V;
|
||||
let vec_idx = (global_v_row_offset + v_col) >> 2u;
|
||||
let v4 = select(vec4<V_TYPE>(0.0), V[vec_idx], in_bounds);
|
||||
kv_shmem[elem_idx + 0u] = f32(v4.x);
|
||||
kv_shmem[elem_idx + 1u] = f32(v4.y);
|
||||
kv_shmem[elem_idx + 2u] = f32(v4.z);
|
||||
kv_shmem[elem_idx + 3u] = f32(v4.w);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@@ -308,6 +342,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
}
|
||||
|
||||
for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) {
|
||||
let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile);
|
||||
#ifdef BLK
|
||||
let q_blk = q_row_start;
|
||||
let kv_blk = kv_tile / KV_TILE;
|
||||
@@ -324,76 +359,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
}
|
||||
|
||||
// load k tile into shared memory
|
||||
#if defined(KV_Q4_0)
|
||||
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let k_row = blck_idx / BLOCKS_K;
|
||||
let global_k_row = kv_tile + k_row;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let row_offset = k_row * HEAD_DIM_QK;
|
||||
|
||||
if (global_k_row < params.seq_len_kv) {
|
||||
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = K[base_idx];
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = K[base_idx + 1u + block_offset + j];
|
||||
let q_1 = K[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d);
|
||||
let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d);
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_lo;
|
||||
kv_shmem[row_offset + idx + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_Q8_0)
|
||||
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let k_row = blck_idx / BLOCKS_K;
|
||||
let global_k_row = kv_tile + k_row;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let row_offset = k_row * HEAD_DIM_QK;
|
||||
|
||||
if (global_k_row < params.seq_len_kv) {
|
||||
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = K[base_idx];
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = K[base_idx + 1u + block_offset + j];
|
||||
let q_1 = K[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * f32(d);
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_DIRECT)
|
||||
// Direct global loads for KV
|
||||
#else
|
||||
for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) {
|
||||
let k_row = elem_idx / HEAD_DIM_QK;
|
||||
let k_col = elem_idx % HEAD_DIM_QK;
|
||||
let global_k_row = kv_tile + k_row;
|
||||
let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
|
||||
let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK;
|
||||
let vec_idx = (global_k_row_offset + k_col) >> 2u;
|
||||
let k4 = select(vec4<KV_TYPE>(0.0), K[vec_idx], in_bounds);
|
||||
kv_shmem[elem_idx + 0u] = f32(k4.x);
|
||||
kv_shmem[elem_idx + 1u] = f32(k4.y);
|
||||
kv_shmem[elem_idx + 2u] = f32(k4.z);
|
||||
kv_shmem[elem_idx + 3u] = f32(k4.w);
|
||||
}
|
||||
#ifndef KV_DIRECT
|
||||
load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset);
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
@@ -510,76 +477,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
}
|
||||
|
||||
// load v tile into shared memory
|
||||
#if defined(KV_Q4_0)
|
||||
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let v_row = blck_idx / BLOCKS_V;
|
||||
let global_v_row = kv_tile + v_row;
|
||||
let block_k = blck_idx % BLOCKS_V;
|
||||
let row_offset = v_row * HEAD_DIM_V;
|
||||
|
||||
if (global_v_row < params.seq_len_kv) {
|
||||
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = V[base_idx];
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = V[base_idx + 1u + block_offset + j];
|
||||
let q_1 = V[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d);
|
||||
let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d);
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_lo;
|
||||
kv_shmem[row_offset + idx + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_Q8_0)
|
||||
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
|
||||
let blck_idx = elem_idx / BLOCK_SIZE;
|
||||
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
|
||||
let v_row = blck_idx / BLOCKS_V;
|
||||
let global_v_row = kv_tile + v_row;
|
||||
let block_k = blck_idx % BLOCKS_V;
|
||||
let row_offset = v_row * HEAD_DIM_V;
|
||||
|
||||
if (global_v_row < params.seq_len_kv) {
|
||||
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
|
||||
let base_idx = global_block_idx * F16_PER_BLOCK;
|
||||
let d = V[base_idx];
|
||||
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
|
||||
let q_0 = V[base_idx + 1u + block_offset + j];
|
||||
let q_1 = V[base_idx + 1u + block_offset + j + 1];
|
||||
let q_packed = bitcast<u32>(vec2(q_0, q_1));
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = f32(q_byte) * f32(d);
|
||||
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
|
||||
kv_shmem[row_offset + idx] = q_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(KV_DIRECT)
|
||||
// Direct global loads for KV
|
||||
#else
|
||||
for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) {
|
||||
let v_row = elem_idx / HEAD_DIM_V;
|
||||
let v_col = elem_idx % HEAD_DIM_V;
|
||||
let global_v_row = kv_tile + v_row;
|
||||
let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
|
||||
let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V;
|
||||
let vec_idx = (global_v_row_offset + v_col) >> 2u;
|
||||
let v4 = select(vec4<KV_TYPE>(0.0), V[vec_idx], in_bounds);
|
||||
kv_shmem[elem_idx + 0u] = f32(v4.x);
|
||||
kv_shmem[elem_idx + 1u] = f32(v4.y);
|
||||
kv_shmem[elem_idx + 2u] = f32(v4.z);
|
||||
kv_shmem[elem_idx + 3u] = f32(v4.w);
|
||||
}
|
||||
#ifndef KV_DIRECT
|
||||
load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset);
|
||||
#endif
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
@@ -25,6 +25,10 @@ fn store_shmem(val: f16, idx: u32) {
|
||||
}
|
||||
#endif // SCALAR
|
||||
|
||||
#define QUANT_SHMEM shmem
|
||||
#define QUANT_OUT_TYPE f16
|
||||
#include "quant_inner_loops.tmpl"
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_FLOAT
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
|
||||
@@ -124,14 +128,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
|
||||
let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
|
||||
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
|
||||
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo;
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
|
||||
}
|
||||
dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -314,12 +311,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
|
||||
let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
|
||||
let q_val = f16(q_byte) * d;
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val;
|
||||
}
|
||||
dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
#ifdef U32_DEQUANT_HELPERS
|
||||
fn dequant_q4_0_packed_to_shmem(q_packed: u32, d: f16, dst_idx: u32) {
|
||||
let scale = QUANT_OUT_TYPE(d);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = (QUANT_OUT_TYPE((q_byte >> 4) & 0xFu) - QUANT_OUT_TYPE(8.0)) * scale;
|
||||
let q_lo = (QUANT_OUT_TYPE(q_byte & 0xFu) - QUANT_OUT_TYPE(8.0)) * scale;
|
||||
QUANT_SHMEM[dst_idx + k] = q_lo;
|
||||
QUANT_SHMEM[dst_idx + k + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
|
||||
fn dequant_q8_0_packed_to_shmem(q_packed: u32, d: f16, dst_idx: u32) {
|
||||
let scale = QUANT_OUT_TYPE(d);
|
||||
for (var k = 0u; k < 4u; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
let q_val = QUANT_OUT_TYPE(q_byte) * scale;
|
||||
QUANT_SHMEM[dst_idx + k] = q_val;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
Reference in New Issue
Block a user