forked from wylab/llama.cpp
cpu: fix ARM NEON nvfp4 vec dot
This commit is contained in:
@@ -672,34 +672,36 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
float32x4_t acc = vdupq_n_f32(0.0f);
|
||||
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
const int8x8_t q8_0_lo = vld1_s8(y[2*ib].qs);
|
||||
const int8x8_t q8_0_hi = vld1_s8(y[2*ib].qs + 8);
|
||||
const int8x8_t q8_1_lo = vld1_s8(y[2*ib].qs + 16);
|
||||
const int8x8_t q8_1_hi = vld1_s8(y[2*ib].qs + 24);
|
||||
const int8x8_t q8_2_lo = vld1_s8(y[2*ib+1].qs);
|
||||
const int8x8_t q8_2_hi = vld1_s8(y[2*ib+1].qs + 8);
|
||||
const int8x8_t q8_3_lo = vld1_s8(y[2*ib+1].qs + 16);
|
||||
const int8x8_t q8_3_hi = vld1_s8(y[2*ib+1].qs + 24);
|
||||
|
||||
const uint8x16_t q4bits_0 = vld1q_u8(x[ib].qs);
|
||||
const uint8x16_t q4bits_1 = vld1q_u8(x[ib].qs + 16);
|
||||
|
||||
const int8x16_t q4_lo_0 = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits_0, m4b));
|
||||
const int8x16_t q4_hi_0 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_0, 4));
|
||||
const int8x16_t q4_lo_1 = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits_1, m4b));
|
||||
const int8x16_t q4_hi_1 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_1, 4));
|
||||
|
||||
const int8x16_t q8_0a = vld1q_s8(y[2*ib].qs);
|
||||
const int8x16_t q8_0b = vld1q_s8(y[2*ib].qs + 16);
|
||||
const int8x16_t q8_lo_0 = vcombine_s8(vget_low_s8(q8_0a), vget_low_s8(q8_0b));
|
||||
const int8x16_t q8_hi_0 = vcombine_s8(vget_high_s8(q8_0a), vget_high_s8(q8_0b));
|
||||
const int8x8_t q4_0_lo = vget_low_s8(q4_lo_0);
|
||||
const int8x8_t q4_0_hi = vget_low_s8(q4_hi_0);
|
||||
const int8x8_t q4_1_lo = vget_high_s8(q4_lo_0);
|
||||
const int8x8_t q4_1_hi = vget_high_s8(q4_hi_0);
|
||||
const int8x8_t q4_2_lo = vget_low_s8(q4_lo_1);
|
||||
const int8x8_t q4_2_hi = vget_low_s8(q4_hi_1);
|
||||
const int8x8_t q4_3_lo = vget_high_s8(q4_lo_1);
|
||||
const int8x8_t q4_3_hi = vget_high_s8(q4_hi_1);
|
||||
|
||||
const int8x16_t q8_1a = vld1q_s8(y[2*ib+1].qs);
|
||||
const int8x16_t q8_1b = vld1q_s8(y[2*ib+1].qs + 16);
|
||||
const int8x16_t q8_lo_1 = vcombine_s8(vget_low_s8(q8_1a), vget_low_s8(q8_1b));
|
||||
const int8x16_t q8_hi_1 = vcombine_s8(vget_high_s8(q8_1a), vget_high_s8(q8_1b));
|
||||
const int32x4_t p0 = ggml_nvfp4_dot8(q4_0_lo, q8_0_lo, q4_0_hi, q8_0_hi);
|
||||
const int32x4_t p1 = ggml_nvfp4_dot8(q4_1_lo, q8_1_lo, q4_1_hi, q8_1_hi);
|
||||
const int32x4_t p2 = ggml_nvfp4_dot8(q4_2_lo, q8_2_lo, q4_2_hi, q8_2_hi);
|
||||
const int32x4_t p3 = ggml_nvfp4_dot8(q4_3_lo, q8_3_lo, q4_3_hi, q8_3_hi);
|
||||
|
||||
const int32x4_t p0 = vaddq_s32(
|
||||
ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0),
|
||||
ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0));
|
||||
const int32x4_t p1 = vaddq_s32(
|
||||
ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1),
|
||||
ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1));
|
||||
|
||||
const int32x4_t sums = vpaddq_s32(p0, p1);
|
||||
|
||||
// Decode 4 UE4M3 scales to f32 and multiply with q8 scales
|
||||
const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
|
||||
const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
|
||||
const float32x4_t nvsc = {
|
||||
@@ -710,7 +712,13 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
};
|
||||
const float32x4_t scales = vmulq_f32(nvsc, (float32x4_t){dy0, dy0, dy1, dy1});
|
||||
|
||||
acc = vfmaq_f32(acc, vcvtq_f32_s32(sums), scales);
|
||||
const float32x4_t sums = (float32x4_t){
|
||||
(float)vaddvq_s32(p0),
|
||||
(float)vaddvq_s32(p1),
|
||||
(float)vaddvq_s32(p2),
|
||||
(float)vaddvq_s32(p3)
|
||||
};
|
||||
acc = vfmaq_f32(acc, sums, scales);
|
||||
}
|
||||
sumf = vaddvq_f32(acc);
|
||||
#else
|
||||
|
||||
@@ -319,6 +319,15 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
|
||||
|
||||
#endif // !defined(__ARM_FEATURE_DOTPROD)
|
||||
|
||||
static inline int32x4_t ggml_nvfp4_dot8(const int8x8_t q4_lo, const int8x8_t q8_lo,
|
||||
const int8x8_t q4_hi, const int8x8_t q8_hi) {
|
||||
const int16x8_t p_lo = vmull_s8(q4_lo, q8_lo);
|
||||
const int16x8_t p_hi = vmull_s8(q4_hi, q8_hi);
|
||||
const int32x4_t sum_lo = vpaddlq_s16(p_lo);
|
||||
const int32x4_t sum_hi = vpaddlq_s16(p_hi);
|
||||
return vaddq_s32(sum_lo, sum_hi);
|
||||
}
|
||||
|
||||
#endif // defined(__ARM_NEON)
|
||||
|
||||
#ifdef __wasm_simd128__
|
||||
|
||||
Reference in New Issue
Block a user