mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-09 07:16:44 +02:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3ac3c20c96 | |||
| 1e1aca09da | |||
| 7d2b45b4f7 | |||
| 42a0afd594 | |||
| a66d50588b | |||
| 1705d434f6 | |||
| 3b3da01dc2 | |||
| 3ebe862b5d | |||
| 8f83d6c271 | |||
| c2b1518fd4 | |||
| 6a1de6fbf1 | |||
| 715b86a366 |
@@ -53,7 +53,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
|
||||
org.opencontainers.image.source=$IMAGE_SOURCE
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl \
|
||||
&& apt-get install -y libgomp1 curl ffmpeg \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
|
||||
@@ -59,7 +59,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
|
||||
org.opencontainers.image.source=$IMAGE_SOURCE
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl \
|
||||
&& apt-get install -y libgomp1 curl ffmpeg \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
|
||||
@@ -85,7 +85,7 @@ RUN mkdir /tmp/neo/ && cd /tmp/neo/ \
|
||||
&& dpkg --install *.deb
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl \
|
||||
&& apt-get install -y libgomp1 curl ffmpeg \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
|
||||
@@ -64,7 +64,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
|
||||
org.opencontainers.image.source=$IMAGE_SOURCE
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl \
|
||||
&& apt-get install -y libgomp1 curl ffmpeg \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
|
||||
@@ -107,7 +107,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
|
||||
org.opencontainers.image.source=$IMAGE_SOURCE
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 libtbb12 curl wget ocl-icd-libopencl1 \
|
||||
&& apt-get install -y libgomp1 libtbb12 curl wget ffmpeg ocl-icd-libopencl1 \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
|
||||
@@ -76,7 +76,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
|
||||
org.opencontainers.image.source=$IMAGE_SOURCE
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl \
|
||||
&& apt-get install -y libgomp1 curl ffmpeg \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
|
||||
@@ -49,7 +49,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
|
||||
org.opencontainers.image.source=$IMAGE_SOURCE
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 curl libvulkan1 mesa-vulkan-drivers \
|
||||
&& apt-get install -y libgomp1 curl ffmpeg libvulkan1 mesa-vulkan-drivers \
|
||||
libglvnd0 libgl1 libglx0 libegl1 libgles2 \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
|
||||
@@ -46,7 +46,7 @@ LABEL org.opencontainers.image.created=$BUILD_DATE \
|
||||
org.opencontainers.image.source=$IMAGE_SOURCE
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y libgomp1 libnuma1 curl \
|
||||
&& apt-get install -y libgomp1 libnuma1 curl ffmpeg \
|
||||
&& apt autoremove -y \
|
||||
&& apt clean -y \
|
||||
&& rm -rf /tmp/* /var/tmp/* \
|
||||
|
||||
@@ -35,6 +35,29 @@ env:
|
||||
LLAMA_ARG_LOG_TIMESTAMPS: 1
|
||||
|
||||
jobs:
|
||||
format:
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Install clang-format 22
|
||||
run: |
|
||||
wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key |
|
||||
sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc > /dev/null
|
||||
sudo add-apt-repository -y \
|
||||
"deb http://apt.llvm.org/noble/ llvm-toolchain-noble-22 main"
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y clang-format-22
|
||||
|
||||
- name: Check formatting
|
||||
run: |
|
||||
find ggml/src/ggml-webgpu \
|
||||
-type f \( -name '*.cpp' -o -name '*.hpp' -o -name '*.h' \) \
|
||||
-print0 |
|
||||
xargs -0 clang-format-22 --dry-run --Werror
|
||||
|
||||
macos:
|
||||
runs-on: macos-latest
|
||||
|
||||
|
||||
+2
-2
@@ -2221,8 +2221,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ_OFFLOAD"));
|
||||
add_opt(common_arg(
|
||||
{"--image", "--audio"}, "FILE",
|
||||
"path to an image or audio file. use with multimodal models, use comma-separated values for multiple files\n",
|
||||
{"--image", "--audio", "--video"}, "FILE",
|
||||
"path to an image, audio, or video file. use with multimodal models, use comma-separated values for multiple files\n",
|
||||
[](common_params & params, const std::string & value) {
|
||||
for (const auto & item : parse_csv_row(value)) {
|
||||
params.image.emplace_back(item);
|
||||
|
||||
+1
-1
@@ -571,7 +571,7 @@ struct common_params {
|
||||
struct common_params_model mmproj;
|
||||
bool mmproj_use_gpu = true; // use GPU for multimodal model
|
||||
bool no_mmproj = false; // explicitly disable multimodal model
|
||||
std::vector<std::string> image; // path to image file(s)
|
||||
std::vector<std::string> image; // path to image file(s) ; TODO: change the name to "media"
|
||||
int image_min_tokens = -1;
|
||||
int image_max_tokens = -1;
|
||||
|
||||
|
||||
@@ -789,6 +789,16 @@ class Gemma4UnifiedModel(Gemma4Model):
|
||||
class Gemma4AssistantModel(Gemma4Model):
|
||||
model_arch = gguf.MODEL_ARCH.GEMMA4_ASSISTANT
|
||||
|
||||
@classmethod
|
||||
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
|
||||
name, gen = item
|
||||
|
||||
if "masked_embedding" in name:
|
||||
logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.")
|
||||
return None
|
||||
|
||||
return super().filter_tensors(item)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_embedding_length_out(self.hparams["backbone_hidden_size"])
|
||||
|
||||
+2
-2
@@ -4,8 +4,8 @@ project("ggml" C CXX ASM)
|
||||
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 13)
|
||||
set(GGML_VERSION_PATCH 1)
|
||||
set(GGML_VERSION_MINOR 14)
|
||||
set(GGML_VERSION_PATCH 0)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
|
||||
@@ -448,15 +448,19 @@ struct ggml_webgpu_upscale_pipeline_key_hash {
|
||||
/** Concat **/
|
||||
|
||||
struct ggml_webgpu_concat_pipeline_key {
|
||||
int type;
|
||||
int type;
|
||||
bool src_overlap;
|
||||
|
||||
bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; }
|
||||
bool operator==(const ggml_webgpu_concat_pipeline_key & other) const {
|
||||
return type == other.type && src_overlap == other.src_overlap;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_concat_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.type);
|
||||
ggml_webgpu_hash_combine(seed, key.src_overlap);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
@@ -640,7 +644,8 @@ inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) {
|
||||
|
||||
inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) {
|
||||
const uint32_t offset_elems =
|
||||
(uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) / ggml_type_size(K->type));
|
||||
(uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) /
|
||||
ggml_type_size(K->type));
|
||||
return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u;
|
||||
}
|
||||
|
||||
@@ -651,8 +656,10 @@ inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K,
|
||||
ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment);
|
||||
}
|
||||
|
||||
inline bool ggml_webgpu_flash_attn_kv_direct(
|
||||
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, uint32_t kv_direct_align) {
|
||||
inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q,
|
||||
const ggml_tensor * K,
|
||||
const ggml_tensor * V,
|
||||
uint32_t kv_direct_align) {
|
||||
return K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && (Q->ne[0] % kv_direct_align == 0) &&
|
||||
(K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
|
||||
}
|
||||
@@ -667,10 +674,10 @@ inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_co
|
||||
key.dst_type = context.dst->type;
|
||||
key.head_dim_qk = (uint32_t) context.src0->ne[0];
|
||||
key.head_dim_v = (uint32_t) context.src2->ne[0];
|
||||
key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align);
|
||||
key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2);
|
||||
key.has_mask = context.src3 != nullptr;
|
||||
key.has_sinks = context.src4 != nullptr;
|
||||
key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align);
|
||||
key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2);
|
||||
key.has_mask = context.src3 != nullptr;
|
||||
key.has_sinks = context.src4 != nullptr;
|
||||
key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f;
|
||||
return key;
|
||||
}
|
||||
@@ -1723,7 +1730,7 @@ class ggml_webgpu_shader_lib {
|
||||
key.type = context.dst->type;
|
||||
key.d_state = (int) context.src0->ne[0];
|
||||
key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) &&
|
||||
ggml_webgpu_tensor_overlap(context.src1, context.src5);
|
||||
ggml_webgpu_tensor_overlap(context.src1, context.src5);
|
||||
|
||||
auto it = ssm_scan_pipelines.find(key);
|
||||
if (it != ssm_scan_pipelines.end()) {
|
||||
@@ -2634,6 +2641,7 @@ class ggml_webgpu_shader_lib {
|
||||
webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
|
||||
ggml_webgpu_concat_pipeline_key key = {};
|
||||
key.type = context.dst->type;
|
||||
key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1);
|
||||
|
||||
auto it = concat_pipelines.find(key);
|
||||
if (it != concat_pipelines.end()) {
|
||||
@@ -2656,11 +2664,17 @@ class ggml_webgpu_shader_lib {
|
||||
GGML_ABORT("Unsupported type for concat shader");
|
||||
}
|
||||
|
||||
if (key.src_overlap) {
|
||||
defines.push_back("SRC_OVERLAP");
|
||||
variant += "_src_overlap";
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
|
||||
auto processed = preprocessor.preprocess(wgsl_concat, defines);
|
||||
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
|
||||
auto decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>();
|
||||
decisions->wg_size = context.max_wg_size;
|
||||
decisions->src_overlap = key.src_overlap;
|
||||
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
|
||||
pipeline.context = decisions;
|
||||
concat_pipelines[key] = pipeline;
|
||||
|
||||
@@ -621,10 +621,11 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
|
||||
uint32_t value,
|
||||
size_t offset,
|
||||
size_t size) {
|
||||
std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
|
||||
std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) };
|
||||
size_t bytes_per_wg = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread;
|
||||
uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg);
|
||||
std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
|
||||
std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) };
|
||||
size_t bytes_per_wg =
|
||||
ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread;
|
||||
uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg);
|
||||
|
||||
ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t));
|
||||
|
||||
@@ -1362,7 +1363,7 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx,
|
||||
shader_lib_ctx.src0 = src;
|
||||
shader_lib_ctx.src1 = nullptr;
|
||||
shader_lib_ctx.dst = dst;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
@@ -2169,8 +2170,10 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
|
||||
}
|
||||
|
||||
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
|
||||
uint32_t wg_x, wg_y;
|
||||
uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
|
||||
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
@@ -2244,8 +2247,10 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
|
||||
uint32_t wg_x, wg_y;
|
||||
uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
|
||||
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_add_id(webgpu_context & ctx,
|
||||
@@ -2305,33 +2310,6 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx,
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
uint32_t dim = (uint32_t) dst->op_params[0];
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
ne,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||
(uint32_t) dst->ne[0],
|
||||
(uint32_t) dst->ne[1],
|
||||
(uint32_t) dst->ne[2],
|
||||
(uint32_t) dst->ne[3],
|
||||
dim,
|
||||
(uint32_t) src0->ne[dim]
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0),
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1),
|
||||
ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst),
|
||||
};
|
||||
|
||||
ggml_webgpu_shader_lib_context shader_lib_ctx = {};
|
||||
shader_lib_ctx.src0 = src0;
|
||||
shader_lib_ctx.src1 = src1;
|
||||
@@ -2339,8 +2317,52 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx,
|
||||
shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
|
||||
|
||||
webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
|
||||
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
|
||||
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
||||
auto * decisions = static_cast<ggml_webgpu_binary_shader_decisions *>(pipeline.context.get());
|
||||
|
||||
uint32_t offset_src0 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type));
|
||||
uint32_t offset_src1 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type));
|
||||
size_t merged_offset = 0;
|
||||
size_t merged_size = 0;
|
||||
if (decisions->src_overlap) {
|
||||
const ggml_webgpu_merged_binding_range merged_range =
|
||||
ggml_webgpu_tensor_merged_binding_range(ctx, { src0, src1 });
|
||||
merged_offset = merged_range.offset;
|
||||
merged_size = merged_range.size;
|
||||
offset_src0 = ggml_webgpu_tensor_merged_element_offset(src0, merged_range);
|
||||
offset_src1 = ggml_webgpu_tensor_merged_element_offset(src1, merged_range);
|
||||
}
|
||||
|
||||
std::vector<uint32_t> params = { ne,
|
||||
offset_src0,
|
||||
offset_src1,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
|
||||
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
|
||||
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
|
||||
(uint32_t) dst->ne[0],
|
||||
(uint32_t) dst->ne[1],
|
||||
(uint32_t) dst->ne[2],
|
||||
(uint32_t) dst->ne[3],
|
||||
dim,
|
||||
(uint32_t) src0->ne[dim] };
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {};
|
||||
if (decisions->src_overlap) {
|
||||
entries.push_back(
|
||||
ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, merged_size));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
|
||||
} else {
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1));
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst));
|
||||
}
|
||||
|
||||
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
@@ -2673,8 +2695,10 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s
|
||||
entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst));
|
||||
}
|
||||
|
||||
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
|
||||
uint32_t wg_x, wg_y;
|
||||
uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
|
||||
compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y);
|
||||
return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
|
||||
}
|
||||
|
||||
static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx,
|
||||
@@ -3751,7 +3775,8 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
|
||||
|
||||
static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
|
||||
// we use the maximum workgroup size for the memset pipeline
|
||||
size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup *
|
||||
ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
|
||||
// Size the bytes_per_thread so that the largest buffer size can be handled
|
||||
ctx->capabilities.memset_bytes_per_thread =
|
||||
CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads);
|
||||
@@ -4228,9 +4253,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
const uint32_t q_tile =
|
||||
use_subgroup_matrix ? capabilities.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE;
|
||||
const uint32_t kv_granularity = use_subgroup_matrix ? capabilities.sg_mat_n : 1u;
|
||||
const bool kv_direct = use_subgroup_matrix ?
|
||||
ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) :
|
||||
false;
|
||||
const bool kv_direct = use_subgroup_matrix ?
|
||||
ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) :
|
||||
false;
|
||||
const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(
|
||||
capabilities.limits.maxComputeWorkgroupStorageSize, q_tile, kv_granularity, (uint32_t) src0->ne[0],
|
||||
(uint32_t) src2->ne[0], op->src[3] != nullptr, kv_direct);
|
||||
|
||||
@@ -130,10 +130,13 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
|
||||
}
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x < params.ne) {
|
||||
let src0_i = params.offset_src0 + src0_index(gid.x);
|
||||
let src1_i = params.offset_src1 + src1_index(gid.x);
|
||||
update(params.offset_dst + gid.x, src0_i, src1_i);
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>,
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>) {
|
||||
let threads_per_group = u32(WG_SIZE);
|
||||
let i = gid.x + (num_wg.x * threads_per_group) * gid.y;
|
||||
if (i < params.ne) {
|
||||
let src0_i = params.offset_src0 + src0_index(i);
|
||||
let src1_i = params.offset_src1 + src1_index(i);
|
||||
update(params.offset_dst + i, src0_i, src1_i);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,16 @@ struct Params {
|
||||
#define DataType i32
|
||||
#endif
|
||||
|
||||
#ifdef SRC_OVERLAP
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> merged_src: array<DataType>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> dst: array<DataType>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
#else
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src0: array<DataType>;
|
||||
|
||||
@@ -42,7 +52,7 @@ var<storage, read_write> dst: array<DataType>;
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<uniform> params: Params;
|
||||
|
||||
#endif
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
|
||||
@@ -62,14 +72,22 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
ni[1] * params.stride_src0_1 +
|
||||
ni[2] * params.stride_src0_2 +
|
||||
ni[3] * params.stride_src0_3;
|
||||
#ifdef SRC_OVERLAP
|
||||
dst[params.offset_dst + gid.x] = merged_src[params.offset_src0 + src_i];
|
||||
#else
|
||||
dst[params.offset_dst + gid.x] = src0[params.offset_src0 + src_i];
|
||||
#endif
|
||||
} else {
|
||||
ni[params.dim] -= params.src0_nedim;
|
||||
let src_i = ni[0] * params.stride_src1_0 +
|
||||
ni[1] * params.stride_src1_1 +
|
||||
ni[2] * params.stride_src1_2 +
|
||||
ni[3] * params.stride_src1_3;
|
||||
#ifdef SRC_OVERLAP
|
||||
dst[params.offset_dst + gid.x] = merged_src[params.offset_src1 + src_i];
|
||||
#else
|
||||
dst[params.offset_dst + gid.x] = src1[params.offset_src1 + src_i];
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,72 +98,50 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q1_0
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_0
|
||||
#if defined(INIT_SRC0_SHMEM_Q4_0) || defined(INIT_SRC0_SHMEM_Q4_1) || defined(INIT_SRC0_SHMEM_Q5_0) || defined(INIT_SRC0_SHMEM_Q5_1) || defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1) || defined(INIT_SRC0_SHMEM_MXFP4)
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 18u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
#if defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1)
|
||||
const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q
|
||||
#else
|
||||
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
|
||||
#endif
|
||||
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / NQ;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
|
||||
let shmem_idx = block_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let tile_m = block_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let block_k = block_idx % BLOCKS_K;
|
||||
let global_block_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_0
|
||||
let block_byte_base = src0_idx * 18u; // BLOCK_SIZE_BYTES = 18u;
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
|
||||
// store NQ(16) weights
|
||||
// load NQ(16) weights
|
||||
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
|
||||
|
||||
let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q4_0
|
||||
#elif INIT_SRC0_SHMEM_Q4_1
|
||||
let block_byte_base = src0_idx * 20u; // BLOCK_SIZE_BYTES = 20u;
|
||||
let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
|
||||
let d = f16(dm[0]);
|
||||
let m = f16(dm[1]);
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_1
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 20u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
|
||||
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / NQ;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_block_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
let m = load_f16_at_src0(block_byte_base + 2u);
|
||||
|
||||
// store NQ(16) weights
|
||||
// load NQ(16) weights
|
||||
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
|
||||
|
||||
let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
|
||||
@@ -175,41 +153,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q4_1
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q5_0
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 22u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
// tile_k is defined as 32u, so blocks_k ends up being 1 always
|
||||
override BLOCKS_K = TILE_K / BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
|
||||
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / NQ;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_block_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
#elif INIT_SRC0_SHMEM_Q5_0
|
||||
let block_byte_base = src0_idx * 22u; // BLOCK_SIZE_BYTES = 22u;
|
||||
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
let qh_packed = load_u32_at_src0(block_byte_base + 2u);
|
||||
|
||||
// store NQ(16) weights
|
||||
// load NQ(16) weights
|
||||
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
|
||||
let q_byte_offset = block_byte_base + 6u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
@@ -226,44 +176,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q5_0
|
||||
#elif INIT_SRC0_SHMEM_Q5_1
|
||||
let block_byte_base = src0_idx * 24u; // BLOCK_SIZE_BYTES = 24u;
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q5_1
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 24u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K / BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q
|
||||
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
|
||||
let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
|
||||
let d = f16(dm[0]);
|
||||
let m = f16(dm[1]);
|
||||
let qh_packed = load_u32_at_src0_aligned(block_byte_base + 4u);
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / NQ;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_block_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
let m = load_f16_at_src0(block_byte_base + 2u);
|
||||
let qh_packed = load_u32_at_src0(block_byte_base + 4u);
|
||||
|
||||
// store NQ(16) weights
|
||||
// load NQ(16) weights
|
||||
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
|
||||
let q_byte_offset = block_byte_base + 8u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
let q_packed = load_u32_at_src0_aligned(q_byte_offset);
|
||||
|
||||
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
@@ -277,461 +201,306 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q5_1
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q8_0
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 34u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q
|
||||
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / NQ;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_block_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
#elif INIT_SRC0_SHMEM_Q8_0
|
||||
let block_byte_base = src0_idx * 34u; // BLOCK_SIZE_BYTES = 34u;
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
|
||||
// store NQ(16) weights
|
||||
// load NQ(16) weights
|
||||
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
|
||||
let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q8_0
|
||||
#elif INIT_SRC0_SHMEM_Q8_1
|
||||
let block_byte_base = src0_idx * 36u; // BLOCK_SIZE_BYTES = 36u;
|
||||
let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base));
|
||||
let d = f16(dm[0]);
|
||||
let m = f16(dm[1]);
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q8_1
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 36u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q
|
||||
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / NQ;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_block_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
let m = load_f16_at_src0(block_byte_base + 2u);
|
||||
|
||||
// store NQ(16) weights
|
||||
// load NQ(16) weights
|
||||
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
|
||||
let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
|
||||
let q_byte = get_byte_i32(q_packed, k);
|
||||
|
||||
let q_val = f16(q_byte) * d + m;
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val;
|
||||
}
|
||||
}
|
||||
#elif INIT_SRC0_SHMEM_MXFP4
|
||||
let block_byte_base = src0_idx * 17u;
|
||||
let eu8 = get_byte(load_u32_at_src0_aligned(block_byte_base), block_byte_base & 3u);
|
||||
let e = ldexp(1.0, i32(eu8) - 128);
|
||||
|
||||
// load NQ(16) weights
|
||||
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
|
||||
let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e;
|
||||
let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e;
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo);
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q8_1
|
||||
#endif
|
||||
|
||||
// k-quants
|
||||
#if defined(INIT_SRC0_SHMEM_Q2_K) || defined(INIT_SRC0_SHMEM_Q3_K) || defined(INIT_SRC0_SHMEM_Q4_K) || defined(INIT_SRC0_SHMEM_Q5_K) || defined(INIT_SRC0_SHMEM_Q6_K)
|
||||
const BLOCK_SIZE = 256u;
|
||||
const NQ = 4u;
|
||||
|
||||
fn store_shmem_kquants(val: vec4<f16>, idx: u32) {
|
||||
shmem[idx] = val.x;
|
||||
shmem[idx + 1] = val.y;
|
||||
shmem[idx + 2] = val.z;
|
||||
shmem[idx + 3] = val.w;
|
||||
}
|
||||
|
||||
fn load_byte_at_src0_aligned(byte_offset: u32) -> u32 {
|
||||
return get_byte(load_u32_at_src0_aligned(byte_offset), byte_offset % 4u);
|
||||
}
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id * NQ; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
store_shmem_kquants(vec4<f16>(f16(0.0), f16(0.0), f16(0.0), f16(0.0)), elem_idx);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE; // k_in_block % 4 == 0;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q2_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 84u;
|
||||
let block_byte_base = src0_idx * 84u; // BLOCK_SIZE_BYTES = 84u;
|
||||
let scales_byte_base = block_byte_base;
|
||||
let qs_byte_base = block_byte_base + 16u;
|
||||
let dm_byte_base = block_byte_base + 80u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
// Use standard thread layout instead of lane/row_group
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let d_packed = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base));
|
||||
let d = f16(d_packed[0]);
|
||||
let dmin = f16(d_packed[1]);
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
let chunk = k_in_block / 128u;
|
||||
let pos_in_chunk = k_in_block % 32u;
|
||||
let sub_block = k_in_block / 16u;
|
||||
let shift_phase = (k_in_block % 128u) / 32u;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
// whole 2 bits (4 elems)
|
||||
let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk);
|
||||
let qs_vec4 = vec4<f16>(
|
||||
f16((qs_word >> (2u * shift_phase + 0u)) & 0x3u),
|
||||
f16((qs_word >> (2u * shift_phase + 8u)) & 0x3u),
|
||||
f16((qs_word >> (2u * shift_phase + 16u)) & 0x3u),
|
||||
f16((qs_word >> (2u * shift_phase + 24u)) & 0x3u),
|
||||
);
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
let scale = load_byte_at_src0_aligned(scales_byte_base + sub_block);
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let dl = d * f16(scale & 0xFu);
|
||||
let ml = dmin * f16(scale >> 4u);
|
||||
|
||||
let d = load_f16_at_src0(block_byte_base + 80u);
|
||||
let dmin = load_f16_at_src0(block_byte_base + 82u);
|
||||
store_shmem_kquants(qs_vec4 * dl - ml, elem_idx);
|
||||
#elif INIT_SRC0_SHMEM_Q3_K
|
||||
let block_byte_base = src0_idx * 110u; // BLOCK_SIZE_BYTES = 110u;
|
||||
let hmask_byte_base = block_byte_base + 0u;
|
||||
let qs_byte_base = block_byte_base + 32u;
|
||||
let scales_byte_base = block_byte_base + 96u;
|
||||
|
||||
// Decode the element at position k_in_block
|
||||
let block_of_32 = k_in_block / 32u;
|
||||
let pos_in_32 = k_in_block % 32u;
|
||||
let d_all = load_f16_at_src0(block_byte_base + 108u);
|
||||
|
||||
let q_b_idx = (block_of_32 / 4u) * 32u;
|
||||
let shift = (block_of_32 % 4u) * 2u;
|
||||
let k = (pos_in_32 / 16u) * 16u;
|
||||
let l = pos_in_32 % 16u;
|
||||
let chunk = k_in_block / 128u;
|
||||
let pos_in_chunk = k_in_block % 32u;
|
||||
let sub_block = k_in_block / 16u;
|
||||
let shift_phase = (k_in_block % 128u) / 32u;
|
||||
|
||||
let is = k_in_block / 16u;
|
||||
let hmask_block = pos_in_chunk;
|
||||
let hmask_shift_phase = k_in_block / 32u;
|
||||
|
||||
let sc_packed = load_u32_at_src0(block_byte_base + 4u * (is / 4u));
|
||||
let sc = get_byte(sc_packed, is % 4u);
|
||||
// low 2 bits (4 elems)
|
||||
let q_lo2_word = load_u32_at_src0(qs_byte_base + 32u * chunk + 1u * hmask_block);
|
||||
let q_lo2_vec4 = vec4<f16>(
|
||||
f16((q_lo2_word >> (2u * shift_phase + 0u)) & 3u),
|
||||
f16((q_lo2_word >> (2u * shift_phase + 8u)) & 3u),
|
||||
f16((q_lo2_word >> (2u * shift_phase + 16u)) & 3u),
|
||||
f16((q_lo2_word >> (2u * shift_phase + 24u)) & 3u)
|
||||
);
|
||||
|
||||
let dl = d * f16(sc & 0xFu);
|
||||
let ml = dmin * f16(sc >> 4u);
|
||||
// high 1 bit (4 elems)
|
||||
let q_hi1_word = load_u32_at_src0(hmask_byte_base + pos_in_chunk);
|
||||
let q_hi1_vec4 = vec4<f16>(
|
||||
f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 0u)) & 1u) == 1u)),
|
||||
f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 8u)) & 1u) == 1u)),
|
||||
f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 16u)) & 1u) == 1u)),
|
||||
f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 24u)) & 1u) == 1u))
|
||||
);
|
||||
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u));
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
let qs_val = (q_byte >> shift) & 3u;
|
||||
let q_vec4 = q_lo2_vec4 - q_hi1_vec4;
|
||||
|
||||
let q_val = f16(qs_val) * dl - ml;
|
||||
shmem[elem_idx] = q_val;
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q2_K
|
||||
let scale_low4 = (load_byte_at_src0_aligned(scales_byte_base + (sub_block % 8u)) >> (4u * (sub_block / 8u))) & 0xFu;
|
||||
let scale_hi2 = (load_byte_at_src0_aligned(scales_byte_base + 8u + (sub_block % 4u)) >> (2u * (sub_block / 4u))) & 3u;
|
||||
let dl = d_all * (f16((scale_hi2 << 4u) | scale_low4) - 32.0);
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q3_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 110u;
|
||||
store_shmem_kquants(dl * q_vec4, elem_idx);
|
||||
#elif INIT_SRC0_SHMEM_Q4_K
|
||||
let block_byte_base = src0_idx * 144u; // BLOCK_SIZE_BYTES = 144u;
|
||||
let dm_byte_base = block_byte_base + 0u;
|
||||
let scale_byte_base = block_byte_base + 4u;
|
||||
let qs_byte_base = block_byte_base + 16u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base));
|
||||
let d = f16(dm[0]);
|
||||
let dmin = f16(dm[1]);
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
let chunk = k_in_block / 64u;
|
||||
let pos_in_chunk = (k_in_block % 64u) % 32u;
|
||||
let sub_block = k_in_block / 32u;
|
||||
let shift_phase = sub_block & 1u;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_f16_at_src0(block_byte_base + 108u);
|
||||
|
||||
// Load and unpack scales
|
||||
let kmask1: u32 = 0x03030303u;
|
||||
let kmask2: u32 = 0x0f0f0f0fu;
|
||||
|
||||
var scale_vals: array<u32, 4>;
|
||||
for (var i: u32 = 0u; i < 4u; i++) {
|
||||
scale_vals[i] = load_u32_at_src0(block_byte_base + 96u + 4u * i);
|
||||
}
|
||||
|
||||
var tmp: u32 = scale_vals[2];
|
||||
scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u);
|
||||
scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u);
|
||||
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u);
|
||||
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u);
|
||||
|
||||
// Load hmask and qs arrays
|
||||
var hmask_vals: array<u32, 8>;
|
||||
for (var i: u32 = 0u; i < 8u; i++) {
|
||||
hmask_vals[i] = load_u32_at_src0(block_byte_base + 4u * i);
|
||||
}
|
||||
|
||||
var qs_vals: array<u32, 16>;
|
||||
for (var i: u32 = 0u; i < 16u; i++) {
|
||||
qs_vals[i] = load_u32_at_src0(block_byte_base + 32u + 4u * i);
|
||||
}
|
||||
|
||||
let half = k_in_block / 128u; // 0 or 1
|
||||
let pos_in_half = k_in_block % 128u; // 0-127
|
||||
let shift_group = pos_in_half / 32u; // 0-3
|
||||
let pos_in_32 = pos_in_half % 32u; // 0-31
|
||||
let k_group = pos_in_32 / 16u; // 0 or 1
|
||||
let l = pos_in_32 % 16u; // 0-15
|
||||
|
||||
let q_b_idx = half * 32u; // 0 or 32
|
||||
let shift = shift_group * 2u; // 0, 2, 4, 6
|
||||
let k = k_group * 16u; // 0 or 16
|
||||
let is = k_in_block / 16u; // 0-15
|
||||
|
||||
// m increments every 32 elements across entire 256 element block
|
||||
let m_shift = k_in_block / 32u; // 0-7
|
||||
let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128
|
||||
|
||||
let sc = get_byte(scale_vals[is / 4u], is % 4u);
|
||||
let dl = d * (f16(sc) - 32.0);
|
||||
|
||||
let q_idx = q_b_idx + k + l;
|
||||
let hm_idx = k + l;
|
||||
|
||||
let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u);
|
||||
let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u);
|
||||
|
||||
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
|
||||
let qs_val = (q_byte >> shift) & 3u;
|
||||
|
||||
let q_val = (f16(qs_val) - f16(hm)) * dl;
|
||||
shmem[elem_idx] = q_val;
|
||||
}
|
||||
}
|
||||
|
||||
#endif // INIT_SRC0_SHMEM_Q3_K
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 144u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
let dmin = load_f16_at_src0(block_byte_base + 2u);
|
||||
|
||||
// Map k_in_block to loop structure:
|
||||
// Outer loop over 64-element groups (alternating q_b_idx)
|
||||
// Inner loop over 2 shifts per group
|
||||
let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx)
|
||||
let pos_in_64 = k_in_block % 64u; // 0-63
|
||||
let shift_group = pos_in_64 / 32u; // 0 or 1
|
||||
let l = pos_in_64 % 32u; // 0-31
|
||||
|
||||
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
|
||||
let shift = shift_group * 4u; // 0 or 4
|
||||
let is = k_in_block / 32u; // 0-7
|
||||
// whole 4 bits (4 elems)
|
||||
let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk);
|
||||
let qs_vec4 = vec4<f16>(
|
||||
f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu),
|
||||
f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu),
|
||||
f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu),
|
||||
f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu)
|
||||
);
|
||||
|
||||
var sc: u32;
|
||||
var mn: u32;
|
||||
|
||||
let scale_base = block_byte_base + 4u;
|
||||
|
||||
if (is < 4u) {
|
||||
let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u);
|
||||
let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
|
||||
sc = sc_byte & 63u;
|
||||
mn = min_byte & 63u;
|
||||
if (sub_block < 4u) {
|
||||
let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u);
|
||||
let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
|
||||
sc = sc_byte & 63u;
|
||||
mn = min_byte & 63u;
|
||||
} else {
|
||||
let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u);
|
||||
let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u);
|
||||
let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
|
||||
|
||||
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
|
||||
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
|
||||
let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u);
|
||||
let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u);
|
||||
let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
|
||||
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
|
||||
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
|
||||
}
|
||||
|
||||
let dl = d * f16(sc);
|
||||
let ml = dmin * f16(mn);
|
||||
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u));
|
||||
store_shmem_kquants(dl * qs_vec4 - vec4(ml, ml, ml, ml), elem_idx);
|
||||
#elif INIT_SRC0_SHMEM_Q5_K
|
||||
let block_byte_base = src0_idx * 176u; // BLOCK_SIZE_BYTES = 176u;
|
||||
let dm_byte_base = block_byte_base + 0u;
|
||||
let scale_byte_base = block_byte_base + 4u;
|
||||
let qh_byte_base = block_byte_base + 16u;
|
||||
let qs_byte_base = block_byte_base + 48u;
|
||||
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
let qs_val = (q_byte >> shift) & 0xFu;
|
||||
let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base));
|
||||
let d = f16(dm[0]);
|
||||
let dmin = f16(dm[1]);
|
||||
|
||||
let q_val = f16(qs_val) * dl - ml;
|
||||
shmem[elem_idx] = q_val;
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q4_K
|
||||
let chunk = k_in_block / 64u;
|
||||
let pos_in_chunk = (k_in_block % 64u) % 32u;
|
||||
let sub_block = k_in_block / 32u;
|
||||
let shift_phase = sub_block & 1u;
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q5_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 176u;
|
||||
let qh_block = k_in_block % 32u;
|
||||
let qh_shift_phase = sub_block;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
// low 4 bits (4 elems)
|
||||
let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk);
|
||||
let qs_lo4_vec4 = vec4<f16>(
|
||||
f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu),
|
||||
f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu),
|
||||
f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu),
|
||||
f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu)
|
||||
);
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
let dmin = load_f16_at_src0(block_byte_base + 2u);
|
||||
|
||||
|
||||
// The original loop processes elements in groups of 64
|
||||
// Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
|
||||
// But u increments EVERY 32 elements (after each l loop)
|
||||
let group_of_64 = k_in_block / 64u; // 0-3
|
||||
let pos_in_64 = k_in_block % 64u; // 0-63
|
||||
let shift_group = pos_in_64 / 32u; // 0 or 1
|
||||
let l = pos_in_64 % 32u; // 0-31
|
||||
|
||||
let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96
|
||||
let shift = shift_group * 4u; // 0 or 4
|
||||
let is = k_in_block / 32u; // 0-7
|
||||
|
||||
// u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128)
|
||||
let u_shift = k_in_block / 32u; // 0-7
|
||||
let u: u32 = 1u << u_shift;
|
||||
// high 1 bit (4 elems)
|
||||
let qh_word = load_u32_at_src0_aligned(qh_byte_base + qh_block);
|
||||
let qh_vec4 = vec4<f16>(
|
||||
f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 0u)) & 1u) == 1u)),
|
||||
f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 8u)) & 1u) == 1u)),
|
||||
f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 16u)) & 1u) == 1u)),
|
||||
f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 24u)) & 1u) == 1u))
|
||||
);
|
||||
|
||||
var sc: u32;
|
||||
var mn: u32;
|
||||
|
||||
let scale_base = block_byte_base + 4u;
|
||||
|
||||
if (is < 4u) {
|
||||
let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u);
|
||||
let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
|
||||
sc = sc_byte & 63u;
|
||||
mn = min_byte & 63u;
|
||||
if (sub_block < 4u) {
|
||||
let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u);
|
||||
let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
|
||||
sc = sc_byte & 63u;
|
||||
mn = min_byte & 63u;
|
||||
} else {
|
||||
let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u);
|
||||
let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u);
|
||||
let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u);
|
||||
|
||||
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
|
||||
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
|
||||
let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u);
|
||||
let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u);
|
||||
let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u);
|
||||
sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
|
||||
mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
|
||||
}
|
||||
|
||||
let dl = d * f16(sc);
|
||||
let ml = dmin * f16(mn);
|
||||
|
||||
let q_idx = q_b_idx + l;
|
||||
let q_packed = load_u32_at_src0(block_byte_base + 48u + 4u * (q_idx / 4u));
|
||||
store_shmem_kquants((qh_vec4 + qs_lo4_vec4) * dl - vec4<f16>(ml, ml, ml, ml), elem_idx);
|
||||
#elif INIT_SRC0_SHMEM_Q6_K
|
||||
let block_byte_base = src0_idx * 210u; // BLOCK_SIZE_BYTES = 210u;
|
||||
let ql_byte_base = block_byte_base;
|
||||
let qh_byte_base = block_byte_base + 128u;
|
||||
let scales_byte_base = block_byte_base + 192u;
|
||||
let d_byte_base = block_byte_base + 208u;
|
||||
|
||||
let q_byte = get_byte(q_packed, q_idx % 4u);
|
||||
let d = load_f16_at_src0(d_byte_base);
|
||||
|
||||
let qh_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (l / 4u));
|
||||
let chunk = k_in_block / 128u;
|
||||
let ql_pos_in_chunk = (k_in_block % 128u) % 64u;
|
||||
let qh_pos_in_chunk = (k_in_block % 128u) % 32u;
|
||||
let sub_block = k_in_block / 16u;
|
||||
let ql_shift_phase = (k_in_block % 128u) / 64u;
|
||||
let qh_shift_phase = (k_in_block % 128u) / 32u;
|
||||
|
||||
let qh_byte = get_byte(qh_packed, l % 4u);
|
||||
// low 4 bits (4 elems)
|
||||
let ql_word = load_u32_at_src0(ql_byte_base + 64u * chunk + 1u * ql_pos_in_chunk);
|
||||
let ql_lo4_vec4 = vec4<u32>(
|
||||
(ql_word >> (4u * ql_shift_phase + 0u)) & 0xFu,
|
||||
(ql_word >> (4u * ql_shift_phase + 8u)) & 0xFu,
|
||||
(ql_word >> (4u * ql_shift_phase + 16u)) & 0xFu,
|
||||
(ql_word >> (4u * ql_shift_phase + 24u)) & 0xFu
|
||||
);
|
||||
|
||||
let qs_val = (q_byte >> shift) & 0xFu;
|
||||
let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
|
||||
// hi 2 bits (4 elems)
|
||||
let qh_word = load_u32_at_src0(qh_byte_base + 32u * chunk + 1u * qh_pos_in_chunk);
|
||||
let qh_hi2_vec4 = vec4<u32>(
|
||||
((qh_word >> (2u * qh_shift_phase + 0u)) & 0x3u) << 4u,
|
||||
((qh_word >> (2u * qh_shift_phase + 8u)) & 0x3u) << 4u,
|
||||
((qh_word >> (2u * qh_shift_phase + 16u)) & 0x3u) << 4u,
|
||||
((qh_word >> (2u * qh_shift_phase + 24u)) & 0x3u) << 4u,
|
||||
);
|
||||
|
||||
let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml;
|
||||
shmem[elem_idx] = q_val;
|
||||
let q_vec4 = vec4<f16>(qh_hi2_vec4 | ql_lo4_vec4) - vec4<f16>(32.0, 32.0, 32.0, 32.0);
|
||||
|
||||
let scale_byte = scales_byte_base + 1u * sub_block;
|
||||
let scale_word = load_u32_at_src0_aligned(scale_byte);
|
||||
let scale = get_byte_i32(scale_word, scale_byte & 3u);
|
||||
|
||||
store_shmem_kquants(d * q_vec4 * f16(scale), elem_idx);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
#endif // INIT_SRC0_SHMEM_Q5_K
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q6_K
|
||||
const BLOCK_SIZE = 256u;
|
||||
const BLOCK_SIZE_BYTES = 210u;
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
|
||||
if (global_m >= params.m || global_k >= params.k) {
|
||||
shmem[elem_idx] = f16(0.0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let block_k = global_k / BLOCK_SIZE;
|
||||
let k_in_block = global_k % BLOCK_SIZE;
|
||||
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
|
||||
let half = k_in_block / 128u;
|
||||
let pos_in_half = k_in_block % 128u;
|
||||
let quarter = pos_in_half / 32u;
|
||||
let l = pos_in_half % 32u;
|
||||
|
||||
let ql_b_idx = half * 64u;
|
||||
let qh_b_idx = half * 32u;
|
||||
let sc_b_idx = half * 8u;
|
||||
|
||||
// Load only ql13 word needed
|
||||
let ql13_flat = ql_b_idx + l;
|
||||
let ql13 = load_u32_at_src0(block_byte_base + ql13_flat);
|
||||
let ql13_b = get_byte(ql13, 0u);
|
||||
|
||||
// Load only ql24 word needed
|
||||
let ql24_flat = ql_b_idx + l + 32u;
|
||||
let ql24 = load_u32_at_src0(block_byte_base + ql24_flat);
|
||||
let ql24_b = get_byte(ql24, 0u);
|
||||
|
||||
// Load only qh word needed
|
||||
let qh_flat = qh_b_idx + l;
|
||||
let qh = load_u32_at_src0(block_byte_base + 128u + qh_flat);
|
||||
let qh_b = get_byte(qh, 0u);
|
||||
|
||||
let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
|
||||
let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
|
||||
let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0);
|
||||
let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0);
|
||||
|
||||
// Load only the scale word needed
|
||||
let is = l / 16u;
|
||||
let sc_idx = sc_b_idx + is + quarter * 2u;
|
||||
let sc = load_u32_at_src0(block_byte_base + 192u + sc_idx);
|
||||
let sc_val = get_byte_i32(sc, 0u);
|
||||
|
||||
let d = load_f16_at_src0(block_byte_base + 208u);
|
||||
|
||||
var q_val: f16;
|
||||
if (quarter == 0u) {
|
||||
q_val = q1;
|
||||
} else if (quarter == 1u) {
|
||||
q_val = q2;
|
||||
} else if (quarter == 2u) {
|
||||
q_val = q3;
|
||||
} else {
|
||||
q_val = q4;
|
||||
}
|
||||
|
||||
shmem[elem_idx] = d * f16(sc_val) * q_val;
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q6_K
|
||||
#endif // k-quants
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_IQ4_NL
|
||||
const BLOCK_SIZE = 32u;
|
||||
@@ -1155,48 +924,3 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_IQ3_S
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_MXFP4
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 17u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
const NQ = 16u;
|
||||
const BYTES_PER_THREAD = 8u; // NQ(16) weights uses 8 bytes of q
|
||||
const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed)
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let blck_idx = i / BLOCK_SIZE;
|
||||
let block_offset = (i % BLOCK_SIZE) / NQ;
|
||||
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD;
|
||||
|
||||
let tile_m = blck_idx / BLOCKS_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let block_k = blck_idx % BLOCKS_K;
|
||||
let global_block_k = k_outer / BLOCK_SIZE + block_k;
|
||||
|
||||
if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) {
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0);
|
||||
let e = ldexp(1.0, i32(eu8) - 128);
|
||||
|
||||
// store NQ(16) weights
|
||||
for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) {
|
||||
|
||||
let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP;
|
||||
let q_packed = load_u32_at_src0(q_byte_offset);
|
||||
|
||||
for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e;
|
||||
let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e;
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo);
|
||||
shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_MXFP4
|
||||
|
||||
@@ -43,12 +43,14 @@ struct Params {
|
||||
var<storage, read_write> src: array<f32>;
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x >= params.ne) {
|
||||
fn main(
|
||||
@builtin(global_invocation_id) gid: vec3<u32>,
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>) {
|
||||
let threads_per_group = u32(WG_SIZE);
|
||||
var i = gid.x + (num_wg.x * threads_per_group) * gid.y;
|
||||
if (i >= params.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
var i = gid.x;
|
||||
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
|
||||
i = i % (params.ne2 * params.ne1 * params.ne0);
|
||||
let i2 = i / (params.ne1 * params.ne0);
|
||||
|
||||
@@ -66,11 +66,14 @@ fn erf_approx(x: TYPE) -> TYPE {
|
||||
}
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x >= params.ne) {
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>,
|
||||
@builtin(num_workgroups) num_wg: vec3<u32>) {
|
||||
let threads_per_group = u32(WG_SIZE);
|
||||
let flat_i = gid.x + (num_wg.x * threads_per_group) * gid.y;
|
||||
if (flat_i >= params.ne) {
|
||||
return;
|
||||
}
|
||||
var i = gid.x;
|
||||
var i = flat_i;
|
||||
let ne2 = params.ne2;
|
||||
#ifdef DIAG
|
||||
let ne1 = params.ne0;
|
||||
@@ -205,6 +208,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
#ifdef INPLACE
|
||||
src[params.offset_src + src_idx] = res;
|
||||
#else
|
||||
dst[params.offset_dst + gid.x] = res;
|
||||
dst[params.offset_dst + flat_i] = res;
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -538,6 +538,8 @@ class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
TOKEN_EMBD = auto()
|
||||
TOKEN_EMBD_NORM = auto()
|
||||
MASKED_EMBD_CENTROIDS= auto()
|
||||
MASKED_EMBD_ORDERING = auto()
|
||||
TOKEN_TYPES = auto()
|
||||
POS_EMBD = auto()
|
||||
OUTPUT = auto()
|
||||
@@ -1087,6 +1089,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
|
||||
MODEL_TENSOR.TOKEN_TYPES: "token_types",
|
||||
MODEL_TENSOR.MASKED_EMBD_CENTROIDS: "masked_embd_centroids",
|
||||
MODEL_TENSOR.MASKED_EMBD_ORDERING: "masked_embd_ordering",
|
||||
MODEL_TENSOR.POS_EMBD: "position_embd",
|
||||
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
|
||||
MODEL_TENSOR.OUTPUT: "output",
|
||||
@@ -2586,6 +2590,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_ARCH.GEMMA4_ASSISTANT: [
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.MASKED_EMBD_CENTROIDS,
|
||||
MODEL_TENSOR.MASKED_EMBD_ORDERING,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.NEXTN_PROJ_PRE,
|
||||
MODEL_TENSOR.NEXTN_PROJ_POST,
|
||||
|
||||
@@ -37,6 +37,14 @@ class TensorNameMap:
|
||||
"model.embed", # talkie
|
||||
),
|
||||
|
||||
# Masked embeddings
|
||||
MODEL_TENSOR.MASKED_EMBD_CENTROIDS: (
|
||||
"masked_embedding.centroids", # gemma-4 E2B/E4B assistants
|
||||
),
|
||||
MODEL_TENSOR.MASKED_EMBD_ORDERING: (
|
||||
"masked_embedding.token_ordering", # gemma-4 E2B/E4B assistants
|
||||
),
|
||||
|
||||
# Token type embeddings
|
||||
MODEL_TENSOR.TOKEN_TYPES: (
|
||||
"embeddings.token_type_embeddings", # bert nomic-bert
|
||||
|
||||
@@ -1 +1 @@
|
||||
1e33fed33e87c43aa4c4078e2a9c239d4c1f1bd3
|
||||
7142aa6bf9fcaeec0fef8d80fcd90afe4268adf1
|
||||
|
||||
@@ -559,6 +559,8 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
|
||||
{ LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" },
|
||||
{ LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" },
|
||||
{ LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" },
|
||||
{ LLM_TENSOR_MASKED_EMBD_CENTROIDS, "masked_embd_centroids" },
|
||||
{ LLM_TENSOR_MASKED_EMBD_ORDERING, "masked_embd_ordering" },
|
||||
};
|
||||
|
||||
// declare information about the model weight tensors:
|
||||
@@ -783,6 +785,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
// latent projections feed ggml_mul_mat, the buft probe must use MUL_MAT to keep them on GPU
|
||||
{LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_MASKED_EMBD_CENTROIDS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}},
|
||||
{LLM_TENSOR_MASKED_EMBD_ORDERING, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}},
|
||||
};
|
||||
|
||||
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
|
||||
|
||||
@@ -566,8 +566,11 @@ enum llm_tensor {
|
||||
LLM_TENSOR_NEXTN_HNORM,
|
||||
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
|
||||
LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
|
||||
LLM_TENSOR_MASKED_EMBD_CENTROIDS,
|
||||
LLM_TENSOR_MASKED_EMBD_ORDERING,
|
||||
};
|
||||
|
||||
|
||||
enum llm_tensor_layer {
|
||||
LLM_TENSOR_LAYER_INPUT,
|
||||
LLM_TENSOR_LAYER_REPEATING,
|
||||
|
||||
+13
-4
@@ -567,7 +567,10 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
|
||||
}
|
||||
|
||||
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
// the kq mask guards on its own buffer: shared cells leave idxs unbacked while the mask stays live
|
||||
if (self_kq_mask && self_kq_mask->buffer) {
|
||||
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
// swa tensors may not be allocated if there are no SWA attention layers
|
||||
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
|
||||
@@ -575,7 +578,9 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
|
||||
}
|
||||
|
||||
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
|
||||
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
if (self_k_rot) {
|
||||
mctx->get_base()->set_input_k_rot(self_k_rot);
|
||||
@@ -607,7 +612,9 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
||||
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
}
|
||||
|
||||
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
|
||||
if (self_kq_mask && self_kq_mask->buffer) {
|
||||
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
|
||||
}
|
||||
|
||||
// swa tensors may not be allocated if there are no SWA attention layers
|
||||
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
|
||||
@@ -615,7 +622,9 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
||||
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
}
|
||||
|
||||
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
|
||||
if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
|
||||
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -39,6 +39,9 @@ void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) {
|
||||
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
|
||||
|
||||
create_tensor(tn(LLM_TENSOR_MASKED_EMBD_CENTROIDS, "weight"), {}, TENSOR_NOT_REQUIRED);
|
||||
create_tensor(tn(LLM_TENSOR_MASKED_EMBD_ORDERING), {}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
const int64_t n_embd_backbone = hparams.n_embd_inp();
|
||||
nextn_proj_post = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_POST, "weight"), { n_embd, n_embd_backbone }, 0);
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include <assert.h>
|
||||
|
||||
#include "mtmd.h"
|
||||
#include "mtmd-helper.h"
|
||||
|
||||
int main(void) {
|
||||
printf("\n\nTesting libmtmd C API...\n");
|
||||
@@ -17,6 +18,11 @@ int main(void) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// simple test for the helper
|
||||
size_t n_tokens_total = mtmd_helper_get_n_tokens(chunks);
|
||||
printf("Total tokens in chunks: %zu\n", n_tokens_total);
|
||||
assert(n_tokens_total > 0);
|
||||
|
||||
size_t n_chunks = mtmd_input_chunks_size(chunks);
|
||||
printf("Number of chunks: %zu\n", n_chunks);
|
||||
assert(n_chunks > 0);
|
||||
|
||||
+19
-3
@@ -128,7 +128,18 @@ struct cli_context {
|
||||
console::spinner::start();
|
||||
server_task_result_ptr result = rd.next(should_stop);
|
||||
|
||||
console::spinner::stop();
|
||||
while (true) {
|
||||
auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
|
||||
if (res_partial && res_partial->is_begin) {
|
||||
// this is the "send 200 status to client" signal in streaming mode
|
||||
// skip, do not stop the spinner
|
||||
result = rd.next(should_stop);
|
||||
} else {
|
||||
console::spinner::stop();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::string curr_content;
|
||||
bool is_thinking = false;
|
||||
|
||||
@@ -224,7 +235,7 @@ struct cli_context {
|
||||
};
|
||||
|
||||
// TODO?: Make this reusable, enums, docs
|
||||
static const std::array<std::string_view, 7> cmds = {
|
||||
static const std::array<std::string_view, 8> cmds = {
|
||||
"/audio ",
|
||||
"/clear",
|
||||
"/exit",
|
||||
@@ -232,6 +243,7 @@ static const std::array<std::string_view, 7> cmds = {
|
||||
"/image ",
|
||||
"/read ",
|
||||
"/regen",
|
||||
"/video ",
|
||||
};
|
||||
|
||||
static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std::string_view line, size_t cursor_byte_pos) {
|
||||
@@ -446,6 +458,9 @@ int llama_cli(int argc, char ** argv) {
|
||||
if (inf.has_inp_audio) {
|
||||
console::log(" /audio <file> add an audio file\n");
|
||||
}
|
||||
if (inf.has_inp_video) {
|
||||
console::log(" /video <file> add a video file\n");
|
||||
}
|
||||
console::log("\n");
|
||||
|
||||
// interactive loop
|
||||
@@ -542,7 +557,8 @@ int llama_cli(int argc, char ** argv) {
|
||||
continue;
|
||||
} else if (
|
||||
(string_starts_with(buffer, "/image ") && inf.has_inp_image) ||
|
||||
(string_starts_with(buffer, "/audio ") && inf.has_inp_audio)) {
|
||||
(string_starts_with(buffer, "/audio ") && inf.has_inp_audio) ||
|
||||
(string_starts_with(buffer, "/video ") && inf.has_inp_video)) {
|
||||
// just in case (bad copy-paste for example), we strip all trailing/leading spaces
|
||||
std::string fname = string_strip(buffer.substr(7));
|
||||
std::string marker = ctx_cli.load_input_file(fname, true);
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
# mtmd
|
||||
|
||||
set(MTMD_VIDEO ON CACHE BOOL "enable video support in mtmd (requires ffmpeg binary in PATH)")
|
||||
# TODO: add MTMD_VIDEO_METHOD in the future to select between ffmpeg and other backends
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
add_library(mtmd
|
||||
@@ -63,6 +66,10 @@ target_include_directories(mtmd PRIVATE ../..)
|
||||
target_include_directories(mtmd PRIVATE ../../vendor)
|
||||
target_compile_features (mtmd PRIVATE cxx_std_17)
|
||||
|
||||
if (MTMD_VIDEO)
|
||||
target_compile_definitions(mtmd PRIVATE MTMD_VIDEO)
|
||||
endif()
|
||||
|
||||
if (BUILD_SHARED_LIBS)
|
||||
set_target_properties (mtmd PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_compile_definitions(mtmd PRIVATE LLAMA_BUILD)
|
||||
|
||||
+14
-5
@@ -77,6 +77,7 @@ struct mtmd_cli_context {
|
||||
int n_batch;
|
||||
|
||||
mtmd::bitmaps bitmaps;
|
||||
std::vector<mtmd_helper::video_ptr> videos;
|
||||
|
||||
// chat template
|
||||
common_chat_templates_ptr tmpls;
|
||||
@@ -166,11 +167,14 @@ struct mtmd_cli_context {
|
||||
}
|
||||
|
||||
bool load_media(const std::string & fname) {
|
||||
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_file(ctx_vision.get(), fname.c_str(), false));
|
||||
if (!bmp.ptr) {
|
||||
auto res = mtmd_helper_bitmap_init_from_file(ctx_vision.get(), fname.c_str(), false);
|
||||
if (!res.bitmap) {
|
||||
return false;
|
||||
}
|
||||
bitmaps.entries.push_back(std::move(bmp));
|
||||
bitmaps.entries.emplace_back(res.bitmap);
|
||||
if (res.video_ctx) {
|
||||
videos.emplace_back(res.video_ctx);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
@@ -253,6 +257,7 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) {
|
||||
}
|
||||
|
||||
ctx.bitmaps.entries.clear();
|
||||
ctx.videos.clear();
|
||||
|
||||
llama_pos new_n_past;
|
||||
if (mtmd_helper_eval_chunks(ctx.ctx_vision.get(),
|
||||
@@ -373,6 +378,9 @@ int main(int argc, char ** argv) {
|
||||
if (mtmd_support_audio(ctx.ctx_vision.get())) {
|
||||
LOG("\n /audio <path> load an audio");
|
||||
}
|
||||
if (mtmd_helper_support_video(ctx.ctx_vision.get())) {
|
||||
LOG("\n /video <path> load a video");
|
||||
}
|
||||
LOG("\n /clear clear the chat history");
|
||||
LOG("\n /quit or /exit exit the program");
|
||||
LOG("\n");
|
||||
@@ -407,14 +415,15 @@ int main(int argc, char ** argv) {
|
||||
g_is_generating = true;
|
||||
bool is_image = line == "/image" || line.find("/image ") == 0;
|
||||
bool is_audio = line == "/audio" || line.find("/audio ") == 0;
|
||||
if (is_image || is_audio) {
|
||||
bool is_video = line == "/video" || line.find("/video ") == 0;
|
||||
if (is_image || is_audio || is_video) {
|
||||
if (line.size() < 8) {
|
||||
LOG_ERR("ERR: Missing media filename\n");
|
||||
continue;
|
||||
}
|
||||
std::string media_path = line.substr(7);
|
||||
if (ctx.load_media(media_path)) {
|
||||
LOG("%s %s loaded\n", media_path.c_str(), is_image ? "image" : "audio");
|
||||
LOG("%s %s loaded\n", media_path.c_str(), is_image ? "image" : is_audio ? "audio" : "video");
|
||||
content += mtmd_default_marker();
|
||||
}
|
||||
// else, error is already printed by libmtmd
|
||||
|
||||
+490
-16
@@ -36,6 +36,11 @@
|
||||
#error "mtmd-helper is a public library outside of mtmd. it must not include internal headers"
|
||||
#endif
|
||||
|
||||
#ifdef MTMD_VIDEO
|
||||
#include "sheredom/subprocess.h"
|
||||
#include <thread>
|
||||
#endif
|
||||
|
||||
//
|
||||
// internal logging functions
|
||||
//
|
||||
@@ -79,6 +84,7 @@ struct mtmd_helper_logger {
|
||||
}
|
||||
} g_logger;
|
||||
|
||||
#define LOG_DBG(...) g_logger.log(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
|
||||
#define LOG_INF(...) g_logger.log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
|
||||
#define LOG_WRN(...) g_logger.log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
|
||||
#define LOG_ERR(...) g_logger.log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
|
||||
@@ -478,42 +484,94 @@ static bool decode_audio_from_buf(const unsigned char * buf_in, size_t len, int
|
||||
|
||||
} // namespace audio_helpers
|
||||
|
||||
mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len, bool placeholder) {
|
||||
// Computes FNV-1a hash of the data
|
||||
static std::string fnv_hash(const uint8_t * data, size_t len) {
|
||||
const uint64_t fnv_prime = 0x100000001b3ULL;
|
||||
uint64_t hash = 0xcbf29ce484222325ULL;
|
||||
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
hash ^= data[i];
|
||||
hash *= fnv_prime;
|
||||
}
|
||||
return std::to_string(hash);
|
||||
}
|
||||
|
||||
mtmd_helper_bitmap_wrapper mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len, bool placeholder) {
|
||||
// calculate the hash if needed
|
||||
std::string id;
|
||||
mtmd_bitmap * result = nullptr;
|
||||
|
||||
if (!placeholder) {
|
||||
id = fnv_hash(buf, len);
|
||||
}
|
||||
|
||||
if (audio_helpers::is_audio_file((const char *)buf, len)) {
|
||||
std::vector<float> pcmf32;
|
||||
const int sample_rate = mtmd_get_audio_sample_rate(ctx);
|
||||
if (sample_rate < 0) {
|
||||
LOG_ERR("This model does not support audio input\n");
|
||||
return nullptr;
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
if (!audio_helpers::decode_audio_from_buf(buf, len, sample_rate, pcmf32)) {
|
||||
LOG_ERR("Unable to read WAV audio file from buffer\n");
|
||||
return nullptr;
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
return mtmd_bitmap_init_from_audio(pcmf32.size(), placeholder ? nullptr : pcmf32.data());
|
||||
result = mtmd_bitmap_init_from_audio(pcmf32.size(), placeholder ? nullptr : pcmf32.data());
|
||||
mtmd_bitmap_set_id(result, id.empty() ? nullptr : id.c_str());
|
||||
return {result, nullptr};
|
||||
}
|
||||
|
||||
// otherwise, we assume it's an image
|
||||
mtmd_bitmap * result = nullptr;
|
||||
{
|
||||
if (!result) {
|
||||
int nx, ny, nc;
|
||||
auto * data = stbi_load_from_memory(buf, len, &nx, &ny, &nc, 3);
|
||||
if (!data) {
|
||||
LOG_ERR("%s: failed to decode image bytes\n", __func__);
|
||||
return nullptr;
|
||||
if (data) {
|
||||
result = mtmd_bitmap_init(nx, ny, placeholder ? nullptr : data);
|
||||
mtmd_bitmap_set_id(result, id.empty() ? nullptr : id.c_str());
|
||||
stbi_image_free(data);
|
||||
return {result, nullptr};
|
||||
}
|
||||
result = mtmd_bitmap_init(nx, ny, placeholder ? nullptr : data);
|
||||
stbi_image_free(data);
|
||||
// otherwise, fallthrough to video decoding (if supported)
|
||||
}
|
||||
return result;
|
||||
|
||||
// last try: load as video
|
||||
#ifdef MTMD_VIDEO
|
||||
if (!result) {
|
||||
auto params = mtmd_helper_video_init_params_default();
|
||||
auto video_ctx = mtmd_helper_video_init_from_buf(ctx, buf, len, params);
|
||||
if (!video_ctx) {
|
||||
LOG_ERR("%s: failed to decode buffer as either image/audio/video\n", __func__);
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
result = mtmd_bitmap_init_lazy(ctx,
|
||||
id.empty() ? nullptr : id.c_str(),
|
||||
video_ctx,
|
||||
[](size_t, void * user_data, mtmd_bitmap ** out_bitmap, char ** out_text) -> int {
|
||||
auto * vctx = static_cast<mtmd_helper_video *>(user_data);
|
||||
char * text = nullptr;
|
||||
int ret = mtmd_helper_video_read_next(vctx, out_bitmap, &text);
|
||||
*out_text = text; // heap-allocated by read_next; freed automatically by mtmd
|
||||
return ret;
|
||||
});
|
||||
return {result, video_ctx};
|
||||
}
|
||||
#else
|
||||
if (!result) {
|
||||
LOG_ERR("%s: failed to decode buffer as either image or audio (video support not compiled in)\n", __func__);
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
#endif
|
||||
|
||||
// should not reach here
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
|
||||
mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char * fname, bool placeholder) {
|
||||
mtmd_helper_bitmap_wrapper mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char * fname, bool placeholder) {
|
||||
std::vector<unsigned char> buf;
|
||||
FILE * f = fopen(fname, "rb");
|
||||
if (!f) {
|
||||
LOG_ERR("Unable to open file %s: %s\n", fname, strerror(errno));
|
||||
return nullptr;
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
|
||||
fseek(f, 0, SEEK_END);
|
||||
@@ -522,7 +580,7 @@ mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char *
|
||||
if (file_size < 0) {
|
||||
LOG_ERR("Failed to get file size of %s\n", fname);
|
||||
fclose(f);
|
||||
return nullptr;
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
buf.resize(file_size);
|
||||
|
||||
@@ -530,9 +588,425 @@ mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char *
|
||||
fclose(f);
|
||||
if (n_read != (size_t)file_size) {
|
||||
LOG_ERR("Failed to read entire file %s", fname);
|
||||
return nullptr;
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
|
||||
return mtmd_helper_bitmap_init_from_buf(ctx, buf.data(), buf.size(), placeholder);
|
||||
}
|
||||
|
||||
bool mtmd_helper_support_video(mtmd_context * ctx) {
|
||||
#ifdef MTMD_VIDEO
|
||||
return mtmd_support_vision(ctx);
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
//
|
||||
// Video input helpers
|
||||
//
|
||||
|
||||
#ifdef MTMD_VIDEO
|
||||
|
||||
struct mtmd_helper_video {
|
||||
mtmd_context * mctx;
|
||||
std::string path;
|
||||
std::vector<uint8_t> input_buf; // non-empty when initialized from buffer
|
||||
std::string ffmpeg_bin;
|
||||
std::string ffprobe_bin;
|
||||
float fps_target = 0.0f;
|
||||
mtmd_helper_video_info info = {};
|
||||
|
||||
struct subprocess_s proc = {};
|
||||
bool proc_alive = false;
|
||||
int32_t current_frame = 0;
|
||||
std::thread feeder_thread;
|
||||
|
||||
std::string prompt_start = "Video:";
|
||||
int32_t timestamp_interval_ms = 5000; // emit a timestamp text every N ms (0 = disabled)
|
||||
float next_timestamp_ms = 0.0f; // next elapsed-ms threshold at which to emit
|
||||
|
||||
std::vector<uint8_t> frame_buf;
|
||||
std::string pending_text; // text queued to be returned before the next frame
|
||||
bool start_emitted = false;
|
||||
|
||||
bool is_buf_input() const { return !input_buf.empty(); }
|
||||
|
||||
// must run in a separate thread alongside stdout reading to avoid pipe deadlock
|
||||
void feed_stdin(struct subprocess_s * sp) {
|
||||
FILE * f = subprocess_stdin(sp);
|
||||
if (!f) {
|
||||
LOG_DBG("%s: subprocess has no stdin pipe\n", __func__);
|
||||
return;
|
||||
}
|
||||
LOG_DBG("%s: feeding %zu bytes to stdin\n", __func__, input_buf.size());
|
||||
size_t written = fwrite(input_buf.data(), 1, input_buf.size(), f);
|
||||
LOG_DBG("%s: wrote %zu bytes, closing stdin\n", __func__, written);
|
||||
fclose(f);
|
||||
}
|
||||
|
||||
bool probe(float fps_target_arg) {
|
||||
const char * input_arg = is_buf_input() ? "pipe:0" : path.c_str();
|
||||
const char * cmd[] = {
|
||||
ffprobe_bin.c_str(),
|
||||
"-v", "quiet",
|
||||
"-show_entries", "stream=width,height,r_frame_rate,nb_frames,duration",
|
||||
"-select_streams", "v:0",
|
||||
"-of", "default=noprint_wrappers=1",
|
||||
input_arg,
|
||||
nullptr,
|
||||
};
|
||||
|
||||
LOG_DBG("%s: launching:", __func__);
|
||||
for (size_t i = 0; cmd[i]; i++) { LOG_DBG(" %s", cmd[i]); }
|
||||
LOG_DBG("\n");
|
||||
|
||||
struct subprocess_s fprobe;
|
||||
if (subprocess_create(cmd,
|
||||
subprocess_option_search_user_path | subprocess_option_inherit_environment,
|
||||
&fprobe) != 0) {
|
||||
LOG_ERR("%s: failed to launch ffprobe\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::thread probe_feeder;
|
||||
if (is_buf_input()) {
|
||||
probe_feeder = std::thread([this, &fprobe]() { feed_stdin(&fprobe); });
|
||||
}
|
||||
|
||||
uint32_t width = 0;
|
||||
uint32_t height = 0;
|
||||
float orig_fps = 0.0f;
|
||||
float duration = -1.0f;
|
||||
int32_t n_frames_orig = -1;
|
||||
char line[256];
|
||||
FILE * fp = subprocess_stdout(&fprobe);
|
||||
|
||||
while (fgets(line, sizeof(line), fp)) {
|
||||
char * eq = strchr(line, '=');
|
||||
if (!eq) continue;
|
||||
*eq = '\0';
|
||||
const char * key = line;
|
||||
const char * val = eq + 1;
|
||||
char * nl = (char *)strchr(val, '\n');
|
||||
if (nl) *nl = '\0';
|
||||
|
||||
if (strcmp(key, "width") == 0) {
|
||||
width = (uint32_t)atoi(val);
|
||||
} else if (strcmp(key, "height") == 0) {
|
||||
height = (uint32_t)atoi(val);
|
||||
} else if (strcmp(key, "r_frame_rate") == 0) {
|
||||
orig_fps = parse_rational(val);
|
||||
} else if (strcmp(key, "nb_frames") == 0 && strcmp(val, "N/A") != 0) {
|
||||
n_frames_orig = atoi(val);
|
||||
} else if (strcmp(key, "duration") == 0 && strcmp(val, "N/A") != 0) {
|
||||
duration = (float)atof(val);
|
||||
}
|
||||
}
|
||||
|
||||
if (probe_feeder.joinable()) {
|
||||
probe_feeder.join();
|
||||
}
|
||||
|
||||
int ret_code;
|
||||
subprocess_join(&fprobe, &ret_code);
|
||||
subprocess_destroy(&fprobe);
|
||||
|
||||
if (width == 0 || height == 0 || orig_fps <= 0.0f) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (duration < 0.0f && n_frames_orig > 0) {
|
||||
duration = (float)n_frames_orig / orig_fps;
|
||||
}
|
||||
|
||||
fps_target = fps_target_arg > 0.0f ? fps_target_arg : orig_fps;
|
||||
info.width = width;
|
||||
info.height = height;
|
||||
info.fps = fps_target;
|
||||
LOG_DBG("%s: %ux%u fps=%.2f duration=%.2fs n_frames=%d\n",
|
||||
__func__, width, height, fps_target, duration, info.n_frames);
|
||||
info.n_frames = duration > 0.0f ? (int32_t)(duration * fps_target + 0.5f) : -1;
|
||||
frame_buf.resize((size_t)width * height * 3);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool start_ffmpeg(float seek_seconds) {
|
||||
char seek_buf[64];
|
||||
char fps_buf[64];
|
||||
|
||||
std::vector<const char *> cmd;
|
||||
cmd.push_back(ffmpeg_bin.c_str());
|
||||
|
||||
if (!is_buf_input() && seek_seconds > 0.0f) {
|
||||
// input-side seek: fast, keyframe-accurate; only valid for seekable file inputs
|
||||
snprintf(seek_buf, sizeof(seek_buf), "%.6f", seek_seconds);
|
||||
cmd.push_back("-ss");
|
||||
cmd.push_back(seek_buf);
|
||||
}
|
||||
|
||||
cmd.push_back("-i");
|
||||
// cache:pipe:0 wraps stdin with a seekable in-memory cache, letting ffmpeg seek
|
||||
// backwards for container headers (e.g. MP4 moov atom at end of file)
|
||||
cmd.push_back(is_buf_input() ? "cache:pipe:0" : path.c_str());
|
||||
|
||||
if (seek_seconds > 0.0f && is_buf_input()) {
|
||||
// output-side seek: frame-accurate but decodes and discards frames up to seek point
|
||||
snprintf(seek_buf, sizeof(seek_buf), "%.6f", seek_seconds);
|
||||
cmd.push_back("-ss");
|
||||
cmd.push_back(seek_buf);
|
||||
}
|
||||
|
||||
if (fps_target > 0.0f) {
|
||||
snprintf(fps_buf, sizeof(fps_buf), "fps=%.6f", fps_target);
|
||||
cmd.push_back("-vf");
|
||||
cmd.push_back(fps_buf);
|
||||
}
|
||||
|
||||
cmd.push_back("-f");
|
||||
cmd.push_back("rawvideo");
|
||||
cmd.push_back("-pix_fmt");
|
||||
cmd.push_back("rgb24");
|
||||
cmd.push_back("pipe:1");
|
||||
cmd.push_back("-loglevel");
|
||||
cmd.push_back("error");
|
||||
cmd.push_back(nullptr);
|
||||
|
||||
LOG_DBG("%s: launching:", __func__);
|
||||
for (size_t i = 0; cmd[i]; i++) {
|
||||
LOG_DBG(" %s", cmd[i]);
|
||||
}
|
||||
LOG_DBG("\n");
|
||||
|
||||
int ret = subprocess_create(
|
||||
cmd.data(),
|
||||
subprocess_option_search_user_path | subprocess_option_inherit_environment,
|
||||
&proc);
|
||||
|
||||
proc_alive = (ret == 0);
|
||||
LOG_DBG("%s: subprocess_create ret=%d proc_alive=%d\n", __func__, ret, (int)proc_alive);
|
||||
|
||||
if (proc_alive && is_buf_input()) {
|
||||
LOG_DBG("%s: starting feeder thread for %zu-byte buffer\n", __func__, input_buf.size());
|
||||
feeder_thread = std::thread([this]() { feed_stdin(&proc); });
|
||||
}
|
||||
|
||||
return proc_alive;
|
||||
}
|
||||
|
||||
void stop_ffmpeg() {
|
||||
if (proc_alive) {
|
||||
subprocess_terminate(&proc);
|
||||
subprocess_destroy(&proc);
|
||||
proc_alive = false;
|
||||
}
|
||||
if (feeder_thread.joinable()) {
|
||||
feeder_thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
mtmd_bitmap * read_next_frame() {
|
||||
if (!proc_alive) return nullptr;
|
||||
|
||||
FILE * fp = subprocess_stdout(&proc);
|
||||
const size_t frame_size = (size_t)info.width * info.height * 3;
|
||||
LOG_DBG("%s: reading frame %d, expecting %zu bytes (%ux%u)\n",
|
||||
__func__, current_frame, frame_size, info.width, info.height);
|
||||
|
||||
size_t total_read = 0;
|
||||
while (total_read < frame_size) {
|
||||
size_t n = fread(frame_buf.data() + total_read, 1, frame_size - total_read, fp);
|
||||
if (n == 0) {
|
||||
// clean EOF only if no bytes read yet; partial frame is an error
|
||||
LOG_DBG("%s: fread returned 0 after %zu/%zu bytes (ferror=%d)\n",
|
||||
__func__, total_read, frame_size, ferror(fp));
|
||||
proc_alive = false;
|
||||
return nullptr;
|
||||
}
|
||||
total_read += n;
|
||||
}
|
||||
|
||||
LOG_DBG("%s: frame %d read OK\n", __func__, current_frame);
|
||||
current_frame++;
|
||||
return mtmd_bitmap_init(info.width, info.height, frame_buf.data());
|
||||
}
|
||||
|
||||
int32_t read_next(mtmd_bitmap ** out_bitmap, char ** out_text) {
|
||||
*out_bitmap = nullptr;
|
||||
*out_text = nullptr;
|
||||
|
||||
if (!pending_text.empty()) {
|
||||
*out_text = strdup(pending_text.c_str());
|
||||
pending_text.clear();
|
||||
return *out_text ? 0 : -2;
|
||||
}
|
||||
|
||||
LOG_DBG("%s: proc_alive=%d start_emitted=%d current_frame=%d\n",
|
||||
__func__, (int)proc_alive, (int)start_emitted, current_frame);
|
||||
|
||||
if (!proc_alive) {
|
||||
return (current_frame == 0) ? -2 : -1;
|
||||
}
|
||||
|
||||
if (!start_emitted) {
|
||||
start_emitted = true;
|
||||
if (!prompt_start.empty()) {
|
||||
*out_text = strdup(prompt_start.c_str());
|
||||
return *out_text ? 0 : -2;
|
||||
}
|
||||
}
|
||||
|
||||
mtmd_bitmap * frame = read_next_frame();
|
||||
if (!frame) return -1;
|
||||
*out_bitmap = frame;
|
||||
|
||||
if (timestamp_interval_ms > 0) {
|
||||
// current_frame was already incremented by read_next_frame(); undo for elapsed calc
|
||||
float elapsed_ms = (float)(current_frame - 1) / info.fps * 1000.0f;
|
||||
if (elapsed_ms >= next_timestamp_ms) {
|
||||
char ts_buf[32];
|
||||
float elapsed_s = elapsed_ms / 1000.0f;
|
||||
int minutes = (int)(elapsed_s / 60);
|
||||
float seconds = elapsed_s - minutes * 60.0f;
|
||||
snprintf(ts_buf, sizeof(ts_buf), "[%dm%.2fs]", minutes, seconds);
|
||||
pending_text = ts_buf;
|
||||
next_timestamp_ms += (float)timestamp_interval_ms;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static float parse_rational(const char * s) {
|
||||
int num = 0, den = 1;
|
||||
if (sscanf(s, "%d/%d", &num, &den) == 2 && den > 0) {
|
||||
return (float)num / (float)den;
|
||||
}
|
||||
float val;
|
||||
if (sscanf(s, "%f", &val) == 1) {
|
||||
return val;
|
||||
}
|
||||
return 0.0f;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
mtmd_helper_video_init_params mtmd_helper_video_init_params_default() {
|
||||
return {
|
||||
/* fps_target */ 4.0f,
|
||||
/* ffmpeg_bin_dir */ nullptr,
|
||||
/* timestamp_interval_ms */ 5000,
|
||||
};
|
||||
}
|
||||
|
||||
static std::string video_resolve_bin(const char * bin_dir, const char * name) {
|
||||
if (!bin_dir || bin_dir[0] == '\0') {
|
||||
return name; // rely on PATH
|
||||
}
|
||||
std::string result = bin_dir;
|
||||
char last = result.back();
|
||||
if (last != '/' && last != '\\') {
|
||||
#ifdef _WIN32
|
||||
result += '\\';
|
||||
#else
|
||||
result += '/';
|
||||
#endif
|
||||
}
|
||||
result += name;
|
||||
#ifdef _WIN32
|
||||
result += ".exe";
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
|
||||
mtmd_helper_video * mtmd_helper_video_init(
|
||||
mtmd_context * mctx,
|
||||
const char * path,
|
||||
mtmd_helper_video_init_params params) {
|
||||
#ifdef MTMD_VIDEO
|
||||
auto * ctx = new mtmd_helper_video();
|
||||
|
||||
ctx->mctx = mctx;
|
||||
ctx->path = path;
|
||||
ctx->ffmpeg_bin = video_resolve_bin(params.ffmpeg_bin_dir, "ffmpeg");
|
||||
ctx->ffprobe_bin = video_resolve_bin(params.ffmpeg_bin_dir, "ffprobe");
|
||||
ctx->timestamp_interval_ms = params.timestamp_interval_ms;
|
||||
|
||||
if (!ctx->probe(params.fps_target)) {
|
||||
LOG_ERR("%s: ffprobe failed for '%s' (is ffprobe in PATH?)\n", __func__, path);
|
||||
delete ctx;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!ctx->start_ffmpeg(0.0f)) {
|
||||
LOG_ERR("%s: failed to start ffmpeg for '%s' (is ffmpeg in PATH?)\n", __func__, path);
|
||||
delete ctx;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return ctx;
|
||||
#else
|
||||
LOG_ERR("%s: video is not supported in this build (MTMD_VIDEO is set to OFF)\n", __func__);
|
||||
return nullptr;
|
||||
#endif
|
||||
}
|
||||
|
||||
mtmd_helper_video * mtmd_helper_video_init_from_buf(
|
||||
mtmd_context * mctx,
|
||||
const unsigned char * buf, size_t len,
|
||||
mtmd_helper_video_init_params params) {
|
||||
#ifdef MTMD_VIDEO
|
||||
auto * ctx = new mtmd_helper_video();
|
||||
|
||||
ctx->mctx = mctx;
|
||||
ctx->input_buf.assign(buf, buf + len);
|
||||
ctx->ffmpeg_bin = video_resolve_bin(params.ffmpeg_bin_dir, "ffmpeg");
|
||||
ctx->ffprobe_bin = video_resolve_bin(params.ffmpeg_bin_dir, "ffprobe");
|
||||
ctx->timestamp_interval_ms = params.timestamp_interval_ms;
|
||||
|
||||
if (!ctx->probe(params.fps_target)) {
|
||||
LOG_ERR("%s: ffprobe failed on buffer (is ffprobe in PATH?)\n", __func__);
|
||||
delete ctx;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!ctx->start_ffmpeg(0.0f)) {
|
||||
LOG_ERR("%s: failed to start ffmpeg on buffer (is ffmpeg in PATH?)\n", __func__);
|
||||
delete ctx;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return ctx;
|
||||
#else
|
||||
LOG_ERR("%s: video is not supported in this build (MTMD_VIDEO is set to OFF)\n", __func__);
|
||||
return nullptr;
|
||||
#endif
|
||||
}
|
||||
|
||||
void mtmd_helper_video_free(mtmd_helper_video * ctx) {
|
||||
#ifdef MTMD_VIDEO
|
||||
if (!ctx) return;
|
||||
ctx->stop_ffmpeg();
|
||||
delete ctx;
|
||||
#else
|
||||
LOG_ERR("%s: video is not supported in this build (MTMD_VIDEO is set to OFF)\n", __func__);
|
||||
#endif
|
||||
}
|
||||
|
||||
mtmd_helper_video_info mtmd_helper_video_get_info(const mtmd_helper_video * ctx) {
|
||||
#ifdef MTMD_VIDEO
|
||||
return ctx->info;
|
||||
#else
|
||||
GGML_ASSERT(false && "video is not supported in this build (MTMD_VIDEO is set to OFF)");
|
||||
#endif
|
||||
}
|
||||
|
||||
int32_t mtmd_helper_video_read_next(mtmd_helper_video * ctx,
|
||||
mtmd_bitmap ** out_bitmap, char ** out_text) {
|
||||
#ifdef MTMD_VIDEO
|
||||
if (!ctx) return -2;
|
||||
return ctx->read_next(out_bitmap, out_text);
|
||||
#else
|
||||
GGML_ASSERT(false && "video is not supported in this build (MTMD_VIDEO is set to OFF)");
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -20,25 +20,39 @@ extern "C" {
|
||||
// BREAKING CHANGES are expected.
|
||||
//
|
||||
|
||||
struct mtmd_helper_video;
|
||||
typedef struct mtmd_helper_video mtmd_helper_video;
|
||||
|
||||
// Set callback for all future logging events.
|
||||
// If this is not called, or NULL is supplied, everything is output on stderr.
|
||||
// Note: this also call mtmd_log_set() internally
|
||||
MTMD_API void mtmd_helper_log_set(ggml_log_callback log_callback, void * user_data);
|
||||
|
||||
// Returns true if this build includes video support (MTMD_VIDEO was ON at compile time).
|
||||
MTMD_API bool mtmd_helper_support_video(mtmd_context * ctx);
|
||||
|
||||
struct mtmd_helper_bitmap_wrapper {
|
||||
mtmd_bitmap * bitmap;
|
||||
mtmd_helper_video * video_ctx;
|
||||
};
|
||||
|
||||
// helper function to construct a mtmd_bitmap from a file
|
||||
// it calls mtmd_helper_bitmap_init_from_buf() internally
|
||||
// returns nullptr on failure
|
||||
// this function is thread-safe
|
||||
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char * fname, bool placeholder);
|
||||
MTMD_API struct mtmd_helper_bitmap_wrapper mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char * fname, bool placeholder);
|
||||
|
||||
// helper function to construct a mtmd_bitmap from a buffer containing a file
|
||||
// supported formats:
|
||||
// image: formats supported by stb_image: jpg, png, bmp, gif, etc.
|
||||
// audio: formats supported by miniaudio: wav, mp3, flac
|
||||
// note: audio files will be auto-detected based on magic bytes
|
||||
// note:
|
||||
// - for now, video input is only supported via C++ helper functions
|
||||
// - audio files will be auto-detected based on magic bytes
|
||||
// - output bitmap will have FNV hash as the ID
|
||||
// returns nullptr on failure
|
||||
// this function is thread-safe
|
||||
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len, bool placeholder);
|
||||
MTMD_API struct mtmd_helper_bitmap_wrapper mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len, bool placeholder);
|
||||
|
||||
// helper to count the total number of tokens from a list of chunks, useful to keep track of KV cache
|
||||
MTMD_API size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks);
|
||||
@@ -89,6 +103,56 @@ MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx,
|
||||
int32_t n_batch,
|
||||
llama_pos * new_n_past);
|
||||
|
||||
//
|
||||
// video input helpers (requires ffmpeg/ffprobe installed on the system)
|
||||
// the notion of video only exists at the helper level, it is not visible to the core mtmd library
|
||||
//
|
||||
// NOTE: this implementation is model-agnostic, it can be used with any vision-capable model
|
||||
// however, it may not be accurate for some specific models
|
||||
// (this is expected for now, to keep the implementation simple)
|
||||
//
|
||||
|
||||
struct mtmd_helper_video_info {
|
||||
uint32_t width;
|
||||
uint32_t height;
|
||||
float fps; // effective fps (fps_target if set, else original video fps)
|
||||
int32_t n_frames; // estimated total frames at effective fps (-1 if unknown)
|
||||
};
|
||||
|
||||
struct mtmd_helper_video_init_params {
|
||||
float fps_target; // desired output fps; <= 0 means use the video's native fps, defaulted to 4.0f
|
||||
const char * ffmpeg_bin_dir; // directory containing ffmpeg/ffprobe binaries; NULL means search PATH
|
||||
int64_t timestamp_interval_ms; // interval for adding timestamp as text chunk (example: "[10m50.5s]"); <= 0 means no timestamp, defaulted to 5000ms
|
||||
// TODO @ngxson : allow "placeholder" bitmap output for counting tokens
|
||||
};
|
||||
|
||||
MTMD_API struct mtmd_helper_video_init_params mtmd_helper_video_init_params_default(void);
|
||||
|
||||
// returns NULL on failure (ffprobe not found, file unreadable, etc.)
|
||||
MTMD_API mtmd_helper_video * mtmd_helper_video_init(
|
||||
struct mtmd_context * mctx,
|
||||
const char * path,
|
||||
struct mtmd_helper_video_init_params params);
|
||||
|
||||
// Same as mtmd_helper_video_init(), but reads from an in-memory buffer.
|
||||
// The buffer is copied internally; the caller does not need to keep it alive.
|
||||
// Note: pipe input is not seekable, so seeking will use output-side seeking
|
||||
// (ffmpeg decodes and discards frames up to the target position).
|
||||
MTMD_API mtmd_helper_video * mtmd_helper_video_init_from_buf(
|
||||
struct mtmd_context * mctx,
|
||||
const unsigned char * buf, size_t len,
|
||||
struct mtmd_helper_video_init_params params);
|
||||
MTMD_API void mtmd_helper_video_free(mtmd_helper_video * ctx);
|
||||
MTMD_API struct mtmd_helper_video_info mtmd_helper_video_get_info(const mtmd_helper_video * ctx);
|
||||
|
||||
// Read the next item from the video stream; exactly one of out_bitmap or out_text is set per call.
|
||||
// *out_bitmap - heap-allocated; caller must free with mtmd_bitmap_free()
|
||||
// *out_text - heap-allocated (always via strdup/malloc); caller must free with free()
|
||||
// returns 0 on success, -1 on EOF, -2 on error
|
||||
MTMD_API int32_t mtmd_helper_video_read_next(mtmd_helper_video * ctx,
|
||||
mtmd_bitmap ** out_bitmap,
|
||||
char ** out_text);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
@@ -97,4 +161,16 @@ MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx,
|
||||
// C++ wrappers
|
||||
//
|
||||
|
||||
#ifdef __cplusplus
|
||||
namespace mtmd_helper {
|
||||
|
||||
// video-related C++ wrappers
|
||||
struct mtmd_helper_video_deleter {
|
||||
void operator()(mtmd_helper_video * val) { mtmd_helper_video_free(val); }
|
||||
};
|
||||
using video_ptr = std::unique_ptr<mtmd_helper_video, mtmd_helper_video_deleter>;
|
||||
|
||||
} // namespace mtmd_helper
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
+138
-29
@@ -35,6 +35,10 @@ struct mtmd_bitmap {
|
||||
std::string id; // optional user-defined id, for ex: can be set to image hash, useful for KV cache tracking
|
||||
bool is_audio = false; // true if the bitmap is audio
|
||||
|
||||
// lazy-loaded bitmap
|
||||
mtmd_bitmap_lazy_callback lazy_callback = nullptr;
|
||||
void * lazy_user_data = nullptr;
|
||||
|
||||
mtmd_bitmap(const unsigned char * data, uint32_t nx, uint32_t ny)
|
||||
: nx(nx), ny(ny), is_audio(false) {
|
||||
if (data) {
|
||||
@@ -732,30 +736,111 @@ void mtmd_free(mtmd_context * ctx) {
|
||||
|
||||
struct mtmd_tokenizer {
|
||||
mtmd_context * ctx;
|
||||
std::vector<const mtmd_bitmap *> bitmaps;
|
||||
|
||||
std::string input_text;
|
||||
bool add_special;
|
||||
bool parse_special;
|
||||
const llama_vocab * vocab;
|
||||
|
||||
struct part {
|
||||
std::string text;
|
||||
const mtmd_bitmap * bitmap;
|
||||
};
|
||||
std::vector<part> parts;
|
||||
// these will be freed when mtmd_tokenizer finishes
|
||||
std::vector<mtmd::bitmap> bm_from_lazy; // TODO @ngxson : refactor, free bm_from_lazy progressively
|
||||
std::vector<const char *> text_from_lazy;
|
||||
|
||||
mtmd_input_chunks cur;
|
||||
uint32_t n_images_added = 0; // 0-based index assigned to the next image chunk
|
||||
|
||||
~mtmd_tokenizer() {
|
||||
// note: mtmd::bitmap is already RAII
|
||||
for (auto & str : text_from_lazy) {
|
||||
free((void *)str);
|
||||
}
|
||||
}
|
||||
|
||||
mtmd_tokenizer(mtmd_context * ctx,
|
||||
const mtmd_input_text * text,
|
||||
const mtmd_bitmap ** bitmaps,
|
||||
size_t n_bitmaps) : ctx(ctx), bitmaps(bitmaps, bitmaps + n_bitmaps) {
|
||||
const mtmd_bitmap ** bmps,
|
||||
size_t n_bitmaps) : ctx(ctx) {
|
||||
add_special = text->add_special;
|
||||
parse_special = text->parse_special;
|
||||
input_text = text->text;
|
||||
vocab = ctx->vocab;
|
||||
|
||||
std::vector<const mtmd_bitmap *> bitmaps(bmps, bmps + n_bitmaps);
|
||||
auto parts_str = split_text(input_text, ctx->media_marker);
|
||||
size_t i_bm = 0;
|
||||
for (const auto & part : parts_str) {
|
||||
if (part == ctx->media_marker) {
|
||||
if (i_bm >= bitmaps.size()) {
|
||||
throw std::runtime_error(string_format("number of media markers in text (%zu) exceeds number of bitmaps (%zu)", i_bm + 1, bitmaps.size()));
|
||||
}
|
||||
parts.push_back({"", bitmaps[i_bm++]});
|
||||
} else {
|
||||
parts.push_back({std::move(part), nullptr});
|
||||
}
|
||||
}
|
||||
|
||||
size_t n_markers = 0;
|
||||
for (const auto & part : parts) {
|
||||
if (part.bitmap != nullptr) {
|
||||
n_markers++;
|
||||
}
|
||||
}
|
||||
if (n_markers != bitmaps.size()) {
|
||||
throw std::runtime_error(string_format("number of media markers in text (%zu) does not match number of bitmaps (%zu)", n_markers, bitmaps.size()));
|
||||
}
|
||||
|
||||
expand_lazy_bitmaps();
|
||||
}
|
||||
|
||||
void expand_lazy_bitmaps() {
|
||||
std::vector<part> expanded;
|
||||
expanded.reserve(parts.size());
|
||||
for (auto & p : parts) {
|
||||
if (p.bitmap != nullptr && p.bitmap->lazy_callback) {
|
||||
LOG_DBG("%s: expanding lazy bitmap\n", __func__);
|
||||
for (size_t i = 0;; i++) {
|
||||
char * out_str = nullptr;
|
||||
mtmd_bitmap * out_bm = nullptr;
|
||||
int res = p.bitmap->lazy_callback(i,
|
||||
p.bitmap->lazy_user_data,
|
||||
&out_bm,
|
||||
&out_str);
|
||||
if (out_bm && out_str) {
|
||||
throw std::runtime_error(string_format("lazy callback cannot return both bitmap and text"));
|
||||
}
|
||||
if (res == 0) {
|
||||
// OK, append the returned chunk; lazy part is not yet added
|
||||
if (out_bm) {
|
||||
auto & ptr = bm_from_lazy.emplace_back(out_bm); // remember to free it later
|
||||
expanded.push_back({"", ptr.ptr.get()});
|
||||
LOG_DBG("%s: lazy callback returned bitmap with dimensions %d x %d\n", __func__, out_bm->nx, out_bm->ny);
|
||||
} else if (out_str) {
|
||||
auto & ptr = text_from_lazy.emplace_back(out_str); // remember to free it later
|
||||
expanded.push_back({ptr, nullptr});
|
||||
LOG_DBG("%s: lazy callback returned text: %s\n", __func__, out_str);
|
||||
}
|
||||
} else if (res == -1) {
|
||||
// EOF: lazy part removes itself (not added to expanded)
|
||||
break;
|
||||
} else if (res == -2) {
|
||||
// error
|
||||
throw std::runtime_error(string_format("lazy callback returned error"));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
expanded.push_back(std::move(p));
|
||||
}
|
||||
}
|
||||
parts = std::move(expanded);
|
||||
}
|
||||
|
||||
int32_t tokenize(mtmd_input_chunks * output) {
|
||||
cur.entries.clear();
|
||||
std::vector<std::string> parts = split_text(input_text, ctx->media_marker);
|
||||
size_t i_bm = 0; // index of the current bitmap
|
||||
|
||||
// [QWEN_VIDEO] handle frame merging for models that support it (i.e. qwen-vl)
|
||||
int n_merge_frames = 1;
|
||||
@@ -764,53 +849,50 @@ struct mtmd_tokenizer {
|
||||
GGML_ASSERT(n_merge_frames <= 2 && "we only support merging maximum 2 images for now; open an issue if this model supports merging more");
|
||||
}
|
||||
|
||||
// Build merged_bitmaps: each entry is a group of 1 or 2 bitmaps.
|
||||
// For consecutive mergeable bitmap parts, merge them and collapse the second part out of this->parts.
|
||||
std::vector<std::vector<const mtmd_bitmap *>> merged_bitmaps;
|
||||
if (n_merge_frames > 1) {
|
||||
size_t i_bm_scan = 0;
|
||||
for (size_t i = 0; i < parts.size(); ++i) {
|
||||
if (parts[i] != ctx->media_marker) {
|
||||
if (parts[i].bitmap == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (i + 1 < parts.size()
|
||||
&& parts[i + 1] == ctx->media_marker
|
||||
&& i_bm_scan + 1 < bitmaps.size()) {
|
||||
const mtmd_bitmap * bm_a = bitmaps[i_bm_scan];
|
||||
const mtmd_bitmap * bm_b = bitmaps[i_bm_scan + 1];
|
||||
if (i + 1 < parts.size() && parts[i + 1].bitmap != nullptr) {
|
||||
const mtmd_bitmap * bm_a = parts[i].bitmap;
|
||||
const mtmd_bitmap * bm_b = parts[i + 1].bitmap;
|
||||
if (bm_a->can_batch_with(*bm_b)) {
|
||||
LOG_DBG("%s: merging 2 frames at bitmap index %zu and %zu\n", __func__, i_bm_scan, i_bm_scan + 1);
|
||||
LOG_DBG("%s: merging 2 frames at part index %zu and %zu\n", __func__, i, i + 1);
|
||||
merged_bitmaps.push_back({bm_a, bm_b});
|
||||
parts.erase(parts.begin() + i + 1); // remove the second marker
|
||||
i_bm_scan += 2;
|
||||
parts.erase(parts.begin() + i + 1); // collapse the second bitmap part
|
||||
continue;
|
||||
}
|
||||
}
|
||||
LOG_DBG("%s: no merging for bitmap index %zu\n", __func__, i_bm_scan);
|
||||
merged_bitmaps.push_back({bitmaps[i_bm_scan]});
|
||||
++i_bm_scan;
|
||||
LOG_DBG("%s: no merging for part index %zu\n", __func__, i);
|
||||
merged_bitmaps.push_back({parts[i].bitmap});
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < bitmaps.size(); ++i) {
|
||||
merged_bitmaps.push_back({bitmaps[i]});
|
||||
for (const auto & p : parts) {
|
||||
if (p.bitmap != nullptr) {
|
||||
merged_bitmaps.push_back({p.bitmap});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
i_bm = 0;
|
||||
for (auto & part : parts) {
|
||||
if (part == ctx->media_marker) {
|
||||
// this is a marker, we should add the next bitmap
|
||||
size_t i_bm = 0;
|
||||
for (const auto & p : parts) {
|
||||
if (p.bitmap != nullptr) {
|
||||
if (i_bm >= merged_bitmaps.size()) {
|
||||
LOG_ERR("%s: error: number of bitmaps (%zu) does not match number of markers (%zu)\n",
|
||||
__func__, merged_bitmaps.size(), parts.size() - 1);
|
||||
return 1;
|
||||
}
|
||||
auto & bmps = merged_bitmaps[i_bm++];
|
||||
auto bmps = merged_bitmaps[i_bm++];
|
||||
int32_t res = add_media(bmps);
|
||||
if (res != 0) {
|
||||
return res;
|
||||
}
|
||||
} else {
|
||||
// this is a text part, we should add it as text
|
||||
add_text(part, parse_special);
|
||||
add_text(p.text, parse_special);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1236,8 +1318,13 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
|
||||
const mtmd_input_text * text,
|
||||
const mtmd_bitmap ** bitmaps,
|
||||
size_t n_bitmaps) {
|
||||
mtmd_tokenizer tokenizer(ctx, text, bitmaps, n_bitmaps);
|
||||
return tokenizer.tokenize(output);
|
||||
try {
|
||||
mtmd_tokenizer tokenizer(ctx, text, bitmaps, n_bitmaps);
|
||||
return tokenizer.tokenize(output);
|
||||
} catch (const std::exception & e) {
|
||||
LOG_ERR("%s: error: %s\n", __func__, e.what());
|
||||
return 2;
|
||||
}
|
||||
}
|
||||
|
||||
int32_t mtmd_encode_chunk(mtmd_context * ctx, const mtmd_input_chunk * chunk) {
|
||||
@@ -1373,6 +1460,10 @@ int mtmd_get_audio_sample_rate(const mtmd_context * ctx) {
|
||||
return clip_get_hparams(ctx->ctx_a)->audio_sample_rate;
|
||||
}
|
||||
|
||||
const char * mtmd_get_marker(const mtmd_context * ctx) {
|
||||
return ctx->media_marker.c_str();
|
||||
}
|
||||
|
||||
//
|
||||
// public API functions
|
||||
//
|
||||
@@ -1405,10 +1496,16 @@ uint32_t mtmd_bitmap_get_ny(const mtmd_bitmap * bitmap) {
|
||||
}
|
||||
|
||||
const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap) {
|
||||
if (bitmap->is_placeholder()) {
|
||||
return nullptr;
|
||||
}
|
||||
return bitmap->get_ro_buf().data();
|
||||
}
|
||||
|
||||
size_t mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap) {
|
||||
if (bitmap->is_placeholder()) {
|
||||
return 0;
|
||||
}
|
||||
return bitmap->get_ro_buf().size();
|
||||
}
|
||||
|
||||
@@ -1428,6 +1525,18 @@ void mtmd_bitmap_set_id(mtmd_bitmap * bitmap, const char * id) {
|
||||
}
|
||||
}
|
||||
|
||||
mtmd_bitmap * mtmd_bitmap_init_lazy(mtmd_context * ctx,
|
||||
const char * id,
|
||||
void * user_data,
|
||||
mtmd_bitmap_lazy_callback callback) {
|
||||
GGML_UNUSED(ctx); // reserved for future use
|
||||
mtmd_bitmap * bitmap = new mtmd_bitmap(nullptr, 0, 0);
|
||||
bitmap->lazy_callback = callback;
|
||||
bitmap->lazy_user_data = user_data;
|
||||
mtmd_bitmap_set_id(bitmap, id);
|
||||
return bitmap;
|
||||
}
|
||||
|
||||
void mtmd_bitmap_free(mtmd_bitmap * bitmap) {
|
||||
if (bitmap) {
|
||||
delete bitmap;
|
||||
|
||||
@@ -128,6 +128,9 @@ MTMD_API bool mtmd_support_audio(const mtmd_context * ctx);
|
||||
// return -1 if audio is not supported
|
||||
MTMD_API int mtmd_get_audio_sample_rate(const mtmd_context * ctx);
|
||||
|
||||
// get the current marker string
|
||||
MTMD_API const char * mtmd_get_marker(const mtmd_context * ctx);
|
||||
|
||||
// mtmd_bitmap
|
||||
//
|
||||
// if bitmap is image:
|
||||
@@ -156,6 +159,34 @@ MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap);
|
||||
MTMD_API const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap);
|
||||
MTMD_API void mtmd_bitmap_set_id(mtmd_bitmap * bitmap, const char * id);
|
||||
|
||||
// mtmd_bitmap lazy
|
||||
//
|
||||
// this is a special bitmap that:
|
||||
// - does not hold the actual data
|
||||
// - can be expanded into one or more chunks (either media to text chunks)
|
||||
// user must provide a callback to fill in the data when mtmd_tokenize() is called
|
||||
// this is useful for large video inputs:
|
||||
// - allow reading video frame by frame, without loading the entire video into memory
|
||||
// - allow tracking the whole video with a single ID (for example, the file hash)
|
||||
|
||||
// set (*out_bitmap) to non-nullptr to emit a bitmap chunk; it will be freed automatically
|
||||
// set (*out_text) to non-nullptr to emit a text chunk; it must be heap-allocated, null-terminated and will be freed automatically
|
||||
// either out_bitmap or out_text can be set, but not both
|
||||
// out_bitmap cannot be another lazy bitmap (no nested lazy allowed)
|
||||
// return value:
|
||||
// 0 on success
|
||||
// -1 on EOF (signal to mtmd_tokenize to move on)
|
||||
// -2 on error (signal to mtmd_tokenize to abort)
|
||||
typedef int(* mtmd_bitmap_lazy_callback)(
|
||||
size_t chunk_idx,
|
||||
void * user_data,
|
||||
mtmd_bitmap ** out_bitmap,
|
||||
char ** out_text);
|
||||
|
||||
MTMD_API mtmd_bitmap * mtmd_bitmap_init_lazy(mtmd_context * ctx,
|
||||
const char * id, // usually set to file hash
|
||||
void * user_data,
|
||||
mtmd_bitmap_lazy_callback callback);
|
||||
|
||||
// mtmd_input_chunks
|
||||
//
|
||||
|
||||
Binary file not shown.
@@ -1252,6 +1252,10 @@ The `response_format` parameter supports both plain JSON output (e.g. `{"type":
|
||||
|
||||
`parallel_tool_calls` : Whether to enable parallel/multiple tool calls (only supported on some models, verification is based on jinja template).
|
||||
|
||||
For multimodal input:
|
||||
- Content type `image_url` and `input_audio` are the same as OAI schema
|
||||
- Content type `input_video` is an extension from OAI schema. For now, it only accepts base64 input
|
||||
|
||||
*Examples:*
|
||||
|
||||
You can use either Python `openai` library with appropriate checkpoints:
|
||||
|
||||
@@ -701,29 +701,19 @@ size_t validate_utf8(const std::string& text) {
|
||||
return len;
|
||||
}
|
||||
|
||||
// Computes FNV-1a hash of the data
|
||||
static std::string fnv_hash(const uint8_t * data, size_t len) {
|
||||
const uint64_t fnv_prime = 0x100000001b3ULL;
|
||||
uint64_t hash = 0xcbf29ce484222325ULL;
|
||||
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
hash ^= data[i];
|
||||
hash *= fnv_prime;
|
||||
}
|
||||
return std::to_string(hash);
|
||||
}
|
||||
|
||||
server_tokens process_mtmd_prompt(mtmd_context * mctx, const std::string & prompt, const std::vector<raw_buffer> & files, bool is_placeholder) {
|
||||
// these will be freed upon going out of scope
|
||||
mtmd::bitmaps bitmaps;
|
||||
std::vector<mtmd_helper::video_ptr> videos;
|
||||
for (auto & file : files) {
|
||||
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size(), is_placeholder));
|
||||
if (!bmp.ptr) {
|
||||
auto out = mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size(), is_placeholder);
|
||||
if (!out.bitmap) {
|
||||
throw std::runtime_error("Failed to load image or audio file");
|
||||
}
|
||||
// calculate bitmap hash (for KV caching)
|
||||
std::string hash = fnv_hash(bmp.data(), bmp.n_bytes());
|
||||
bmp.set_id(hash.c_str());
|
||||
bitmaps.entries.push_back(std::move(bmp));
|
||||
bitmaps.entries.emplace_back(out.bitmap);
|
||||
if (out.video_ctx) {
|
||||
videos.emplace_back(out.video_ctx);
|
||||
}
|
||||
}
|
||||
// process prompt
|
||||
std::vector<server_tokens> inputs;
|
||||
@@ -1023,6 +1013,20 @@ json oaicompat_chat_params_parse(
|
||||
p["text"] = get_media_marker();
|
||||
p.erase("input_audio");
|
||||
|
||||
} else if (type == "input_video") {
|
||||
if (!opt.allow_video) {
|
||||
throw std::runtime_error("video input is not supported - hint: if this is unexpected, you may need to provide the mmproj");
|
||||
}
|
||||
|
||||
json input_video = json_value(p, "input_video", json::object());
|
||||
std::string data = json_value(input_video, "data", std::string());
|
||||
auto decoded_data = base64_decode(data); // expected to be base64 encoded
|
||||
out_files.push_back(decoded_data);
|
||||
|
||||
p["type"] = "media_marker";
|
||||
p["text"] = get_media_marker();
|
||||
p.erase("input_video");
|
||||
|
||||
} else if (type != "text") {
|
||||
throw std::invalid_argument("unsupported content[].type");
|
||||
}
|
||||
|
||||
@@ -294,6 +294,7 @@ struct server_chat_params {
|
||||
common_chat_templates_ptr tmpls;
|
||||
bool allow_image;
|
||||
bool allow_audio;
|
||||
bool allow_video;
|
||||
bool enable_thinking = true;
|
||||
int reasoning_budget = -1;
|
||||
std::string reasoning_budget_message;
|
||||
|
||||
@@ -1247,6 +1247,7 @@ private:
|
||||
/* tmpls */ std::move(chat_templates),
|
||||
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
|
||||
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
|
||||
/* allow_video */ mctx ? mtmd_helper_support_video(mctx) : false,
|
||||
/* enable_thinking */ enable_thinking,
|
||||
/* reasoning_budget */ params_base.sampling.reasoning_budget_tokens,
|
||||
/* reasoning_budget_msg */ params_base.sampling.reasoning_budget_message,
|
||||
@@ -3586,6 +3587,7 @@ server_context_meta server_context::get_meta() const {
|
||||
/* has_mtmd */ impl->mctx != nullptr,
|
||||
/* has_inp_image */ impl->chat_params.allow_image,
|
||||
/* has_inp_audio */ impl->chat_params.allow_audio,
|
||||
/* has_inp_video */ impl->chat_params.allow_video,
|
||||
/* json_ui_settings */ impl->json_ui_settings,
|
||||
/* json_webui_settings */ impl->json_webui_settings, // Deprecated
|
||||
/* slot_n_ctx */ impl->get_slot_n_ctx(),
|
||||
@@ -4183,6 +4185,7 @@ void server_routes::init_routes() {
|
||||
{ "model_path", meta->model_path },
|
||||
{ "modalities", json {
|
||||
{"vision", meta->has_inp_image},
|
||||
{"video", meta->has_inp_video},
|
||||
{"audio", meta->has_inp_audio},
|
||||
} },
|
||||
{ "media_marker", get_media_marker() },
|
||||
@@ -4976,7 +4979,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_count_tokens(const l
|
||||
n_tokens = tokenize_mixed(vocab, prompt, true, true).size();
|
||||
}
|
||||
|
||||
json response = {{"input_tokens", static_cast<int>(n_tokens)}};
|
||||
json response = {{"input_tokens", static_cast<int64_t>(n_tokens)}};
|
||||
if (is_oai) {
|
||||
response["object"] = "response.input_tokens";
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ struct server_context_meta {
|
||||
bool has_mtmd;
|
||||
bool has_inp_image;
|
||||
bool has_inp_audio;
|
||||
bool has_inp_video;
|
||||
json json_ui_settings; // Primary: new name
|
||||
json json_webui_settings; // Deprecated: use json_ui_settings instead (kept for backward compat)
|
||||
int slot_n_ctx;
|
||||
|
||||
@@ -1393,6 +1393,9 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() {
|
||||
//
|
||||
void server_task_result_cmpl_partial::update(task_result_state & state) {
|
||||
is_updated = true;
|
||||
if (is_begin) {
|
||||
return; // begin marker only flushes headers, skip parsing
|
||||
}
|
||||
state.update_chat_msg(content, true, oaicompat_msg_diffs);
|
||||
|
||||
// Copy current state for use in to_json_*() (reflects state BEFORE this chunk)
|
||||
|
||||
Reference in New Issue
Block a user