feat: TurboQuant KV-cache quantization for AMD ROCm (turbo3/turbo4)
Implements TurboQuant (Zandieh et al., ICLR 2026) KV-cache vector quantization targeting AMD RDNA 4 (gfx1201, RX 9070 XT).

Algorithm: L2-normalize → FWHT(128) → Lloyd-Max scalar quantize → bitpack
Decode: unpack → codebook lookup → inverse FWHT → denormalize

Two new GGML types:
- GGML_TYPE_TURBO3_0: 3-bit, 3.5 bpw, MSE*d=0.034 (block_size=32, 14 bytes)
- GGML_TYPE_TURBO4_0: 4-bit, 4.5 bpw, MSE*d=0.009 (block_size=32, 18 bytes)

Architecture (pre-dequantize strategy):
- Write path: FWHT-aware set-rows kernels (128 threads, shared-mem FWHT)
- Read path: bulk dequantize turbo→f16 before standard Flash Attention
- Stride scaling preserves ggml_permute dim swaps (critical fix)

Performance (Qwen3-14B Q4_K_M, RX 9070 XT, 16 GB VRAM):
- f16/f16: 1865 pp512, 54 tg128 (baseline)
- q8_0/q8_0: 1694 pp512, 52 tg128
- turbo4/turbo4: 1813 pp512, 49 tg128 (-3% pp, -9% tg, 72% less KV VRAM)
- turbo3/turbo3: 1983 pp512, 49 tg128 (+6% pp, -9% tg, 78% less KV VRAM)

Usage: llama-cli -fa 1 --cache-type-k turbo4 --cache-type-v turbo4

Includes 7 CPU reference tests validating FWHT self-inverse, MSE against paper values, bitpack determinism, and dequantize sanity.

