ggml : implement fast walsh-hadamard transform for kv rotation (#21352) (#22631)

This commit is contained in:
Ismail
2026-05-05 04:05:05 +02:00
committed by GitHub
parent eff06702b2
commit a817a22bc6
8 changed files with 183 additions and 0 deletions
+65
View File
@@ -3952,6 +3952,59 @@ struct test_mul_mat : public test_case {
}
};
// GGML_HINT_SRC0_IS_HADAMARD
//
// Variant of test_mul_mat that tags every MUL_MAT node in the graph with
// GGML_HINT_SRC0_IS_HADAMARD and fills the tensor named "a" with a normalized
// Hadamard matrix in Sylvester (natural) order:
//
//     a[r][i] = (popcount(r & i) is even ? +1 : -1) / sqrt(ne[0])
//
// All other F32/F16 tensors are filled with uniform random data.
struct test_mul_mat_hadamard : public test_mul_mat {
    // Parameters are forwarded unchanged to test_mul_mat.
    //
    // Only F32 is accepted for type_a because initialize_tensors() uploads raw floats into "a".
    // m == k and k being a power of two are required for the generated sign pattern to be a
    // genuine Hadamard matrix (see initialize_tensors()).
    // NOTE(review): this assumes test_mul_mat creates "a" with k columns and m rows per batch
    // slice - the base class is not visible here, confirm against it.
    test_mul_mat_hadamard(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
            int64_t m = 32, int64_t n = 32, int64_t k = 32,
            std::array<int64_t, 2> bs = {1, 1},
            std::array<int64_t, 2> nr = {1, 1})
        : test_mul_mat(type_a, type_b, m, n, k, bs, nr) {
        GGML_ASSERT(type_a == GGML_TYPE_F32);
        // a Hadamard matrix is square ...
        GGML_ASSERT(m == k);
        // ... and the Sylvester construction only exists for power-of-two sizes
        GGML_ASSERT(k > 0 && (k & (k - 1)) == 0);
    }

    // Builds the regular mul_mat graph, then sets the Hadamard hint on every MUL_MAT tensor
    // found in the context. Returns the output tensor produced by the base class.
    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * out = test_mul_mat::build_graph(ctx);

        // Find the mul_mat op in the graph and set the hint
        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
            if (t->op == GGML_OP_MUL_MAT) {
                ggml_mul_mat_set_hint(t, GGML_HINT_SRC0_IS_HADAMARD);
            }
        }

        return out;
    }

    // Fills "a" with the normalized Hadamard matrix and every other F32/F16 tensor with uniform
    // random values. Tensors of any other type are left untouched.
    void initialize_tensors(ggml_context * ctx) override {
        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
            if (strcmp(t->name, "a") != 0) {
                if (t->type == GGML_TYPE_F32 || t->type == GGML_TYPE_F16) {
                    init_tensor_uniform(t);
                }
                continue;
            }

            const int64_t n_cols = t->ne[0];
            const int64_t n_rows = ggml_nrows(t); // counts the rows of all batch slices

            // 1/sqrt(n) normalization of the +-1 entries
            const float scale = 1.0f / sqrtf((float) n_cols);

            std::vector<float> data(n_cols * n_rows);

            // The row index runs across batch slices. Because n_cols is a power of two and each
            // slice has n_cols rows (asserted in the constructor), r & i only depends on
            // r mod n_cols, so every slice receives the same matrix.
            for (int64_t r = 0; r < n_rows; r++) {
                float * row_data = data.data() + r * n_cols;
                for (int64_t i = 0; i < n_cols; i++) {
                    // parity of popcount(r & i): toggle once per set bit
                    bool odd = false;
                    for (int64_t bits = r & i; bits != 0; bits &= bits - 1) {
                        odd = !odd;
                    }
                    row_data[i] = odd ? -scale : scale;
                }
            }

            ggml_backend_tensor_set(t, data.data(), 0, data.size() * sizeof(float));
        }
    }

    // Reports a dedicated op name so these cases are distinguishable from plain MUL_MAT tests.
    std::string op_desc(ggml_tensor * t) override {
        GGML_UNUSED(t);
        return "MUL_MAT_HADAMARD";
    }
};
static void init_mul_mat_id_tensors(ggml_context * ctx, int n_mats) {
std::random_device rd;
std::default_random_engine rng(rd());
@@ -8063,6 +8116,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
// FWHT tests
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 128, 1, 128));
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 64, 1, 64));
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 256, 1, 256));
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 128, 32, 128));
#if 0
// > 4GB A matrix. Too slow to be enabled by default.
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 900000, 3, 2592, {1, 1}, {1, 1}));
@@ -8917,6 +8976,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));
// FWHT tests
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 128, 1, 128));
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 64, 1, 64));
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 256, 1, 256));
test_cases.emplace_back(new test_mul_mat_hadamard(GGML_TYPE_F32, GGML_TYPE_F32, 128, 32, 128));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 32, 64, 4, 4 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 2 }, { 32, 128, 4, 2 }));
// qwen3next with CHUNK_SIZE 64