Compare commits

...

3 Commits

Author SHA1 Message Date
Xuan-Son Nguyen e3af5563bd llama: store mrope data in KV cell (#16825)
* llama: store mrope data in KV cell

* correct x,y ordering

* address review comments

* add consistency checks

* Update src/llama-kv-cache.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* add TODO

* fix asan error

* kv-cells : improve ext handling

* cont : fix headers

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-10-29 18:09:18 +01:00
Jeff Bolz 10fcc41290 vulkan: Update topk_moe fusion to handle gpt's late softmax (#16656)
* vulkan: Update topk_moe fusion to handle gpt's late softmax

Based on #16649.

* Add ggml_check_edges

* Add sync logging to show fusion effects

* handle clamp added in #16655

* Update ggml/src/ggml-impl.h

Co-authored-by: Diego Devesa <slarengh@gmail.com>
2025-10-29 14:44:29 +01:00
Ruben Ortlam bcf5bda6f5 Vulkan MMQ Integer Dot Refactor and K-Quant support (#16536)
* vulkan: add mmq q2_k integer dot support

* Refactor mmq caching

* Reduce mmq register use

* Load 4 quant blocks into shared memory in one step

* Pack q2_k blocks into caches of 32

* Use 32-bit accumulators for integer dot matmul

* Add q4_k mmq

* Add q3_k mmq

* Add q5_k mmq

* Add q6_k mmq

* Add mxfp4 mmq, enable MMQ MUL_MAT_ID

* Fix mmv dm loads
2025-10-29 14:39:03 +01:00
26 changed files with 1360 additions and 592 deletions
+16
View File
@@ -682,6 +682,7 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
#endif
#ifdef __cplusplus
#include <array>
#include <initializer_list>
#include <vector>
@@ -697,6 +698,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
}
// Return true if the edges in the graph match expectations.
inline bool ggml_check_edges(const struct ggml_cgraph * cgraph,
int start_idx,
std::initializer_list<std::array<int, 3>> edges) {
for (const auto & edge : edges) {
int dst_node = edge[0];
int src_idx = edge[1];
int src_node = edge[2];
if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
return false;
}
}
return true;
}
// expose GGUF internals for test code
GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
+335 -134
View File
@@ -385,12 +385,76 @@ static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
static constexpr uint32_t num_topk_moe_pipelines = 10;
static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS };
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
GGML_OP_RESHAPE };
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS };
static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ]
//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
//node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ]
//node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ]
//node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ]
//node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ]
//node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ]
//node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ]
//node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ]
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
{ 1, 0, 0 }, // reshape->src[0] == softmax
{ 2, 0, 0 }, // argsort->src[0] == softmax
{ 3, 0, 2 }, // view->src[0] == argsort
{ 4, 0, 1 }, // get_rows->src[0] == reshape
{ 4, 1, 3 }, // get_rows->src[1] == view
{ 5, 0, 4 }, // reshape->src[0] == get_rows
{ 6, 0, 5 }, // sum_rows->src[0] == reshape
{ 7, 0, 6 }, // clamp->src[0] == sum_rows
{ 8, 0, 5 }, // div->src[0] == reshape
{ 8, 1, 7 }, // div->src[1] == clamp
{ 9, 0, 8 }, // reshape->src[0] == div
};
// same as early_softmax_norm but ending after the get_rows
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
{ 1, 0, 0 }, // reshape->src[0] == softmax
{ 2, 0, 0 }, // argsort->src[0] == softmax
{ 3, 0, 2 }, // view->src[0] == argsort
{ 4, 0, 1 }, // get_rows->src[0] == reshape
{ 4, 1, 3 }, // get_rows->src[1] == view
};
//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ]
//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ]
//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ]
//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ]
//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ]
//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ]
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
{ 1, 0, 0 }, // view->src[0] == argsort
{ 2, 1, 1 }, // get_rows->src[1] == view
{ 3, 0, 2 }, // reshape->src[0] == get_rows
{ 4, 0, 3 }, // soft_max->src[0] == reshape
{ 5, 0, 4 }, // reshape->src[0] == soft_max
};
enum topk_moe_mode {
TOPK_MOE_EARLY_SOFTMAX,
TOPK_MOE_EARLY_SOFTMAX_NORM,
TOPK_MOE_LATE_SOFTMAX,
TOPK_MOE_COUNT,
};
static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX :
TOPK_MOE_LATE_SOFTMAX;
return mode;
}
struct vk_device_struct {
std::recursive_mutex mutex;
@@ -486,6 +550,7 @@ struct vk_device_struct {
vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_COUNT];
vk_pipeline pipeline_matmul_split_k_reduce;
vk_pipeline pipeline_quantize_q8_1;
@@ -604,8 +669,7 @@ struct vk_device_struct {
vk_pipeline pipeline_flash_attn_split_k_reduce;
// [2] is {!norm, norm}
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
std::vector<vk_pipeline_ref> all_pipelines;
@@ -953,6 +1017,8 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
struct vk_op_topk_moe_push_constants {
uint32_t n_rows;
uint32_t n_expert_used;
float clamp_min;
float clamp_max;
};
struct vk_op_add_id_push_constants {
@@ -2448,8 +2514,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
l_warptile_id, m_warptile_id, s_warptile_id,
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k,
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid,
l_warptile_mmqid_int, m_warptile_mmqid_int, s_warptile_mmqid_int,
l_warptile_mmqid_int_k, m_warptile_mmqid_int_k, s_warptile_mmqid_int_k;
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
@@ -2512,10 +2581,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
// Integer MMQ has a smaller shared memory profile, but heavier register use
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
// K-quants use even more registers, mitigate by setting WMITER to 1
l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
@@ -2524,10 +2599,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
// chip specific tuning
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
}
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
@@ -2912,18 +2995,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) { \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
} \
if (device->mul_mat ## ID ## _m[TYPE]) { \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
} \
if (device->mul_mat ## ID ## _s[TYPE]) { \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
} \
// Create 2 variants, {f16,f32} accumulator
@@ -2962,11 +3042,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_MXFP4], matmul_mxfp4_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
}
#endif
@@ -2996,6 +3084,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
}
#endif
} else {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
@@ -3022,6 +3128,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
}
#endif
}
#undef CREATE_MM2
#undef CREATE_MMQ
@@ -3086,6 +3210,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
}
#endif
@@ -3145,7 +3275,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
// reusing CREATE_MM from the fp32 path
if ((device->coopmat2 || device->coopmat_support)
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
&& !device->coopmat_bf16_support
#endif
) {
@@ -3739,8 +3869,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true);
}
for (auto &c : compiles) {
@@ -4928,7 +5059,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
// MMQ
if (src1_type == GGML_TYPE_Q8_1) {
vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
return nullptr;
@@ -5075,6 +5206,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
}
}
// MMQ
if (src1_type == GGML_TYPE_Q8_1) {
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;
if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
return nullptr;
}
return pipelines;
}
GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
switch (src0_type) {
@@ -6877,10 +7019,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
// Check for mmq first
vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
if (mmp == nullptr) {
// Fall back to f16 dequant mul mat
mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
quantize_y = false;
}
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig;
const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
if (qx_needs_dequant) {
// Fall back to dequant + f16 mulmat
@@ -6890,8 +7041,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8;
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
@@ -6904,12 +7055,13 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
const uint64_t ids_sz = nbi2;
const uint64_t d_sz = sizeof(float) * d_ne;
vk_pipeline to_fp16_vk_0 = nullptr;
vk_pipeline to_fp16_vk_1 = nullptr;
vk_pipeline to_q8_1 = nullptr;
if (x_non_contig) {
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
@@ -6924,9 +7076,16 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
if (quantize_y) {
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
}
if (dryrun) {
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
uint64_t y_sz_upd = y_sz * ne12 * ne13;
if (quantize_y) {
y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
}
if (
(qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
(qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
@@ -6935,7 +7094,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
ctx->prealloc_size_x = x_sz_upd;
}
if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
ctx->prealloc_size_y = y_sz_upd;
}
@@ -6947,6 +7106,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (qy_needs_dequant) {
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
}
if (quantize_y) {
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
}
return;
}
@@ -6983,6 +7145,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (qy_needs_dequant) {
d_Y = ctx->prealloc_y;
GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
} else if (quantize_y) {
d_Y = ctx->prealloc_y;
GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
} else {
d_Y = d_Qy;
y_buf_offset = qy_buf_offset;
@@ -7014,6 +7179,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
ctx->prealloc_y_last_tensor_used = src1;
}
}
if (quantize_y) {
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
}
uint32_t stride_batch_x = ne00*ne01;
uint32_t stride_batch_y = ne10*ne11;
@@ -7022,14 +7198,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
}
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
}
uint32_t y_sz_total = y_sz * ne12 * ne13;
if (quantize_y) {
y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
}
// compute
ggml_vk_matmul_id(
ctx, subctx, pipeline,
{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
{ d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
ne01, ne21, ne10, ne10, ne10, ne01,
stride_batch_x, stride_batch_y, ne20*ne21,
@@ -7968,8 +8149,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
return ctx->device->pipeline_topk_moe[idx][with_norm];
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
return ctx->device->pipeline_topk_moe[idx][mode];
}
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@@ -8024,6 +8205,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
}
case GGML_OP_ARGSORT:
if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
return ctx->device->pipeline_topk_moe[idx][mode];
}
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
return ctx->device->pipeline_argsort_f32[idx];
@@ -9563,10 +9751,12 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
ggml_tensor * ids = cgraph->nodes[node_idx + 3];
ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
(mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] :
cgraph->nodes[node_idx + 5];
ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
@@ -9625,9 +9815,14 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
GGML_ASSERT(d_ids != nullptr);
}
vk_op_topk_moe_push_constants pc;
vk_op_topk_moe_push_constants pc {};
pc.n_rows = n_rows;
pc.n_expert_used = n_expert_used;
if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
}
GGML_ASSERT(n_expert_used <= n_experts);
@@ -11222,7 +11417,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
}
}
}
#define ENABLE_SYNC_LOGGING 0
if (need_sync) {
#if ENABLE_SYNC_LOGGING
std::cerr << "sync" << std::endl;
#endif
ctx->unsynced_nodes_written.clear();
ctx->unsynced_nodes_read.clear();
ggml_vk_sync_buffers(ctx, compute_ctx);
@@ -11240,6 +11441,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
}
}
}
#if ENABLE_SYNC_LOGGING
if (!dryrun) {
for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
auto *n = cgraph->nodes[node_idx + i];
std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name;
if (n->op == GGML_OP_GLU) {
std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
}
std::cerr << std::endl;
}
}
#endif
switch (node->op) {
case GGML_OP_REPEAT:
@@ -11418,7 +11631,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_ARGSORT:
ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
if (ctx->num_additional_fused_ops) {
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
} else {
ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
}
break;
case GGML_OP_SUM:
@@ -12191,31 +12408,28 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
}
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
int node_idx, bool with_norm) {
int node_idx, topk_moe_mode mode) {
if (with_norm) {
if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
return false;
}
for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
return false;
}
}
} else {
if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
return false;
}
for (size_t i = 0; i < topk_moe.size(); ++i) {
if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
return false;
}
}
const ggml_tensor * softmax;
const ggml_tensor * weights;
switch (mode) {
case TOPK_MOE_EARLY_SOFTMAX_NORM:
softmax = cgraph->nodes[node_idx + 0];
weights = cgraph->nodes[node_idx + 9];
break;
case TOPK_MOE_EARLY_SOFTMAX:
softmax = cgraph->nodes[node_idx + 0];
weights = cgraph->nodes[node_idx + 4];
break;
case TOPK_MOE_LATE_SOFTMAX:
softmax = cgraph->nodes[node_idx + 4];
weights = cgraph->nodes[node_idx + 5];
break;
default:
return false;
}
const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
const float * op_params = (const float *)softmax->op_params;
float scale = op_params[0];
@@ -12240,60 +12454,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
return false;
}
// Check that the nodes don't have any unexpected uses
const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
const ggml_tensor * view = cgraph->nodes[node_idx + 3];
const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
// softmax is used by reshape and argsort
if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
reshape1->src[0] != softmax ||
argsort->src[0] != softmax) {
return false;
}
// reshape is used by get_rows
if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
get_rows->src[0] != reshape1) {
return false;
}
// argsort is used by view
if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
view->src[0] != argsort) {
return false;
}
// view is written (via argsort), we can skip checking it
if (with_norm) {
// get_rows is used by reshape
if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
reshape5->src[0] != get_rows) {
return false;
}
// reshape is used by sum_rows and div
if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
sum_rows->src[0] != reshape5 ||
div->src[0] != reshape5) {
return false;
}
// sum_rows is used by div
if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
div->src[1] != sum_rows) {
return false;
}
// div/reshape are written
if (reshape8->src[0] != div) {
return false;
}
}
if (!ctx->device->subgroup_arithmetic ||
!ctx->device->subgroup_shuffle ||
!ctx->device->subgroup_require_full_support ||
@@ -12379,10 +12539,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
ctx->num_additional_fused_ops = topk_moe.size() - 1;
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
}
}
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -12480,10 +12648,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
ctx->num_additional_fused_ops = topk_moe.size() - 1;
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
}
}
@@ -12615,25 +12791,44 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
while (first_unused < graph->n_nodes) {
std::vector<int> current_set;
// Avoid reordering topk_moe_norm
if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
bool is_topk_moe_norm = true;
for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
is_topk_moe_norm = false;
// Check for fusion patterns and avoid reordering them
auto const &match_pattern = [&](const std::initializer_list<ggml_op> &pattern, int start) -> bool {
if (start + (int)pattern.size() <= graph->n_nodes) {
bool is_pattern = true;
for (size_t j = 0; j < pattern.size(); ++j) {
if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) {
is_pattern = false;
}
}
return is_pattern;
}
if (is_topk_moe_norm) {
for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
return false;
};
auto const &keep_pattern = [&](const std::initializer_list<ggml_op> &pattern) -> bool {
if (match_pattern(pattern, first_unused)) {
for (size_t j = 0; j < pattern.size(); ++j) {
new_order.push_back(graph->nodes[first_unused + j]);
used[first_unused + j] = true;
}
while (first_unused < graph->n_nodes && used[first_unused]) {
first_unused++;
}
continue;
return true;
}
return false;
};
if (keep_pattern(topk_moe_early_softmax_norm)) {
continue;
}
if (keep_pattern(topk_moe_early_softmax)) {
continue;
}
if (keep_pattern(topk_moe_late_softmax)) {
continue;
}
// First, grab the next unused node.
current_set.push_back(first_unused);
@@ -12651,6 +12846,12 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
if (is_empty(graph->nodes[j])) {
continue;
}
// Don't pull forward nodes from fusion patterns
if (match_pattern(topk_moe_early_softmax_norm, j) ||
match_pattern(topk_moe_early_softmax, j) ||
match_pattern(topk_moe_late_softmax, j)) {
continue;
}
bool ok = true;
for (int c = first_unused; c < j; ++c) {
if (!used[c] &&
@@ -437,7 +437,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
#if defined(DATA_A_MXFP4)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
vec2 v0 = dequantize(ib, iqs, a_offset);
@@ -488,9 +488,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
const uint scales = data_a[a_offset + ib].scales[scalesi];
const vec2 d = vec2(data_a[a_offset + ib].d);
const vec2 dm = vec2(data_a[a_offset + ib].dm);
return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
@@ -529,7 +529,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
const vec2 loadd = vec2(data_a[a_offset + ib].d);
const vec2 loadd = vec2(data_a[a_offset + ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -567,7 +567,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint8_t hm = uint8_t(1 << (iqs / 16));
const vec2 loadd = vec2(data_a[a_offset + ib].d);
const vec2 loadd = vec2(data_a[a_offset + ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -120,7 +120,7 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
const f16vec2 d = bl.block.d;
const f16vec2 dm = bl.block.dm;
const uint idx = coordInBlock[1];
const uint scalesi = (idx & 0xF0) >> 4; // 0..15
@@ -131,7 +131,7 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2
qs = unpack8(qs)[idx & 1];
const uint scales = bl.block.scales[scalesi];
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4);
return ret;
}
@@ -680,7 +680,7 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
return ret;
}
#endif
@@ -26,7 +26,7 @@ void main() {
const float d = e8m0_to_fp32(data_a[ib].e);
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));
data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]));
}
}
@@ -24,8 +24,8 @@ void main() {
const uint ql_idx = 32 * ip + il;
const uint8_t qs = data_a[i].qs[32 * ip + il];
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x);
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y);
data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
@@ -20,8 +20,8 @@ void main() {
const uint is = 2 * il;
const uint n = 4;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
const uint qs_idx = 32*il + n * ir;
@@ -19,8 +19,8 @@ void main() {
const uint ir = tid % 16;
const uint is = 2 * il;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;
const uint qs_idx = 32*il + 2 * ir;
@@ -41,9 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
@@ -75,7 +73,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
}
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
temp[j][n] = fma(dm.x, sum1, fma(-dm.y, sum2, temp[j][n]));
}
}
}
@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -81,7 +79,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
}
}
}
@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -113,7 +111,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
}
}
}
@@ -120,81 +120,11 @@ shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[BN];
uint _ne1;
#ifdef MUL_MAT_ID_USE_SUBGROUPS
shared uvec4 ballots_sh[NUM_WARPS];
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0);
uint ids[16];
uint iter = 0;
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
// prefetch up to 16 elements
if (iter == 0) {
[[unroll]] for (uint k = 0; k < 16; ++k) {
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
bool in_range = i < num_elements;
uint ii1;
if (nei0_is_pow2) {
ii1 = i >> nei0shift;
} else {
ii1 = i / p.nei0;
}
uint ii0 = i - ii1 * p.nei0;
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
}
}
uint i = j + gl_LocalInvocationIndex;
bool in_range = i < num_elements;
uint ii1;
if (nei0_is_pow2) {
ii1 = i >> nei0shift;
} else {
ii1 = i / p.nei0;
}
uint ii0 = i - ii1 * p.nei0;
uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
ballots_sh[gl_SubgroupID] = ballot;
barrier();
uint subgroup_base = 0;
uint total = 0;
for (uint k = 0; k < gl_NumSubgroups; ++k) {
if (k == gl_SubgroupID) {
subgroup_base = total;
}
total += subgroupBallotBitCount(ballots_sh[k]);
}
barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
}
_ne1 += total;
iter &= 15;
if (_ne1 >= (ic + 1) * BN) {
break;
}
}
barrier();
}
#endif // MUL_MAT_ID_USE_SUBGROUPS
#endif // MUL_MAT_ID
#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
#include "mul_mm_id_funcs.glsl"
#include "mul_mm_funcs.glsl"
void main() {
@@ -134,15 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15
const uint scalesi = iqs / 8; // 0..15
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
const uint scales = data_a[ib].scales[scalesi];
const vec2 d = vec2(data_a[ib].d);
const vec2 dm = vec2(data_a[ib].dm);
const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_Q3_K)
@@ -179,7 +179,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
const vec2 loadd = vec2(data_a[ib].d);
const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -215,7 +215,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint8_t hm = uint8_t(1 << (iqs / 16));
const vec2 loadd = vec2(data_a[ib].d);
const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -468,7 +468,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint ib = idx / 8;
const uint iqs = (idx & 0x07) * 2;
const float d = e8m0_to_fp32(data_a[ib].e);
const float d = e8m0_to_fp32(data_a[ib].e) * 0.5;
const uint vui = uint(data_a[ib].qs[iqs]);
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
@@ -0,0 +1,70 @@
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[BN];
uint _ne1;
#ifdef MUL_MAT_ID_USE_SUBGROUPS
shared uvec4 ballots_sh[NUM_WARPS];
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
_ne1 = 0;
uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0);
uint ids[16];
uint iter = 0;
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
// prefetch up to 16 elements
if (iter == 0) {
[[unroll]] for (uint k = 0; k < 16; ++k) {
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
bool in_range = i < num_elements;
uint ii1;
if (nei0_is_pow2) {
ii1 = i >> nei0shift;
} else {
ii1 = i / p.nei0;
}
uint ii0 = i - ii1 * p.nei0;
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
}
}
uint i = j + gl_LocalInvocationIndex;
bool in_range = i < num_elements;
uint ii1;
if (nei0_is_pow2) {
ii1 = i >> nei0shift;
} else {
ii1 = i / p.nei0;
}
uint ii0 = i - ii1 * p.nei0;
uint id = ids[iter++];
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
ballots_sh[gl_SubgroupID] = ballot;
barrier();
uint subgroup_base = 0;
uint total = 0;
for (uint k = 0; k < gl_NumSubgroups; ++k) {
if (k == gl_SubgroupID) {
subgroup_base = total;
}
total += subgroupBallotBitCount(ballots_sh[k]);
}
barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
}
_ne1 += total;
iter &= 15;
if (_ne1 >= (ic + 1) * BN) {
break;
}
}
barrier();
}
#endif // MUL_MAT_ID_USE_SUBGROUPS
#endif // MUL_MAT_ID
+77 -227
View File
@@ -10,10 +10,9 @@
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#if defined(MUL_MAT_ID_USE_SUBGROUPS)
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#endif
#ifdef MUL_MAT_ID
@@ -24,7 +23,10 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
@@ -76,40 +78,27 @@ layout (constant_id = 10) const uint WARP = 32;
#define BK 32
#ifdef COOPMAT
#define SHMEM_STRIDE (BK / 4 + 4)
#else
#define SHMEM_STRIDE (BK / 4 + 1)
#define MMQ_SHMEM
#include "mul_mmq_shmem_types.glsl"
#ifndef BK_STEP
#define BK_STEP 4
#endif
shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
// Shared memory cache
shared block_a_cache buf_a[BM * BK_STEP];
shared block_b_cache buf_b[BN * BK_STEP];
// Register cache
block_a_cache cache_a[WMITER * TM];
block_b_cache cache_b;
#ifndef COOPMAT
#if QUANT_AUXF == 1
shared FLOAT_TYPE buf_a_dm[BM];
#else
shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
#endif
#endif
shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
#ifndef COOPMAT
shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
#endif
#define LOAD_VEC_A (4 * QUANT_R)
#define LOAD_VEC_A (4 * QUANT_R_MMQ)
#define LOAD_VEC_B 16
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
#endif // MUL_MAT_ID
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
#include "mul_mm_id_funcs.glsl"
#include "mul_mmq_funcs.glsl"
void main() {
@@ -139,26 +128,12 @@ void main() {
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER;
const uint WSUBN = WN / WNITER;
#ifdef COOPMAT
const uint warp_i = gl_SubgroupID;
const uint tiw = gl_SubgroupInvocationID;
const uint cms_per_row = WM / TM;
const uint cms_per_col = WN / TN;
const uint storestride = WARP / TM;
const uint store_r = tiw % TM;
const uint store_c = tiw / TM;
#else
const uint warp_i = gl_LocalInvocationID.x / WARP;
const uint tiw = gl_LocalInvocationID.x % WARP;
const uint tiwr = tiw % (WSUBM / TM);
const uint tiwc = tiw / (WSUBM / TM);
#endif
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
@@ -172,17 +147,27 @@ void main() {
const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
#ifdef MUL_MAT_ID
uint _ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
#ifdef MUL_MAT_ID_USE_SUBGROUPS
if (bitCount(p.nei0) == 1) {
load_row_ids(expert_idx, true, ic);
} else {
load_row_ids(expert_idx, false, ic);
}
#else
_ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
row_ids[_ne1] = u16vec2(ii0, ii1);
if (_ne1 >= ic * BN) {
row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
}
_ne1++;
}
}
}
barrier();
#endif
// Workgroup has no work
if (ic * BN >= _ne1) return;
@@ -209,159 +194,70 @@ void main() {
uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
#endif
#ifdef COOPMAT
coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
}
#else
int32_t cache_a_qs[WMITER * TM * BK / 4];
int32_t cache_b_qs[TN * BK / 4];
ACC_TYPE sums[WMITER * TM * WNITER * TN];
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = ACC_TYPE(0.0f);
}
#endif
#if QUANT_AUXF == 1
FLOAT_TYPE cache_a_dm[WMITER * TM];
#else
FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
#endif
FLOAT_TYPE_VEC2 cache_b_ds[TN];
for (uint block = start_k; block < end_k; block += BK) {
for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
const uint iqs = loadr_a;
const uint buf_ib = loadc_a + l;
const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
const uint iqs = loadr_a;
if (iqs == 0) {
#if QUANT_AUXF == 1
buf_a_dm[buf_ib] = get_d(ib);
#else
buf_a_dm[buf_ib] = get_dm(ib);
#endif
[[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
}
#if QUANT_R == 1
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
#else
const i32vec2 vals = repack(ib, iqs);
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
#endif
}
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
const uint ib = idx / 8;
const uint iqs = idx & 0x7;
#else
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;
const uint iqs = loadr_b;
#endif
const uint buf_ib = loadc_b + l;
if (iqs == 0) {
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[buf_ib];
const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;
#else
const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
#endif
const uint iqs = loadr_b;
[[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);
}
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
}
barrier();
pos_a_ib += 1;
pos_b_ib += 1;
pos_a_ib += BK_STEP;
pos_b_ib += BK_STEP;
#ifdef COOPMAT
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
const uint ib_a = warp_r * WM + cm_row * TM;
for (uint k_step = 0; k_step < BK_STEP; k_step++) {
// Load from shared into cache
coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
// TODO: only cache values that are actually needed
[[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
}
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const uint ib_b = warp_c * WN + cm_col * TN;
coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
// TODO: only cache values that are actually needed
[[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
}
cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
}
coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
}
}
#else
// Load from shared into cache
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
}
}
}
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
cache_b_ds[cc] = buf_b_ds[ib];
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
}
}
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint cache_a_idx = wsir * TM + cr;
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
int32_t q_sum = 0;
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
cache_b_qs[cc * (BK / 4) + idx_k]);
}
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint reg_ib = wsir * TM + cr;
const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
block_a_to_registers(reg_ib, k_step * BM + buf_ib);
}
}
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
block_b_to_registers(ib);
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint cache_a_idx = wsir * TM + cr;
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
sums[sums_idx] += mmq_dot_product(cache_a_idx);
}
}
}
}
}
#endif
barrier();
}
@@ -373,54 +269,6 @@ void main() {
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif
#ifdef COOPMAT
#ifdef MUL_MAT_ID
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
const uint row_i = dc + cm_col * TN + col + store_c;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
#else
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
if (is_aligned && is_in_bounds) {
// Full coopMat is within bounds and stride_d is aligned with 16B
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
} else if (is_in_bounds) {
// Full coopMat is within bounds, but stride_d is not aligned
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
// Partial coopMat is within bounds
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
}
}
#endif // MUL_MAT_ID
#else
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
@@ -431,19 +279,21 @@ void main() {
const uint row_i = dc_warp + cc;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
#ifdef MUL_MAT_ID
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
if (dr_warp + cr < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
}
#else
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
}
#endif // MUL_MAT_ID
}
}
}
}
#endif // COOPMAT
}
@@ -6,41 +6,89 @@
// Each iqs value maps to a 32-bit integer
#if defined(DATA_A_Q4_0)
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
// 2-byte loads for Q4_0 blocks (18 bytes)
// 4-byte loads for Q4_1 blocks (20 bytes)
i32vec2 repack(uint ib, uint iqs) {
// Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
data_a[ib].qs[iqs * 2 + 1]);
#ifdef DATA_A_Q4_0
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
#else // DATA_A_Q4_1
const uint32_t vui = data_a_packed32[ib].qs[iqs];
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
#endif
}
#ifdef DATA_A_Q4_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
#endif
#if defined(DATA_A_Q4_1)
i32vec2 repack(uint ib, uint iqs) {
// Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
const uint32_t vui = data_a_packed32[ib].qs[iqs];
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
}
#else // DATA_A_Q4_1
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#if defined(DATA_A_Q5_0)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q4_0
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1]));
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
}
#else // DATA_A_Q4_1
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
const uint32_t vui = cache_a[ib_a].qs[iqs];
const i32vec2 qs_a = i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
const int32_t qs_b0 = cache_b.qs[iqs];
const int32_t qs_b1 = cache_b.qs[iqs + 4];
q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);
q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
// 2-byte loads for Q5_0 blocks (22 bytes)
// 4-byte loads for Q5_1 blocks (24 bytes)
i32vec2 repack(uint ib, uint iqs) {
// Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
data_a[ib].qs[iqs * 2 + 1]);
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
#ifdef DATA_A_Q5_0
const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
#else // DATA_A_Q5_1
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
#endif
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
@@ -50,40 +98,457 @@ i32vec2 repack(uint ib, uint iqs) {
return i32vec2(v0, v1);
}
#ifdef DATA_A_Q5_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
#endif
#if defined(DATA_A_Q5_1)
i32vec2 repack(uint ib, uint iqs) {
// Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
const uint32_t vui = data_a_packed32[ib].qs[iqs];
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
#else // DATA_A_Q5_1
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q5_0
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1]));
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1]));
}
#else // DATA_A_Q5_1
buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
buf_a[buf_ib].qh = data_a_packed32[ib].qh;
}
#endif
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
cache_a[reg_ib].qh = buf_a[buf_ib].qh;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
const uint32_t vui = cache_a[ib_a].qs[iqs];
const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs));
const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
const int32_t qs_b0 = cache_b.qs[iqs];
const int32_t qs_b1 = cache_b.qs[iqs + 4];
q_sum += dotPacked4x8EXT(qs_a0, qs_b0);
q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q8_0)
// 2-byte loads for Q8_0 blocks (34 bytes)
int32_t repack(uint ib, uint iqs) {
// Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
return pack32(i16vec2(data_a[ib].qs[iqs * 2 ],
data_a[ib].qs[iqs * 2 + 1]));
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]));
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * da * dsb.x);
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1]));
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
const int32_t qs_a = cache_a[ib_a].qs[iqs];
const int32_t qs_b = cache_b.qs[iqs];
q_sum += dotPacked4x8EXT(qs_a, qs_b);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
i32vec2 repack(uint ib, uint iqs) {
const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));
return i32vec2( quants & 0x0F0F0F0F,
(quants >> 4) & 0x0F0F0F0F);
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * dsb.x * float(q_sum));
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
buf_a[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w]));
buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]));
if (iqs == 0) {
buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].d = buf_a[buf_ib].d;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
const int32_t qs_a = cache_a[ib_a].qs[iqs];
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
#if defined(DATA_A_Q2_K)
// 4-byte loads for Q2_K blocks (84 bytes)
int32_t repack(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
}
uint8_t get_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
return data_a[ib_k].scales[iqs_k / 4];
}
ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
// Repack 4x4 quants into one int
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303;
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
if (iqs == 0) {
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
cache_a[reg_ib].scales = buf_a[buf_ib].scales;
[[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t sum_d = 0;
int32_t sum_m = 0;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
const uint8_t scale = cache_a[ib_a].scales[iqs / 4];
const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);
sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
}
return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q3_K)
// 2-byte loads for Q3_K blocks (110 bytes)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint hm_idx = iqs * QUANT_R_MMQ;
const uint iqs_k = (ib % 8) * 8 + hm_idx;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
const uint hm_shift = iqs_k / 8;
// Repack 2x4 quants into one int
// Add the 3rd bit instead of subtracting it to allow packing the quants
const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
(pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);
if (iqs == 0) {
const uint is = iqs_k / 4;
const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
(((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4)));
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
float result = 0.0;
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
// Subtract 4 from the quants to correct the 3rd bit offset
const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
q_sum = 0;
[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
return ACC_TYPE(cache_b.ds.x * result);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 16) / 8) * 4;
// Repack 2x4 quants into one int
#if defined(DATA_A_Q4_K)
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4);
#else // defined(DATA_A_Q5_K)
const uint qh_idx = iqs * QUANT_R_MMQ;
const uint qh_shift = iqs_k / 8;
buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
#endif
if (iqs == 0) {
// Scale index
const uint is = iqs_k / 8;
u8vec2 scale_dm;
if (is < 4) {
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
} else {
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
}
buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].dm = buf_a[buf_ib].dm;
[[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
#if defined(DATA_A_Q4_K)
const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);
#else // defined(DATA_A_Q5_K)
const int32_t qs_a = cache_a[ib_a].qs[iqs];
#endif
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#ifdef MMQ_SHMEM
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
}
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
}
void block_b_to_registers(const uint ib) {
cache_b.ds = buf_b[ib].ds;
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
}
}
#endif
#if defined(DATA_A_Q6_K)
// 2-byte loads for Q6_K blocks (210 bytes)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
const uint ql_shift = ((iqs_k % 32) / 16) * 4;
const uint qh_idx = (iqs_k / 32) * 8 + iqs;
const uint qh_shift = ((iqs_k % 32) / 8) * 2;
const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));
if (iqs == 0) {
const uint is = iqs_k / 4;
const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]);
buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
}
}
void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
[[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
}
}
ACC_TYPE mmq_dot_product(const uint ib_a) {
float result = 0.0;
int32_t q_sum = 0;
[[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
const int32_t qs_a = cache_a[ib_a].qs[iqs];
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
q_sum = 0;
[[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
const int32_t qs_a = cache_a[ib_a].qs[iqs];
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
return ACC_TYPE(cache_b.ds.x * result);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
@@ -103,3 +568,10 @@ FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
#if defined(DATA_A_Q2_K)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
const uint ib_k = ib / 8;
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
}
#endif
@@ -0,0 +1,78 @@
#if defined(DATA_A_Q4_0)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];
FLOAT_TYPE dm;
};
#elif defined(DATA_A_Q4_1)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q5_0)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];
uint32_t qh;
FLOAT_TYPE dm;
};
#elif defined(DATA_A_Q5_1)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[16/4];
uint32_t qh;
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q8_0)
#define QUANT_R_MMQ 1
// AMD likes 4, Intel likes 1 and Nvidia likes 2
#define BK_STEP 1
struct block_a_cache {
int32_t qs[32/4];
FLOAT_TYPE dm;
};
#elif defined(DATA_A_MXFP4)
#define QUANT_R_MMQ 2
struct block_a_cache {
int32_t qs[8];
FLOAT_TYPE d;
};
#elif defined(DATA_A_Q2_K)
#define QUANT_R_MMQ 4
struct block_a_cache {
uint32_t qs[2];
u8vec2 scales;
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q3_K)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[4];
FLOAT_TYPE_VEC2 d_scales;
};
#elif defined(DATA_A_Q4_K)
#define QUANT_R_MMQ 2
struct block_a_cache {
uint32_t qs[4];
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q5_K)
#define QUANT_R_MMQ 1
struct block_a_cache {
int32_t qs[8];
FLOAT_TYPE_VEC2 dm;
};
#elif defined(DATA_A_Q6_K)
#define QUANT_R_MMQ 1
struct block_a_cache {
int32_t qs[8];
FLOAT_TYPE_VEC2 d_scales;
};
#endif
struct block_b_cache
{
int32_t qs[8];
FLOAT_TYPE_VEC2 ds;
};
@@ -11,6 +11,8 @@ layout (push_constant) uniform parameter
{
uint n_rows;
uint n_expert_used;
float clamp_min;
float clamp_max;
};
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
@@ -18,6 +20,7 @@ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;
layout(constant_id = 2) const bool with_norm = true;
layout(constant_id = 3) const bool late_softmax = false;
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
@@ -25,6 +28,52 @@ layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
const float INFINITY = 1.0 / 0.0;
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
float max_val = -INFINITY;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
const uint idx = lane + i * WARP_SIZE;
const bool is_active = !use_limit || (idx < limit);
if (is_active) {
max_val = max(max_val, vals[i]);
}
}
max_val = subgroupMax(max_val);
float sum = 0.f;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
const uint idx = lane + i * WARP_SIZE;
const bool is_active = !use_limit || (idx < limit);
if (is_active) {
const float val = exp(vals[i] - max_val);
vals[i] = val;
sum += val;
} else {
vals[i] = 0.f;
}
}
sum = subgroupAdd(sum);
const float inv_sum = 1.0f / sum;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
const uint idx = lane + i * WARP_SIZE;
const bool is_active = !use_limit || (idx < limit);
if (is_active) {
vals[i] *= inv_sum;
}
}
}
void main() {
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
if (row >= n_rows) {
@@ -35,43 +84,16 @@ void main() {
const uint weights_offset = n_expert_used * row;
const uint ids_offset = n_experts * row;
float logits_r[experts_per_thread];
const float INFINITY = 1.0 / 0.0;
float wt[experts_per_thread];
[[unroll]]
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
const uint expert = i + gl_LocalInvocationID.x;
logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
const uint expert = i + gl_LocalInvocationID.x;
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
}
float max_val = logits_r[0];
[[unroll]]
for (int i = 1; i < experts_per_thread; i++) {
const float val = logits_r[i];
max_val = max(val, max_val);
}
max_val = subgroupMax(max_val);
float wt[experts_per_thread];
float tmp = 0.f;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
const float val = logits_r[i];
wt[i] = exp(val - max_val);
tmp += wt[i];
}
tmp = subgroupAdd(tmp);
const float inv_sum = 1.0f / tmp;
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
wt[i] = wt[i] * inv_sum;
if (!late_softmax) {
softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
}
// at this point, each thread holds a portion of softmax,
@@ -82,6 +104,11 @@ void main() {
float output_weights[experts_per_thread];
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
output_weights[i] = 0.f;
}
for (int k = 0; k < n_expert_used; k++) {
float max_val = wt[0];
uint max_expert = gl_LocalInvocationID.x;
@@ -121,6 +148,7 @@ void main() {
if (with_norm) {
wt_sum = subgroupAdd(wt_sum);
wt_sum = clamp(wt_sum, clamp_min, clamp_max);
const float inv_sum = 1.0f / wt_sum;
[[unroll]]
@@ -129,6 +157,10 @@ void main() {
}
}
if (late_softmax) {
softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
}
[[unroll]]
for (uint i = 0; i < experts_per_thread; ++i) {
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
+33 -20
View File
@@ -66,6 +66,7 @@ struct block_q4_0_packed16
#define QUANT_AUXF 1
#define A_TYPE block_q4_0
#define A_TYPE_PACKED16 block_q4_0_packed16
#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q4_1 32
@@ -98,6 +99,7 @@ struct block_q4_1_packed32
#define A_TYPE block_q4_1
#define A_TYPE_PACKED16 block_q4_1_packed16
#define A_TYPE_PACKED32 block_q4_1_packed32
#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q5_0 32
@@ -123,6 +125,7 @@ struct block_q5_0_packed16
#define QUANT_AUXF 1
#define A_TYPE block_q5_0
#define A_TYPE_PACKED16 block_q5_0_packed16
#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q5_1 32
@@ -158,6 +161,7 @@ struct block_q5_1_packed32
#define A_TYPE block_q5_1
#define A_TYPE_PACKED16 block_q5_1_packed16
#define A_TYPE_PACKED32 block_q5_1_packed32
#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q8_0 32
@@ -186,6 +190,7 @@ struct block_q8_0_packed32
#define A_TYPE block_q8_0
#define A_TYPE_PACKED16 block_q8_0_packed16
#define A_TYPE_PACKED32 block_q8_0_packed32
#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q8_1 32
@@ -226,21 +231,21 @@ struct block_q2_K
{
uint8_t scales[QUANT_K_Q2_K/16];
uint8_t qs[QUANT_K_Q2_K/4];
f16vec2 d;
f16vec2 dm;
};
struct block_q2_K_packed16
{
uint16_t scales[QUANT_K_Q2_K/16/2];
uint16_t qs[QUANT_K_Q2_K/4/2];
f16vec2 d;
f16vec2 dm;
};
struct block_q2_K_packed32
{
uint32_t scales[QUANT_K_Q2_K/16/4];
uint32_t qs[QUANT_K_Q2_K/4/4];
f16vec2 d;
f16vec2 dm;
};
#if defined(DATA_A_Q2_K)
@@ -249,6 +254,8 @@ struct block_q2_K_packed32
#define A_TYPE block_q2_K
#define A_TYPE_PACKED16 block_q2_K_packed16
#define A_TYPE_PACKED32 block_q2_K_packed32
#define SCALES_PER_32 2
#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q3_K 256
@@ -274,27 +281,28 @@ struct block_q3_K_packed16
#define QUANT_R 1
#define A_TYPE block_q3_K
#define A_TYPE_PACKED16 block_q3_K_packed16
#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q4_K 256
struct block_q4_K
{
f16vec2 d;
f16vec2 dm;
uint8_t scales[3*QUANT_K_Q4_K/64];
uint8_t qs[QUANT_K_Q4_K/2];
};
struct block_q4_K_packed16
{
f16vec2 d;
f16vec2 dm;
uint16_t scales[3*QUANT_K_Q4_K/64/2];
uint16_t qs[QUANT_K_Q4_K/2/2];
};
struct block_q4_K_packed32
{
f16vec2 d;
f16vec2 dm;
uint32_t scales[3*QUANT_K_Q4_K/64/4];
uint32_t qs[QUANT_K_Q4_K/2/4];
};
@@ -310,13 +318,14 @@ struct block_q4_K_packed128
#define A_TYPE block_q4_K
#define A_TYPE_PACKED16 block_q4_K_packed16
#define A_TYPE_PACKED32 block_q4_K_packed32
#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q5_K 256
struct block_q5_K
{
f16vec2 d;
f16vec2 dm;
uint8_t scales[12];
uint8_t qh[QUANT_K_Q5_K/8];
uint8_t qs[QUANT_K_Q5_K/2];
@@ -324,12 +333,20 @@ struct block_q5_K
struct block_q5_K_packed16
{
f16vec2 d;
f16vec2 dm;
uint16_t scales[12/2];
uint16_t qh[QUANT_K_Q5_K/8/2];
uint16_t qs[QUANT_K_Q5_K/2/2];
};
struct block_q5_K_packed32
{
f16vec2 dm;
uint32_t scales[12/4];
uint32_t qh[QUANT_K_Q5_K/8/4];
uint32_t qs[QUANT_K_Q5_K/2/4];
};
struct block_q5_K_packed128
{
uvec4 q5k[11];
@@ -340,6 +357,8 @@ struct block_q5_K_packed128
#define QUANT_R 1
#define A_TYPE block_q5_K
#define A_TYPE_PACKED16 block_q5_K_packed16
#define A_TYPE_PACKED32 block_q5_K_packed32
#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q6_K 256
@@ -356,7 +375,7 @@ struct block_q6_K_packed16
{
uint16_t ql[QUANT_K_Q6_K/2/2];
uint16_t qh[QUANT_K_Q6_K/4/2];
int8_t scales[QUANT_K_Q6_K/16];
int16_t scales[QUANT_K_Q6_K/16/2];
float16_t d;
};
@@ -365,6 +384,7 @@ struct block_q6_K_packed16
#define QUANT_R 1
#define A_TYPE block_q6_K
#define A_TYPE_PACKED16 block_q6_K_packed16
#define DATA_A_QUANT_K
#endif
// IQuants
@@ -1363,18 +1383,11 @@ struct block_mxfp4
uint8_t qs[QUANT_K_MXFP4/2];
};
//struct block_mxfp4_packed16
//{
// uint8_t e;
// uint16_t qs[QUANT_K_MXFP4/2/2];
//};
#if defined(DATA_A_MXFP4)
#define QUANT_K QUANT_K_MXFP4
#define QUANT_R QUANT_R_MXFP4
#define QUANT_AUXF 1
#define A_TYPE block_mxfp4
//#define A_TYPE_PACKED16 block_mxfp4_packed16
#endif
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
@@ -1397,12 +1410,12 @@ void init_iq_shmem(uvec3 wgsize)
#endif
#if defined(DATA_A_MXFP4)
const FLOAT_TYPE kvalues_mxfp4_const[16] = {
FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f),
FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f)
const int8_t kvalues_mxfp4_const[16] = {
int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12),
int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12),
};
shared FLOAT_TYPE kvalues_mxfp4[16];
shared int8_t kvalues_mxfp4[16];
#define NEEDS_INIT_IQ_SHMEM
void init_iq_shmem(uvec3 wgsize)
@@ -566,7 +566,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) {
// Integer dot mmq performs better with f32 accumulators
if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
}
#endif
@@ -574,7 +575,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
void process_shaders() {
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
// matmul
for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
+42 -27
View File
@@ -215,6 +215,7 @@ bool llama_batch_allocr::init(
/*.n_seq_tokens =*/ (uint32_t) 1,
/*.n_seqs =*/ (uint32_t) batch.n_tokens,
/*.n_seqs_unq =*/ (uint32_t) this->seq_id_unq.size(),
/*.n_pos =*/ n_pos_per_embd,
/*.token =*/ batch.token,
/*.embd =*/ batch.embd,
/*.pos =*/ batch.pos,
@@ -251,45 +252,57 @@ bool llama_batch_allocr::init(
// consistency checks
//
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_pos[s].empty()) {
continue;
}
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
if (p0 >= 0) {
bool ok = true;
if (batch.token) {
if (seq_pos_min(s) != p0 + 1) {
ok = false;
}
} else {
assert(batch.embd);
// for embeddings (typically used as vision input), we allow them to have repeating positions
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
ok = false;
}
if (n_pos_per_embd > 1) {
// M-RoPE case: allow position to "jump" forward only (non-continuous positions are allowed)
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_pos[s].empty()) {
continue;
}
if (!ok) {
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
if (p0 >= 0 && p0 >= seq_pos_min(s)) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
" for M-RoPE, it is required that the position satisfies: X < Y\n",
__func__, s, s, p0, s, seq_pos_min(s));
return false;
}
}
} else {
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_pos[s].empty()) {
continue;
}
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
return false;
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
if (p0 >= 0) {
bool ok = true;
if (seq_pos_min(s) != p0 + 1) {
ok = false;
}
if (!ok) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
__func__, s, s, p0, s, seq_pos_min(s));
return false;
}
}
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
return false;
}
}
}
@@ -389,6 +402,7 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
/*.n_seqs_unq =*/ n_seqs,
/*.n_pos =*/ n_pos_per_embd,
/*.token =*/ udata->token.data(),
/*.embd =*/ nullptr,
@@ -710,6 +724,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
/*.n_seq_tokens =*/ n_tokens/n_seqs,
/*.n_seqs =*/ n_seqs,
/*.n_seqs_unq =*/ (uint32_t) udata->seq_id_unq.size(),
/*.n_pos =*/ n_pos_per_embd,
/*.token =*/ batch.token ? udata->token.data() : nullptr,
/*.embd =*/ batch.embd ? udata->embd.data() : nullptr,
+12 -1
View File
@@ -17,6 +17,16 @@ struct llama_ubatch {
return b_equal_seqs != 0;
}
// typical for M-RoPE cases:
// 0 - sequantial position of the tokens/embeddings in the sequence
// 1 - y position in the image
// 2 - x position in the image
// 3 - other
bool is_pos_2d() const {
// TODO @ngxson : we may need to check for model arch when more models use >1 positions
return n_pos >= 3;
}
uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
// otherwise address sanitizer complains
// TODO: whole_seqs for embeddings?
@@ -25,6 +35,7 @@ struct llama_ubatch {
uint32_t n_seq_tokens; // tokens per sequence set
uint32_t n_seqs; // sequence sets in the ubatch
uint32_t n_seqs_unq; // unique sequence ids in the ubatch
uint32_t n_pos; // number of position inputs for each token/embedding
// seq_id_unq: unique sequence ids in the ubatch
// seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
@@ -33,7 +44,7 @@ struct llama_ubatch {
// // size | idx | val
llama_token * token; // [n_tokens] | i | id, token
float * embd; // [n_embd, n_tokens] | i | embd
llama_pos * pos; // [n_tokens] | i | pos
llama_pos * pos; // [n_tokens*n_pos] | i | pos
int32_t * n_seq_id; // [n_tokens] | i | -
llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id
llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
+32
View File
@@ -338,6 +338,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
llama_pos pos = v_cells[s0].pos_get(i);
llama_pos shift = v_cells[s0].get_shift(i);
llama_kv_cell_ext ext = v_cells[s0].ext_get(i);
if (shift != 0) {
pos -= shift;
assert(pos >= 0);
@@ -349,6 +351,8 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll
if (shift != 0) {
v_cells[s1].pos_add(i, shift);
}
v_cells[s1].ext_set(i, ext);
}
}
@@ -383,6 +387,7 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1");
auto & cells = v_cells[seq_to_stream[seq_id]];
auto & head = v_heads[seq_to_stream[seq_id]];
@@ -427,6 +432,7 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll
void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1");
auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -900,6 +906,14 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
cells.pos_set(idx, ubatch.pos[i]);
if (ubatch.is_pos_2d()) {
llama_kv_cell_ext ext {
/*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
/*.y =*/ ubatch.pos[i + ubatch.n_tokens],
};
cells.ext_set(idx, ext);
}
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
cells.seq_add(idx, ubatch.seq_id[i][s]);
}
@@ -1247,6 +1261,11 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
const llama_pos p1 = ubatch->pos[i];
// for M-RoPE
const bool is_2d = ubatch->is_pos_2d();
const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);
for (uint32_t j = 0; j < n_kv; ++j) {
@@ -1266,6 +1285,14 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
continue;
}
// M-RoPE causal mask
if (causal_attn && is_2d && p0 == p1) {
const auto & p0_ext = cells.ext_get(j);
if (p0_ext.is_2d_gt(p1_x, p1_y)) {
continue;
}
}
// apply SWA if any
if (is_masked_swa(p0, p1)) {
continue;
@@ -1559,6 +1586,9 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t
io.write(&pos, sizeof(pos));
io.write(&n_seq_id, sizeof(n_seq_id));
// TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it
// see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
for (const auto & seq_id : seq_ids) {
io.write(&seq_id, sizeof(seq_id));
}
@@ -1704,6 +1734,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
return false;
}
// TODO: we cannot yet restore llama_kv_cell_ext as the apply_ubatch() does not support it yet
// see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
apply_ubatch(sinfo, ubatch);
const auto head_cur = sinfo.head();
+44 -2
View File
@@ -5,9 +5,27 @@
#include <bitset>
#include <cassert>
#include <vector>
#include <set>
#include <cstring>
#include <map>
#include <set>
#include <vector>
struct llama_kv_cell_ext {
// 2D spatial positions, typically used for M-RoPE
llama_pos x = 0;
llama_pos y = 0;
// return true if the current 2D spatial position is greater than other
bool is_2d_gt(llama_pos ox, llama_pos oy) const {
return (y > oy) || (y == oy && x > ox);
}
void reset() {
static_assert(std::is_trivially_copyable_v<llama_kv_cell_ext>);
memset(this, 0, sizeof(*this));
}
};
// meta information about KV cells that can be part of multiple sequences at the same time
// TODO: add unit tests
@@ -16,6 +34,7 @@ public:
void reset() {
for (uint32_t i = 0; i < pos.size(); ++i) {
pos[i] = -1;
ext[i].reset();
shift[i] = 0;
seq[i].reset();
}
@@ -43,6 +62,7 @@ public:
void resize(uint32_t n) {
pos.resize(n);
ext.resize(n);
shift.resize(n);
seq.resize(n);
@@ -108,6 +128,7 @@ public:
const auto idx = i + j;
res.pos[j] = pos[idx];
res.ext[j] = ext[idx];
res.seq[j] = seq[idx];
assert(shift[idx] == 0);
@@ -126,6 +147,7 @@ public:
const auto idx = idxs[j];
res.pos[j] = pos[idx];
res.ext[j] = ext[idx];
res.seq[j] = seq[idx];
assert(shift[idx] == 0);
@@ -154,6 +176,7 @@ public:
}
pos[idx] = other.pos[j];
ext[idx] = other.ext[j];
seq[idx] = other.seq[j];
if (pos[idx] != -1) {
@@ -184,6 +207,7 @@ public:
}
pos[idx] = other.pos[j];
ext[idx] = other.ext[j];
seq[idx] = other.seq[j];
if (pos[idx] != -1) {
@@ -203,6 +227,7 @@ public:
seq[i].reset();
pos[i] = -1;
ext[i].reset();
shift[i] = 0;
used.erase(i);
@@ -221,6 +246,7 @@ public:
if (seq[i].none()) {
pos[i] = -1;
ext[i].reset();
shift[i] = 0;
used.erase(i);
@@ -250,6 +276,7 @@ public:
seq[i].reset();
pos[i] = -1;
ext[i].reset();
shift[i] = 0;
used.erase(i);
@@ -340,6 +367,13 @@ public:
return pos[i];
}
const llama_kv_cell_ext & ext_get(uint32_t i) const {
assert(i < pos.size());
assert(pos[i] != -1);
return ext[i];
}
// note: call only if the cell is not empty
llama_pos get_shift(uint32_t i) const {
assert(i < pos.size());
@@ -368,6 +402,11 @@ public:
used.insert(i);
}
void ext_set(uint32_t i, llama_kv_cell_ext p) {
assert(i < ext.size());
ext[i] = p;
}
// pos[i] = pos[i] + d
// sets "has_shift" to true
// note: call only if the cell is not empty
@@ -424,6 +463,9 @@ private:
std::vector<llama_pos> pos;
// stores extra info per cell
std::vector<llama_kv_cell_ext> ext;
// this array accumulates any applied shifts to the pos array since the last reset_shift() call
// this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
//
+12 -1
View File
@@ -5,6 +5,15 @@
#include "llama.h"
// fix problem with std::min and std::max
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
# define NOMINMAX
#endif
#include <windows.h>
#endif
#include <algorithm>
#include <cerrno>
#include <cstdio>
@@ -1031,7 +1040,9 @@ const char * mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
llama_pos mtmd_image_tokens_get_n_pos(const mtmd_image_tokens * image_tokens) {
if (image_tokens->use_mrope_pos) {
return 1; // for M-RoPE, the whole image is 1 in temporal dimension
// for M-RoPE, temporal dimension = max(t,h,w)
// t is omitted as we don't support video input
return std::max(image_tokens->nx, image_tokens->ny);
}
return image_tokens->n_tokens();
}
+2 -2
View File
@@ -153,7 +153,7 @@ MTMD_API const mtmd_image_tokens * mtmd_input_chunk_get_tokens_image(const mtmd
MTMD_API size_t mtmd_input_chunk_get_n_tokens (const mtmd_input_chunk * chunk);
// returns nullptr for ID on text chunk
MTMD_API const char * mtmd_input_chunk_get_id (const mtmd_input_chunk * chunk);
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
// number of temporal positions (equals to max(t,h,w) for M-RoPE; equals to n_tokens otherwise)
MTMD_API llama_pos mtmd_input_chunk_get_n_pos (const mtmd_input_chunk * chunk);
// in case you want to use custom logic to handle the chunk (i.e. KV cache management)
@@ -171,7 +171,7 @@ MTMD_API size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * i
MTMD_API size_t mtmd_image_tokens_get_nx (const mtmd_image_tokens * image_tokens);
MTMD_API size_t mtmd_image_tokens_get_ny (const mtmd_image_tokens * image_tokens);
MTMD_API const char * mtmd_image_tokens_get_id (const mtmd_image_tokens * image_tokens); // TODO: deprecate
// number of temporal positions (always 1 for M-RoPE, n_tokens otherwise)
// number of temporal positions (equals to max(t,h,w) for M-RoPE; equals to n_tokens otherwise)
MTMD_API llama_pos mtmd_image_tokens_get_n_pos (const mtmd_image_tokens * image_tokens); // TODO: deprecate
// tokenize an input text prompt and a list of bitmaps (images/audio)