Compare commits

...

13 Commits

Author SHA1 Message Date
Jeff Bolz d6d0ce8215 vulkan: reduce iq1 shared memory usage for mul_mm (#24287) 2026-06-09 13:27:38 +02:00
Ruben Ortlam b4e3dc613b vulkan: add v_dot2_f32_f16 support in matrix-matrix multiplication and Flash Attention (#24123)
* vulkan: add support for valve fp16 dot2 extension

* use macro for dot2 path choice

* properly check for the feature

* add dot_product abstraction to reduce preprocessor branching
2026-06-09 13:27:04 +02:00
Nick Towle ae735b1314 ui: Fix excessive style recalculation on hover (#24243) 2026-06-09 12:52:20 +02:00
Xuan-Son Nguyen 9682e351b8 mtmd: refactor video subproc handling (#24316)
* mtmd: refactor video subproc handling

* Update tools/mtmd/mtmd-helper.cpp

Co-authored-by: Mikko Juola <mikjuo@gmail.com>

---------

Co-authored-by: Mikko Juola <mikjuo@gmail.com>
2026-06-09 13:15:12 +03:00
jacekpoplawski 1e912561dd server: log prompts to directory (#22031)
* server: log prompts to directory

Add `--log-prompts-dir` to write each prompt to a separate text file in
the specified directory.

* Apply suggestion from @ngxson

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2026-06-09 12:09:07 +02:00
Pascal efbacf8d21 ui: fix mobile chat form overflow and bust stale bundle cache (#24158) 2026-06-09 11:12:58 +02:00
Pascal 26021699bc ggml : add GGML_OP_COL2IM_1D (#24206)
* cpu: add GGML_OP_COL2IM_1D

Add the overlap-add (scatter-add) step of a 1D transposed convolution.
A ConvTranspose1d factorizes as a GEMM followed by col2im: a weight
pre-permuted to [IC, K*OC] is contracted against the [IC, T_in] input
with mul_mat to produce a column matrix [K*OC, T_in], and col2im_1d
scatters those columns back into the [T_out, OC] signal, with
T_out = (T_in - 1)*s0 + K - 2*p0.

Keeping the contraction as a plain mul_mat leaves the heavy work on the
optimized (and quantizable) matmul kernels, so col2im_1d only does the
cheap overlap-add.

CPU uses a gather formulation parallelized over output channels,
supporting F32, F16 and BF16 with an F32 accumulator.

* tests: add backend coverage for GGML_OP_COL2IM_1D

Add test_col2im_1d next to the conv_transpose_1d cases, covering F32,
F16 and BF16 across eight geometries: the canonical kernel = 2*stride
DAC upsampling shape, overlap, no overlap, cropping (p0 = 1 and
p0 = stride/2), kernel < stride with zeroed gaps, kernel not a
multiple of stride, and a single column unfold.

Perf mode gets three real vocoder stage shapes reporting memory
bandwidth. max_nmse_err relaxes to 5e-4 for F16 and BF16.

* cpu: harden GGML_OP_COL2IM_1D

ggml_col2im_1d validates s0, oc, p0 and input contiguity at graph
build time, before the oc division, protecting every backend at once.
The kernel asserts the contiguity its flat indexing assumes and its
doc states the full output length including the crop term.

The kernel parallelizes over the time axis: the split stays balanced
down to OC = 1, where the previous channel split was single threaded.
Values are bit identical on the three real vocoder chains, two out of
three improve.

* tests: extend the GGML_OP_COL2IM_1D grid

The eval grid grows to eleven geometries: OC = 1 (mono output stage),
K = 1 with stride > 1 (sparse scatter, every gap position zeroed) and
a crop down to T_out = 2 where all the gather bounds act at once.

* tests: add col2im_1d equivalence test

tests/test-col2im-1d.cpp proves mul_mat + col2im_1d matches the
native ggml_conv_transpose_1d on the CPU backend, F32 bit exact, F16
and BF16 through casts of the column matrix. test-backend-ops cannot
cover this for a CPU only op since the CPU backend is its own
reference there.

* rpc: bump protocol patch version for GGML_OP_COL2IM_1D

GGML_OP_COUNT goes from 96 to 97 with the new op, which trips the
static_assert in ggml-rpc.h. Bump RPC_PROTO_PATCH_VERSION since the
op is appended and no existing op code shifts.
2026-06-09 12:01:37 +03:00
fiesh 961e9a3e46 server : do not clear slots without unified KV cache (#24190)
* Always export idle slots to RAM

Without this, a slot's VRAM cache may not be written to RAM.  If this
slot happens to be busy then later on, this triggers needless
preprocessing in another slot.

* cont : clean-up

---------

Co-authored-by: Christoph Weiss <weiss@wsoptics.de>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-06-09 10:45:16 +03:00
Sigbjørn Skjæret f0152efe40 models : fix plamo2 attention_key/value_length regression (#24317) 2026-06-09 10:26:44 +03:00
Yash Raj Pandey fd3271e0b4 ggml-cpu : fix rms_norm_back wrong output under in-place aliasing (#24305)
* ggml-cpu : fix rms_norm_back wrong output under in-place aliasing

* cont : clean-up comment

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-06-09 10:24:27 +03:00
ravel7524 e3471b3e73 Remove case for GGML_TYPE_Q4_K in mvvq.cu (#23528) 2026-06-09 07:46:23 +02:00
Reese Levine 3ac3c20c96 ggml-webgpu: Add clang-format job (#24308)
* Add clang-format job

* try local formatting
2026-06-08 20:54:24 -07:00
Masashi Yoshimura 1e1aca09da ggml-webgpu: Improve prefill speeds for k-quants + refactor matmul for Q4/Q5/Q8 and k-quants (#24225)
* ggml-webgpu: Improve prefill speeds + refactor matmul for quants

* Fixes for editroconfig checker
2026-06-08 15:19:56 -07:00
30 changed files with 923 additions and 676 deletions
+23
View File
@@ -35,6 +35,29 @@ env:
LLAMA_ARG_LOG_TIMESTAMPS: 1
jobs:
format:
runs-on: ubuntu-24.04
steps:
- name: Clone
uses: actions/checkout@v6
- name: Install clang-format 22
run: |
wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key |
sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc > /dev/null
sudo add-apt-repository -y \
"deb http://apt.llvm.org/noble/ llvm-toolchain-noble-22 main"
sudo apt-get update
sudo apt-get install -y clang-format-22
- name: Check formatting
run: |
find ggml/src/ggml-webgpu \
-type f \( -name '*.cpp' -o -name '*.hpp' -o -name '*.h' \) \
-print0 |
xargs -0 clang-format-22 --dry-run --Werror
macos:
runs-on: macos-latest
+8 -1
View File
@@ -1360,7 +1360,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
add_opt(common_arg(
{"--cache-idle-slots"},
{"--no-cache-idle-slots"},
"save and clear idle slots on new task (default: enabled, requires unified KV and cache-ram)",
"save idle slots to the prompt cache on new task, and clear them when using unified KV (default: enabled, requires cache-ram)",
[](common_params & params, bool value) {
params.cache_idle_slots = value;
}
@@ -3333,6 +3333,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
common_log_set_file(common_log_main(), value.c_str());
}
).set_env("LLAMA_ARG_LOG_FILE"));
add_opt(common_arg(
{"--log-prompts-dir"}, "PATH",
"Log prompts to directory (only used for debugging, default: disabled)",
[](common_params & params, const std::string & value) {
params.path_prompts_log_dir = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--log-colors"}, "[on|off|auto]",
"Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
+1
View File
@@ -489,6 +489,7 @@ struct common_params {
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
std::string input_suffix = ""; // string to suffix user inputs with // NOLINT
std::string logits_file = ""; // file for saving *all* logits // NOLINT
std::string path_prompts_log_dir = ""; // directory with logged prompts // NOLINT
// llama-debug specific options
std::string logits_output_dir = "data"; // directory for saving logits output files // NOLINT
+2 -2
View File
@@ -8,10 +8,10 @@ extern "C" {
#define RPC_PROTO_MAJOR_VERSION 4
#define RPC_PROTO_MINOR_VERSION 0
#define RPC_PROTO_PATCH_VERSION 0
#define RPC_PROTO_PATCH_VERSION 1
#ifdef __cplusplus
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION");
static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION");
#endif
#define GGML_RPC_MAX_SERVERS 16
+11
View File
@@ -535,6 +535,7 @@ extern "C" {
GGML_OP_IM2COL,
GGML_OP_IM2COL_BACK,
GGML_OP_IM2COL_3D,
GGML_OP_COL2IM_1D,
GGML_OP_CONV_2D,
GGML_OP_CONV_3D,
GGML_OP_CONV_2D_DW,
@@ -2007,6 +2008,16 @@ extern "C" {
int d1, // dilation dimension 1
bool is_2D);
// col2im_1d: scatter-add GEMM columns back to 1D signal
// a: [K*OC, T_in] (columns from matmul, K = a->ne[0]/OC)
// result: [T_out, OC] where T_out = (T_in - 1)*s0 + K - 2*p0
GGML_API struct ggml_tensor * ggml_col2im_1d(
struct ggml_context * ctx,
struct ggml_tensor * a, // columns [K*OC, T_in]
int s0, // stride
int oc, // output channels
int p0); // padding to crop from both sides
GGML_API struct ggml_tensor * ggml_conv_1d(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel
+5
View File
@@ -1912,6 +1912,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_im2col_3d(params, tensor);
} break;
case GGML_OP_COL2IM_1D:
{
ggml_compute_forward_col2im_1d(params, tensor);
} break;
case GGML_OP_CONV_2D:
{
ggml_compute_forward_conv_2d(params, tensor);
@@ -2343,6 +2347,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_CONV_2D:
case GGML_OP_CONV_3D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_COL2IM_1D:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_CONV_TRANSPOSE_2D:
{
+78 -6
View File
@@ -4008,12 +4008,12 @@ static void ggml_compute_forward_rms_norm_back_f32(
// dx := scale(dx, rrms)
float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
// dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
ggml_vec_cpy_f32 (ne00, dx, x);
// ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
ggml_vec_acc_f32 (ne00, dx, dz);
ggml_vec_scale_f32(ne00, dx, rrms);
// dx[i00] = (dz + x*(-sum_xdz/sum_eps)) * rrms
// note: https://github.com/ggml-org/ggml/issues/1491
const float scale_x = (float) (-sum_xdz) / sum_eps;
for (int64_t i00 = 0; i00 < ne00; i00++) {
dx[i00] = (dz[i00] + x[i00] * scale_x) * rrms;
}
}
}
}
@@ -6730,6 +6730,78 @@ static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
return (coord + size) % size; // adding size avoids negative number weirdness
}
// ggml_compute_forward_col2im_1d
//
// Scatter-add columns [K*OC, T_in] -> signal [T_out, OC]
// where T_out = (T_in - 1)*s + K - 2*p. Gather approach: each output reads ceil(K/s) inputs.
// Parallelized over the time axis so the split stays balanced whatever OC is.
// Supports F32, F16, BF16 input/output (same type), F32 accumulator.
template <typename elem_t>
static void ggml_compute_forward_col2im_1d_impl(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src = dst->src[0]; // [K*OC, T_in]
GGML_ASSERT(ggml_is_contiguous(src));
GGML_ASSERT(ggml_is_contiguous(dst));
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t OC = ((const int32_t *)(dst->op_params))[1];
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
const int64_t K_OC = src->ne[0];
const int64_t T_in = src->ne[1];
const int64_t K = K_OC / OC;
const int64_t T_out = dst->ne[0];
const elem_t * col_data = (const elem_t *) src->data;
elem_t * dst_data = (elem_t *) dst->data;
const int ith = params->ith;
const int nth = params->nth;
// Parallelize over the time axis: the split stays balanced whatever OC is,
// down to OC = 1 for mono audio, and threads read disjoint column bands
const int64_t dr = (T_out + nth - 1) / nth;
const int64_t it0 = dr * ith;
const int64_t it1 = it0 + dr < T_out ? it0 + dr : T_out;
for (int64_t oc = 0; oc < OC; oc++) {
for (int64_t t_out = it0; t_out < it1; t_out++) {
const int64_t t_abs = t_out + p0; // absolute position in uncropped signal
// Gather: find all (t_in, k) where t_in * s + k == t_abs, 0 <= k < K
int64_t t_in_min = (t_abs - K + 1 + s0 - 1) / s0; // ceil((t_abs-K+1)/s)
if (t_in_min < 0) t_in_min = 0;
int64_t t_in_max = t_abs / s0;
if (t_in_max >= T_in) t_in_max = T_in - 1;
float sum = 0.0f;
for (int64_t t_in = t_in_min; t_in <= t_in_max; t_in++) {
int64_t k = t_abs - t_in * s0;
if (k >= 0 && k < K) {
// col layout: [K*OC, T_in], element (oc*K+k, t_in)
sum += type_conversion_table<elem_t>::to_f32(col_data[(oc * K + k) + t_in * K_OC]);
}
}
// dst layout: [T_out, OC], element (t_out, oc)
dst_data[t_out + oc * T_out] = type_conversion_table<elem_t>::from_f32(sum);
}
}
}
void ggml_compute_forward_col2im_1d(
const ggml_compute_params * params,
ggml_tensor * dst) {
switch (dst->src[0]->type) {
case GGML_TYPE_F32: ggml_compute_forward_col2im_1d_impl<float> (params, dst); break;
case GGML_TYPE_F16: ggml_compute_forward_col2im_1d_impl<ggml_fp16_t>(params, dst); break;
case GGML_TYPE_BF16: ggml_compute_forward_col2im_1d_impl<ggml_bf16_t>(params, dst); break;
default: GGML_ABORT("col2im_1d: unsupported type %d", dst->src[0]->type);
}
}
// ggml_compute_forward_conv_2d
+1
View File
@@ -68,6 +68,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p
void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_col2im_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-1
View File
@@ -411,7 +411,6 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_K:
return 8;
case GGML_TYPE_Q6_K:
return 2;
+77 -11
View File
@@ -113,6 +113,21 @@ typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
} VkPhysicalDeviceShaderBfloat16FeaturesKHR;
#endif
#if !defined(VK_VALVE_shader_mixed_float_dot_product)
#define VK_VALVE_shader_mixed_float_dot_product 1
#define VK_VALVE_SHADER_MIXED_FLOAT_DOT_PRODUCT_SPEC_VERSION 1
#define VK_VALVE_SHADER_MIXED_FLOAT_DOT_PRODUCT_EXTENSION_NAME "VK_VALVE_shader_mixed_float_dot_product"
#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE ((VkStructureType)1000673000)
typedef struct VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE {
VkStructureType sType;
void* pNext;
VkBool32 shaderMixedFloatDotProductFloat16AccFloat32;
VkBool32 shaderMixedFloatDotProductFloat16AccFloat16;
VkBool32 shaderMixedFloatDotProductBFloat16Acc;
VkBool32 shaderMixedFloatDotProductFloat8AccFloat32;
} VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE;
#endif
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
@@ -705,6 +720,8 @@ struct vk_device_struct {
bool coopmat2_bf16_support {};
bool coopmat2_decode_vector;
bool dot2_f16 {};
bool pipeline_executable_properties_support {};
size_t idx;
@@ -3377,7 +3394,9 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
switch (src0_type) {
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
lut_size = 2*2048 + 4*2048;
// Regular matmul uses the compact uint16_t IQ1 grid; the expanded
// uint32_t grid is only enabled for the q8_1/int-dot vector path.
lut_size = 2*2048;
break;
case GGML_TYPE_IQ2_XXS:
lut_size = 8*256;
@@ -3920,8 +3939,13 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16";
} else {
if (device->fp16) {
if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; }
else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; }
if (device->dot2_f16) {
if (f32acc) { spv_data = flash_attn_f32_f16_dot2_data; spv_size = flash_attn_f32_f16_dot2_len; }
else { spv_data = flash_attn_f32_f16_dot2_f16acc_data; spv_size = flash_attn_f32_f16_dot2_f16acc_len; }
} else {
if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; }
else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; }
}
} else {
spv_data = flash_attn_f32_f16_fp32_data;
spv_size = flash_attn_f32_f16_fp32_len;
@@ -4215,7 +4239,23 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->fp16) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
// Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
// bf16 scalar path promotes to f32, no dot2 variant
#define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat ## ID ## _m[TYPE]) \
@@ -4250,7 +4290,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
@@ -4258,7 +4298,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
@@ -4298,8 +4337,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
@@ -4344,8 +4382,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
@@ -4390,6 +4427,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
#undef CREATE_MM2
#undef CREATE_MMQ
#undef CREATE_MM
#undef CREATE_MM_NODOT2
} else {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
@@ -5453,6 +5491,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->integer_dot_product = false;
device->shader_64b_indexing = false;
bool bfloat16_support = false;
bool dot2_f16_support = false;
for (const auto& properties : ext_props) {
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -5495,6 +5534,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
bfloat16_support = true;
#endif
} else if (strcmp("VK_VALVE_shader_mixed_float_dot_product", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_DOT2")) {
dot2_f16_support = true;
} else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) {
pipeline_executable_properties_support = true;
} else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 &&
@@ -5802,6 +5844,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
device_extensions.push_back("VK_KHR_shader_integer_dot_product");
}
VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE dot2_features {};
dot2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE;
if (dot2_f16_support) {
last_struct->pNext = (VkBaseOutStructure *)&dot2_features;
last_struct = (VkBaseOutStructure *)&dot2_features;
device_extensions.push_back("VK_VALVE_shader_mixed_float_dot_product");
}
VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {};
pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR;
if (pipeline_executable_properties_support) {
@@ -5836,6 +5886,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->bf16 = false;
#endif
device->dot2_f16 = dot2_f16_support && dot2_features.shaderMixedFloatDotProductFloat16AccFloat32;
device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
@@ -6250,6 +6302,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
bool coopmat2_decode_vector_support = false;
bool integer_dot_product = false;
bool bfloat16_support = false;
bool dot2_f16_support = false;
for (auto properties : ext_props) {
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
@@ -6279,6 +6332,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
bfloat16_support = true;
#endif
} else if (strcmp("VK_VALVE_shader_mixed_float_dot_product", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_DOT2")) {
dot2_f16_support = true;
}
}
@@ -6369,6 +6425,13 @@ static void ggml_vk_print_gpu_info(size_t idx) {
last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
}
VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE dot2_features {};
dot2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE;
if (dot2_f16_support) {
last_struct->pNext = (VkBaseOutStructure *)&dot2_features;
last_struct = (VkBaseOutStructure *)&dot2_features;
}
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
fp16 = fp16 && vk12_features.shaderFloat16;
@@ -6415,9 +6478,12 @@ static void ggml_vk_print_gpu_info(size_t idx) {
: coopmat_support ? "KHR_coopmat"
: "none";
bool dot2_f16 = dot2_f16_support && dot2_features.shaderMixedFloatDotProductFloat16AccFloat32;
const char *fp16_str = fp16 ? (dot2_f16 ? "dot2" : "1") : "0";
std::string device_name = props2.properties.deviceName.data();
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size,
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %s | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16_str, bf16, subgroup_size,
props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
@@ -0,0 +1,27 @@
#ifdef DOT2_F16
#extension GL_EXT_spirv_intrinsics : require
spirv_instruction(extensions = ["SPV_VALVE_mixed_float_dot_product"],
capabilities = [6912], id = 6916)
float v_dot2_f32_f16(f16vec2 a, f16vec2 b, float acc);
ACC_TYPE dot_product(f16vec4 a, f16vec4 b, ACC_TYPE acc) {
return ACC_TYPE(v_dot2_f32_f16(a.zw, b.zw, v_dot2_f32_f16(a.xy, b.xy, float(acc))));
}
ACC_TYPE dot_product(f16vec2 a, f16vec2 b, ACC_TYPE acc) {
return ACC_TYPE(v_dot2_f32_f16(a, b, float(acc)));
}
#else
ACC_TYPE dot_product(FLOAT_TYPEV4 a, FLOAT_TYPEV4 b, ACC_TYPE acc) {
return fma(ACC_TYPE(a.x), ACC_TYPE(b.x), fma(ACC_TYPE(a.y), ACC_TYPE(b.y),
fma(ACC_TYPE(a.z), ACC_TYPE(b.z), fma(ACC_TYPE(a.w), ACC_TYPE(b.w), acc))));
}
ACC_TYPE dot_product(FLOAT_TYPEV2 a, FLOAT_TYPEV2 b, ACC_TYPE acc) {
return fma(ACC_TYPE(a.x), ACC_TYPE(b.x), fma(ACC_TYPE(a.y), ACC_TYPE(b.y), acc));
}
#endif
@@ -21,6 +21,7 @@
#extension GL_KHR_shader_subgroup_vote : enable
#include "types.glsl"
#include "dot_product_funcs.glsl"
#include "flash_attn_base.glsl"
#include "flash_attn_dequant.glsl"
@@ -318,7 +319,7 @@ void main() {
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf));
Sf[r][c] = dot_product(Q_cache[r], K_Tf, Sf[r][c]);
}
}
}
@@ -341,7 +342,7 @@ void main() {
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf));
Sf[r][c] = dot_product(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf, Sf[r][c]);
}
}
}
@@ -4,6 +4,7 @@
#extension GL_EXT_integer_dot_product : require
#define MMQ
#define NEEDS_IQ1S_GRID_GPU
#define B_TYPE block_q8_1_x4
#include "mul_mat_vec_base.glsl"
@@ -29,6 +29,7 @@
#endif
#include "types.glsl"
#include "dot_product_funcs.glsl"
#ifndef LOAD_VEC_A
#define LOAD_VEC_A 1
@@ -329,15 +330,8 @@ void main() {
[[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
// [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
#if defined(DATA_A_F32) || defined(DATA_A_F16)
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y),
fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x))));
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y),
fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y))));
#else
sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
#endif
sums[sums_idx].x = dot_product(cache_a[wsir * TM + 2 * cr ], cache_b, sums[sums_idx].x);
sums[sums_idx].y = dot_product(cache_a[wsir * TM + 2 * cr + 1], cache_b, sums[sums_idx].y);
}
}
}
@@ -598,9 +598,10 @@ const uint[1024] iq1s_grid_const = {
0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
};
#if defined(NEEDS_IQ1S_GRID_GPU)
// Same content as iq1s_grid_const except each 2-bit value is expanded to 4-bit
// and has 1 added to it (allows packed values to be extracted with & 0x0F0F0F0F
// and 0xF0F0F0F0).
// and 0xF0F0F0F0). This is only used by the q8_1/int-dot vector path.
const uint32_t[2048] iq1s_grid_gpu_const = {
0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,
0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,
@@ -859,9 +860,12 @@ const uint32_t[2048] iq1s_grid_gpu_const = {
0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020,
0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,
};
#endif
shared uint16_t iq1s_grid[2048];
#if defined(NEEDS_IQ1S_GRID_GPU)
shared uint32_t iq1s_grid_gpu[2048];
#endif
#define NEEDS_INIT_IQ_SHMEM
void init_iq_shmem(uvec3 wgsize)
@@ -875,12 +879,14 @@ void init_iq_shmem(uvec3 wgsize)
iq1s_grid[2*idx+1] = g.y;
}
}
#if defined(NEEDS_IQ1S_GRID_GPU)
[[unroll]] for (uint i = 0; i < iq1s_grid_gpu_const.length(); i += wgsize.x) {
uint idx = i + gl_LocalInvocationIndex.x;
if (iq1s_grid_gpu_const.length() % wgsize.x == 0 || idx < iq1s_grid_gpu_const.length()) {
iq1s_grid_gpu[idx] = iq1s_grid_gpu_const[idx];
}
}
#endif
barrier();
}
#endif
@@ -336,7 +336,8 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
// disable spirv-opt for coopmat shaders for https://github.com/ggml-org/llama.cpp/issues/10734
// disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344
// disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860
if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos) {
// disable spirv-opt for dot2 shaders (spirv-opt doesn't recognize SPV_VALVE_mixed_float_dot_product capability)
if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos && name.find("_dot2") == std::string::npos) {
cmd.push_back("-O");
}
@@ -427,10 +428,11 @@ void string_to_spv(std::string name, const std::string& source, const std::map<s
generate_dep_file = false;
}
void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) {
void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc, bool dot2 = false) {
std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
std::string dot2_sfx = dot2 ? "_dot2" : "";
std::map<std::string, std::string> base_dict;
std::string shader_name = "matmul";
@@ -463,6 +465,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
#endif
if (dot2) {
base_dict["DOT2_F16"] = "1";
}
const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
auto const &FLOAT_TYPE = [&](int vec, const std::string &t) -> std::string {
@@ -528,11 +534,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
};
// Shaders with f16 B_TYPE
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
// bf16
{
@@ -553,8 +559,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
if (!(coopmat || coopmat2))
#endif
{
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
if (!dot2) {
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
}
}
@@ -584,18 +592,18 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
// don't generate f32 variants for coopmat2
if (!coopmat2) {
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
if (tname != "f16" && tname != "f32") {
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
// Integer dot mmq performs better with f32 accumulators
if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
// Integer dot mmq performs better with f32 accumulators (different shader, skip for dot2)
if (!f16acc && !coopmat && !coopmat2 && !dot2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
}
#endif
@@ -613,6 +621,10 @@ void process_shaders() {
matmul_shaders(true, matmul_id_type, false, false, false);
matmul_shaders(true, matmul_id_type, false, false, true);
// dot2 variants (scalar fp16 only)
matmul_shaders(true, matmul_id_type, false, false, false, true);
matmul_shaders(true, matmul_id_type, false, false, true, true);
if (matmul_id_type != MatMulIdType::DEFAULT) {
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
// Coopmat, fp32acc and fp16acc
@@ -660,6 +672,12 @@ void process_shaders() {
string_to_spv("flash_attn_f32_f16", "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
if (fp16) {
string_to_spv("flash_attn_f32_f16_dot2", "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DOT2_F16", "1"}}), fp16, false, false, f16acc);
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
string_to_spv("flash_attn_f32_f16", "flash_attn.comp",
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"MMQ", "1"}, {"FA_MMQ_MIXED", "1"}}), fp16, false, false, f16acc, "_int8");
@@ -644,7 +644,8 @@ inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) {
inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) {
const uint32_t offset_elems =
(uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) / ggml_type_size(K->type));
(uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) /
ggml_type_size(K->type));
return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u;
}
@@ -655,8 +656,10 @@ inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K,
ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment);
}
inline bool ggml_webgpu_flash_attn_kv_direct(
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, uint32_t kv_direct_align) {
inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q,
const ggml_tensor * K,
const ggml_tensor * V,
uint32_t kv_direct_align) {
return K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && (Q->ne[0] % kv_direct_align == 0) &&
(K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
}
@@ -671,10 +674,10 @@ inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_co
key.dst_type = context.dst->type;
key.head_dim_qk = (uint32_t) context.src0->ne[0];
key.head_dim_v = (uint32_t) context.src2->ne[0];
key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align);
key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2);
key.has_mask = context.src3 != nullptr;
key.has_sinks = context.src4 != nullptr;
key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align);
key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2);
key.has_mask = context.src3 != nullptr;
key.has_sinks = context.src4 != nullptr;
key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
return key;
}
@@ -1727,7 +1730,7 @@ class ggml_webgpu_shader_lib {
key.type = context.dst->type;
key.d_state = (int) context.src0->ne[0];
key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
ggml_webgpu_tensor_overlap(context.src1, context.src5);
ggml_webgpu_tensor_overlap(context.src1, context.src5);
auto it = ssm_scan_pipelines.find(key);
if (it != ssm_scan_pipelines.end()) {
+3 -3
View File
@@ -4253,9 +4253,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
const uint32_t q_tile =
use_subgroup_matrix ? capabilities.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
const uint32_t kv_granularity = use_subgroup_matrix ? capabilities.sg_mat_n : 1u;
const bool kv_direct = use_subgroup_matrix ?
ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) :
false;
const bool kv_direct = use_subgroup_matrix ?
ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) :
false;
const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(
capabilities.limits.maxComputeWorkgroupStorageSize, q_tile, kv_granularity, (uint32_t) src0->ne[0],
(uint32_t) src2->ne[0], op->src[3] != nullptr, kv_direct);
@@ -98,72 +98,50 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
}
#endif // INIT_SRC0_SHMEM_Q1_0
#ifdef INIT_SRC0_SHMEM_Q4_0
#if defined(INIT_SRC0_SHMEM_Q4_0) || defined(INIT_SRC0_SHMEM_Q4_1) || defined(INIT_SRC0_SHMEM_Q5_0) || defined(INIT_SRC0_SHMEM_Q5_1) || defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1) || defined(INIT_SRC0_SHMEM_MXFP4)
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 18u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
#if defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1)
const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q
#else
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
#endif
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let shmem_idx = block_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let tile_m = block_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let block_k = block_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
#ifdef INIT_SRC0_SHMEM_Q4_0
let block_byte_base = src0_idx * 18u; // BLOCK_SIZE_BYTES = 18u;
let d = load_f16_at_src0(block_byte_base);
// store NQ(16) weights
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP);
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q4_0
#elif INIT_SRC0_SHMEM_Q4_1
let block_byte_base = src0_idx * 20u; // BLOCK_SIZE_BYTES = 20u;
let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
let d = f16(dm[0]);
let m = f16(dm[1]);
#ifdef INIT_SRC0_SHMEM_Q4_1
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 20u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base);
let m = load_f16_at_src0(block_byte_base + 2u);
// store NQ(16) weights
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
@@ -175,41 +153,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q4_1
#ifdef INIT_SRC0_SHMEM_Q5_0
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 22u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
// tile_k is defined as 32u, so blocks_k ends up being 1 always
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
#elif INIT_SRC0_SHMEM_Q5_0
let block_byte_base = src0_idx * 22u; // BLOCK_SIZE_BYTES = 22u;
let d = load_f16_at_src0(block_byte_base);
let qh_packed = load_u32_at_src0(block_byte_base + 2u);
// store NQ(16) weights
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 6u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
@@ -226,44 +176,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q5_0
#elif INIT_SRC0_SHMEM_Q5_1
let block_byte_base = src0_idx * 24u; // BLOCK_SIZE_BYTES = 24u;
#ifdef INIT_SRC0_SHMEM_Q5_1
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 24u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K / BLOCK_SIZE;
const NQ = 16u;
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
let d = f16(dm[0]);
let m = f16(dm[1]);
let qh_packed = load_u32_at_src0_aligned(block_byte_base + 4u);
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base);
let m = load_f16_at_src0(block_byte_base + 2u);
let qh_packed = load_u32_at_src0(block_byte_base + 4u);
// store NQ(16) weights
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 8u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
let q_packed = load_u32_at_src0_aligned(q_byte_offset);
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
let q_byte = get_byte(q_packed, k);
@@ -277,461 +201,306 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q5_1
#ifdef INIT_SRC0_SHMEM_Q8_0
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 34u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
#elif INIT_SRC0_SHMEM_Q8_0
let block_byte_base = src0_idx * 34u; // BLOCK_SIZE_BYTES = 34u;
let d = load_f16_at_src0(block_byte_base);
// store NQ(16) weights
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP);
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q8_0
#elif INIT_SRC0_SHMEM_Q8_1
let block_byte_base = src0_idx * 36u; // BLOCK_SIZE_BYTES = 36u;
let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
let d = f16(dm[0]);
let m = f16(dm[1]);
#ifdef INIT_SRC0_SHMEM_Q8_1
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 36u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base);
let m = load_f16_at_src0(block_byte_base + 2u);
// store NQ(16) weights
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d + m;
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val;
}
}
#elif INIT_SRC0_SHMEM_MXFP4
let block_byte_base = src0_idx * 17u;
let eu8 = get_byte(load_u32_at_src0_aligned(block_byte_base), block_byte_base & 3u);
let e = ldexp(1.0, i32(eu8) - 128);
// load NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e;
let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e;
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo);
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi);
}
}
#endif
}
}
}
#endif // INIT_SRC0_SHMEM_Q8_1
#endif
// k-quants
#if defined(INIT_SRC0_SHMEM_Q2_K) || defined(INIT_SRC0_SHMEM_Q3_K) || defined(INIT_SRC0_SHMEM_Q4_K) || defined(INIT_SRC0_SHMEM_Q5_K) || defined(INIT_SRC0_SHMEM_Q6_K)
const BLOCK_SIZE = 256u;
const NQ = 4u;
fn store_shmem_kquants(val: vec4<f16>, idx: u32) {
shmem[idx] = val.x;
shmem[idx + 1] = val.y;
shmem[idx + 2] = val.z;
shmem[idx + 3] = val.w;
}
fn load_byte_at_src0_aligned(byte_offset: u32) -> u32 {
return get_byte(load_u32_at_src0_aligned(byte_offset), byte_offset % 4u);
}
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id * NQ; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * NQ) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
store_shmem_kquants(vec4<f16>(f16(0.0), f16(0.0), f16(0.0), f16(0.0)), elem_idx);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE; // k_in_block % 4 == 0;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
#ifdef INIT_SRC0_SHMEM_Q2_K
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 84u;
let block_byte_base = src0_idx * 84u; // BLOCK_SIZE_BYTES = 84u;
let scales_byte_base = block_byte_base;
let qs_byte_base = block_byte_base + 16u;
let dm_byte_base = block_byte_base + 80u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
// Use standard thread layout instead of lane/row_group
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let d_packed = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base));
let d = f16(d_packed[0]);
let dmin = f16(d_packed[1]);
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
let chunk = k_in_block / 128u;
let pos_in_chunk = k_in_block % 32u;
let sub_block = k_in_block / 16u;
let shift_phase = (k_in_block % 128u) / 32u;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
// whole 2 bits (4 elems)
let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk);
let qs_vec4 = vec4<f16>(
f16((qs_word >> (2u * shift_phase + 0u)) & 0x3u),
f16((qs_word >> (2u * shift_phase + 8u)) & 0x3u),
f16((qs_word >> (2u * shift_phase + 16u)) & 0x3u),
f16((qs_word >> (2u * shift_phase + 24u)) & 0x3u),
);
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let scale = load_byte_at_src0_aligned(scales_byte_base + sub_block);
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let dl = d * f16(scale & 0xFu);
let ml = dmin * f16(scale >> 4u);
let d = load_f16_at_src0(block_byte_base + 80u);
let dmin = load_f16_at_src0(block_byte_base + 82u);
store_shmem_kquants(qs_vec4 * dl - ml, elem_idx);
#elif INIT_SRC0_SHMEM_Q3_K
let block_byte_base = src0_idx * 110u; // BLOCK_SIZE_BYTES = 110u;
let hmask_byte_base = block_byte_base + 0u;
let qs_byte_base = block_byte_base + 32u;
let scales_byte_base = block_byte_base + 96u;
// Decode the element at position k_in_block
let block_of_32 = k_in_block / 32u;
let pos_in_32 = k_in_block % 32u;
let d_all = load_f16_at_src0(block_byte_base + 108u);
let q_b_idx = (block_of_32 / 4u) * 32u;
let shift = (block_of_32 % 4u) * 2u;
let k = (pos_in_32 / 16u) * 16u;
let l = pos_in_32 % 16u;
let chunk = k_in_block / 128u;
let pos_in_chunk = k_in_block % 32u;
let sub_block = k_in_block / 16u;
let shift_phase = (k_in_block % 128u) / 32u;
let is = k_in_block / 16u;
let hmask_block = pos_in_chunk;
let hmask_shift_phase = k_in_block / 32u;
let sc_packed = load_u32_at_src0(block_byte_base + 4u * (is / 4u));
let sc = get_byte(sc_packed, is % 4u);
// low 2 bits (4 elems)
let q_lo2_word = load_u32_at_src0(qs_byte_base + 32u * chunk + 1u * hmask_block);
let q_lo2_vec4 = vec4<f16>(
f16((q_lo2_word >> (2u * shift_phase + 0u)) & 3u),
f16((q_lo2_word >> (2u * shift_phase + 8u)) & 3u),
f16((q_lo2_word >> (2u * shift_phase + 16u)) & 3u),
f16((q_lo2_word >> (2u * shift_phase + 24u)) & 3u)
);
let dl = d * f16(sc & 0xFu);
let ml = dmin * f16(sc >> 4u);
// high 1 bit (4 elems)
let q_hi1_word = load_u32_at_src0(hmask_byte_base + pos_in_chunk);
let q_hi1_vec4 = vec4<f16>(
f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 0u)) & 1u) == 1u)),
f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 8u)) & 1u) == 1u)),
f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 16u)) & 1u) == 1u)),
f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 24u)) & 1u) == 1u))
);
let q_idx = q_b_idx + k + l;
let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u));
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 3u;
let q_vec4 = q_lo2_vec4 - q_hi1_vec4;
let q_val = f16(qs_val) * dl - ml;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q2_K
let scale_low4 = (load_byte_at_src0_aligned(scales_byte_base + (sub_block % 8u)) >> (4u * (sub_block / 8u))) & 0xFu;
let scale_hi2 = (load_byte_at_src0_aligned(scales_byte_base + 8u + (sub_block % 4u)) >> (2u * (sub_block / 4u))) & 3u;
let dl = d_all * (f16((scale_hi2 << 4u) | scale_low4) - 32.0);
#ifdef INIT_SRC0_SHMEM_Q3_K
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 110u;
store_shmem_kquants(dl * q_vec4, elem_idx);
#elif INIT_SRC0_SHMEM_Q4_K
let block_byte_base = src0_idx * 144u; // BLOCK_SIZE_BYTES = 144u;
let dm_byte_base = block_byte_base + 0u;
let scale_byte_base = block_byte_base + 4u;
let qs_byte_base = block_byte_base + 16u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base));
let d = f16(dm[0]);
let dmin = f16(dm[1]);
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
let chunk = k_in_block / 64u;
let pos_in_chunk = (k_in_block % 64u) % 32u;
let sub_block = k_in_block / 32u;
let shift_phase = sub_block & 1u;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base + 108u);
// Load and unpack scales
let kmask1: u32 = 0x03030303u;
let kmask2: u32 = 0x0f0f0f0fu;
var scale_vals: array<u32, 4>;
for (var i: u32 = 0u; i < 4u; i++) {
scale_vals[i] = load_u32_at_src0(block_byte_base + 96u + 4u * i);
}
var tmp: u32 = scale_vals[2];
scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u);
scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u);
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u);
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u);
// Load hmask and qs arrays
var hmask_vals: array<u32, 8>;
for (var i: u32 = 0u; i < 8u; i++) {
hmask_vals[i] = load_u32_at_src0(block_byte_base + 4u * i);
}
var qs_vals: array<u32, 16>;
for (var i: u32 = 0u; i < 16u; i++) {
qs_vals[i] = load_u32_at_src0(block_byte_base + 32u + 4u * i);
}
let half = k_in_block / 128u; // 0 or 1
let pos_in_half = k_in_block % 128u; // 0-127
let shift_group = pos_in_half / 32u; // 0-3
let pos_in_32 = pos_in_half % 32u; // 0-31
let k_group = pos_in_32 / 16u; // 0 or 1
let l = pos_in_32 % 16u; // 0-15
let q_b_idx = half * 32u; // 0 or 32
let shift = shift_group * 2u; // 0, 2, 4, 6
let k = k_group * 16u; // 0 or 16
let is = k_in_block / 16u; // 0-15
// m increments every 32 elements across entire 256 element block
let m_shift = k_in_block / 32u; // 0-7
let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128
let sc = get_byte(scale_vals[is / 4u], is % 4u);
let dl = d * (f16(sc) - 32.0);
let q_idx = q_b_idx + k + l;
let hm_idx = k + l;
let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u);
let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u);
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
let qs_val = (q_byte >> shift) & 3u;
let q_val = (f16(qs_val) - f16(hm)) * dl;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q3_K
#ifdef INIT_SRC0_SHMEM_Q4_K
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 144u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base);
let dmin = load_f16_at_src0(block_byte_base + 2u);
// Map k_in_block to loop structure:
// Outer loop over 64-element groups (alternating q_b_idx)
// Inner loop over 2 shifts per group
let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx)
let pos_in_64 = k_in_block % 64u; // 0-63
let shift_group = pos_in_64 / 32u; // 0 or 1
let l = pos_in_64 % 32u; // 0-31
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
let shift = shift_group * 4u; // 0 or 4
let is = k_in_block / 32u; // 0-7
// whole 4 bits (4 elems)
let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk);
let qs_vec4 = vec4<f16>(
f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu),
f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu),
f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu),
f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu)
);
var sc: u32;
var mn: u32;
let scale_base = block_byte_base + 4u;
if (is < 4u) {
let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u);
let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
if (sub_block < 4u) {
let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u);
let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u);
let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u);
let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u);
let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u);
let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
}
let dl = d * f16(sc);
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u));
store_shmem_kquants(dl * qs_vec4 - vec4(ml, ml, ml, ml), elem_idx);
#elif INIT_SRC0_SHMEM_Q5_K
let block_byte_base = src0_idx * 176u; // BLOCK_SIZE_BYTES = 176u;
let dm_byte_base = block_byte_base + 0u;
let scale_byte_base = block_byte_base + 4u;
let qh_byte_base = block_byte_base + 16u;
let qs_byte_base = block_byte_base + 48u;
let q_byte = get_byte(q_packed, q_idx % 4u);
let qs_val = (q_byte >> shift) & 0xFu;
let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base));
let d = f16(dm[0]);
let dmin = f16(dm[1]);
let q_val = f16(qs_val) * dl - ml;
shmem[elem_idx] = q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q4_K
let chunk = k_in_block / 64u;
let pos_in_chunk = (k_in_block % 64u) % 32u;
let sub_block = k_in_block / 32u;
let shift_phase = sub_block & 1u;
#ifdef INIT_SRC0_SHMEM_Q5_K
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 176u;
let qh_block = k_in_block % 32u;
let qh_shift_phase = sub_block;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
// low 4 bits (4 elems)
let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk);
let qs_lo4_vec4 = vec4<f16>(
f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu),
f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu),
f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu),
f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu)
);
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base);
let dmin = load_f16_at_src0(block_byte_base + 2u);
// The original loop processes elements in groups of 64
// Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
// But u increments EVERY 32 elements (after each l loop)
let group_of_64 = k_in_block / 64u; // 0-3
let pos_in_64 = k_in_block % 64u; // 0-63
let shift_group = pos_in_64 / 32u; // 0 or 1
let l = pos_in_64 % 32u; // 0-31
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
let shift = shift_group * 4u; // 0 or 4
let is = k_in_block / 32u; // 0-7
// u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128)
let u_shift = k_in_block / 32u; // 0-7
let u: u32 = 1u << u_shift;
// high 1 bit (4 elems)
let qh_word = load_u32_at_src0_aligned(qh_byte_base + qh_block);
let qh_vec4 = vec4<f16>(
f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 0u)) & 1u) == 1u)),
f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 8u)) & 1u) == 1u)),
f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 16u)) & 1u) == 1u)),
f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 24u)) & 1u) == 1u))
);
var sc: u32;
var mn: u32;
let scale_base = block_byte_base + 4u;
if (is < 4u) {
let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u);
let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
if (sub_block < 4u) {
let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u);
let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
sc = sc_byte & 63u;
mn = min_byte & 63u;
} else {
let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u);
let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u);
let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u);
let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u);
let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
}
let dl = d * f16(sc);
let ml = dmin * f16(mn);
let q_idx = q_b_idx + l;
let q_packed = load_u32_at_src0(block_byte_base + 48u + 4u * (q_idx / 4u));
store_shmem_kquants((qh_vec4 + qs_lo4_vec4) * dl - vec4<f16>(ml, ml, ml, ml), elem_idx);
#elif INIT_SRC0_SHMEM_Q6_K
let block_byte_base = src0_idx * 210u; // BLOCK_SIZE_BYTES = 210u;
let ql_byte_base = block_byte_base;
let qh_byte_base = block_byte_base + 128u;
let scales_byte_base = block_byte_base + 192u;
let d_byte_base = block_byte_base + 208u;
let q_byte = get_byte(q_packed, q_idx % 4u);
let d = load_f16_at_src0(d_byte_base);
let qh_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (l / 4u));
let chunk = k_in_block / 128u;
let ql_pos_in_chunk = (k_in_block % 128u) % 64u;
let qh_pos_in_chunk = (k_in_block % 128u) % 32u;
let sub_block = k_in_block / 16u;
let ql_shift_phase = (k_in_block % 128u) / 64u;
let qh_shift_phase = (k_in_block % 128u) / 32u;
let qh_byte = get_byte(qh_packed, l % 4u);
// low 4 bits (4 elems)
let ql_word = load_u32_at_src0(ql_byte_base + 64u * chunk + 1u * ql_pos_in_chunk);
let ql_lo4_vec4 = vec4<u32>(
(ql_word >> (4u * ql_shift_phase + 0u)) & 0xFu,
(ql_word >> (4u * ql_shift_phase + 8u)) & 0xFu,
(ql_word >> (4u * ql_shift_phase + 16u)) & 0xFu,
(ql_word >> (4u * ql_shift_phase + 24u)) & 0xFu
);
let qs_val = (q_byte >> shift) & 0xFu;
let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
// hi 2 bits (4 elems)
let qh_word = load_u32_at_src0(qh_byte_base + 32u * chunk + 1u * qh_pos_in_chunk);
let qh_hi2_vec4 = vec4<u32>(
((qh_word >> (2u * qh_shift_phase + 0u)) & 0x3u) << 4u,
((qh_word >> (2u * qh_shift_phase + 8u)) & 0x3u) << 4u,
((qh_word >> (2u * qh_shift_phase + 16u)) & 0x3u) << 4u,
((qh_word >> (2u * qh_shift_phase + 24u)) & 0x3u) << 4u,
);
let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml;
shmem[elem_idx] = q_val;
let q_vec4 = vec4<f16>(qh_hi2_vec4 | ql_lo4_vec4) - vec4<f16>(32.0, 32.0, 32.0, 32.0);
let scale_byte = scales_byte_base + 1u * sub_block;
let scale_word = load_u32_at_src0_aligned(scale_byte);
let scale = get_byte_i32(scale_word, scale_byte & 3u);
store_shmem_kquants(d * q_vec4 * f16(scale), elem_idx);
#endif
}
}
#endif // INIT_SRC0_SHMEM_Q5_K
#ifdef INIT_SRC0_SHMEM_Q6_K
const BLOCK_SIZE = 256u;
const BLOCK_SIZE_BYTES = 210u;
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
let tile_m = elem_idx / TILE_K;
let tile_k = elem_idx % TILE_K;
let global_m = offset_m + tile_m;
let global_k = k_outer + tile_k;
if (global_m >= params.m || global_k >= params.k) {
shmem[elem_idx] = f16(0.0);
continue;
}
let block_k = global_k / BLOCK_SIZE;
let k_in_block = global_k % BLOCK_SIZE;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let half = k_in_block / 128u;
let pos_in_half = k_in_block % 128u;
let quarter = pos_in_half / 32u;
let l = pos_in_half % 32u;
let ql_b_idx = half * 64u;
let qh_b_idx = half * 32u;
let sc_b_idx = half * 8u;
// Load only ql13 word needed
let ql13_flat = ql_b_idx + l;
let ql13 = load_u32_at_src0(block_byte_base + ql13_flat);
let ql13_b = get_byte(ql13, 0u);
// Load only ql24 word needed
let ql24_flat = ql_b_idx + l + 32u;
let ql24 = load_u32_at_src0(block_byte_base + ql24_flat);
let ql24_b = get_byte(ql24, 0u);
// Load only qh word needed
let qh_flat = qh_b_idx + l;
let qh = load_u32_at_src0(block_byte_base + 128u + qh_flat);
let qh_b = get_byte(qh, 0u);
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0);
let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0);
// Load only the scale word needed
let is = l / 16u;
let sc_idx = sc_b_idx + is + quarter * 2u;
let sc = load_u32_at_src0(block_byte_base + 192u + sc_idx);
let sc_val = get_byte_i32(sc, 0u);
let d = load_f16_at_src0(block_byte_base + 208u);
var q_val: f16;
if (quarter == 0u) {
q_val = q1;
} else if (quarter == 1u) {
q_val = q2;
} else if (quarter == 2u) {
q_val = q3;
} else {
q_val = q4;
}
shmem[elem_idx] = d * f16(sc_val) * q_val;
}
}
#endif // INIT_SRC0_SHMEM_Q6_K
#endif // k-quants
#ifdef INIT_SRC0_SHMEM_IQ4_NL
const BLOCK_SIZE = 32u;
@@ -1155,48 +924,3 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
}
}
#endif // INIT_SRC0_SHMEM_IQ3_S
#ifdef INIT_SRC0_SHMEM_MXFP4
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 17u;
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
override BLOCKS_K = TILE_K/BLOCK_SIZE;
const NQ = 16u;
const BYTES_PER_THREAD = 8u; // NQ(16) weights uses 8 bytes of q
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let blck_idx = i / BLOCK_SIZE;
let block_offset = (i % BLOCK_SIZE) / NQ;
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
let tile_m = blck_idx / BLOCKS_K;
let global_m = offset_m + tile_m;
let block_k = blck_idx % BLOCKS_K;
let global_block_k = k_outer / BLOCK_SIZE + block_k;
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0);
let e = ldexp(1.0, i32(eu8) - 128);
// store NQ(16) weights
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
let q_packed = load_u32_at_src0(q_byte_offset);
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e;
let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e;
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo);
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi);
}
}
}
}
}
#endif // INIT_SRC0_SHMEM_MXFP4
+39 -2
View File
@@ -1031,6 +1031,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"IM2COL",
"IM2COL_BACK",
"IM2COL_3D",
"COL2IM_1D",
"CONV_2D",
"CONV_3D",
"CONV_2D_DW",
@@ -1080,7 +1081,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -1141,6 +1142,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"im2col(x)",
"im2col_back(x)",
"im2col_3d(x)",
"col2im_1d(x)",
"conv_2d(x)",
"conv_3d(x)",
"conv_2d_dw(x)",
@@ -1190,7 +1192,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)",
};
static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -4541,6 +4543,41 @@ struct ggml_tensor * ggml_conv_1d_dw_ph(
return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0);
}
// ggml_col2im_1d
struct ggml_tensor * ggml_col2im_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
int s0,
int oc,
int p0) {
GGML_ASSERT(ggml_is_matrix(a));
GGML_ASSERT(ggml_is_contiguous(a));
GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16 || a->type == GGML_TYPE_BF16);
GGML_ASSERT(s0 > 0);
GGML_ASSERT(oc > 0);
GGML_ASSERT(p0 >= 0);
const int64_t K_OC = a->ne[0];
const int64_t T_in = a->ne[1];
const int64_t K = K_OC / oc;
const int64_t T_out = (T_in - 1) * s0 + K - 2 * p0;
GGML_ASSERT(K_OC == K * oc); // a->ne[0] must be a whole number of oc blocks
GGML_ASSERT(K > 0 && T_out > 0);
const int64_t ne[4] = { T_out, oc, 1, 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 2, ne);
int32_t params[] = { s0, (int32_t)oc, (int32_t)p0 };
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_COL2IM_1D;
result->src[0] = a;
return result;
}
// ggml_conv_transpose_1d
static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
+5 -1
View File
@@ -11,6 +11,10 @@ void llama_model_plamo2::load_arch_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
// Load attention parameters
ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full, false);
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full, false);
for (uint32_t i = 0; i < hparams.n_layer(); ++i) {
hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0;
}
@@ -273,7 +277,7 @@ ggml_tensor * llama_model_plamo2::graph::build_plamo2_mamba_layer(llm_graph_inpu
GGML_ASSERT(n_seqs != 0);
GGML_ASSERT(ubatch.equal_seqs());
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
GGML_ASSERT(d_inner % n_head == 0);
GGML_ASSERT(d_inner % n_heads == 0);
GGML_ASSERT(n_group == 0);
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+1
View File
@@ -265,6 +265,7 @@ if (NOT GGML_BACKEND_DL)
llama_build_and_test(test-quantize-fns.cpp)
llama_build_and_test(test-quantize-perf.cpp)
llama_build_and_test(test-rope.cpp)
llama_build_and_test(test-col2im-1d.cpp)
endif()
# libmtmd
+53
View File
@@ -5098,6 +5098,39 @@ struct test_conv_transpose_1d : public test_case {
}
};
// GGML_OP_COL2IM_1D
struct test_col2im_1d : public test_case {
const ggml_type type;
const int64_t K; // kernel size
const int64_t OC; // output channels
const int64_t T_in; // input length (number of columns)
const int s0; // stride
const int p0; // padding cropped from both sides
std::string vars() override {
return VARS_TO_STR6(type, K, OC, T_in, s0, p0);
}
double max_nmse_err() override {
return type == GGML_TYPE_F32 ? 1e-7 : 5e-4;
}
test_col2im_1d(ggml_type type = GGML_TYPE_F32,
int64_t K = 4, int64_t OC = 3, int64_t T_in = 7,
int s0 = 2, int p0 = 0)
: type(type), K(K), OC(OC), T_in(T_in), s0(s0), p0(p0) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * cols = ggml_new_tensor_2d(ctx, type, K*OC, T_in);
ggml_set_name(cols, "cols");
ggml_tensor * out = ggml_col2im_1d(ctx, cols, s0, (int) OC, p0);
ggml_set_name(out, "out");
return out;
}
};
// GGML_OP_CONV_TRANSPOSE_2D
struct test_conv_transpose_2d : public test_case {
// Dimensions
@@ -8013,6 +8046,21 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16}) {
// ConvTranspose1d expressed as mul_mat + col2im (DAC decoder upsampling)
test_cases.emplace_back(new test_col2im_1d(type, 16, 32, 197, 8, 0)); // kernel = 2*stride
test_cases.emplace_back(new test_col2im_1d(type, 4, 3, 7, 2, 0));
test_cases.emplace_back(new test_col2im_1d(type, 1, 5, 13, 1, 0)); // stride 1, no overlap
test_cases.emplace_back(new test_col2im_1d(type, 6, 4, 11, 3, 1)); // with cropping
test_cases.emplace_back(new test_col2im_1d(type, 2, 3, 9, 3, 0)); // kernel < stride, gap positions are zeroed
test_cases.emplace_back(new test_col2im_1d(type, 5, 4, 11, 2, 0)); // kernel not a multiple of stride, alternating overlap
test_cases.emplace_back(new test_col2im_1d(type, 8, 4, 13, 4, 2)); // padding = stride/2 (DAC causal cropping)
test_cases.emplace_back(new test_col2im_1d(type, 4, 3, 1, 2, 0)); // single column, pure kernel unfold
test_cases.emplace_back(new test_col2im_1d(type, 16, 1, 197, 8, 0)); // OC = 1, mono output stage
test_cases.emplace_back(new test_col2im_1d(type, 1, 5, 13, 3, 0)); // K = 1 with stride > 1, sparse scatter
test_cases.emplace_back(new test_col2im_1d(type, 8, 2, 3, 2, 5)); // cropping eats most of the signal, T_out = 2
}
for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
test_cases.emplace_back(new test_conv_transpose_2d({3, 2, 3, 1}, {2, 2, 1, 3}, 1, kernel_type));
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2, kernel_type));
@@ -9366,6 +9414,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_conv_transpose_2d({10, 10, 9, 1}, {3, 3, 1, 9}, 2, kernel_type));
}
// Memory bound overlap-add of the GEMM + col2im_1d transposed conv path, real vocoder stage shapes
test_cases.emplace_back(new test_col2im_1d(GGML_TYPE_F32, 16, 512, 2048, 8, 0));
test_cases.emplace_back(new test_col2im_1d(GGML_TYPE_F32, 4, 128, 65536, 2, 0));
test_cases.emplace_back(new test_col2im_1d(GGML_TYPE_F16, 16, 512, 2048, 8, 0));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
+159
View File
@@ -0,0 +1,159 @@
// test-col2im-1d.cpp: validate GGML_OP_COL2IM_1D against ggml_conv_transpose_1d.
//
// A ConvTranspose1d factorizes as a GEMM followed by an overlap-add:
// conv_transpose_1d(w, x) equals col2im_1d(mul_mat(w_perm, x_t), s0, OC, p0)
// with w_perm the [IC, K*OC] permutation of the [K, OC, IC] kernel and x_t the
// [IC, T_in] transpose of the [T_in, IC] input. The test derives both alternative
// layouts from one logical weight and one logical input with graph ops only
// (permute + cont + reshape), runs the two paths on the CPU backend, and compares
// them in F32. The F16 and BF16 kernels are exercised by casting the column
// matrix before the scatter. Cropping (p0 > 0) is checked against the shifted
// slice of the uncropped reference, which conv_transpose_1d cannot express.
#include "ggml.h"
#include "ggml-cpu.h"
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>
// One geometry: kernel size, output channels, input length, stride, crop
struct col2im_case {
int64_t K;
int64_t OC;
int64_t T_in;
int s0;
int p0;
};
// Mirrors the eval grid of test-backend-ops
static const col2im_case CASES[] = {
{ 16, 32, 197, 8, 0 }, // kernel = 2*stride, DAC upsampling shape
{ 4, 3, 7, 2, 0 },
{ 1, 5, 13, 1, 0 }, // stride 1, no overlap
{ 6, 4, 11, 3, 1 }, // with cropping
{ 2, 3, 9, 3, 0 }, // kernel < stride, gap positions are zeroed
{ 5, 4, 11, 2, 0 }, // kernel not a multiple of stride, alternating overlap
{ 8, 4, 13, 4, 2 }, // padding = stride/2, DAC causal cropping
{ 4, 3, 1, 2, 0 }, // single column, pure kernel unfold
{ 16, 1, 197, 8, 0 }, // OC = 1, mono output stage
{ 1, 5, 13, 3, 0 }, // K = 1 with stride > 1, sparse scatter
{ 8, 2, 3, 2, 5 }, // cropping eats most of the signal, T_out = 2
};
// Input channels of the GEMM, shared by every case
static const int64_t IC = 7;
// Deterministic LCG mapped to [-1, 1]
static uint64_t g_rng = 0x12345678ULL;
static float frand(void) {
g_rng = g_rng * 6364136223846793005ULL + 1442695040888963407ULL;
return (float)((g_rng >> 33) & 0xffffff) / (float)0x800000 - 1.0f;
}
// Read a F32/F16/BF16 tensor back as a flat F32 vector
static std::vector<float> tensor_to_f32(const struct ggml_tensor * t) {
const int64_t n = ggml_nelements(t);
std::vector<float> out(n);
if (t->type == GGML_TYPE_F32) {
memcpy(out.data(), t->data, n * sizeof(float));
} else if (t->type == GGML_TYPE_F16) {
for (int64_t i = 0; i < n; i++) {
out[i] = ggml_fp16_to_fp32(((const ggml_fp16_t *) t->data)[i]);
}
} else {
for (int64_t i = 0; i < n; i++) {
out[i] = ggml_bf16_to_fp32(((const ggml_bf16_t *) t->data)[i]);
}
}
return out;
}
// NMSE of the cropped output against the p0 shifted slice of the full reference
static double nmse_cropped(const float * y, const float * ref, int64_t T_out, int64_t T_ref, int64_t OC, int p0) {
double num = 0.0;
double den = 0.0;
for (int64_t oc = 0; oc < OC; oc++) {
for (int64_t t = 0; t < T_out; t++) {
const double a = y [t + oc * T_out];
const double b = ref[t + p0 + oc * T_ref];
num += (a - b) * (a - b);
den += b * b;
}
}
return num / (den + 1e-30);
}
int main(void) {
int fails = 0;
for (const col2im_case & c : CASES) {
const int64_t T_ref = (c.T_in - 1) * c.s0 + c.K;
const int64_t T_out = T_ref - 2 * c.p0;
struct ggml_init_params params = {
/* .mem_size = */ (size_t) 64 << 20,
/* .mem_base = */ NULL,
/* .no_alloc = */ false,
};
struct ggml_context * ctx = ggml_init(params);
// One logical weight and one logical input feed both paths
struct ggml_tensor * w = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, c.K, c.OC, IC);
struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, c.T_in, IC);
for (int64_t i = 0; i < ggml_nelements(w); i++) {
((float *) w->data)[i] = frand();
}
for (int64_t i = 0; i < ggml_nelements(x); i++) {
((float *) x->data)[i] = frand();
}
// Reference path: the native op, uncropped
struct ggml_tensor * y_ref = ggml_conv_transpose_1d(ctx, w, x, c.s0, 0, 1);
// Decomposed path: [K, OC, IC] -> [IC, K, OC] -> [IC, K*OC], k fastest inside each oc block
struct ggml_tensor * w_perm = ggml_cont(ctx, ggml_permute(ctx, w, 1, 2, 0, 3));
w_perm = ggml_reshape_2d(ctx, w_perm, IC, c.K * c.OC);
struct ggml_tensor * x_t = ggml_cont(ctx, ggml_transpose(ctx, x));
struct ggml_tensor * col = ggml_mul_mat(ctx, w_perm, x_t);
struct ggml_tensor * y32 = ggml_col2im_1d(ctx, col, c.s0, (int) c.OC, c.p0);
// Half precision kernels: the same columns cast before the scatter
struct ggml_tensor * y16 = ggml_col2im_1d(ctx, ggml_cast(ctx, col, GGML_TYPE_F16), c.s0, (int) c.OC, c.p0);
struct ggml_tensor * ybf = ggml_col2im_1d(ctx, ggml_cast(ctx, col, GGML_TYPE_BF16), c.s0, (int) c.OC, c.p0);
GGML_ASSERT(y_ref->ne[0] == T_ref && y_ref->ne[1] == c.OC);
GGML_ASSERT(y32->ne[0] == T_out && y32->ne[1] == c.OC);
struct ggml_cgraph * gf = ggml_new_graph(ctx);
ggml_build_forward_expand(gf, y_ref);
ggml_build_forward_expand(gf, y32);
ggml_build_forward_expand(gf, y16);
ggml_build_forward_expand(gf, ybf);
ggml_graph_compute_with_ctx(ctx, gf, 4);
const std::vector<float> f32 = tensor_to_f32(y32);
const std::vector<float> f16 = tensor_to_f32(y16);
const std::vector<float> fbf = tensor_to_f32(ybf);
const float * ref = (const float *) y_ref->data;
const double e32 = nmse_cropped(f32.data(), ref, T_out, T_ref, c.OC, c.p0);
const double e16 = nmse_cropped(f16.data(), ref, T_out, T_ref, c.OC, c.p0);
const double ebf = nmse_cropped(fbf.data(), ref, T_out, T_ref, c.OC, c.p0);
// Same thresholds as test-backend-ops: 1e-7 full precision, 5e-4 half
const bool ok = e32 <= 1e-7 && e16 <= 5e-4 && ebf <= 5e-4;
if (!ok) {
fails++;
}
printf("col2im_1d K=%2d OC=%2d T_in=%3d s0=%d p0=%d: nmse f32=%.2e f16=%.2e bf16=%.2e %s\n",
(int) c.K, (int) c.OC, (int) c.T_in, c.s0, c.p0, e32, e16, ebf, ok ? "OK" : "FAIL");
ggml_free(ctx);
}
printf(fails == 0 ? "all col2im_1d checks passed\n" : "%d col2im_1d checks FAILED\n", fails);
return fails == 0 ? 0 : 1;
}
+66 -47
View File
@@ -617,10 +617,52 @@ struct mtmd_helper_video {
float fps_target = 0.0f;
mtmd_helper_video_info info = {};
struct subprocess_s proc = {};
bool proc_alive = false;
// RAII wrapper for managing subprocess
struct subprocess_handle {
struct subprocess_s proc = {};
bool alive = false;
std::thread feeder;
subprocess_handle() = default;
subprocess_handle(const subprocess_handle &) = delete;
subprocess_handle & operator=(const subprocess_handle &) = delete;
~subprocess_handle() { stop(); }
void stop() {
if (alive) {
subprocess_terminate(&proc);
}
// join before destroy: feeder holds a FILE* from subprocess_stdin;
// subprocess_destroy closes it, so the thread must finish first
if (feeder.joinable()) {
feeder.join();
}
if (alive) {
subprocess_destroy(&proc);
alive = false;
}
}
FILE * stdout_pipe() {
return subprocess_stdout(&proc);
}
// buf is tied to lifetime of mtmd_helper_video, so it's guaranteed to outlive the feeder thread
void start_feeder(const std::vector<uint8_t> & buf) {
feeder = std::thread([this, &buf]() {
FILE * f = subprocess_stdin(&proc);
if (!f) {
return;
}
fwrite(buf.data(), 1, buf.size(), f);
fclose(f);
proc.stdin_file = nullptr; // prevent double-close in subprocess_destroy
});
}
};
subprocess_handle sp;
int32_t current_frame = 0;
std::thread feeder_thread;
std::string prompt_start = "Video:";
int32_t timestamp_interval_ms = 5000; // emit a timestamp text every N ms (0 = disabled)
@@ -630,19 +672,8 @@ struct mtmd_helper_video {
std::string pending_text; // text queued to be returned before the next frame
bool start_emitted = false;
bool is_buf_input() const { return !input_buf.empty(); }
// must run in a separate thread alongside stdout reading to avoid pipe deadlock
void feed_stdin(struct subprocess_s * sp) {
FILE * f = subprocess_stdin(sp);
if (!f) {
LOG_DBG("%s: subprocess has no stdin pipe\n", __func__);
return;
}
LOG_DBG("%s: feeding %zu bytes to stdin\n", __func__, input_buf.size());
size_t written = fwrite(input_buf.data(), 1, input_buf.size(), f);
LOG_DBG("%s: wrote %zu bytes, closing stdin\n", __func__, written);
fclose(f);
bool is_buf_input() const {
return !input_buf.empty();
}
bool probe(float fps_target_arg) {
@@ -661,17 +692,17 @@ struct mtmd_helper_video {
for (size_t i = 0; cmd[i]; i++) { LOG_DBG(" %s", cmd[i]); }
LOG_DBG("\n");
struct subprocess_s fprobe;
subprocess_handle probe_sp;
if (subprocess_create(cmd,
subprocess_option_search_user_path | subprocess_option_inherit_environment,
&fprobe) != 0) {
&probe_sp.proc) != 0) {
LOG_ERR("%s: failed to launch ffprobe\n", __func__);
return false;
}
probe_sp.alive = true;
std::thread probe_feeder;
if (is_buf_input()) {
probe_feeder = std::thread([this, &fprobe]() { feed_stdin(&fprobe); });
probe_sp.start_feeder(input_buf);
}
uint32_t width = 0;
@@ -680,7 +711,7 @@ struct mtmd_helper_video {
float duration = -1.0f;
int32_t n_frames_orig = -1;
char line[256];
FILE * fp = subprocess_stdout(&fprobe);
FILE * fp = probe_sp.stdout_pipe();
while (fgets(line, sizeof(line), fp)) {
char * eq = strchr(line, '=');
@@ -704,13 +735,7 @@ struct mtmd_helper_video {
}
}
if (probe_feeder.joinable()) {
probe_feeder.join();
}
int ret_code;
subprocess_join(&fprobe, &ret_code);
subprocess_destroy(&fprobe);
probe_sp.stop();
if (width == 0 || height == 0 || orig_fps <= 0.0f) {
return false;
@@ -745,6 +770,7 @@ struct mtmd_helper_video {
cmd.push_back(seek_buf);
}
cmd.push_back("-nostdin");
cmd.push_back("-i");
// cache:pipe:0 wraps stdin with a seekable in-memory cache, letting ffmpeg seek
// backwards for container headers (e.g. MP4 moov atom at end of file)
@@ -781,34 +807,27 @@ struct mtmd_helper_video {
int ret = subprocess_create(
cmd.data(),
subprocess_option_search_user_path | subprocess_option_inherit_environment,
&proc);
&sp.proc);
proc_alive = (ret == 0);
LOG_DBG("%s: subprocess_create ret=%d proc_alive=%d\n", __func__, ret, (int)proc_alive);
sp.alive = (ret == 0);
LOG_DBG("%s: subprocess_create ret=%d proc_alive=%d\n", __func__, ret, (int)sp.alive);
if (proc_alive && is_buf_input()) {
if (sp.alive && is_buf_input()) {
LOG_DBG("%s: starting feeder thread for %zu-byte buffer\n", __func__, input_buf.size());
feeder_thread = std::thread([this]() { feed_stdin(&proc); });
sp.start_feeder(input_buf);
}
return proc_alive;
return sp.alive;
}
void stop_ffmpeg() {
if (proc_alive) {
subprocess_terminate(&proc);
subprocess_destroy(&proc);
proc_alive = false;
}
if (feeder_thread.joinable()) {
feeder_thread.join();
}
sp.stop();
}
mtmd_bitmap * read_next_frame() {
if (!proc_alive) return nullptr;
if (!sp.alive) return nullptr;
FILE * fp = subprocess_stdout(&proc);
FILE * fp = sp.stdout_pipe();
const size_t frame_size = (size_t)info.width * info.height * 3;
LOG_DBG("%s: reading frame %d, expecting %zu bytes (%ux%u)\n",
__func__, current_frame, frame_size, info.width, info.height);
@@ -820,7 +839,7 @@ struct mtmd_helper_video {
// clean EOF only if no bytes read yet; partial frame is an error
LOG_DBG("%s: fread returned 0 after %zu/%zu bytes (ferror=%d)\n",
__func__, total_read, frame_size, ferror(fp));
proc_alive = false;
sp.alive = false;
return nullptr;
}
total_read += n;
@@ -842,9 +861,9 @@ struct mtmd_helper_video {
}
LOG_DBG("%s: proc_alive=%d start_emitted=%d current_frame=%d\n",
__func__, (int)proc_alive, (int)start_emitted, current_frame);
__func__, (int)sp.alive, (int)start_emitted, current_frame);
if (!proc_alive) {
if (!sp.alive) {
return (current_frame == 0) ? -2 : -1;
}
+1 -1
View File
@@ -165,7 +165,7 @@ For the full list of features, please refer to [server's changelog](https://gith
| `-cms, --checkpoint-min-step N` | minimum spacing between context checkpoints in tokens (default: 256, 0 = no minimum)<br/>(env: LLAMA_ARG_CHECKPOINT_MIN_SPACING_NT) |
| `-cram, --cache-ram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)<br/>(env: LLAMA_ARG_CACHE_RAM) |
| `-kvu, --kv-unified, -no-kvu, --no-kv-unified` | use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)<br/>(env: LLAMA_ARG_KV_UNIFIED) |
| `--cache-idle-slots, --no-cache-idle-slots` | save and clear idle slots on new task (default: enabled, requires unified KV and cache-ram)<br/>(env: LLAMA_ARG_CACHE_IDLE_SLOTS) |
| `--cache-idle-slots, --no-cache-idle-slots` | save idle slots to the prompt cache on new task, and clear them when using unified KV (default: enabled, requires cache-ram)<br/>(env: LLAMA_ARG_CACHE_IDLE_SLOTS) |
| `--context-shift, --no-context-shift` | whether to use context shift on infinite text generation (default: disabled)<br/>(env: LLAMA_ARG_CONTEXT_SHIFT) |
| `-r, --reverse-prompt PROMPT` | halt generation at PROMPT, return control in interactive mode |
| `-sp, --special` | special tokens output enabled (default: false) |
+41 -27
View File
@@ -27,6 +27,7 @@
#include <memory>
#include <filesystem>
#include <utility>
#include <fstream>
// fix problem with std::min and std::max
#if defined(_WIN32)
@@ -127,7 +128,11 @@ struct server_slot {
server_prompt prompt;
void prompt_save(server_prompt_cache & prompt_cache) const {
bool prompt_save(server_prompt_cache & prompt_cache) const {
if (prompt.tokens.size() == 0) {
return false;
}
GGML_ASSERT(prompt.data.size() == 0);
const size_t cur_size_tgt = llama_state_seq_get_size_ext(ctx_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE);
@@ -140,13 +145,15 @@ struct server_slot {
auto * cur = prompt_cache.alloc(prompt, cur_size_tgt, cur_size_dft);
if (cur == nullptr) {
return;
return false;
}
llama_state_seq_get_data_ext(ctx_tgt, cur->data.main.data(), cur_size_tgt, id, LLAMA_STATE_SEQ_FLAGS_NONE);
if (ctx_dft) {
llama_state_seq_get_data_ext(ctx_dft, cur->data.drft.data(), cur_size_dft, id, LLAMA_STATE_SEQ_FLAGS_NONE);
}
return true;
}
bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
@@ -739,17 +746,6 @@ private:
llama_batch_free(batch);
}
void slot_save_and_clear(server_slot & slot) {
if (slot.prompt.n_tokens() == 0) {
return;
}
SLT_INF(slot, "%s", "saving idle slot to prompt cache\n");
SLT_DBG(slot, "%s", "__TEST_TAG_CACHE_IDLE_SLOT__\n");
slot.prompt_save(*prompt_cache);
slot.prompt_clear(false);
prompt_cache->update();
}
void handle_sleeping_state(bool new_state) {
GGML_ASSERT(sleeping != new_state);
if (new_state) {
@@ -1186,14 +1182,17 @@ private:
metrics.init();
if (params_base.cache_idle_slots) {
if (!params_base.kv_unified) {
SRV_WRN("%s", "--cache-idle-slots requires --kv-unified, disabling\n");
params_base.cache_idle_slots = false;
} else if (params_base.cache_ram_mib == 0) {
if (params_base.cache_ram_mib == 0) {
SRV_WRN("%s", "--cache-idle-slots requires --cache-ram, disabling\n");
params_base.cache_idle_slots = false;
} else {
SRV_INF("%s", "idle slots will be saved to prompt cache and cleared upon starting a new task\n");
if (params_base.kv_unified) {
SRV_INF("%s", "idle slots will be saved to prompt cache and cleared upon starting a new task\n");
} else {
// without a unified KV cache, clearing a slot frees no reusable room, so we only
// publish a RAM-cache copy of idle slots (their KV stays in VRAM) [TAG_IDLE_SLOT_CLEAR]
SRV_INF("%s", "idle slots will be saved to prompt cache upon starting a new task\n");
}
SRV_DBG("%s", "__TEST_TAG_CACHE_IDLE_SLOTS_ENABLED__\n");
}
}
@@ -1357,8 +1356,6 @@ private:
}
if (ret) {
const auto & tokens = ret->prompt.tokens;
update_cache = update_cache && prompt_cache;
// cache prompts only for completion tasks
@@ -1369,10 +1366,7 @@ private:
const int64_t t_start = ggml_time_us();
// don't save the slot's state if its context is empty
if (tokens.size() > 0) {
ret->prompt_save(*prompt_cache);
}
ret->prompt_save(*prompt_cache);
if (!ret->prompt_load(*prompt_cache, task.tokens)) {
ret->prompt_clear(false);
@@ -2120,9 +2114,19 @@ private:
}
if (params_base.cache_idle_slots) {
for (auto & s : slots) {
if (!s.is_processing()) {
slot_save_and_clear(s);
for (auto & slot : slots) {
if (!slot.is_processing()) {
SLT_INF(slot, "%s", "saving idle slot to prompt cache\n");
if (slot.prompt_save(*prompt_cache)) {
SLT_DBG(slot, "%s", "__TEST_TAG_CACHE_IDLE_SLOT__\n");
prompt_cache->update();
}
if (params_base.kv_unified) {
// [TAG_IDLE_SLOT_CLEAR]
slot.prompt_clear(false);
}
}
}
}
@@ -3716,6 +3720,16 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
// TODO: this log can become very long, put it behind a flag or think about a more compact format
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
if (!params.path_prompts_log_dir.empty()) {
const auto file_path = std::filesystem::path(params.path_prompts_log_dir) / string_format("%012" PRId64 ".txt", ggml_time_ms());
std::ofstream f(file_path);
if (f) {
f << (prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
} else {
SRV_ERR("failed to create %s\n", file_path.string().c_str());
}
}
// process prompt
std::vector<server_tokens> inputs;
@@ -46,10 +46,12 @@ export function llamaCppBuildPlugin(): Plugin {
content = content.replace(/\r/g, '');
content = GUIDE_FOR_FRONTEND + '\n' + content;
content = content.replace(/\/_app\/immutable\/bundle\.[^"]+\.js/g, './bundle.js');
// Keep the Vite hash as a query string so each build busts the browser cache
content = content.replace(/\/_app\/immutable\/bundle\.([^".]+)\.js/g, './bundle.js?$1');
content = content.replace(
/\/_app\/immutable\/assets\/bundle\.[^"]+\.css/g,
'./bundle.css'
/\/_app\/immutable\/assets\/bundle\.([^".]+)\.css/g,
'./bundle.css?$1'
);
content = content.replace(/__sveltekit_[a-z0-9]+/g, '__sveltekit__');
-1
View File
@@ -148,7 +148,6 @@
* {
scrollbar-width: thin;
scrollbar-color: transparent transparent;
transition: scrollbar-color 0.2s ease;
}
*:hover {
+1 -1
View File
@@ -254,7 +254,7 @@
/>
<Sidebar.Provider bind:open={sidebarOpen}>
<div class="flex h-screen w-full">
<div class="flex h-dvh w-full">
<Sidebar.Root variant="floating" class="h-full"
><SidebarNavigation bind:this={chatSidebar} /></Sidebar.Root
>