Compare commits

...

3 Commits

Author SHA1 Message Date
Jeff Bolz f9bd518a6b vulkan: make FA mask/softcap enables spec constants (#19309)
* vulkan: make FA mask/softcap enables spec constants

* don't specialize for sinks

* bump timeout a little bit
2026-02-06 08:49:58 +01:00
Georgi Gerganov 7fcf1ef45d metal : skip loading all-zero mask (#19337)
* metal : skip loading all-zero mask

* cont : minor
2026-02-06 09:25:11 +02:00
Daniel Bevenius e696cfc016 llama : rename llama-sampling to llama-sampler (#19363)
This commit addresses the TODO in llama-sampling.h to rename that header
and the implementation to llama-sampler.
2026-02-06 07:26:54 +01:00
11 changed files with 87 additions and 67 deletions
+1 -1
View File
@@ -468,7 +468,7 @@ jobs:
export GGML_VK_VISIBLE_DEVICES=0
export GGML_VK_DISABLE_F16=1
# This is using llvmpipe and runs slower than other backends
ctest -L main --verbose --timeout 4200
ctest -L main --verbose --timeout 4800
ubuntu-24-cmake-webgpu:
runs-on: ubuntu-24.04
+39 -24
View File
@@ -5285,6 +5285,7 @@ constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_E
// scan the blocks of the mask that are not masked
// 0 - masked (i.e. full of -INF, skip)
// 1 - not masked (i.e. at least one element of the mask is not -INF)
// 2 - all zero
kernel void kernel_flash_attn_ext_blk(
constant ggml_metal_kargs_flash_attn_ext_blk & args,
device const char * mask,
@@ -5306,27 +5307,29 @@ kernel void kernel_flash_attn_ext_blk(
device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
// fast route
if (res == 0) {
if (simd_max(*mask_src) > -MAXHALF/2) {
res = 1;
}
}
// detailed check of the elements of the block
if ((C > NW || Q > 1) && res == 0) {
half m = -MAXHALF;
half mmin = MAXHALF;
half mmax = -MAXHALF;
FOR_UNROLL (short j = 0; j < Q; ++j) {
FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
m = max(m, mask_src[ii*NW]);
mmin = min(mmin, mask_src[ii*NW]);
mmax = max(mmax, mask_src[ii*NW]);
}
mask_src += args.nb31/2;
}
if (simd_max(m) > -MAXHALF/2) {
res = 1;
mmin = simd_min(mmin);
mmax = simd_max(mmax);
if (mmax > -MAXHALF) {
if (mmin == 0.0 && mmax == 0.0) {
res = 2;
} else {
res = 1;
}
}
}
@@ -5568,9 +5571,13 @@ void kernel_flash_attn_ext_impl(
ic = 0;
}
char blk_cur = 1;
// read the mask into shared mem
if (FC_flash_attn_ext_has_mask) {
if (blk[ic0] == 0) {
blk_cur = blk[ic0];
if (blk_cur == 0) {
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
pm2[jj] += NW;
}
@@ -5578,16 +5585,22 @@ void kernel_flash_attn_ext_impl(
continue;
}
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
const short j = jj*NSG + sgitg;
if (blk_cur == 1) {
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
const short j = jj*NSG + sgitg;
if (FC_flash_attn_ext_bc_mask) {
sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
} else {
sm2[j*SH + tiisg] = pm2[jj][tiisg];
if (FC_flash_attn_ext_bc_mask) {
sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
} else {
sm2[j*SH + tiisg] = pm2[jj][tiisg];
}
pm2[jj] += NW;
}
} else if (blk_cur == 2) {
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
pm2[jj] += NW;
}
pm2[jj] += NW;
}
#if 0
@@ -5752,10 +5765,12 @@ void kernel_flash_attn_ext_impl(
}
// mqk = mqk + slope*mask
if (FC_flash_attn_ext_has_bias) {
s2 += s2_t(sm2[j*SH + tiisg])*slope;
} else {
s2 += s2_t(sm2[j*SH + tiisg]);
if (blk_cur != 2) {
if (FC_flash_attn_ext_has_bias) {
s2 += s2_t(sm2[j*SH + tiisg])*slope;
} else {
s2 += s2_t(sm2[j*SH + tiisg]);
}
}
M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));
+30 -26
View File
@@ -402,19 +402,19 @@ enum FaCodePath {
};
struct vk_fa_pipeline_state {
vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, bool use_mask_opt)
: HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), use_mask_opt(use_mask_opt) {}
vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags)
: HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {}
uint32_t HSK, HSV;
bool small_rows, small_cache;
FaCodePath path;
bool aligned;
bool f32acc;
bool use_mask_opt;
uint32_t flags;
bool operator<(const vk_fa_pipeline_state &b) const {
return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt) <
std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.use_mask_opt);
return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags) <
std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags);
}
};
@@ -3193,7 +3193,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
};
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, bool use_mask_opt) -> std::vector<uint32_t> {
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
// For large number of rows, 128 invocations seems to work best.
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
// can't use 256 for D==80.
@@ -3225,7 +3225,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
// AMD prefers loading K directly from global memory
const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, use_mask_opt};
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags};
};
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
@@ -3237,19 +3237,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
FaCodePath path = fa.first.path; \
bool aligned = fa.first.aligned; \
bool f32acc = fa.first.f32acc; \
bool use_mask_opt = fa.first.use_mask_opt; \
uint32_t flags = fa.first.flags; \
if (path == FAPATH) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
} \
} \
} \
@@ -8595,10 +8595,26 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
if (logit_softcap != 0) {
scale /= logit_softcap;
}
// Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt);
uint32_t flags = (use_mask_opt ? 1 : 0) |
(mask != nullptr ? 2 : 0) |
(logit_softcap != 0 ? 4 : 0);
vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags);
vk_pipeline pipeline = nullptr;
@@ -8678,18 +8694,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
}
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
if (logit_softcap != 0) {
scale /= logit_softcap;
}
const uint32_t n_head_kv = neq2;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -8703,7 +8707,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf;
uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2;
if (use_mask_opt)
{
@@ -127,7 +127,7 @@ void main() {
continue;
}
// Only load if the block is not all zeros
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) {
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
@@ -181,7 +181,7 @@ void main() {
}
}
if (p.logit_softcap != 0.0f) {
if (LOGIT_SOFTCAP) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
@@ -189,7 +189,7 @@ void main() {
}
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) {
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float mvf = masksh[c * cols_per_iter + col_tid][r];
@@ -10,7 +10,11 @@ layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
layout (constant_id = 7) const uint32_t SubGroupSize = 32;
layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
layout (constant_id = 9) const bool USE_MASK_OPT = false;
layout (constant_id = 9) const uint32_t Flags = 0;
const bool USE_MASK_OPT = (Flags & 1) != 0;
const bool MASK_ENABLE = (Flags & 2) != 0;
const bool LOGIT_SOFTCAP = (Flags & 4) != 0;
// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
const uint32_t HSK_pad = (HSK + 15) & ~15;
@@ -60,7 +64,6 @@ layout (push_constant) uniform parameter {
} p;
#define SINK_ENABLE_BIT (1<<24)
#define MASK_ENABLE_BIT (1<<16)
#define N_LOG2_MASK 0xFFFF
layout (binding = 4) readonly buffer S {float data_s[];};
@@ -160,7 +160,7 @@ void main() {
mask_cache[idx] = f16vec4(0);
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
if (MASK_ENABLE) {
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
mask_opt_idx = j / 16;
@@ -303,7 +303,7 @@ void main() {
coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
barrier();
if (p.logit_softcap != 0.0f) {
if (LOGIT_SOFTCAP) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / (Br / 4);
uint32_t r = (idx + tid) % (Br / 4);
@@ -314,7 +314,7 @@ void main() {
barrier();
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
if (MASK_ENABLE) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / (Br / 4);
uint32_t r = (idx + tid) % (Br / 4);
@@ -155,7 +155,7 @@ void main() {
for (uint32_t j = start_j; j < end_j; ++j) {
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
if (MASK_ENABLE) {
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
mask_opt_idx = j / 16;
@@ -197,14 +197,14 @@ void main() {
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
S = coopMatMulAdd(Qf16, K_T, S);
if (p.logit_softcap != 0.0f) {
if (LOGIT_SOFTCAP) {
[[unroll]]
for (int k = 0; k < S.length(); ++k) {
S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
}
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
if (MASK_ENABLE) {
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
}
+1 -1
View File
@@ -31,7 +31,7 @@ add_library(llama
llama-model-saver.cpp
llama-model.cpp
llama-quant.cpp
llama-sampling.cpp
llama-sampler.cpp
llama-vocab.cpp
unicode-data.cpp
unicode.cpp
+1 -1
View File
@@ -2,7 +2,7 @@
#include "llama-impl.h"
#include "llama-vocab.h"
#include "llama-sampling.h"
#include "llama-sampler.h"
#include <cmath>
#include <algorithm>
@@ -1,4 +1,4 @@
#include "llama-sampling.h"
#include "llama-sampler.h"
#include "llama-impl.h"
#include "llama-vocab.h"
@@ -1,7 +1,5 @@
#pragma once
// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
#include "llama.h"
#include <vector>