add dot_product abstraction to reduce preprocessor branching

This commit is contained in:
Ruben Ortlam
2026-06-05 13:23:43 +02:00
parent 4b2739e846
commit 62a989f9c1
4 changed files with 33 additions and 38 deletions
@@ -1,7 +0,0 @@
#ifdef DOT2_F16
#extension GL_EXT_spirv_intrinsics : require
spirv_instruction(extensions = ["SPV_VALVE_mixed_float_dot_product"],
capabilities = [6912], id = 6916)
float v_dot2_f32_f16(f16vec2 a, f16vec2 b, float acc);
#endif
@@ -0,0 +1,27 @@
#ifdef DOT2_F16
#extension GL_EXT_spirv_intrinsics : require
spirv_instruction(extensions = ["SPV_VALVE_mixed_float_dot_product"],
capabilities = [6912], id = 6916)
float v_dot2_f32_f16(f16vec2 a, f16vec2 b, float acc);
ACC_TYPE dot_product(f16vec4 a, f16vec4 b, ACC_TYPE acc) {
return ACC_TYPE(v_dot2_f32_f16(a.zw, b.zw, v_dot2_f32_f16(a.xy, b.xy, float(acc))));
}
ACC_TYPE dot_product(f16vec2 a, f16vec2 b, ACC_TYPE acc) {
return ACC_TYPE(v_dot2_f32_f16(a, b, float(acc)));
}
#else
ACC_TYPE dot_product(FLOAT_TYPEV4 a, FLOAT_TYPEV4 b, ACC_TYPE acc) {
return fma(ACC_TYPE(a.x), ACC_TYPE(b.x), fma(ACC_TYPE(a.y), ACC_TYPE(b.y),
fma(ACC_TYPE(a.z), ACC_TYPE(b.z), fma(ACC_TYPE(a.w), ACC_TYPE(b.w), acc))));
}
ACC_TYPE dot_product(FLOAT_TYPEV2 a, FLOAT_TYPEV2 b, ACC_TYPE acc) {
return fma(ACC_TYPE(a.x), ACC_TYPE(b.x), fma(ACC_TYPE(a.y), ACC_TYPE(b.y), acc));
}
#endif
@@ -21,7 +21,7 @@
#extension GL_KHR_shader_subgroup_vote : enable
#include "types.glsl"
#include "dot2_f16.glsl"
#include "dot_product_funcs.glsl"
#include "flash_attn_base.glsl"
#include "flash_attn_dequant.glsl"
@@ -319,12 +319,7 @@ void main() {
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
#ifdef DOT2_F16
Sf[r][c] = ACC_TYPE(v_dot2_f32_f16(Q_cache[r].zw, K_Tf.zw,
v_dot2_f32_f16(Q_cache[r].xy, K_Tf.xy, float(Sf[r][c]))));
#else
Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf));
#endif
Sf[r][c] = dot_product(Q_cache[r], K_Tf, Sf[r][c]);
}
}
}
@@ -347,12 +342,7 @@ void main() {
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
#ifdef DOT2_F16
Sf[r][c] = ACC_TYPE(v_dot2_f32_f16(f16vec2(Qf[tile_row(r) * qf_stride + d * D_split + d_tid].zw), K_Tf.zw,
v_dot2_f32_f16(f16vec2(Qf[tile_row(r) * qf_stride + d * D_split + d_tid].xy), K_Tf.xy, float(Sf[r][c]))));
#else
Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf));
#endif
Sf[r][c] = dot_product(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf, Sf[r][c]);
}
}
}
@@ -29,7 +29,7 @@
#endif
#include "types.glsl"
#include "dot2_f16.glsl"
#include "dot_product_funcs.glsl"
#ifndef LOAD_VEC_A
#define LOAD_VEC_A 1
@@ -330,23 +330,8 @@ void main() {
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
// [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
#if defined(DOT2_F16) && (defined(DATA_A_F32) || defined(DATA_A_F16))
float dot2_x = v_dot2_f32_f16(cache_a[wsir * TM + 2 * cr ].xy, cache_b.xy, float(sums[sums_idx].x));
sums[sums_idx].x = ACC_TYPE(v_dot2_f32_f16(cache_a[wsir * TM + 2 * cr ].zw, cache_b.zw, dot2_x));
float dot2_y = v_dot2_f32_f16(cache_a[wsir * TM + 2 * cr + 1].xy, cache_b.xy, float(sums[sums_idx].y));
sums[sums_idx].y = ACC_TYPE(v_dot2_f32_f16(cache_a[wsir * TM + 2 * cr + 1].zw, cache_b.zw, dot2_y));
#elif defined(DATA_A_F32) || defined(DATA_A_F16)
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y),
fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x))));
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y),
fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y))));
#elif defined(DOT2_F16)
sums[sums_idx].x = ACC_TYPE(v_dot2_f32_f16(cache_a[wsir * TM + 2 * cr ], cache_b, float(sums[sums_idx].x)));
sums[sums_idx].y = ACC_TYPE(v_dot2_f32_f16(cache_a[wsir * TM + 2 * cr + 1], cache_b, float(sums[sums_idx].y)));
#else
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
#endif
sums[sums_idx].x = dot_product(cache_a[wsir * TM + 2 * cr ], cache_b, sums[sums_idx].x);
sums[sums_idx].y = dot_product(cache_a[wsir * TM + 2 * cr + 1], cache_b, sums[sums_idx].y);
}
}
}