Requires head_dim=128 (covers most current models including Llama, Qwen, Mistral, Gemma). Guard added to KV cache init with a clear error message.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -387,6 +387,8 @@ const std::vector<ggml_type> kv_cache_types = {
|
|||||||
GGML_TYPE_IQ4_NL,
|
GGML_TYPE_IQ4_NL,
|
||||||
GGML_TYPE_Q5_0,
|
GGML_TYPE_Q5_0,
|
||||||
GGML_TYPE_Q5_1,
|
GGML_TYPE_Q5_1,
|
||||||
|
GGML_TYPE_TURBO3_0,
|
||||||
|
GGML_TYPE_TURBO4_0,
|
||||||
};
|
};
|
||||||
|
|
||||||
static ggml_type kv_cache_type_from_str(const std::string & s) {
|
static ggml_type kv_cache_type_from_str(const std::string & s) {
|
||||||
|
|||||||
+3
-1
@@ -428,7 +428,9 @@ extern "C" {
|
|||||||
// GGML_TYPE_IQ4_NL_8_8 = 38,
|
// GGML_TYPE_IQ4_NL_8_8 = 38,
|
||||||
GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
|
GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
|
||||||
GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale)
|
GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale)
|
||||||
GGML_TYPE_COUNT = 41,
|
GGML_TYPE_TURBO3_0 = 41, // TurboQuant 3-bit KV-cache (3.5 bpw)
|
||||||
|
GGML_TYPE_TURBO4_0 = 42, // TurboQuant 4-bit KV-cache (4.5 bpw)
|
||||||
|
GGML_TYPE_COUNT = 43,
|
||||||
};
|
};
|
||||||
|
|
||||||
// precision
|
// precision
|
||||||
|
|||||||
@@ -205,6 +205,26 @@ typedef struct {
|
|||||||
} block_nvfp4;
|
} block_nvfp4;
|
||||||
static_assert(sizeof(block_nvfp4) == sizeof(uint8_t)*(QK_NVFP4/QK_NVFP4_SUB) + QK_NVFP4/2, "wrong nvfp4 block size/padding");
|
static_assert(sizeof(block_nvfp4) == sizeof(uint8_t)*(QK_NVFP4/QK_NVFP4_SUB) + QK_NVFP4/2, "wrong nvfp4 block size/padding");
|
||||||
|
|
||||||
|
// TurboQuant 3-bit KV-cache quantization (3.5 bpw).
// One block holds QK_TURBO3 elements: an FP16 L2-norm followed by the
// 3-bit codebook indices packed back to back, (QK_TURBO3 * 3) / 8 bytes in total.
#define QK_TURBO3 32
#define QR_TURBO3 2
// Alias kept for existing users; derived from QK_TURBO3 so the two cannot drift apart.
#define TURBO3_BLOCK_SIZE QK_TURBO3
typedef struct {
    ggml_half d;                       // FP16 L2-norm
    uint8_t   qs[QK_TURBO3 * 3 / 8];   // 32 x 3-bit packed indices (12 bytes)
} block_turbo3_0;
// Computed rather than hard-coded so a change of QK_TURBO3 is caught here (expected: 14 bytes).
static_assert(sizeof(block_turbo3_0) == sizeof(ggml_half) + QK_TURBO3 * 3 / 8, "wrong turbo3 block size");
|
||||||
|
|
||||||
|
// TurboQuant 4-bit KV-cache quantization (4.5 bpw).
// One block holds QK_TURBO4 elements: an FP16 L2-norm followed by the
// 4-bit codebook indices, two per byte (QK_TURBO4 / 2 bytes in total).
#define QK_TURBO4 32
#define QR_TURBO4 2
// Alias kept for existing users; derived from QK_TURBO4 so the two cannot drift apart.
#define TURBO4_BLOCK_SIZE QK_TURBO4
typedef struct {
    ggml_half d;                   // FP16 L2-norm
    uint8_t   qs[QK_TURBO4 / 2];   // 32 x 4-bit packed indices (16 bytes)
} block_turbo4_0;
// Computed rather than hard-coded so a change of QK_TURBO4 is caught here (expected: 18 bytes).
static_assert(sizeof(block_turbo4_0) == sizeof(ggml_half) + QK_TURBO4 / 2, "wrong turbo4 block size");
|
||||||
|
|
||||||
#define QK5_0 32
|
#define QK5_0 32
|
||||||
typedef struct {
|
typedef struct {
|
||||||
ggml_half d; // delta
|
ggml_half d; // delta
|
||||||
|
|||||||
@@ -120,7 +120,9 @@ if (CUDAToolkit_FOUND)
|
|||||||
template-instances/fattn-vec-instance-f16-f16.cu
|
template-instances/fattn-vec-instance-f16-f16.cu
|
||||||
template-instances/fattn-vec-instance-q4_0-q4_0.cu
|
template-instances/fattn-vec-instance-q4_0-q4_0.cu
|
||||||
template-instances/fattn-vec-instance-q8_0-q8_0.cu
|
template-instances/fattn-vec-instance-q8_0-q8_0.cu
|
||||||
template-instances/fattn-vec-instance-bf16-bf16.cu)
|
template-instances/fattn-vec-instance-bf16-bf16.cu
|
||||||
|
template-instances/fattn-vec-instance-turbo3_0-turbo3_0.cu
|
||||||
|
template-instances/fattn-vec-instance-turbo4_0-turbo4_0.cu)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
ggml_add_backend_library(ggml-cuda
|
ggml_add_backend_library(ggml-cuda
|
||||||
|
|||||||
@@ -656,6 +656,140 @@ static void dequantize_row_nvfp4_cuda(
|
|||||||
const int nb = k / QK_NVFP4;
|
const int nb = k / QK_NVFP4;
|
||||||
dequantize_block_nvfp4<<<nb, 32, 0, stream>>>(vx, y, k);
|
dequantize_block_nvfp4<<<nb, 32, 0, stream>>>(vx, y, k);
|
||||||
}
|
}
|
||||||
|
// ============================================================
|
||||||
|
// TurboQuant GPU bulk dequantize kernels (with FWHT)
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
// Each CUDA block processes one 128-element chunk (= 4 turbo blocks).
|
||||||
|
// 128 threads per block, one thread per element.
|
||||||
|
// Step 1: unpack index + centroid lookup -> shared memory
|
||||||
|
// Step 2: FWHT butterfly in shared memory (7 stages for n=128)
|
||||||
|
// Step 3: normalize by 1/sqrt(128) and scale by stored norm
|
||||||
|
// Step 4: write to output
|
||||||
|
|
||||||
|
#define TURBO_HEAD_DIM_GPU 128
|
||||||
|
#define TURBO_BLOCKS_PER_CHUNK_GPU (TURBO_HEAD_DIM_GPU / 32) // 4
|
||||||
|
|
||||||
|
template <typename dst_t>
|
||||||
|
static __global__ void dequantize_block_turbo3_0_kernel(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
|
||||||
|
__shared__ float smem[TURBO_HEAD_DIM_GPU];
|
||||||
|
|
||||||
|
const int64_t chunk_idx = blockIdx.x;
|
||||||
|
const int tid = threadIdx.x; // 0..127
|
||||||
|
|
||||||
|
const int64_t output_offset = chunk_idx * TURBO_HEAD_DIM_GPU;
|
||||||
|
if (output_offset + tid >= k) return;
|
||||||
|
|
||||||
|
// Which of the 4 blocks within this chunk does this thread belong to?
|
||||||
|
const int local_block = tid / TURBO3_BLOCK_SIZE; // 0..3
|
||||||
|
const int elem_in_block = tid % TURBO3_BLOCK_SIZE; // 0..31
|
||||||
|
|
||||||
|
const int64_t global_block_idx = chunk_idx * TURBO_BLOCKS_PER_CHUNK_GPU + local_block;
|
||||||
|
|
||||||
|
// Unpack 3-bit index and look up centroid
|
||||||
|
const block_turbo3_0 * x = (const block_turbo3_0 *)vx + global_block_idx;
|
||||||
|
const uint8_t * qs = x->qs;
|
||||||
|
|
||||||
|
int bit_off = elem_in_block * 3;
|
||||||
|
int byte_idx = bit_off / 8;
|
||||||
|
int shift = bit_off % 8;
|
||||||
|
uint16_t raw = (uint16_t)qs[byte_idx] >> shift;
|
||||||
|
if (shift > 5 && byte_idx + 1 < 12)
|
||||||
|
raw |= (uint16_t)qs[byte_idx + 1] << (8 - shift);
|
||||||
|
uint8_t idx = (uint8_t)(raw & 0x07);
|
||||||
|
|
||||||
|
smem[tid] = dc_codebook_3bit[idx];
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// FWHT butterfly stages (7 stages for n=128)
|
||||||
|
for (int h = 1; h < TURBO_HEAD_DIM_GPU; h *= 2) {
|
||||||
|
if (tid < 64) { // 128/2 = 64 butterflies per stage
|
||||||
|
int group = tid / h;
|
||||||
|
int pos = tid % h;
|
||||||
|
int i = group * h * 2 + pos;
|
||||||
|
float a = smem[i];
|
||||||
|
float b = smem[i + h];
|
||||||
|
smem[i] = a + b;
|
||||||
|
smem[i + h] = a - b;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize by 1/sqrt(128) and scale by stored norm
|
||||||
|
const float fwht_scale = 0.08838834764831844f; // 1/sqrt(128)
|
||||||
|
const block_turbo3_0 * first_block = (const block_turbo3_0 *)vx + chunk_idx * TURBO_BLOCKS_PER_CHUNK_GPU;
|
||||||
|
float norm = __half2float(first_block->d);
|
||||||
|
smem[tid] *= fwht_scale * norm;
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Write to output
|
||||||
|
y[output_offset + tid] = ggml_cuda_cast<dst_t>(smem[tid]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename dst_t>
|
||||||
|
static void dequantize_row_turbo3_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
|
||||||
|
GGML_ASSERT(k % TURBO_HEAD_DIM_GPU == 0);
|
||||||
|
const int num_chunks = (int)(k / TURBO_HEAD_DIM_GPU);
|
||||||
|
dequantize_block_turbo3_0_kernel<<<num_chunks, TURBO_HEAD_DIM_GPU, 0, stream>>>(vx, y, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename dst_t>
|
||||||
|
static __global__ void dequantize_block_turbo4_0_kernel(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
|
||||||
|
__shared__ float smem[TURBO_HEAD_DIM_GPU];
|
||||||
|
|
||||||
|
const int64_t chunk_idx = blockIdx.x;
|
||||||
|
const int tid = threadIdx.x; // 0..127
|
||||||
|
|
||||||
|
const int64_t output_offset = chunk_idx * TURBO_HEAD_DIM_GPU;
|
||||||
|
if (output_offset + tid >= k) return;
|
||||||
|
|
||||||
|
// Which of the 4 blocks within this chunk does this thread belong to?
|
||||||
|
const int local_block = tid / TURBO4_BLOCK_SIZE; // 0..3
|
||||||
|
const int elem_in_block = tid % TURBO4_BLOCK_SIZE; // 0..31
|
||||||
|
|
||||||
|
const int64_t global_block_idx = chunk_idx * TURBO_BLOCKS_PER_CHUNK_GPU + local_block;
|
||||||
|
|
||||||
|
// Unpack 4-bit index and look up centroid
|
||||||
|
const block_turbo4_0 * x = (const block_turbo4_0 *)vx + global_block_idx;
|
||||||
|
int pair_idx = elem_in_block / 2;
|
||||||
|
uint8_t packed = x->qs[pair_idx];
|
||||||
|
uint8_t idx = (elem_in_block & 1) ? ((packed >> 4) & 0x0F) : (packed & 0x0F);
|
||||||
|
|
||||||
|
smem[tid] = dc_codebook_4bit[idx];
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// FWHT butterfly stages (7 stages for n=128)
|
||||||
|
for (int h = 1; h < TURBO_HEAD_DIM_GPU; h *= 2) {
|
||||||
|
if (tid < 64) { // 128/2 = 64 butterflies per stage
|
||||||
|
int group = tid / h;
|
||||||
|
int pos = tid % h;
|
||||||
|
int i = group * h * 2 + pos;
|
||||||
|
float a = smem[i];
|
||||||
|
float b = smem[i + h];
|
||||||
|
smem[i] = a + b;
|
||||||
|
smem[i + h] = a - b;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize by 1/sqrt(128) and scale by stored norm
|
||||||
|
const float fwht_scale = 0.08838834764831844f; // 1/sqrt(128)
|
||||||
|
const block_turbo4_0 * first_block = (const block_turbo4_0 *)vx + chunk_idx * TURBO_BLOCKS_PER_CHUNK_GPU;
|
||||||
|
float norm = __half2float(first_block->d);
|
||||||
|
smem[tid] *= fwht_scale * norm;
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Write to output
|
||||||
|
y[output_offset + tid] = ggml_cuda_cast<dst_t>(smem[tid]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename dst_t>
|
||||||
|
static void dequantize_row_turbo4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
|
||||||
|
GGML_ASSERT(k % TURBO_HEAD_DIM_GPU == 0);
|
||||||
|
const int num_chunks = (int)(k / TURBO_HEAD_DIM_GPU);
|
||||||
|
dequantize_block_turbo4_0_kernel<<<num_chunks, TURBO_HEAD_DIM_GPU, 0, stream>>>(vx, y, k);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename src_t, typename dst_t>
|
template <typename src_t, typename dst_t>
|
||||||
static __global__ void convert_unary(
|
static __global__ void convert_unary(
|
||||||
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
|
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
|
||||||
@@ -756,6 +890,10 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
|||||||
return dequantize_row_mxfp4_cuda;
|
return dequantize_row_mxfp4_cuda;
|
||||||
case GGML_TYPE_NVFP4:
|
case GGML_TYPE_NVFP4:
|
||||||
return dequantize_row_nvfp4_cuda;
|
return dequantize_row_nvfp4_cuda;
|
||||||
|
case GGML_TYPE_TURBO3_0:
|
||||||
|
return dequantize_row_turbo3_0_cuda;
|
||||||
|
case GGML_TYPE_TURBO4_0:
|
||||||
|
return dequantize_row_turbo4_0_cuda;
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
return convert_unary_cont_cuda<float>;
|
return convert_unary_cont_cuda<float>;
|
||||||
case GGML_TYPE_BF16:
|
case GGML_TYPE_BF16:
|
||||||
@@ -809,6 +947,10 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
|||||||
return dequantize_row_mxfp4_cuda;
|
return dequantize_row_mxfp4_cuda;
|
||||||
case GGML_TYPE_NVFP4:
|
case GGML_TYPE_NVFP4:
|
||||||
return dequantize_row_nvfp4_cuda;
|
return dequantize_row_nvfp4_cuda;
|
||||||
|
case GGML_TYPE_TURBO3_0:
|
||||||
|
return dequantize_row_turbo3_0_cuda;
|
||||||
|
case GGML_TYPE_TURBO4_0:
|
||||||
|
return dequantize_row_turbo4_0_cuda;
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
return convert_unary_cont_cuda<half>;
|
return convert_unary_cont_cuda<half>;
|
||||||
case GGML_TYPE_BF16:
|
case GGML_TYPE_BF16:
|
||||||
@@ -832,6 +974,10 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
|
|||||||
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
|
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
|
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
|
||||||
|
case GGML_TYPE_TURBO3_0:
|
||||||
|
return dequantize_block_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
|
||||||
|
case GGML_TYPE_TURBO4_0:
|
||||||
|
return dequantize_block_cuda<QK_TURBO4, QR_TURBO4, dequantize_turbo4_0>;
|
||||||
case GGML_TYPE_BF16:
|
case GGML_TYPE_BF16:
|
||||||
return convert_unary_cuda<nv_bfloat16>;
|
return convert_unary_cuda<nv_bfloat16>;
|
||||||
default:
|
default:
|
||||||
@@ -853,6 +999,10 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
|
|||||||
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
|
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
|
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
|
||||||
|
case GGML_TYPE_TURBO3_0:
|
||||||
|
return dequantize_block_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
|
||||||
|
case GGML_TYPE_TURBO4_0:
|
||||||
|
return dequantize_block_cuda<QK_TURBO4, QR_TURBO4, dequantize_turbo4_0>;
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
return convert_unary_cuda<half, nv_bfloat16>;
|
return convert_unary_cuda<half, nv_bfloat16>;
|
||||||
default:
|
default:
|
||||||
@@ -874,6 +1024,10 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
|
|||||||
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
|
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
|
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
|
||||||
|
case GGML_TYPE_TURBO3_0:
|
||||||
|
return dequantize_block_cuda<QK_TURBO3, QR_TURBO3, dequantize_turbo3_0>;
|
||||||
|
case GGML_TYPE_TURBO4_0:
|
||||||
|
return dequantize_block_cuda<QK_TURBO4, QR_TURBO4, dequantize_turbo4_0>;
|
||||||
case GGML_TYPE_BF16:
|
case GGML_TYPE_BF16:
|
||||||
return convert_unary_cuda<nv_bfloat16, float>;
|
return convert_unary_cuda<nv_bfloat16, float>;
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -211,6 +211,93 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
|
|||||||
quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
|
quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// TurboQuant GPU quantize device functions
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
// Quantizer-side copies of the TurboQuant codebooks.
// These are file-local tables (static, not extern), listed in ascending order.
// NOTE(review): the values are expected to match the dequantize-side codebooks
// (dc_codebook_3bit / dc_codebook_4bit); keep the two sets in sync by hand.

// 3-bit codebook: 8 centroids.
__device__ static const float tq_codebook_3bit_q[8] = {
    -0.1883972972f, -0.1181399059f,
    -0.0665857641f, -0.0216044751f,
     0.0216041461f,  0.0665854520f,
     0.1181396281f,  0.1883970748f,
};

// 4-bit codebook: 16 centroids.
__device__ static const float tq_codebook_4bit_q[16] = {
    -0.2376389871f, -0.1808080141f,
    -0.1417777640f, -0.1102646123f,
    -0.0828112376f, -0.0577640422f,
    -0.0341540905f, -0.0113168380f,
     0.0112761586f,  0.0341139667f,
     0.0577250301f,  0.0827738972f,
     0.1102295202f,  0.1417455465f,
     0.1807794468f,  0.2376153882f,
};
|
||||||
|
|
||||||
|
// Returns the index of the codebook entry closest to `val` (absolute distance).
//   val      : value to quantize
//   codebook : table of `n` centroids; entry 0 is always read
//   n        : number of entries to scan
// Ties keep the lowest index because only a strictly smaller distance replaces the current best.
static __device__ uint8_t tq_nearest_codebook(float val, const float *codebook, int n) {
    int   winner     = 0;
    float winner_err = fabsf(val - codebook[0]);

    for (int cand = 1; cand < n; ++cand) {
        const float err = fabsf(val - codebook[cand]);
        if (!(err < winner_err)) {
            continue;   // not strictly better: keep the earlier entry
        }
        winner     = cand;
        winner_err = err;
    }

    return (uint8_t) winner;
}
|
||||||
|
|
||||||
|
// Quantizes TURBO3_BLOCK_SIZE floats from `x` into one block_turbo3_0 at `y`.
// Stores the block's L2-norm as FP16 in y->d, then maps each element, divided by that
// norm, to its nearest 3-bit codebook entry and packs the indices back to back into y->qs.
// A norm of at most 1e-10 is treated as zero: every element is then quantized as 0.0.
// NOTE(review): this path normalizes per 32-element block and applies no FWHT - confirm
// it is meant to interoperate with the FWHT-based bulk dequantize kernel.
static __device__ void quantize_f32_turbo3_0_block(const float * __restrict__ x, block_turbo3_0 * __restrict__ y) {
    // L2-norm of the block.
    float energy = 0.0f;
    for (int i = 0; i < TURBO3_BLOCK_SIZE; ++i) {
        energy += x[i] * x[i];
    }
    const float l2 = sqrtf(energy);
    y->d = __float2half(l2);

    float rcp = 0.0f;
    if (l2 > 1e-10f) {
        rcp = 1.0f / l2;
    }

    // 8 indices x 3 bits = 24 bits = exactly 3 bytes, so each group of 8 elements is
    // assembled in a 32-bit word and emitted low byte first (same bit stream as packing
    // index j at bit offset 3*j).
    for (int g = 0; g < TURBO3_BLOCK_SIZE / 8; ++g) {
        uint32_t word = 0;
        for (int t = 0; t < 8; ++t) {
            const uint32_t code = tq_nearest_codebook(x[8 * g + t] * rcp, tq_codebook_3bit_q, 8);
            word |= (code & 0x07u) << (3 * t);
        }
        y->qs[3 * g + 0] = (uint8_t) (word);
        y->qs[3 * g + 1] = (uint8_t) (word >> 8);
        y->qs[3 * g + 2] = (uint8_t) (word >> 16);
    }
}
|
||||||
|
|
||||||
|
// Quantizes TURBO4_BLOCK_SIZE floats from `x` into one block_turbo4_0 at `y`.
// Stores the block's L2-norm as FP16 in y->d, then maps each element, divided by that
// norm, to its nearest 4-bit codebook entry. Two indices share a byte: the even element
// goes in the low nibble, the odd element in the high nibble.
// A norm of at most 1e-10 is treated as zero: every element is then quantized as 0.0.
// NOTE(review): this path normalizes per 32-element block and applies no FWHT - confirm
// it is meant to interoperate with the FWHT-based bulk dequantize kernel.
static __device__ void quantize_f32_turbo4_0_block(const float * __restrict__ x, block_turbo4_0 * __restrict__ y) {
    // L2-norm of the block.
    float energy = 0.0f;
    for (int i = 0; i < TURBO4_BLOCK_SIZE; ++i) {
        energy += x[i] * x[i];
    }
    const float l2 = sqrtf(energy);
    y->d = __float2half(l2);

    float rcp = 0.0f;
    if (l2 > 1e-10f) {
        rcp = 1.0f / l2;
    }

    const float * src = x;
    for (int p = 0; p < TURBO4_BLOCK_SIZE / 2; ++p, src += 2) {
        const uint8_t lo = tq_nearest_codebook(src[0] * rcp, tq_codebook_4bit_q, 16);
        const uint8_t hi = tq_nearest_codebook(src[1] * rcp, tq_codebook_4bit_q, 16);
        y->qs[p] = (lo & 0x0F) | ((hi & 0x0F) << 4);
    }
}
|
||||||
|
|
||||||
|
// Copy adapter with byte-pointer arguments: reinterprets `cxi` as a run of floats and
// `cdsti` as a block_turbo3_0, then quantizes one block from the former into the latter.
static __device__ void cpy_blck_f32_turbo3_0(const char * cxi, char * cdsti) {
    const float *    src = (const float *) cxi;
    block_turbo3_0 * dst = (block_turbo3_0 *) cdsti;
    quantize_f32_turbo3_0_block(src, dst);
}
|
||||||
|
|
||||||
|
// Copy adapter with byte-pointer arguments: reinterprets `cxi` as a run of floats and
// `cdsti` as a block_turbo4_0, then quantizes one block from the former into the latter.
static __device__ void cpy_blck_f32_turbo4_0(const char * cxi, char * cdsti) {
    const float *    src = (const float *) cxi;
    block_turbo4_0 * dst = (block_turbo4_0 *) cdsti;
    quantize_f32_turbo4_0_block(src, dst);
}
|
||||||
|
|
||||||
template<typename src_t, typename dst_t>
|
template<typename src_t, typename dst_t>
|
||||||
static __device__ void cpy_1_scalar(const char * cxi, char * cdsti) {
|
static __device__ void cpy_1_scalar(const char * cxi, char * cdsti) {
|
||||||
*(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
|
*(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
|
||||||
|
|||||||
@@ -75,3 +75,67 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
|
|||||||
v.x *= d;
|
v.x *= d;
|
||||||
v.y *= d;
|
v.y *= d;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// TurboQuant GPU dequantize device functions
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
// Dequantize-side TurboQuant codebooks, indexed by the unpacked 3-bit / 4-bit code.
// Entries are listed in ascending order.

// 3-bit codebook: 8 centroids.
__device__ __constant__ static float dc_codebook_3bit[8] = {
    -0.1883972972f, -0.1181399059f,
    -0.0665857641f, -0.0216044751f,
     0.0216041461f,  0.0665854520f,
     0.1181396281f,  0.1883970748f,
};

// 4-bit codebook: 16 centroids.
__device__ __constant__ static float dc_codebook_4bit[16] = {
    -0.2376389871f, -0.1808080141f,
    -0.1417777640f, -0.1102646123f,
    -0.0828112376f, -0.0577640422f,
    -0.0341540905f, -0.0113168380f,
     0.0112761586f,  0.0341139667f,
     0.0577250301f,  0.0827738972f,
     0.1102295202f,  0.1417455465f,
     0.1807794468f,  0.2376153882f,
};
|
||||||
|
|
||||||
|
// Dequantizes two consecutive elements of one turbo3 block.
//   vx  : base pointer of a block_turbo3_0 array
//   ib  : block index
//   iqs : pair index inside the block; elements 2*iqs and 2*iqs + 1 are decoded
//   v   : receives (codebook value * block norm) for the two elements in v.x / v.y
// NOTE(review): unlike the bulk FWHT kernel, no inverse FWHT is applied here - confirm
// this per-block path is only used on data quantized without the FWHT rotation.
static __device__ __forceinline__ void dequantize_turbo3_0(
        const void * vx, const int64_t ib, const int iqs, float2 & v)
{
    const block_turbo3_0 * blk = (const block_turbo3_0 *) vx + ib;

    uint8_t code[2];
#pragma unroll
    for (int e = 0; e < 2; ++e) {
        // Bit position of element (2*iqs + e) in the packed 3-bit stream.
        const int bit_off = (iqs * 2 + e) * 3;
        const int byte    = bit_off / 8;
        const int shift   = bit_off % 8;

        uint16_t raw = (uint16_t) blk->qs[byte] >> shift;
        // shift 6 or 7: the index straddles into the next byte.
        if (shift > 5 && byte + 1 < 12) {
            raw |= (uint16_t) blk->qs[byte + 1] << (8 - shift);
        }
        code[e] = (uint8_t) (raw & 0x07);
    }

    const float norm = __half2float(blk->d);
    v.x = dc_codebook_3bit[code[0]] * norm;
    v.y = dc_codebook_3bit[code[1]] * norm;
}
|
||||||
|
|
||||||
|
// Dequantizes two consecutive elements of one turbo4 block.
//   vx  : base pointer of a block_turbo4_0 array
//   ib  : block index
//   iqs : byte index into qs; the low nibble yields v.x, the high nibble v.y
//   v   : receives (codebook value * block norm) for the two elements
// NOTE(review): unlike the bulk FWHT kernel, no inverse FWHT is applied here - confirm
// this per-block path is only used on data quantized without the FWHT rotation.
static __device__ __forceinline__ void dequantize_turbo4_0(
        const void * vx, const int64_t ib, const int iqs, float2 & v)
{
    const block_turbo4_0 * blk = (const block_turbo4_0 *) vx + ib;

    const uint8_t byte = blk->qs[iqs];
    const uint8_t lo   = byte & 0x0F;
    const uint8_t hi   = (byte >> 4) & 0x0F;

    const float norm = __half2float(blk->d);
    v.x = dc_codebook_4bit[lo] * norm;
    v.y = dc_codebook_4bit[hi] * norm;
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
#include "common.cuh"
|
#include "common.cuh"
|
||||||
#include "convert.cuh"
|
#include "convert.cuh"
|
||||||
#include "vecdotq.cuh"
|
#include "vecdotq.cuh"
|
||||||
|
#include "dequantize.cuh"
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
@@ -577,6 +578,59 @@ static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// TurboQuant V-cache dequantize functions for flash attention
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
template <typename T, int ne>
|
||||||
|
static __device__ __forceinline__ void dequantize_V_turbo3_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
||||||
|
const block_turbo3_0 * x = (const block_turbo3_0 *) vx;
|
||||||
|
|
||||||
|
const int64_t ib = i0 / QK_TURBO3;
|
||||||
|
const int iqs = (int)(i0 % QK_TURBO3) / 2;
|
||||||
|
|
||||||
|
static_assert(ne % 2 == 0, "bad ne");
|
||||||
|
T * dst_t = (T *) dst;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int l = 0; l < ne/2; ++l) {
|
||||||
|
float2 v;
|
||||||
|
dequantize_turbo3_0(vx, ib, iqs + l, v);
|
||||||
|
if constexpr (std::is_same_v<T, half>) {
|
||||||
|
dst_t[2*l + 0] = __float2half(v.x);
|
||||||
|
dst_t[2*l + 1] = __float2half(v.y);
|
||||||
|
} else {
|
||||||
|
dst_t[2*l + 0] = (T)v.x;
|
||||||
|
dst_t[2*l + 1] = (T)v.y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int ne>
|
||||||
|
static __device__ __forceinline__ void dequantize_V_turbo4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
|
||||||
|
const block_turbo4_0 * x = (const block_turbo4_0 *) vx;
|
||||||
|
|
||||||
|
const int64_t ib = i0 / QK_TURBO4;
|
||||||
|
const int iqs = (int)(i0 % QK_TURBO4) / 2;
|
||||||
|
|
||||||
|
static_assert(ne % 2 == 0, "bad ne");
|
||||||
|
T * dst_t = (T *) dst;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int l = 0; l < ne/2; ++l) {
|
||||||
|
float2 v;
|
||||||
|
dequantize_turbo4_0(vx, ib, iqs + l, v);
|
||||||
|
if constexpr (std::is_same_v<T, half>) {
|
||||||
|
dst_t[2*l + 0] = __float2half(v.x);
|
||||||
|
dst_t[2*l + 1] = __float2half(v.y);
|
||||||
|
} else {
|
||||||
|
dst_t[2*l + 0] = (T)v.x;
|
||||||
|
dst_t[2*l + 1] = (T)v.y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
GGML_UNUSED(x);
|
||||||
|
}
|
||||||
|
|
||||||
template <ggml_type type_K, int D, int nthreads>
|
template <ggml_type type_K, int D, int nthreads>
|
||||||
constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
|
constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
|
||||||
if constexpr (type_K == GGML_TYPE_F16) {
|
if constexpr (type_K == GGML_TYPE_F16) {
|
||||||
@@ -593,6 +647,12 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
|
|||||||
return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
|
return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
|
||||||
} else if constexpr (type_K == GGML_TYPE_BF16) {
|
} else if constexpr (type_K == GGML_TYPE_BF16) {
|
||||||
return vec_dot_fattn_vec_KQ_bf16<D, nthreads>;
|
return vec_dot_fattn_vec_KQ_bf16<D, nthreads>;
|
||||||
|
// TurboQuant K-cache: Phase 1 - dequantize to FP16 before attention
|
||||||
|
// (use FP16 dot product after conversion in the dispatch layer)
|
||||||
|
} else if constexpr (type_K == GGML_TYPE_TURBO3_0) {
|
||||||
|
return vec_dot_fattn_vec_KQ_f16<D, nthreads>;
|
||||||
|
} else if constexpr (type_K == GGML_TYPE_TURBO4_0) {
|
||||||
|
return vec_dot_fattn_vec_KQ_f16<D, nthreads>;
|
||||||
} else {
|
} else {
|
||||||
static_assert(type_K == -1, "bad type");
|
static_assert(type_K == -1, "bad type");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@@ -615,6 +675,10 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
|
|||||||
return dequantize_V_q8_0<T, ne>;
|
return dequantize_V_q8_0<T, ne>;
|
||||||
} else if constexpr (type_V == GGML_TYPE_BF16) {
|
} else if constexpr (type_V == GGML_TYPE_BF16) {
|
||||||
return dequantize_V_bf16<float, ne>;
|
return dequantize_V_bf16<float, ne>;
|
||||||
|
} else if constexpr (type_V == GGML_TYPE_TURBO3_0) {
|
||||||
|
return dequantize_V_turbo3_0<T, ne>;
|
||||||
|
} else if constexpr (type_V == GGML_TYPE_TURBO4_0) {
|
||||||
|
return dequantize_V_turbo4_0<T, ne>;
|
||||||
} else {
|
} else {
|
||||||
static_assert(type_V == -1, "bad type");
|
static_assert(type_V == -1, "bad type");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|||||||
@@ -75,17 +75,20 @@ static __global__ void flash_attn_ext_vec(
|
|||||||
#endif // GGML_USE_HIP
|
#endif // GGML_USE_HIP
|
||||||
|
|
||||||
constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device();
|
constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device();
|
||||||
constexpr int nthreads_KQ = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_KQ_q;
|
constexpr bool K_is_fp_like = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16);
|
||||||
constexpr int nthreads_V = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_V_q;
|
constexpr bool V_is_fp_like = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16);
|
||||||
|
|
||||||
|
constexpr int nthreads_KQ = K_is_fp_like ? 128 / cpy_nb : nthreads_KQ_q;
|
||||||
|
constexpr int nthreads_V = V_is_fp_like ? 128 / cpy_nb : nthreads_V_q;
|
||||||
|
|
||||||
static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
|
static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
|
||||||
static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");
|
static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");
|
||||||
|
|
||||||
constexpr int V_rows_per_thread = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 2*cpy_ne : 4;
|
constexpr int V_rows_per_thread = V_is_fp_like ? 2*cpy_ne : 4;
|
||||||
constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V;
|
constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V;
|
||||||
|
|
||||||
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
|
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
|
||||||
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16 && type_K != GGML_TYPE_BF16;
|
constexpr bool Q_q8_1 = !K_is_fp_like;
|
||||||
#ifdef V_DOT2_F32_F16_AVAILABLE
|
#ifdef V_DOT2_F32_F16_AVAILABLE
|
||||||
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
|
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
|
||||||
#else
|
#else
|
||||||
@@ -598,3 +601,12 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
|
|||||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
|
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
|
||||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
|
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
|
||||||
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16)
|
EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16)
|
||||||
|
|
||||||
|
// TurboQuant extern declarations (homogeneous K/V only)
|
||||||
|
extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO3_0);
|
||||||
|
extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO3_0);
|
||||||
|
extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO3_0);
|
||||||
|
|
||||||
|
extern DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0);
|
||||||
|
extern DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0);
|
||||||
|
extern DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0);
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
#include "fattn-vec.cuh"
|
#include "fattn-vec.cuh"
|
||||||
#include "fattn-wmma-f16.cuh"
|
#include "fattn-wmma-f16.cuh"
|
||||||
#include "fattn.cuh"
|
#include "fattn.cuh"
|
||||||
|
#include "convert.cuh"
|
||||||
|
|
||||||
template <int DKQ, int DV, int ncols2>
|
template <int DKQ, int DV, int ncols2>
|
||||||
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
@@ -273,6 +274,8 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t
|
|||||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16)
|
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16)
|
||||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16)
|
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16)
|
||||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
|
FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16)
|
||||||
|
|
||||||
|
// TurboQuant: pre-dequantized to f16 before FA, no turbo vec cases needed
|
||||||
#else
|
#else
|
||||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
|
FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
|
||||||
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
|
FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
|
||||||
@@ -292,6 +295,10 @@ enum best_fattn_kernel {
|
|||||||
BEST_FATTN_KERNEL_MMA_F16 = 400,
|
BEST_FATTN_KERNEL_MMA_F16 = 400,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Returns true when `type` is one of the two TurboQuant KV-cache types
// (GGML_TYPE_TURBO3_0 or GGML_TYPE_TURBO4_0), false for every other type.
static bool ggml_type_is_turbo(ggml_type type) {
    switch (type) {
        case GGML_TYPE_TURBO3_0:
        case GGML_TYPE_TURBO4_0:
            return true;
        default:
            return false;
    }
}
|
||||||
|
|
||||||
static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
|
static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
|
||||||
#ifndef FLASH_ATTN_AVAILABLE
|
#ifndef FLASH_ATTN_AVAILABLE
|
||||||
GGML_UNUSED(device); GGML_UNUSED(dst);
|
GGML_UNUSED(device); GGML_UNUSED(dst);
|
||||||
@@ -353,8 +360,13 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||||||
}
|
}
|
||||||
|
|
||||||
#ifndef GGML_CUDA_FA_ALL_QUANTS
|
#ifndef GGML_CUDA_FA_ALL_QUANTS
|
||||||
if (K->type != V->type) {
|
{
|
||||||
return BEST_FATTN_KERNEL_NONE;
|
// Turbo types are pre-dequantized to f16, so treat them as f16 for type matching
|
||||||
|
const ggml_type eff_k = ggml_type_is_turbo(K->type) ? GGML_TYPE_F16 : K->type;
|
||||||
|
const ggml_type eff_v = ggml_type_is_turbo(V->type) ? GGML_TYPE_F16 : V->type;
|
||||||
|
if (eff_k != eff_v) {
|
||||||
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||||
|
|
||||||
@@ -372,6 +384,10 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
case GGML_TYPE_BF16:
|
case GGML_TYPE_BF16:
|
||||||
break;
|
break;
|
||||||
|
case GGML_TYPE_TURBO3_0:
|
||||||
|
case GGML_TYPE_TURBO4_0:
|
||||||
|
// Turbo types are handled via pre-dequantize to f16 before FA
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
return BEST_FATTN_KERNEL_NONE;
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
}
|
}
|
||||||
@@ -485,8 +501,72 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||||||
return BEST_FATTN_KERNEL_TILE;
|
return BEST_FATTN_KERNEL_TILE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pre-dequantize a turbo tensor to f16 and return (by value) an f16 tensor
// header that points into pool_buf.
//
// The caller must keep pool_buf alive until the flash-attention kernels that
// read the returned tensor have completed.
//
// The returned header keeps src's ne[] and rescales src's byte strides from
// "turbo blocks" to "f16 elements", so any permutation applied to src (e.g. a
// ggml_permute swapping dims 1 and 2) is preserved. Because the strides are
// reused, the f16 buffer has to mirror the *whole* memory span that src
// addresses, not just ggml_nelements(src) values: for a strided view with
// gaps between rows/planes the rescaled strides would otherwise index past
// the end of the buffer. ggml_nbytes() reports exactly that span; for a
// dense (possibly permuted) tensor it equals the packed size, so the common
// case allocates and converts the same amount of data as before.
static ggml_tensor turbo_pre_dequantize(
        const ggml_tensor * src,
        ggml_cuda_pool_alloc<half> & pool_buf,
        cudaStream_t stream) {
    const size_t bs = ggml_blck_size(src->type);  // elements per quantized block
    const size_t ts = ggml_type_size(src->type);  // bytes per quantized block

    // The stride rescaling below is only exact when every stride is a whole
    // number of blocks and dim 0 is the densely packed block dimension.
    GGML_ASSERT(src->nb[0] == ts);
    for (int i = 1; i < GGML_MAX_DIMS; ++i) {
        GGML_ASSERT(src->nb[i] % ts == 0);
    }

    // Number of f16 values needed to mirror the full span addressed by src.
    const size_t span_bytes = ggml_nbytes(src);
    GGML_ASSERT(span_bytes % ts == 0);
    const int64_t n_span = (int64_t) (span_bytes / ts * bs);

    pool_buf.alloc(n_span);

    to_fp16_cuda_t dequant = ggml_get_to_fp16_cuda(src->type);
    GGML_ASSERT(dequant != nullptr);
    dequant(src->data, pool_buf.ptr, n_span, stream);

    // The dequantized f16 data is in the same physical order as the turbo
    // data, so the stride relationships are kept and only rescaled.
    ggml_tensor tmp = *src;
    tmp.type = GGML_TYPE_F16;
    tmp.data = pool_buf.ptr;
    tmp.nb[0] = sizeof(half);
    for (int i = 1; i < GGML_MAX_DIMS; ++i) {
        tmp.nb[i] = src->nb[i] * bs * sizeof(half) / ts;
    }
    // The copy owns its own storage; it is no longer a view of anything.
    tmp.view_src  = nullptr;
    tmp.view_offs = 0;

    return tmp;
}
|
||||||
|
|
||||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
ggml_cuda_set_device(ctx.device);
|
ggml_cuda_set_device(ctx.device);
|
||||||
|
|
||||||
|
const ggml_tensor * K = dst->src[1];
|
||||||
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
|
||||||
|
const bool k_is_turbo = ggml_type_is_turbo(K->type);
|
||||||
|
const bool v_is_turbo = V && ggml_type_is_turbo(V->type);
|
||||||
|
|
||||||
|
// Pre-dequantize turbo KV to f16 so standard FA kernels can handle them.
|
||||||
|
// Pool buffers must outlive the FA dispatch (RAII frees on scope exit).
|
||||||
|
ggml_cuda_pool_alloc<half> k_pool(ctx.pool());
|
||||||
|
ggml_cuda_pool_alloc<half> v_pool(ctx.pool());
|
||||||
|
ggml_tensor k_f16, v_f16;
|
||||||
|
|
||||||
|
cudaStream_t stream = ctx.stream();
|
||||||
|
|
||||||
|
// Save original src pointers
|
||||||
|
ggml_tensor * orig_k = dst->src[1];
|
||||||
|
ggml_tensor * orig_v = dst->src[2];
|
||||||
|
|
||||||
|
if (k_is_turbo) {
|
||||||
|
k_f16 = turbo_pre_dequantize(K, k_pool, stream);
|
||||||
|
dst->src[1] = &k_f16;
|
||||||
|
}
|
||||||
|
if (v_is_turbo) {
|
||||||
|
v_f16 = turbo_pre_dequantize(V, v_pool, stream);
|
||||||
|
dst->src[2] = &v_f16;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Standard FA dispatch — now sees f16 tensors
|
||||||
switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
|
switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
|
||||||
case BEST_FATTN_KERNEL_NONE:
|
case BEST_FATTN_KERNEL_NONE:
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
@@ -503,6 +583,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||||||
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Restore original src pointers
|
||||||
|
dst->src[1] = orig_k;
|
||||||
|
dst->src[2] = orig_v;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
|
bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
|
||||||
|
|||||||
@@ -4842,7 +4842,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||||||
{
|
{
|
||||||
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
|
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
|
||||||
op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
|
op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
|
||||||
op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
|
op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL ||
|
||||||
|
op->type == GGML_TYPE_TURBO3_0 || op->type == GGML_TYPE_TURBO4_0) &&
|
||||||
op->src[0]->type == GGML_TYPE_F32 &&
|
op->src[0]->type == GGML_TYPE_F32 &&
|
||||||
(op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
|
(op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
|
||||||
} break;
|
} break;
|
||||||
|
|||||||
@@ -1,8 +1,16 @@
|
|||||||
#include "set-rows.cuh"
|
#include "set-rows.cuh"
|
||||||
#include "cpy-utils.cuh"
|
#include "cpy-utils.cuh"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
|
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// TurboQuant constants for set-rows FWHT
|
||||||
|
// ============================================================
|
||||||
|
#define TURBO_HEAD_DIM_SR 128
|
||||||
|
#define TURBO_BLOCKS_PER_CHUNK_SR (TURBO_HEAD_DIM_SR / 32) // 4
|
||||||
|
|
||||||
// Generic quantized set_rows kernel template
|
// Generic quantized set_rows kernel template
|
||||||
template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
|
template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
|
||||||
static __global__ void k_set_rows_quant(const float * __restrict__ src0,
|
static __global__ void k_set_rows_quant(const float * __restrict__ src0,
|
||||||
@@ -109,6 +117,398 @@ static void set_rows_cuda_quant(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// TurboQuant specialized set-rows kernel with FWHT
|
||||||
|
// ============================================================
|
||||||
|
// Each CUDA block processes one 128-element chunk.
|
||||||
|
// 128 threads per block, one thread per element in the chunk.
|
||||||
|
// Steps:
|
||||||
|
// 1. Each thread reads one float from the source row
|
||||||
|
// 2. Cooperative norm computation via shared memory reduction
|
||||||
|
// 3. Normalize the chunk
|
||||||
|
// 4. FWHT butterfly in shared memory (7 stages for n=128)
|
||||||
|
// 5. Each thread scalar-quantizes its element and packs into blocks
|
||||||
|
|
||||||
|
// Device-side codebook for the 3-bit turbo quantizer (same as in cpy-utils.cuh).
// Eight reconstruction levels, listed in ascending order; the array index is
// the 3-bit code stored in the packed block.
__device__ static const float sr_codebook_3bit[8] = {
    -0.1883972972f,
    -0.1181399059f,
    -0.0665857641f,
    -0.0216044751f,
     0.0216041461f,
     0.0665854520f,
     0.1181396281f,
     0.1883970748f,
};
|
||||||
|
|
||||||
|
// Device-side codebook for the 4-bit turbo quantizer: sixteen reconstruction
// levels in ascending order; the array index is the 4-bit code stored in the
// packed block.
__device__ static const float sr_codebook_4bit[16] = {
    -0.2376389871f, -0.1808080141f,
    -0.1417777640f, -0.1102646123f,
    -0.0828112376f, -0.0577640422f,
    -0.0341540905f, -0.0113168380f,
     0.0112761586f,  0.0341139667f,
     0.0577250301f,  0.0827738972f,
     0.1102295202f,  0.1417455465f,
     0.1807794468f,  0.2376153882f,
};
|
||||||
|
|
||||||
|
// Returns the index of the entry of codebook[0..n) with the smallest absolute
// difference to val. Entry 0 is always read, so n must be at least 1. On a tie
// the lowest index is kept, because a candidate only replaces the current
// best when it is strictly closer.
static __device__ uint8_t sr_nearest_codebook(float val, const float *codebook, int n) {
    uint8_t nearest  = 0;
    float   min_diff = fabsf(val - codebook[0]);

    int k = 1;
    while (k < n) {
        const float diff = fabsf(val - codebook[k]);
        if (diff < min_diff) {
            nearest  = (uint8_t)k;
            min_diff = diff;
        }
        ++k;
    }

    return nearest;
}
|
||||||
|
|
||||||
|
// Turbo3 set-rows kernel: one thread block per 128-element chunk of a source row.
//
// Each block uses TURBO_HEAD_DIM_SR threads, one per element of the chunk, and:
//   1. loads the chunk into shared memory,
//   2. reduces the sum of squares to obtain the chunk's L2 norm,
//   3. scales the chunk by 1/norm (by 0 when norm <= 1e-10),
//   4. runs the 7 FWHT butterfly stages and applies the 1/sqrt(128) factor,
//   5. maps every element to its nearest sr_codebook_3bit entry,
//   6. lets threads 0..3 each pack 32 three-bit codes into one block_turbo3_0
//      and store the chunk norm in that block's d field.
//
// The destination row index is read from src1; s1/s2/s3 are byte strides of
// dst, the other strides are in elements. ne00 and ne10..ne13 are not read.
template <typename idx_t>
static __global__ void k_set_rows_turbo3(
        const float * __restrict__ src0,
        const idx_t * __restrict__ src1,
        block_turbo3_0 * __restrict__ dst,
        const int64_t ne_total_chunks,
        const int64_t ne10,
        const int64_t ne11,
        const int64_t ne12,
        const int64_t ne13,
        const int64_t s01,
        const int64_t s02,
        const int64_t s03,
        const int64_t s10,
        const int64_t s11,
        const int64_t s12,
        const int64_t s1,
        const int64_t s2,
        const int64_t s3,
        const int64_t ne00,
        const uint3 ne00_fd,
        const uint3 ne01_fd,
        const uint3 ne02_fd,
        const uint3 ne11_fd,
        const uint3 ne12_fd) {

    __shared__ float smem[TURBO_HEAD_DIM_SR];       // chunk values, transformed in place
    __shared__ float reduction[TURBO_HEAD_DIM_SR];  // sum-of-squares scratch, later reused for codes

    const int     tid   = threadIdx.x; // 0..127
    const int64_t chunk = blockIdx.x;

    // The whole block leaves together, so no barrier below is skipped by part of it.
    if (chunk >= ne_total_chunks) {
        return;
    }

    // Split the flat index of the chunk's first element into (i00, i01, i02, i03).
    uint32_t rem = (uint32_t)(chunk * TURBO_HEAD_DIM_SR);

    uint2 qr = fast_div_modulo(rem, ne00_fd);
    const int64_t i00 = qr.y; // offset within the row (multiple of 128)
    rem = qr.x;

    qr = fast_div_modulo(rem, ne01_fd);
    const int64_t i01 = qr.y;
    rem = qr.x;

    qr = fast_div_modulo(rem, ne02_fd);
    const int64_t i02 = qr.y;
    const int64_t i03 = qr.x;

    // Index-tensor coordinates (dims 1 and 2 of src1 wrap around).
    const int64_t i10 = i01;
    const int64_t i11 = fastmodulo((uint32_t)i02, ne11_fd);
    const int64_t i12 = fastmodulo((uint32_t)i03, ne12_fd);

    const int64_t dst_row = src1[i10*s10 + i11*s11 + i12*s12];

    const float * row_in = src0 + i01*s01 + i02*s02 + i03*s03;
    const float   x      = row_in[i00 + tid];
    smem[tid] = x;

    // L2 norm: tree reduction of the squared elements.
    reduction[tid] = x * x;
    __syncthreads();

    for (int stride = TURBO_HEAD_DIM_SR / 2; stride >= 1; stride /= 2) {
        if (tid < stride) {
            reduction[tid] += reduction[tid + stride];
        }
        __syncthreads();
    }

    const float norm     = sqrtf(reduction[0]);
    const float inv_norm = (norm > 1e-10f) ? (1.0f / norm) : 0.0f;

    // Normalize the chunk.
    smem[tid] *= inv_norm;
    __syncthreads();

    // FWHT butterflies: 64 threads each handle one (lo, lo + h) pair per stage.
    for (int h = 1; h < TURBO_HEAD_DIM_SR; h *= 2) {
        if (tid < 64) {
            const int   lo = (tid / h) * (2 * h) + (tid % h);
            const float a  = smem[lo];
            const float b  = smem[lo + h];
            smem[lo]     = a + b;
            smem[lo + h] = a - b;
        }
        __syncthreads();
    }

    // 1/sqrt(128) normalization of the transform.
    const float fwht_scale = 0.08838834764831844f;
    smem[tid] *= fwht_scale;
    __syncthreads();

    // Scalar-quantize this thread's element and publish the code through the
    // (now free) reduction buffer, viewed as bytes.
    uint8_t * codes = (uint8_t *)reduction;
    codes[tid] = sr_nearest_codebook(smem[tid], sr_codebook_3bit, 8);
    __syncthreads();

    // dst_row*s1 + i02*s2 + i03*s3 is the byte offset of the destination row;
    // the chunk's first block inside that row is i00 / TURBO3_BLOCK_SIZE.
    block_turbo3_0 * dst_row_ptr    = (block_turbo3_0 *)((char *)dst + dst_row*s1 + i02*s2 + i03*s3);
    const int64_t    dst_block_base = i00 / TURBO3_BLOCK_SIZE;

    // Only 4 threads (one per 32-element block) do the packing.
    if (tid < TURBO_BLOCKS_PER_CHUNK_SR) {
        block_turbo3_0 * out        = dst_row_ptr + dst_block_base + tid;
        const uint8_t  * block_code = codes + tid * 32;

        // Every block of the chunk carries the norm of the whole chunk.
        out->d = __float2half(norm);

        // Pack 32 x 3-bit codes into 12 bytes, least-significant bit first.
        uint8_t packed[12] = {0};
        for (int j = 0; j < 32; ++j) {
            const int     bit_off  = 3 * j;
            const int     byte_pos = bit_off / 8;
            const int     shift    = bit_off % 8;
            const uint8_t code     = block_code[j] & 0x07;

            packed[byte_pos] |= (uint8_t)(code << shift);
            if (shift > 5 && byte_pos + 1 < 12) {
                // code straddles a byte boundary: spill its high bits
                packed[byte_pos + 1] |= (uint8_t)(code >> (8 - shift));
            }
        }
        for (int b = 0; b < 12; ++b) {
            out->qs[b] = packed[b];
        }
    }

    GGML_UNUSED(ne10);
    GGML_UNUSED(ne11);
    GGML_UNUSED(ne12);
    GGML_UNUSED(ne13);
}
|
||||||
|
|
||||||
|
// Turbo4 set-rows kernel: one thread block per 128-element chunk of a source row.
//
// Each block uses TURBO_HEAD_DIM_SR threads, one per element of the chunk, and:
//   1. loads the chunk into shared memory,
//   2. reduces the sum of squares to obtain the chunk's L2 norm,
//   3. scales the chunk by 1/norm (by 0 when norm <= 1e-10),
//   4. runs the 7 FWHT butterfly stages and applies the 1/sqrt(128) factor,
//   5. maps every element to its nearest sr_codebook_4bit entry,
//   6. lets threads 0..3 each pack 32 four-bit codes into one block_turbo4_0
//      and store the chunk norm in that block's d field.
//
// The destination row index is read from src1; s1/s2/s3 are byte strides of
// dst, the other strides are in elements. ne00 and ne10..ne13 are not read.
template <typename idx_t>
static __global__ void k_set_rows_turbo4(
        const float * __restrict__ src0,
        const idx_t * __restrict__ src1,
        block_turbo4_0 * __restrict__ dst,
        const int64_t ne_total_chunks,
        const int64_t ne10,
        const int64_t ne11,
        const int64_t ne12,
        const int64_t ne13,
        const int64_t s01,
        const int64_t s02,
        const int64_t s03,
        const int64_t s10,
        const int64_t s11,
        const int64_t s12,
        const int64_t s1,
        const int64_t s2,
        const int64_t s3,
        const int64_t ne00,
        const uint3 ne00_fd,
        const uint3 ne01_fd,
        const uint3 ne02_fd,
        const uint3 ne11_fd,
        const uint3 ne12_fd) {

    __shared__ float smem[TURBO_HEAD_DIM_SR];       // chunk values, transformed in place
    __shared__ float reduction[TURBO_HEAD_DIM_SR];  // sum-of-squares scratch, later reused for codes

    const int     tid   = threadIdx.x; // 0..127
    const int64_t chunk = blockIdx.x;

    // The whole block leaves together, so no barrier below is skipped by part of it.
    if (chunk >= ne_total_chunks) {
        return;
    }

    // Split the flat index of the chunk's first element into (i00, i01, i02, i03).
    uint32_t rem = (uint32_t)(chunk * TURBO_HEAD_DIM_SR);

    uint2 qr = fast_div_modulo(rem, ne00_fd);
    const int64_t i00 = qr.y;
    rem = qr.x;

    qr = fast_div_modulo(rem, ne01_fd);
    const int64_t i01 = qr.y;
    rem = qr.x;

    qr = fast_div_modulo(rem, ne02_fd);
    const int64_t i02 = qr.y;
    const int64_t i03 = qr.x;

    // Index-tensor coordinates (dims 1 and 2 of src1 wrap around).
    const int64_t i10 = i01;
    const int64_t i11 = fastmodulo((uint32_t)i02, ne11_fd);
    const int64_t i12 = fastmodulo((uint32_t)i03, ne12_fd);

    const int64_t dst_row = src1[i10*s10 + i11*s11 + i12*s12];

    const float * row_in = src0 + i01*s01 + i02*s02 + i03*s03;
    const float   x      = row_in[i00 + tid];
    smem[tid] = x;

    // L2 norm: tree reduction of the squared elements.
    reduction[tid] = x * x;
    __syncthreads();

    for (int stride = TURBO_HEAD_DIM_SR / 2; stride >= 1; stride /= 2) {
        if (tid < stride) {
            reduction[tid] += reduction[tid + stride];
        }
        __syncthreads();
    }

    const float norm     = sqrtf(reduction[0]);
    const float inv_norm = (norm > 1e-10f) ? (1.0f / norm) : 0.0f;

    // Normalize the chunk.
    smem[tid] *= inv_norm;
    __syncthreads();

    // FWHT butterflies: 64 threads each handle one (lo, lo + h) pair per stage.
    for (int h = 1; h < TURBO_HEAD_DIM_SR; h *= 2) {
        if (tid < 64) {
            const int   lo = (tid / h) * (2 * h) + (tid % h);
            const float a  = smem[lo];
            const float b  = smem[lo + h];
            smem[lo]     = a + b;
            smem[lo + h] = a - b;
        }
        __syncthreads();
    }

    // 1/sqrt(128) normalization of the transform.
    const float fwht_scale = 0.08838834764831844f;
    smem[tid] *= fwht_scale;
    __syncthreads();

    // Scalar-quantize this thread's element and publish the code through the
    // (now free) reduction buffer, viewed as bytes.
    uint8_t * codes = (uint8_t *)reduction;
    codes[tid] = sr_nearest_codebook(smem[tid], sr_codebook_4bit, 16);
    __syncthreads();

    // dst_row*s1 + i02*s2 + i03*s3 is the byte offset of the destination row;
    // the chunk's first block inside that row is i00 / TURBO4_BLOCK_SIZE.
    block_turbo4_0 * dst_row_ptr    = (block_turbo4_0 *)((char *)dst + dst_row*s1 + i02*s2 + i03*s3);
    const int64_t    dst_block_base = i00 / TURBO4_BLOCK_SIZE;

    // Only 4 threads (one per 32-element block) do the packing.
    if (tid < TURBO_BLOCKS_PER_CHUNK_SR) {
        block_turbo4_0 * out        = dst_row_ptr + dst_block_base + tid;
        const uint8_t  * block_code = codes + tid * 32;

        // Every block of the chunk carries the norm of the whole chunk.
        out->d = __float2half(norm);

        // Pack 32 x 4-bit codes into 16 bytes: even code in the low nibble,
        // odd code in the high nibble.
        for (int j = 0; j < TURBO4_BLOCK_SIZE / 2; ++j) {
            const uint8_t lo_nibble = block_code[2*j]     & 0x0F;
            const uint8_t hi_nibble = block_code[2*j + 1] & 0x0F;
            out->qs[j] = lo_nibble | (hi_nibble << 4);
        }
    }

    GGML_UNUSED(ne10);
    GGML_UNUSED(ne11);
    GGML_UNUSED(ne12);
    GGML_UNUSED(ne13);
}
|
||||||
|
|
||||||
|
// Dispatch functions for turbo set-rows
|
||||||
|
// Host-side launcher for k_set_rows_turbo3.
//
// src0_d: f32 source rows, src1_d: destination row indices, dst_d: turbo3
// destination blocks. ne0x/ne1x are the extents of src0/src1; nb0x/nb1x/nbx
// are the byte strides of src0/src1/dst. One thread block of
// TURBO_HEAD_DIM_SR threads is launched per 128-element chunk of src0.
//
// Requirements (asserted):
//   - ne00 is a multiple of TURBO_HEAD_DIM_SR, so chunks never straddle rows;
//   - the total element count fits in 32 bits, because the kernel narrows the
//     flat element index to uint32_t before the fast-division decomposition
//     (this also keeps the chunk count within the int grid size).
// Nothing is launched when any of the extents used for indexing is zero.
template<typename idx_t>
static void set_rows_cuda_turbo3(
        const float * src0_d, const idx_t * src1_d, block_turbo3_0 * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        cudaStream_t stream) {

    GGML_ASSERT(ne00 % TURBO_HEAD_DIM_SR == 0);

    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
    // Guard the uint32_t narrowing done inside the kernel; without it a larger
    // tensor would silently wrap and write to the wrong rows.
    GGML_ASSERT(ne_total <= (int64_t) UINT32_MAX);

    const int64_t ne_total_chunks = ne_total / TURBO_HEAD_DIM_SR;

    if (ne_total_chunks <= 0 || ne00 <= 0 || ne01 <= 0 || ne02 <= 0 || ne11 <= 0 || ne12 <= 0) {
        return;
    }

    const dim3 grid_size((int)ne_total_chunks);
    const dim3 block_size(TURBO_HEAD_DIM_SR);

    // src0 and src1 strides are passed in elements, dst strides stay in bytes.
    const int64_t s01 = nb01/sizeof(float);
    const int64_t s02 = nb02/sizeof(float);
    const int64_t s03 = nb03/sizeof(float);
    const int64_t s10 = nb10/sizeof(idx_t);
    const int64_t s11 = nb11/sizeof(idx_t);
    const int64_t s12 = nb12/sizeof(idx_t);
    const int64_t s1  = nb1;
    const int64_t s2  = nb2;
    const int64_t s3  = nb3;

    // Precomputed constants for the kernel's fast division / modulo.
    const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
    const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
    const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
    const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
    const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);

    k_set_rows_turbo3<idx_t><<<grid_size, block_size, 0, stream>>>(
        src0_d, src1_d, dst_d, ne_total_chunks, ne10, ne11, ne12, ne13,
        s01, s02, s03, s10, s11, s12, s1, s2, s3,
        ne00, ne00_fd, ne01_fd, ne02_fd, ne11_fd, ne12_fd);
}
|
||||||
|
|
||||||
|
// Host-side launcher for k_set_rows_turbo4.
//
// src0_d: f32 source rows, src1_d: destination row indices, dst_d: turbo4
// destination blocks. ne0x/ne1x are the extents of src0/src1; nb0x/nb1x/nbx
// are the byte strides of src0/src1/dst. One thread block of
// TURBO_HEAD_DIM_SR threads is launched per 128-element chunk of src0.
//
// Requirements (asserted):
//   - ne00 is a multiple of TURBO_HEAD_DIM_SR, so chunks never straddle rows;
//   - the total element count fits in 32 bits, because the kernel narrows the
//     flat element index to uint32_t before the fast-division decomposition
//     (this also keeps the chunk count within the int grid size).
// Nothing is launched when any of the extents used for indexing is zero.
template<typename idx_t>
static void set_rows_cuda_turbo4(
        const float * src0_d, const idx_t * src1_d, block_turbo4_0 * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t nb10, const size_t nb11, const size_t nb12,
        const size_t nb1, const size_t nb2, const size_t nb3,
        cudaStream_t stream) {

    GGML_ASSERT(ne00 % TURBO_HEAD_DIM_SR == 0);

    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
    // Guard the uint32_t narrowing done inside the kernel; without it a larger
    // tensor would silently wrap and write to the wrong rows.
    GGML_ASSERT(ne_total <= (int64_t) UINT32_MAX);

    const int64_t ne_total_chunks = ne_total / TURBO_HEAD_DIM_SR;

    if (ne_total_chunks <= 0 || ne00 <= 0 || ne01 <= 0 || ne02 <= 0 || ne11 <= 0 || ne12 <= 0) {
        return;
    }

    const dim3 grid_size((int)ne_total_chunks);
    const dim3 block_size(TURBO_HEAD_DIM_SR);

    // src0 and src1 strides are passed in elements, dst strides stay in bytes.
    const int64_t s01 = nb01/sizeof(float);
    const int64_t s02 = nb02/sizeof(float);
    const int64_t s03 = nb03/sizeof(float);
    const int64_t s10 = nb10/sizeof(idx_t);
    const int64_t s11 = nb11/sizeof(idx_t);
    const int64_t s12 = nb12/sizeof(idx_t);
    const int64_t s1  = nb1;
    const int64_t s2  = nb2;
    const int64_t s3  = nb3;

    // Precomputed constants for the kernel's fast division / modulo.
    const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
    const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
    const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
    const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
    const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);

    k_set_rows_turbo4<idx_t><<<grid_size, block_size, 0, stream>>>(
        src0_d, src1_d, dst_d, ne_total_chunks, ne10, ne11, ne12, ne13,
        s01, s02, s03, s10, s11, s12, s1, s2, s3,
        ne00, ne00_fd, ne01_fd, ne02_fd, ne11_fd, ne12_fd);
}
|
||||||
|
|
||||||
template <typename src_t, typename idx_t, typename dst_t>
|
template <typename src_t, typename idx_t, typename dst_t>
|
||||||
static __global__ void k_set_rows(const src_t * __restrict__ src0,
|
static __global__ void k_set_rows(const src_t * __restrict__ src0,
|
||||||
const idx_t * __restrict__ src1,
|
const idx_t * __restrict__ src1,
|
||||||
@@ -309,6 +709,28 @@ static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * s
|
|||||||
nb1, nb2, nb3,
|
nb1, nb2, nb3,
|
||||||
stream
|
stream
|
||||||
);
|
);
|
||||||
|
} else if (dst->type == GGML_TYPE_TURBO3_0) {
|
||||||
|
// FWHT-aware 128-thread kernels for correct TurboQuant encoding
|
||||||
|
set_rows_cuda_turbo3<idx_t>(
|
||||||
|
src0_d, src1_d, (block_turbo3_0*)dst->data,
|
||||||
|
ne00, ne01, ne02, ne03,
|
||||||
|
ne10, ne11, ne12, ne13,
|
||||||
|
nb01, nb02, nb03,
|
||||||
|
nb10, nb11, nb12,
|
||||||
|
nb1, nb2, nb3,
|
||||||
|
stream
|
||||||
|
);
|
||||||
|
} else if (dst->type == GGML_TYPE_TURBO4_0) {
|
||||||
|
// FWHT-aware 128-thread kernels for correct TurboQuant encoding
|
||||||
|
set_rows_cuda_turbo4<idx_t>(
|
||||||
|
src0_d, src1_d, (block_turbo4_0*)dst->data,
|
||||||
|
ne00, ne01, ne02, ne03,
|
||||||
|
ne10, ne11, ne12, ne13,
|
||||||
|
nb01, nb02, nb03,
|
||||||
|
nb10, nb11, nb12,
|
||||||
|
nb1, nb2, nb3,
|
||||||
|
stream
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
|
GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
// TurboQuant flash attention template instances
|
||||||
|
|
||||||
|
#include "../fattn-vec.cuh"
|
||||||
|
|
||||||
|
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO3_0);
|
||||||
|
DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO3_0);
|
||||||
|
DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO3_0);
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
// TurboQuant flash attention template instances
|
||||||
|
|
||||||
|
#include "../fattn-vec.cuh"
|
||||||
|
|
||||||
|
DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0);
|
||||||
|
DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0);
|
||||||
|
DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0);
|
||||||
@@ -75,7 +75,9 @@ else()
|
|||||||
../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
|
../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
|
||||||
../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
|
../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
|
||||||
../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
|
../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
|
||||||
../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu)
|
../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu
|
||||||
|
../ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo3_0.cu
|
||||||
|
../ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo4_0.cu)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
ggml_add_backend_library(ggml-hip
|
ggml_add_backend_library(ggml-hip
|
||||||
|
|||||||
@@ -494,6 +494,342 @@ void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_REST
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// TurboQuant codebook data (Lloyd-Max optimal for d=128 after WHT)
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
// 3-bit TurboQuant codebook: eight reconstruction levels in ascending order.
// The array index is the 3-bit code stored in a packed block.
static const float turbo_codebook_3bit[8] = {
    -0.1883972972f,
    -0.1181399059f,
    -0.0665857641f,
    -0.0216044751f,
     0.0216041461f,
     0.0665854520f,
     0.1181396281f,
     0.1883970748f,
};
|
||||||
|
|
||||||
|
// 4-bit TurboQuant codebook: sixteen reconstruction levels in ascending order.
// The array index is the 4-bit code stored in a packed block.
static const float turbo_codebook_4bit[16] = {
    -0.2376389871f, -0.1808080141f,
    -0.1417777640f, -0.1102646123f,
    -0.0828112376f, -0.0577640422f,
    -0.0341540905f, -0.0113168380f,
     0.0112761586f,  0.0341139667f,
     0.0577250301f,  0.0827738972f,
     0.1102295202f,  0.1417455465f,
     0.1807794468f,  0.2376153882f,
};
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// TurboQuant helper: pack/unpack bit-packed indices
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
// Packs 32 three-bit codes into 12 bytes (96 bits).
// Code i occupies bits [3*i, 3*i + 3) of the output bit stream, least
// significant bit first; only the low 3 bits of each input byte are used.
static void turbo_pack3(const uint8_t *indices, uint8_t *out) {
    // 8 codes * 3 bits = 24 bits, so every group of 8 codes fills exactly 3 bytes.
    for (int group = 0; group < 4; group++) {
        const uint8_t *in = indices + 8 * group;

        uint32_t bits = 0;
        for (int k = 0; k < 8; k++) {
            bits |= (uint32_t)(in[k] & 0x07) << (3 * k);
        }

        out[3 * group + 0] = (uint8_t)(bits & 0xFF);
        out[3 * group + 1] = (uint8_t)((bits >> 8) & 0xFF);
        out[3 * group + 2] = (uint8_t)((bits >> 16) & 0xFF);
    }
}
|
||||||
|
|
||||||
|
// Unpacks 12 bytes into 32 three-bit codes (inverse of the 3-bit packing:
// code i is read from bits [3*i, 3*i + 3) of the input, LSB first).
static void turbo_unpack3(const uint8_t *packed, uint8_t *indices) {
    // Every 3 input bytes hold exactly 8 codes.
    for (int group = 0; group < 4; group++) {
        const uint8_t *in = packed + 3 * group;

        const uint32_t bits = (uint32_t)in[0]
                            | ((uint32_t)in[1] << 8)
                            | ((uint32_t)in[2] << 16);

        for (int k = 0; k < 8; k++) {
            indices[8 * group + k] = (uint8_t)((bits >> (3 * k)) & 0x07);
        }
    }
}
|
||||||
|
|
||||||
|
// Packs 32 four-bit codes into 16 bytes: for each output byte the even-indexed
// code goes to the low nibble and the odd-indexed code to the high nibble.
// Only the low 4 bits of each input byte are used.
static void turbo_pack4(const uint8_t *indices, uint8_t *out) {
    const uint8_t *pair = indices;
    for (uint8_t *dst = out; dst != out + 16; ++dst, pair += 2) {
        const uint8_t lo = pair[0] & 0x0F;
        const uint8_t hi = pair[1] & 0x0F;
        *dst = (uint8_t)(lo | (hi << 4));
    }
}
|
||||||
|
|
||||||
|
// Unpacks 16 bytes into 32 four-bit codes: the low nibble of byte i becomes
// code 2*i and the high nibble becomes code 2*i + 1.
static void turbo_unpack4(const uint8_t *packed, uint8_t *indices) {
    uint8_t *dst = indices;
    for (const uint8_t *src = packed; src != packed + 16; ++src) {
        const uint8_t byte = *src;
        *dst++ = byte & 0x0F;
        *dst++ = (byte >> 4) & 0x0F;
    }
}
|
||||||
|
|
||||||
|
// Returns the index of the entry of codebook[0..n_codes) with the smallest
// absolute difference to val, found by a linear scan. Entry 0 is always read,
// so n_codes must be at least 1. On a tie the lowest index is kept, because a
// candidate only replaces the current best when it is strictly closer.
static uint8_t turbo_quantize_scalar(float val, const float *codebook, int n_codes) {
    uint8_t nearest  = 0;
    float   min_diff = fabsf(val - codebook[0]);

    int k = 1;
    while (k < n_codes) {
        const float diff = fabsf(val - codebook[k]);
        if (diff < min_diff) {
            nearest  = (uint8_t)k;
            min_diff = diff;
        }
        ++k;
    }

    return nearest;
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// TurboQuant FWHT (Fast Walsh-Hadamard Transform)
|
||||||
|
// Self-inverse with 1/sqrt(n) normalization.
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
// Number of floats processed as one chunk: the quantizers in this file
// normalize and FWHT-rotate this many elements at a time.
#define TURBO_HEAD_DIM 128
|
||||||
|
// Number of 32-element blocks that make up one chunk.
#define TURBO_BLOCKS_PER_CHUNK (TURBO_HEAD_DIM / 32) // 4 blocks of 32 = 128
|
||||||
|
|
||||||
|
// In-place Walsh-Hadamard transform of x[0..n), scaled by 1/sqrt(n).
//
// Each stage pairs element j with element j + half, replacing them with
// their sum and difference; `half` doubles every stage. The butterfly
// indexing only stays inside the buffer when n is a power of two.
//
//   x - buffer of n floats, overwritten with the transform
//   n - transform length
static void turbo_fwht_f32(float *x, int n) {
    for (int half = 1; half < n; half *= 2) {
        const int span = half * 2;
        for (int base = 0; base < n; base += span) {
            float *lo = x + base;
            float *hi = lo + half;
            for (int t = 0; t < half; ++t) {
                const float u = lo[t];
                const float v = hi[t];
                lo[t] = u + v;
                hi[t] = u - v;
            }
        }
    }

    // Orthonormal scaling.
    const float scale = 1.0f / sqrtf((float)n);
    int k = 0;
    while (k < n) {
        x[k] *= scale;
        ++k;
    }
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// TurboQuant TURBO3_0 (3-bit, 3.5 bpw)
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
void quantize_row_turbo3_0_ref(const float * GGML_RESTRICT src, block_turbo3_0 * GGML_RESTRICT dst, int64_t k) {
|
||||||
|
assert(k % TURBO3_BLOCK_SIZE == 0);
|
||||||
|
assert(k % TURBO_HEAD_DIM == 0);
|
||||||
|
|
||||||
|
float tmp[TURBO_HEAD_DIM];
|
||||||
|
int64_t blocks_done = 0;
|
||||||
|
|
||||||
|
for (int64_t offset = 0; offset < k; offset += TURBO_HEAD_DIM) {
|
||||||
|
// Step 1: Compute L2 norm of this head_dim chunk
|
||||||
|
float sum_sq = 0.0f;
|
||||||
|
for (int i = 0; i < TURBO_HEAD_DIM; i++) {
|
||||||
|
sum_sq += src[offset + i] * src[offset + i];
|
||||||
|
}
|
||||||
|
float norm = sqrtf(sum_sq);
|
||||||
|
|
||||||
|
// Step 2: Normalize the chunk
|
||||||
|
float inv_norm = (norm > 1e-10f) ? (1.0f / norm) : 0.0f;
|
||||||
|
for (int i = 0; i < TURBO_HEAD_DIM; i++) {
|
||||||
|
tmp[i] = src[offset + i] * inv_norm;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: Apply FWHT rotation
|
||||||
|
turbo_fwht_f32(tmp, TURBO_HEAD_DIM);
|
||||||
|
|
||||||
|
// Step 4: Scalar quantize + pack, one block at a time
|
||||||
|
for (int blk = 0; blk < TURBO_BLOCKS_PER_CHUNK; blk++) {
|
||||||
|
uint8_t indices[32];
|
||||||
|
for (int i = 0; i < TURBO3_BLOCK_SIZE; i++) {
|
||||||
|
indices[i] = turbo_quantize_scalar(tmp[blk * TURBO3_BLOCK_SIZE + i], turbo_codebook_3bit, 8);
|
||||||
|
}
|
||||||
|
// Store same norm in every block of this chunk
|
||||||
|
dst[blocks_done].d = GGML_FP32_TO_FP16(norm);
|
||||||
|
turbo_pack3(indices, dst[blocks_done].qs);
|
||||||
|
blocks_done++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void quantize_row_turbo3_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t k) {
|
||||||
|
quantize_row_turbo3_0_ref(src, (block_turbo3_0 *)dst, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
void dequantize_row_turbo3_0(const block_turbo3_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||||
|
assert(k % TURBO3_BLOCK_SIZE == 0);
|
||||||
|
const int64_t num_blocks = k / TURBO3_BLOCK_SIZE;
|
||||||
|
|
||||||
|
// Pass 1: Unpack all blocks and look up centroids
|
||||||
|
for (int64_t b = 0; b < num_blocks; b++) {
|
||||||
|
uint8_t indices[32];
|
||||||
|
turbo_unpack3(x[b].qs, indices);
|
||||||
|
for (int i = 0; i < TURBO3_BLOCK_SIZE; i++) {
|
||||||
|
y[b * TURBO3_BLOCK_SIZE + i] = turbo_codebook_3bit[indices[i]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass 2: Inverse FWHT per head_dim chunk, then scale by norm
|
||||||
|
for (int64_t offset = 0; offset < k; offset += TURBO_HEAD_DIM) {
|
||||||
|
int chunk = TURBO_HEAD_DIM;
|
||||||
|
if (offset + chunk > k) chunk = (int)(k - offset);
|
||||||
|
|
||||||
|
// Inverse FWHT (self-inverse with 1/sqrt(n) normalization)
|
||||||
|
turbo_fwht_f32(y + offset, chunk);
|
||||||
|
|
||||||
|
// Read norm from the first block of this chunk
|
||||||
|
float norm = GGML_FP16_TO_FP32(x[offset / TURBO3_BLOCK_SIZE].d);
|
||||||
|
for (int i = 0; i < chunk; i++) {
|
||||||
|
y[offset + i] *= norm;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// TurboQuant TURBO4_0 (4-bit, 4.5 bpw)
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
void quantize_row_turbo4_0_ref(const float * GGML_RESTRICT src, block_turbo4_0 * GGML_RESTRICT dst, int64_t k) {
|
||||||
|
assert(k % TURBO4_BLOCK_SIZE == 0);
|
||||||
|
assert(k % TURBO_HEAD_DIM == 0);
|
||||||
|
|
||||||
|
float tmp[TURBO_HEAD_DIM];
|
||||||
|
int64_t blocks_done = 0;
|
||||||
|
|
||||||
|
for (int64_t offset = 0; offset < k; offset += TURBO_HEAD_DIM) {
|
||||||
|
// Step 1: Compute L2 norm of this head_dim chunk
|
||||||
|
float sum_sq = 0.0f;
|
||||||
|
for (int i = 0; i < TURBO_HEAD_DIM; i++) {
|
||||||
|
sum_sq += src[offset + i] * src[offset + i];
|
||||||
|
}
|
||||||
|
float norm = sqrtf(sum_sq);
|
||||||
|
|
||||||
|
// Step 2: Normalize the chunk
|
||||||
|
float inv_norm = (norm > 1e-10f) ? (1.0f / norm) : 0.0f;
|
||||||
|
for (int i = 0; i < TURBO_HEAD_DIM; i++) {
|
||||||
|
tmp[i] = src[offset + i] * inv_norm;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: Apply FWHT rotation
|
||||||
|
turbo_fwht_f32(tmp, TURBO_HEAD_DIM);
|
||||||
|
|
||||||
|
// Step 4: Scalar quantize + pack, one block at a time
|
||||||
|
for (int blk = 0; blk < TURBO_BLOCKS_PER_CHUNK; blk++) {
|
||||||
|
uint8_t indices[32];
|
||||||
|
for (int i = 0; i < TURBO4_BLOCK_SIZE; i++) {
|
||||||
|
indices[i] = turbo_quantize_scalar(tmp[blk * TURBO4_BLOCK_SIZE + i], turbo_codebook_4bit, 16);
|
||||||
|
}
|
||||||
|
// Store same norm in every block of this chunk
|
||||||
|
dst[blocks_done].d = GGML_FP32_TO_FP16(norm);
|
||||||
|
turbo_pack4(indices, dst[blocks_done].qs);
|
||||||
|
blocks_done++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void quantize_row_turbo4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t k) {
|
||||||
|
quantize_row_turbo4_0_ref(src, (block_turbo4_0 *)dst, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
void dequantize_row_turbo4_0(const block_turbo4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||||
|
assert(k % TURBO4_BLOCK_SIZE == 0);
|
||||||
|
const int64_t num_blocks = k / TURBO4_BLOCK_SIZE;
|
||||||
|
|
||||||
|
// Pass 1: Unpack all blocks and look up centroids
|
||||||
|
for (int64_t b = 0; b < num_blocks; b++) {
|
||||||
|
uint8_t indices[32];
|
||||||
|
turbo_unpack4(x[b].qs, indices);
|
||||||
|
for (int i = 0; i < TURBO4_BLOCK_SIZE; i++) {
|
||||||
|
y[b * TURBO4_BLOCK_SIZE + i] = turbo_codebook_4bit[indices[i]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass 2: Inverse FWHT per head_dim chunk, then scale by norm
|
||||||
|
for (int64_t offset = 0; offset < k; offset += TURBO_HEAD_DIM) {
|
||||||
|
int chunk = TURBO_HEAD_DIM;
|
||||||
|
if (offset + chunk > k) chunk = (int)(k - offset);
|
||||||
|
|
||||||
|
// Inverse FWHT (self-inverse with 1/sqrt(n) normalization)
|
||||||
|
turbo_fwht_f32(y + offset, chunk);
|
||||||
|
|
||||||
|
// Read norm from the first block of this chunk
|
||||||
|
float norm = GGML_FP16_TO_FP32(x[offset / TURBO4_BLOCK_SIZE].d);
|
||||||
|
for (int i = 0; i < chunk; i++) {
|
||||||
|
y[offset + i] *= norm;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// TurboQuant vec_dot (for flash attention compatibility)
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
void ggml_vec_dot_turbo3_0(int n, float * GGML_RESTRICT s, size_t bs,
|
||||||
|
const void * GGML_RESTRICT vx, size_t bx,
|
||||||
|
const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||||
|
const block_turbo3_0 *x = (const block_turbo3_0 *)vx;
|
||||||
|
const float *y = (const float *)vy;
|
||||||
|
|
||||||
|
GGML_UNUSED(bs); GGML_UNUSED(bx); GGML_UNUSED(by); GGML_UNUSED(nrc);
|
||||||
|
|
||||||
|
// Dequantize x into temp buffer (includes inverse FWHT + norm scaling),
|
||||||
|
// then compute dot product with y.
|
||||||
|
float tmp[TURBO_HEAD_DIM];
|
||||||
|
float sum = 0.0f;
|
||||||
|
|
||||||
|
for (int64_t offset = 0; offset < n; offset += TURBO_HEAD_DIM) {
|
||||||
|
int chunk = TURBO_HEAD_DIM;
|
||||||
|
if (offset + chunk > n) chunk = (int)(n - offset);
|
||||||
|
int64_t base_block = offset / TURBO3_BLOCK_SIZE;
|
||||||
|
|
||||||
|
// Unpack + centroid lookup for this chunk
|
||||||
|
for (int blk = 0; blk < chunk / TURBO3_BLOCK_SIZE; blk++) {
|
||||||
|
uint8_t indices[32];
|
||||||
|
turbo_unpack3(x[base_block + blk].qs, indices);
|
||||||
|
for (int i = 0; i < TURBO3_BLOCK_SIZE; i++) {
|
||||||
|
tmp[blk * TURBO3_BLOCK_SIZE + i] = turbo_codebook_3bit[indices[i]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inverse FWHT
|
||||||
|
turbo_fwht_f32(tmp, chunk);
|
||||||
|
|
||||||
|
// Scale by norm and accumulate dot product
|
||||||
|
float norm = GGML_FP16_TO_FP32(x[base_block].d);
|
||||||
|
for (int i = 0; i < chunk; i++) {
|
||||||
|
sum += tmp[i] * norm * y[offset + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*s = sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_vec_dot_turbo4_0(int n, float * GGML_RESTRICT s, size_t bs,
|
||||||
|
const void * GGML_RESTRICT vx, size_t bx,
|
||||||
|
const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||||
|
const block_turbo4_0 *x = (const block_turbo4_0 *)vx;
|
||||||
|
const float *y = (const float *)vy;
|
||||||
|
|
||||||
|
GGML_UNUSED(bs); GGML_UNUSED(bx); GGML_UNUSED(by); GGML_UNUSED(nrc);
|
||||||
|
|
||||||
|
// Dequantize x into temp buffer (includes inverse FWHT + norm scaling),
|
||||||
|
// then compute dot product with y.
|
||||||
|
float tmp[TURBO_HEAD_DIM];
|
||||||
|
float sum = 0.0f;
|
||||||
|
|
||||||
|
for (int64_t offset = 0; offset < n; offset += TURBO_HEAD_DIM) {
|
||||||
|
int chunk = TURBO_HEAD_DIM;
|
||||||
|
if (offset + chunk > n) chunk = (int)(n - offset);
|
||||||
|
int64_t base_block = offset / TURBO4_BLOCK_SIZE;
|
||||||
|
|
||||||
|
// Unpack + centroid lookup for this chunk
|
||||||
|
for (int blk = 0; blk < chunk / TURBO4_BLOCK_SIZE; blk++) {
|
||||||
|
uint8_t indices[32];
|
||||||
|
turbo_unpack4(x[base_block + blk].qs, indices);
|
||||||
|
for (int i = 0; i < TURBO4_BLOCK_SIZE; i++) {
|
||||||
|
tmp[blk * TURBO4_BLOCK_SIZE + i] = turbo_codebook_4bit[indices[i]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inverse FWHT
|
||||||
|
turbo_fwht_f32(tmp, chunk);
|
||||||
|
|
||||||
|
// Scale by norm and accumulate dot product
|
||||||
|
float norm = GGML_FP16_TO_FP32(x[base_block].d);
|
||||||
|
for (int i = 0; i < chunk; i++) {
|
||||||
|
sum += tmp[i] * norm * y[offset + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*s = sum;
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// 2-6 bit quantization in super-blocks
|
// 2-6 bit quantization in super-blocks
|
||||||
//
|
//
|
||||||
@@ -5353,6 +5689,14 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
|
|||||||
{
|
{
|
||||||
VALIDATE_ROW_DATA_D_F16_IMPL(block_tq2_0, data, nb);
|
VALIDATE_ROW_DATA_D_F16_IMPL(block_tq2_0, data, nb);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_TYPE_TURBO3_0:
|
||||||
|
{
|
||||||
|
VALIDATE_ROW_DATA_D_F16_IMPL(block_turbo3_0, data, nb);
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_TURBO4_0:
|
||||||
|
{
|
||||||
|
VALIDATE_ROW_DATA_D_F16_IMPL(block_turbo4_0, data, nb);
|
||||||
|
} break;
|
||||||
case GGML_TYPE_IQ1_S:
|
case GGML_TYPE_IQ1_S:
|
||||||
{
|
{
|
||||||
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);
|
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);
|
||||||
|
|||||||
@@ -24,6 +24,9 @@ GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 *
|
|||||||
GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
|
GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
|
||||||
GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k);
|
GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k);
|
||||||
|
|
||||||
|
GGML_API void quantize_row_turbo3_0_ref(const float * GGML_RESTRICT x, block_turbo3_0 * GGML_RESTRICT y, int64_t k);
|
||||||
|
GGML_API void quantize_row_turbo4_0_ref(const float * GGML_RESTRICT x, block_turbo4_0 * GGML_RESTRICT y, int64_t k);
|
||||||
|
|
||||||
GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
|
GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
|
||||||
GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
|
GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
|
||||||
GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
|
GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
|
||||||
@@ -51,6 +54,9 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG
|
|||||||
GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||||
GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||||
|
|
||||||
|
GGML_API void dequantize_row_turbo3_0(const block_turbo3_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||||
|
GGML_API void dequantize_row_turbo4_0(const block_turbo4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||||
|
|
||||||
GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||||
GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||||
GGML_API void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
GGML_API void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||||
|
|||||||
@@ -904,6 +904,22 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
|
|||||||
.type_size = 0,
|
.type_size = 0,
|
||||||
.is_quantized = false,
|
.is_quantized = false,
|
||||||
},
|
},
|
||||||
|
[GGML_TYPE_TURBO3_0] = {
|
||||||
|
.type_name = "turbo3",
|
||||||
|
.blck_size = TURBO3_BLOCK_SIZE,
|
||||||
|
.type_size = sizeof(block_turbo3_0),
|
||||||
|
.is_quantized = true,
|
||||||
|
.to_float = (ggml_to_float_t) dequantize_row_turbo3_0,
|
||||||
|
.from_float_ref = (ggml_from_float_t) quantize_row_turbo3_0_ref,
|
||||||
|
},
|
||||||
|
[GGML_TYPE_TURBO4_0] = {
|
||||||
|
.type_name = "turbo4",
|
||||||
|
.blck_size = TURBO4_BLOCK_SIZE,
|
||||||
|
.type_size = sizeof(block_turbo4_0),
|
||||||
|
.is_quantized = true,
|
||||||
|
.to_float = (ggml_to_float_t) dequantize_row_turbo4_0,
|
||||||
|
.from_float_ref = (ggml_from_float_t) quantize_row_turbo4_0_ref,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) {
|
const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) {
|
||||||
|
|||||||
@@ -132,6 +132,22 @@ llama_kv_cache::llama_kv_cache(
|
|||||||
throw std::runtime_error("failed to create ggml context for kv cache");
|
throw std::runtime_error("failed to create ggml context for kv cache");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TurboQuant requires head_dim=128 for the FWHT transform
|
||||||
|
if (type_k == GGML_TYPE_TURBO3_0 || type_k == GGML_TYPE_TURBO4_0) {
|
||||||
|
const uint32_t n_embd_head_k = hparams.n_embd_head_k(il);
|
||||||
|
if (n_embd_head_k != 128) {
|
||||||
|
LLAMA_LOG_ERROR("%s: TurboQuant requires head_dim=128, got %d (layer %d)\n", __func__, n_embd_head_k, il);
|
||||||
|
throw std::runtime_error("turbo types require head_dim=128");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (type_v == GGML_TYPE_TURBO3_0 || type_v == GGML_TYPE_TURBO4_0) {
|
||||||
|
const uint32_t n_embd_head_v = hparams.n_embd_head_v(il);
|
||||||
|
if (n_embd_head_v != 128) {
|
||||||
|
LLAMA_LOG_ERROR("%s: TurboQuant requires head_dim=128, got %d (layer %d)\n", __func__, n_embd_head_v, il);
|
||||||
|
throw std::runtime_error("turbo types require head_dim=128");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const bool has_k = true;
|
const bool has_k = true;
|
||||||
const bool has_v = !is_mla;
|
const bool has_v = !is_mla;
|
||||||
|
|
||||||
|
|||||||
@@ -254,6 +254,10 @@ if (NOT GGML_BACKEND_DL)
|
|||||||
llama_build_and_test(test-quantize-fns.cpp)
|
llama_build_and_test(test-quantize-fns.cpp)
|
||||||
llama_build_and_test(test-quantize-perf.cpp)
|
llama_build_and_test(test-quantize-perf.cpp)
|
||||||
llama_build_and_test(test-rope.cpp)
|
llama_build_and_test(test-rope.cpp)
|
||||||
|
|
||||||
|
# TurboQuant CPU reference tests (FWHT, MSE, bitpack)
|
||||||
|
llama_build(test-turboquant.cpp)
|
||||||
|
llama_test(test-turboquant)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# libmtmd
|
# libmtmd
|
||||||
|
|||||||
@@ -0,0 +1,217 @@
|
|||||||
|
// TurboQuant CPU reference tests: FWHT self-inverse, roundtrip MSE, bit-packing
|
||||||
|
//
|
||||||
|
// Validates that the quantize/dequantize pipeline in ggml-quants.c produces
|
||||||
|
// MSE*d values consistent with the paper (Zandieh et al., ICLR 2026).
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
|
||||||
|
#undef NDEBUG
|
||||||
|
#include <cassert>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstring>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
// Length of one test vector; also the length passed to fwht_f32 in the tests.
static constexpr int HEAD_DIM = 128;
|
||||||
|
// Elements per quantized block, used to size the packed buffers.
static constexpr int BLOCK_SIZE = 32;
|
||||||
|
// Quantized blocks needed for one HEAD_DIM-element vector.
static constexpr int BLOCKS_PER_CHUNK = HEAD_DIM / BLOCK_SIZE;
|
||||||
|
// Number of random vectors used by the roundtrip MSE test.
static constexpr int N_VECTORS = 10000;
|
||||||
|
// Total float count across all roundtrip test vectors.
static constexpr int N_ELEMENTS = N_VECTORS * HEAD_DIM;
|
||||||
|
|
||||||
|
// Expected MSE*d ranges (paper: TQ3 ~0.034, TQ4 ~0.009 for d=128)
|
||||||
|
static constexpr float TQ3_MSE_D_MIN = 0.025f;
|
||||||
|
static constexpr float TQ3_MSE_D_MAX = 0.045f;
|
||||||
|
static constexpr float TQ4_MSE_D_MIN = 0.005f;
|
||||||
|
static constexpr float TQ4_MSE_D_MAX = 0.015f;
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// FWHT reference (must match ggml-quants.c)
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
// In-place Walsh-Hadamard transform of x[0..n) with 1/sqrt(n) scaling.
//
// Stage by stage, element j is paired with element j + stride and the pair is
// replaced by its sum and difference; the stride doubles after every stage.
// The pairing only stays inside the buffer when n is a power of two.
static void fwht_f32(float * x, int n) {
    int stride = 1;
    while (stride < n) {
        const int group = stride * 2;
        for (int start = 0; start < n; start += group) {
            const int stop = start + stride;
            for (int lo = start; lo < stop; ++lo) {
                const int   hi    = lo + stride;
                const float left  = x[lo];
                const float right = x[hi];
                x[lo] = left + right;
                x[hi] = left - right;
            }
        }
        stride = group;
    }

    const float scale = 1.0f / sqrtf((float)n);
    for (int i = 0; i < n; ++i) {
        x[i] *= scale;
    }
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Test 1: FWHT self-inverse (FWHT(FWHT(x)) == x)
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
static int test_fwht_self_inverse(void) {
|
||||||
|
printf(" FWHT self-inverse (d=%d)... ", HEAD_DIM);
|
||||||
|
|
||||||
|
float orig[HEAD_DIM];
|
||||||
|
float work[HEAD_DIM];
|
||||||
|
|
||||||
|
for (int i = 0; i < HEAD_DIM; i++) {
|
||||||
|
orig[i] = sinf((float)(i + 1) * 0.7f) * 2.0f;
|
||||||
|
}
|
||||||
|
memcpy(work, orig, sizeof(orig));
|
||||||
|
|
||||||
|
fwht_f32(work, HEAD_DIM);
|
||||||
|
fwht_f32(work, HEAD_DIM);
|
||||||
|
|
||||||
|
float max_err = 0.0f;
|
||||||
|
for (int i = 0; i < HEAD_DIM; i++) {
|
||||||
|
float err = fabsf(work[i] - orig[i]);
|
||||||
|
if (err > max_err) { max_err = err; }
|
||||||
|
}
|
||||||
|
|
||||||
|
bool pass = max_err < 1e-5f;
|
||||||
|
printf("max_err=%.2e %s\n", max_err, pass ? "ok" : "FAILED");
|
||||||
|
return pass ? 0 : 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Test 2 & 3: Roundtrip MSE for TQ3 and TQ4
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
// Fills dst[0..n_elements) with reproducible pseudo-random floats in [-1, 1).
//
// A 32-bit linear congruential generator (multiplier 1664525, increment
// 1013904223) is stepped once per element; the top 24 bits of the state are
// mapped to [0, 1) and then stretched to [-1, 1).
static void generate_random_vectors(float * dst, int n_elements, unsigned int seed) {
    unsigned int lcg = seed;
    float * out = dst;
    for (int remaining = n_elements; remaining > 0; --remaining) {
        lcg = lcg * 1664525u + 1013904223u;
        const float unit = (float)(lcg >> 8) / (float)(1 << 24);
        *out++ = unit * 2.0f - 1.0f;
    }
}
|
||||||
|
|
||||||
|
// Quantizes and dequantizes N_VECTORS pseudo-random vectors through the
// type's from_float_ref / to_float callbacks and checks that the mean of
// ||x - x~||^2 / ||x||^2 over all vectors lies in [mse_d_min, mse_d_max].
// Returns 0 on success, 1 on failure.
static int test_roundtrip_mse(ggml_type type, float mse_d_min, float mse_d_max) {
    printf(" %s roundtrip MSE*d (n=%d, d=%d)... ", ggml_type_name(type), N_VECTORS, HEAD_DIM);

    const ggml_type_traits * tt = ggml_get_type_traits(type);
    assert(tt->from_float_ref != nullptr);
    assert(tt->to_float != nullptr);

    std::vector<float> input(N_ELEMENTS);
    std::vector<float> decoded(N_ELEMENTS);
    const size_t packed_bytes = (size_t)N_ELEMENTS / BLOCK_SIZE * ggml_type_size(type);
    std::vector<uint8_t> packed(packed_bytes);

    generate_random_vectors(input.data(), N_ELEMENTS, 42);

    tt->from_float_ref(input.data(), packed.data(), N_ELEMENTS);
    tt->to_float(packed.data(), decoded.data(), N_ELEMENTS);

    // Accumulate the per-vector normalized squared error in double precision.
    double nmse_sum = 0.0;
    for (int v = 0; v < N_VECTORS; ++v) {
        const int first = v * HEAD_DIM;
        double err    = 0.0;
        double energy = 0.0;
        for (int i = 0; i < HEAD_DIM; ++i) {
            const double ref   = (double)input[first + i];
            const double delta = ref - (double)decoded[first + i];
            err    += delta * delta;
            energy += ref * ref;
        }
        // Skip (near-)zero vectors instead of dividing by ~0.
        if (energy > 1e-20) {
            nmse_sum += err / energy;
        }
    }
    const float mse_d = (float)(nmse_sum / N_VECTORS);

    const bool ok = (mse_d >= mse_d_min) && (mse_d <= mse_d_max);
    printf("MSE*d=%.4f [%.3f..%.3f] %s\n", mse_d, mse_d_min, mse_d_max, ok ? "ok" : "FAILED");
    return ok ? 0 : 1;
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Test 4: Bit-pack determinism and sanity
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
// Quantizes the same HEAD_DIM-element cosine vector twice into separate
// buffers and checks that the two packed outputs are byte-identical.
// Returns 0 on success, 1 on failure.
static int test_bitpack_deterministic(ggml_type type) {
    printf(" %s pack determinism... ", ggml_type_name(type));

    const ggml_type_traits * tt = ggml_get_type_traits(type);

    float input[HEAD_DIM];
    for (int i = 0; i < HEAD_DIM; ++i) {
        input[i] = cosf((float)i * 0.31415f);
    }

    const size_t packed_bytes = BLOCKS_PER_CHUNK * ggml_type_size(type);
    std::vector<uint8_t> first(packed_bytes);
    std::vector<uint8_t> second(packed_bytes);

    tt->from_float_ref(input, first.data(), HEAD_DIM);
    tt->from_float_ref(input, second.data(), HEAD_DIM);

    const bool identical = memcmp(first.data(), second.data(), packed_bytes) == 0;
    printf("%s\n", identical ? "ok" : "FAILED (non-deterministic)");
    return identical ? 0 : 1;
}
|
||||||
|
|
||||||
|
// Roundtrips one HEAD_DIM-element cosine vector through quantize/dequantize
// and checks that every output value is finite and at least one has a
// magnitude above 1e-10.
// Returns 0 on success, 1 on failure.
static int test_bitpack_sanity(ggml_type type) {
    printf(" %s dequantize sanity... ", ggml_type_name(type));

    const ggml_type_traits * tt = ggml_get_type_traits(type);

    float input[HEAD_DIM];
    float output[HEAD_DIM];
    for (int i = 0; i < HEAD_DIM; ++i) {
        input[i] = cosf((float)i * 0.31415f);
    }

    const size_t packed_bytes = BLOCKS_PER_CHUNK * ggml_type_size(type);
    std::vector<uint8_t> packed(packed_bytes);

    tt->from_float_ref(input, packed.data(), HEAD_DIM);
    tt->to_float(packed.data(), output, HEAD_DIM);

    bool all_finite  = true;
    bool any_nonzero = false;
    for (int i = 0; i < HEAD_DIM; ++i) {
        const float value = output[i];
        all_finite  = all_finite && std::isfinite(value);
        any_nonzero = any_nonzero || (fabsf(value) > 1e-10f);
    }

    const bool ok = all_finite && any_nonzero;
    printf("finite=%s nonzero=%s %s\n",
           all_finite ? "yes" : "NO",
           any_nonzero ? "yes" : "NO",
           ok ? "ok" : "FAILED");
    return ok ? 0 : 1;
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Main
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
int main(void) {
|
||||||
|
printf("TurboQuant CPU reference tests\n");
|
||||||
|
printf("==============================\n\n");
|
||||||
|
|
||||||
|
int n_fail = 0;
|
||||||
|
|
||||||
|
printf("Test 1: FWHT self-inverse\n");
|
||||||
|
n_fail += test_fwht_self_inverse();
|
||||||
|
|
||||||
|
printf("\nTest 2: TQ3 roundtrip MSE\n");
|
||||||
|
n_fail += test_roundtrip_mse(GGML_TYPE_TURBO3_0, TQ3_MSE_D_MIN, TQ3_MSE_D_MAX);
|
||||||
|
|
||||||
|
printf("\nTest 3: TQ4 roundtrip MSE\n");
|
||||||
|
n_fail += test_roundtrip_mse(GGML_TYPE_TURBO4_0, TQ4_MSE_D_MIN, TQ4_MSE_D_MAX);
|
||||||
|
|
||||||
|
printf("\nTest 4: Bit-pack tests\n");
|
||||||
|
n_fail += test_bitpack_deterministic(GGML_TYPE_TURBO3_0);
|
||||||
|
n_fail += test_bitpack_deterministic(GGML_TYPE_TURBO4_0);
|
||||||
|
n_fail += test_bitpack_sanity(GGML_TYPE_TURBO3_0);
|
||||||
|
n_fail += test_bitpack_sanity(GGML_TYPE_TURBO4_0);
|
||||||
|
|
||||||
|
printf("\n==============================\n");
|
||||||
|
printf("%d/%d tests passed\n", 7 - n_fail, 7);
|
||||||
|
|
||||||
|
return n_fail > 0 ? 1 : 0;
|
||||||
|
}
|
||||||
@@ -483,6 +483,12 @@ static ggml_type ggml_type_from_name(const std::string & s) {
|
|||||||
if (s == "iq4_nl") {
|
if (s == "iq4_nl") {
|
||||||
return GGML_TYPE_IQ4_NL;
|
return GGML_TYPE_IQ4_NL;
|
||||||
}
|
}
|
||||||
|
if (s == "turbo3") {
|
||||||
|
return GGML_TYPE_TURBO3_0;
|
||||||
|
}
|
||||||
|
if (s == "turbo4") {
|
||||||
|
return GGML_TYPE_TURBO4_0;
|
||||||
|
}
|
||||||
|
|
||||||
return GGML_TYPE_COUNT;
|
return GGML_TYPE_COUNT;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user